Browse Source

feat(dnn): add uint16 support

GitOrigin-RevId: f4c4b1c7b9
release-1.2
Megvii Engine Team 4 years ago
parent
commit
7066ad5ba6
6 changed files with 16 additions and 0 deletions
  1. +5
    -0
      dnn/include/megdnn/dtype.h
  2. +1
    -0
      dnn/test/common/dtype.cpp
  3. +1
    -0
      src/core/include/megbrain/dtype.h
  4. +6
    -0
      src/opr/impl/loop/forward.cpp
  5. +2
    -0
      src/opr/impl/loop/impl.cpp
  6. +1
    -0
      src/serialization/impl/dtype.fbs

+ 5
- 0
dnn/include/megdnn/dtype.h View File

@@ -53,6 +53,7 @@ namespace megdnn {
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(UintB4) \
cb(Bool) \
cb(Uint16) \

/*!
* \brief iterate through each full byte dtype
@@ -67,6 +68,7 @@ namespace megdnn {
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(Bool) \
cb(Uint16) \

/*!
* \brief iterate through each fractional byte dtype
@@ -353,6 +355,7 @@ typedef int16_t dt_int16;
typedef int8_t dt_int8;
typedef uint8_t dt_uint8;
typedef bool dt_bool;
typedef uint16_t dt_uint16;
MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;)
MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)

@@ -381,6 +384,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
BFloat16 = 11,
#endif
Bool = 12,
Uint16 = 13,
#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
#define D(_name) _name,
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D)
@@ -713,6 +717,7 @@ MEGDNN_DEF_DT(Int16, dt_int16, INT, SIGNED, INT16_MIN, INT16_MAX);
MEGDNN_DEF_DT(Int8, dt_int8, INT, SIGNED, INT8_MIN, INT8_MAX);
MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX);
MEGDNN_DEF_DT(Bool, dt_bool, BOOL, UNSIGNED, false, true);
MEGDNN_DEF_DT(Uint16, dt_uint16, INT, UNSIGNED, 0, UINT16_MAX);
MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED,
std::numeric_limits<dt_float16>::lowest(),
std::numeric_limits<dt_float16>::max()));


+ 1
- 0
dnn/test/common/dtype.cpp View File

@@ -26,6 +26,7 @@ TEST(TestDType, SizeCheck) {
ASSERT_EQ(static_cast<size_t>(2), ::megdnn::dtype::IntB4().size(3));
ASSERT_EQ(static_cast<size_t>(2), ::megdnn::dtype::IntB4().size(4));
ASSERT_EQ(static_cast<size_t>(3), ::megdnn::dtype::IntB4().size(5));
ASSERT_EQ(static_cast<size_t>(2), ::megdnn::dtype::Uint16().size(1));
ASSERT_EQ(static_cast<size_t>(2),
::megdnn::dtype::Quantized4Asymm(1.0f, static_cast<uint8_t>(12))
.size(3));


+ 1
- 0
src/core/include/megbrain/dtype.h View File

@@ -28,6 +28,7 @@ using ::megdnn::dt_quint8;
using ::megdnn::dt_qint8;
using ::megdnn::dt_qint32;
using ::megdnn::dt_bool;
using ::megdnn::dt_uint16;
using ::megdnn::DType;
using ::megdnn::DTypeEnum;
using ::megdnn::DTypeTrait;


+ 6
- 0
src/opr/impl/loop/forward.cpp View File

@@ -370,6 +370,12 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
iv += contain_eq;
return std::max(iv, 0);
}
case DTypeEnum::Uint16:
{
auto iv = val.ptr<dt_uint16>()[0];
iv += contain_eq;
return std::max<int>(iv, 0);
}
case DTypeEnum::Float32:
#if !MEGDNN_DISABLE_FLOAT16
case DTypeEnum::Float16:


+ 2
- 0
src/opr/impl/loop/impl.cpp View File

@@ -249,6 +249,8 @@ MGB_DEFINE_OPR_CLASS(LoopImpl::DescImplBase::LoopCondManager::GetCondOpr,
break;
case DTypeEnum::Bool:
break;
case DTypeEnum::Uint16:
break;
#define cb(_dt) \
case DTypeEnum::_dt: \
break;


+ 1
- 0
src/serialization/impl/dtype.fbs View File

@@ -22,6 +22,7 @@ enum DTypeEnum : byte {
QuantizedS16,
BFloat16,
Bool,
Uint16,
}

table LinearQuantizationParam {


Loading…
Cancel
Save