Browse Source

feat(dnn/common): add tensor format for low-bits tensor layout

GitOrigin-RevId: 0aa3753f37
release-1.5
Megvii Engine Team 4 years ago
parent
commit
91d6160769
6 changed files with 319 additions and 38 deletions
  1. +3
    -3
      dnn/include/megdnn/dtype.h
  2. +78
    -28
      dnn/include/megdnn/tensor_format.h
  3. +9
    -1
      dnn/src/common/basic_types.cpp
  4. +178
    -3
      dnn/src/common/tensor_format.cpp
  5. +5
    -3
      dnn/test/common/checker.h
  6. +46
    -0
      dnn/test/common/test_basic_types.cpp

+ 3
- 3
dnn/include/megdnn/dtype.h View File

@@ -505,9 +505,9 @@ class DType {
return std::numeric_limits<size_t>::max() >> m_trait->size_log; return std::numeric_limits<size_t>::max() >> m_trait->size_log;
} }


bool is_low_bit() const {
return m_trait->low_bit != 0;
}
size_t low_bit() const { return m_trait->low_bit; }
bool is_low_bit() const { return low_bit() != 0; }


/*! /*!
* \brief size of this data type, in bytes * \brief size of this data type, in bytes


+ 78
- 28
dnn/include/megdnn/tensor_format.h View File

@@ -20,12 +20,15 @@ namespace megdnn {
enum class TensorFormat::Type { enum class TensorFormat::Type {
DEFAULT = 0, //!< see DefaultTensorFormat DEFAULT = 0, //!< see DefaultTensorFormat
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
FOURBITS_ALIGNED_TO_BYTE = 2, //!< see FourBitsAlignedToBytesTensorFormat
}; };


class TensorFormat::ImplBase { class TensorFormat::ImplBase {
public: public:
using Type = TensorFormat::Type; using Type = TensorFormat::Type;


virtual void assert_valid(const TensorLayout& layout) const = 0;

virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0; virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0;


virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;
@@ -63,6 +66,8 @@ public:


DefaultTensorFormat() : ImplBase(TYPE) {} DefaultTensorFormat() : ImplBase(TYPE) {}


void assert_valid(const TensorLayout& layout) const override;

size_t init_contiguous_stride(TensorLayout& layout) const override; size_t init_contiguous_stride(TensorLayout& layout) const override;


/*! /*!
@@ -180,11 +185,11 @@ public:
*/ */
size_t image_width(const TensorLayout& layout) const; size_t image_width(const TensorLayout& layout) const;


//! raise exception if preconditions violated
void assert_valid(const TensorLayout& layout) const;

size_t image_row_pitch(const TensorLayout& layout) const; size_t image_row_pitch(const TensorLayout& layout) const;


//! raise exception if preconditions violated
void assert_valid(const TensorLayout& layout) const override;

//! span for image must include the padding at the last row //! span for image must include the padding at the last row
TensorLayout::Span span_spec(const TensorLayout& layout) const override; TensorLayout::Span span_spec(const TensorLayout& layout) const override;


@@ -197,31 +202,48 @@ public:
}; };
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;


///*!
// * \brief used for tensors with lowbit data type
// *
// * \p SIZE_NBITS is the size in bits of element of the tensor.
// *
// */
//template <size_t SIZE_NBITS_>
//class LowbitTensorFormat : public TensorFormat::ImplBase {
// static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
// size_t m_align_size_in_bits;
//
//protected: //?
// LowbitTensorFormat(Type type, size_t m_align_size_in_bits);
//
//public:
// size_t align_size_in_bits() const {
// return m_align_size_in_bits;
// }
//
// std::string to_string() const override;
//
// void serialize_append(
//
//
//};
/*!
 * \brief base class of formats used for tensors storing lowbit data
 *
 * \tparam SIZE_NBITS_ size, in bits, of one element of the tensor
 *
 * The alignment is passed in bits on construction and must be a multiple of
 * the element size; the alignment expressed in elements is derived from it.
 */
template <size_t SIZE_NBITS_>
class LowbitsTensorFormatBase : public TensorFormat::ImplBase {
    static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
    //! alignment as given on construction, in bits
    size_t m_align_size_in_bits;
    //! m_align_size_in_bits / SIZE_NBITS
    size_t m_align_size_in_elements;

protected:
    LowbitsTensorFormatBase(Type type, size_t align_size_in_bits);

    virtual ~LowbitsTensorFormatBase() = default;

public:
    //! alignment, in bits, that this format was constructed with
    size_t align_size_in_bits() const { return m_align_size_in_bits; }

    std::string to_string() const override;

    //! raise exception if given layout is illegal
    void assert_valid(const TensorLayout& layout) const override;

    //! append the SerializePack of this format to \p result
    void serialize_append(std::string& result) const override;

    //! span for lowbit tensor must include the padding at the innermost
    //! dimension that make lowbit tensor be aligned to bytes
    TensorLayout::Span span_spec(const TensorLayout& layout) const override;

    size_t init_contiguous_stride(TensorLayout& layout) const override;

    bool is_contiguous_spec(const TensorLayout& layout) const override;

    TensorLayout collapse_contiguous_spec(
            const TensorLayout& layout) const override;

protected:
    //! payload written by serialize_append(); the alignment must fit in
    //! one byte
    struct SerializePack {
        uint8_t align_size_in_bits;
    };
};
using FourBitsAlignedToBytesTensorFormatBase = LowbitsTensorFormatBase<4>;
} // namespace detail } // namespace detail


/*! /*!
@@ -270,6 +292,34 @@ private:
TYPE, align_axis, align_size_in_elements, vendor_type) {} TYPE, align_axis, align_size_in_elements, vendor_type) {}
}; };


/*!
 * \brief Tensor for storing 4bit data that requires stride corresponding to
 *      non-innermost dimension to be aligned to bytes, and pack 2 elems into
 *      a byte
 */
class FourBitsAlignedToBytesTensorFormat final
        : public detail::FourBitsAlignedToBytesTensorFormatBase {
public:
    static constexpr Type TYPE = Type::FOURBITS_ALIGNED_TO_BYTE;

    //! get the format for the given alignment (in bits)
    static TensorFormat make(size_t align_size_in_bits);

    //! rebuild a format from a buffer holding exactly one SerializePack
    static TensorFormat deserialize(const Handle* handle, const void* buf,
                                    size_t size);

    /*!
     * \brief whether \p layout uses this format
     *
     * Returns false when the layout has a different format type. When the
     * type matches, assert_valid() is run on the layout, so a layout of this
     * type that violates the preconditions raises an exception instead of
     * yielding false.
     */
    static bool is_valid_layout(const TensorLayout& layout) {
        if (layout.format.type() == TYPE) {
            layout.format.as_impl<FourBitsAlignedToBytesTensorFormat>()
                    .assert_valid(layout);
            return true;
        }
        return false;
    }

private:
    // explicit: a bare size_t must never silently convert into a format
    explicit FourBitsAlignedToBytesTensorFormat(size_t align_size_in_bits)
            : detail::FourBitsAlignedToBytesTensorFormatBase(
                      TYPE, align_size_in_bits) {}
};
} // namespace megdnn } // namespace megdnn


#include "megdnn/internal/visibility_epilogue.h" #include "megdnn/internal/visibility_epilogue.h"


+ 9
- 1
dnn/src/common/basic_types.cpp View File

@@ -201,7 +201,15 @@ TensorLayout::TensorLayout(DType dtype_, Format format_)
: dtype{dtype_}, format{format_} {} : dtype{dtype_}, format{format_} {}


TensorLayout::TensorLayout(const TensorShape& shape, DType dtype) TensorLayout::TensorLayout(const TensorShape& shape, DType dtype)
: TensorLayout(shape, dtype, DefaultTensorFormat::make()) {}
: TensorShape(shape), dtype{dtype} {
if (dtype.low_bit() == 4_z) {
format = FourBitsAlignedToBytesTensorFormat::make(8_z);
} else {
megdnn_assert(!dtype.is_low_bit(), "Unsupported data type(%s)",
dtype.name());
format = DefaultTensorFormat::make();
}
}


TensorLayout::TensorLayout(const TensorShape& shape, DType dtype, TensorLayout::TensorLayout(const TensorShape& shape, DType dtype,
TensorFormat format_) TensorFormat format_)


+ 178
- 3
dnn/src/common/tensor_format.cpp View File

@@ -35,6 +35,9 @@ TensorFormat TensorFormat::deserialize(const std::string& bin,
case Type::IMAGE2D_PACK4: case Type::IMAGE2D_PACK4:
return Image2DPack4TensorFormat::deserialize( return Image2DPack4TensorFormat::deserialize(
handle, type + 1, bin.size() - sizeof(Type)); handle, type + 1, bin.size() - sizeof(Type));
case Type::FOURBITS_ALIGNED_TO_BYTE:
return FourBitsAlignedToBytesTensorFormat::deserialize(
handle, type + 1, bin.size() - sizeof(Type));
default: default:
megdnn_throw("invalid tensor format type in deserialize"); megdnn_throw("invalid tensor format type in deserialize");
} }
@@ -67,7 +70,15 @@ bool TensorFormat::is_default() const {
} }


/* ===================== DefaultFormat ===================== */ /* ===================== DefaultFormat ===================== */
//! raise exception if \p layout carries a valid low-bit dtype; layouts whose
//! dtype is not valid, or is valid but not low-bit, are accepted
void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const {
megdnn_assert(
!layout.dtype.valid() || !layout.dtype.is_low_bit(),
"DefaultTensorFormat does not support low-bits tensor(dtype:%s)",
layout.dtype.name());
}

size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const { size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const {
assert_valid(layout);
if (!layout.ndim) if (!layout.ndim)
return 0; return 0;
megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM); megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);
@@ -81,11 +92,13 @@ size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const {
} }


bool DefaultTensorFormat::is_contiguous_spec(const TensorLayout& layout) const { bool DefaultTensorFormat::is_contiguous_spec(const TensorLayout& layout) const {
assert_valid(layout);
return layout.is_physical_contiguous(); return layout.is_physical_contiguous();
} }


TensorLayout DefaultTensorFormat::collapse_contiguous_spec( TensorLayout DefaultTensorFormat::collapse_contiguous_spec(
const TensorLayout& layout) const { const TensorLayout& layout) const {
assert_valid(layout);
megdnn_assert(layout.ndim); megdnn_assert(layout.ndim);
TensorLayout res{layout}; TensorLayout res{layout};


@@ -126,6 +139,7 @@ TensorLayout DefaultTensorFormat::collapse_contiguous_spec(


TensorLayout::Span DefaultTensorFormat::span_spec( TensorLayout::Span DefaultTensorFormat::span_spec(
const TensorLayout& layout) const { const TensorLayout& layout) const {
assert_valid(layout);
if (layout.ndim == 0) if (layout.ndim == 0)
return {0, 0, 0, 0}; return {0, 0, 0, 0};


@@ -146,9 +160,6 @@ TensorLayout::Span DefaultTensorFormat::span_spec(
++high_elem; ++high_elem;
ptrdiff_t low_byte; ptrdiff_t low_byte;
if (low_elem < 0) { if (low_elem < 0) {
megdnn_assert(!layout.dtype.is_low_bit(),
"tensors with low-bit dytes shouldn't have negative "
"strides");
low_byte = low_elem * layout.dtype.size(); low_byte = low_elem * layout.dtype.size();
} else { } else {
low_byte = 0; low_byte = 0;
@@ -422,12 +433,151 @@ TensorLayout Image2DPackedTensorFormatBase<PIXEL_SIZE>::collapse_contiguous_spec
return res; return res;
} }



namespace megdnn { namespace megdnn {
namespace detail { namespace detail {
template class Image2DPackedTensorFormatBase<4>; template class Image2DPackedTensorFormatBase<4>;
} // namespace detail } // namespace detail
} // namespace megdnn } // namespace megdnn


/* =============== FourBitsAlignedToBytesTensorFormatBase ============== */
/*!
 * \brief construct a lowbit format with the given alignment
 * \param type concrete format type, forwarded to ImplBase
 * \param align_size_in_bits alignment in bits; must be non-zero and a
 *      multiple of SIZE_NBITS
 */
template <size_t SIZE_NBITS>
LowbitsTensorFormatBase<SIZE_NBITS>::LowbitsTensorFormatBase(
        Type type, size_t align_size_in_bits)
        : ImplBase(type), m_align_size_in_bits(align_size_in_bits) {
    // zero would pass the divisibility check below, and the derived
    // m_align_size_in_elements (also zero) is later used as the right
    // operand of % when strides are validated
    megdnn_assert(m_align_size_in_bits != 0, "align size must be non-zero");
    megdnn_assert(!(m_align_size_in_bits % SIZE_NBITS),
                  "align size(%zu) must be a multiple of element size(%zu)",
                  m_align_size_in_bits, SIZE_NBITS);
    m_align_size_in_elements = m_align_size_in_bits / SIZE_NBITS;
}

//! human readable name: element size and alignment, both in bits
template <size_t SIZE_NBITS>
std::string LowbitsTensorFormatBase<SIZE_NBITS>::to_string() const {
    const size_t elem_bits = SIZE_NBITS;
    const size_t align_bits = m_align_size_in_bits;
    return ssprintf("LOWBITS{%zu,%zu}", elem_bits, align_bits);
}

/*!
 * \brief raise exception if \p layout is illegal for this format
 *
 * Requirements checked:
 *  - the dtype is valid, low-bit, and its bit width equals SIZE_NBITS;
 *  - every stride is non-negative and is either 1 or a multiple of the
 *    alignment expressed in elements;
 *  - unless the layout has no dimension at all, some axis has stride 1.
 */
template <size_t SIZE_NBITS>
void LowbitsTensorFormatBase<SIZE_NBITS>::assert_valid(
        const TensorLayout& layout) const {
    megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() &&
                  layout.dtype.low_bit() == SIZE_NBITS);
    bool has_dim_unity_stride = false;
    // walk from the innermost axis outwards, so the innermost offending
    // stride is the one that gets reported
    for (size_t i = layout.ndim; i-- > 0;) {
        const auto stride = layout.stride[i];
        if (stride == 1)
            has_dim_unity_stride = true;
        // strides are signed: print them as signed so that a negative value
        // is not shown as a huge unsigned number
        megdnn_assert(stride >= 0 &&
                              (stride == 1 ||
                               static_cast<size_t>(stride) %
                                               m_align_size_in_elements ==
                                       0),
                      "bad stride: %lld", static_cast<long long>(stride));
    }
    // a layout without dimensions has no innermost axis to check; callers
    // such as span_spec() handle ndim == 0 right after this validation
    megdnn_assert(layout.ndim == 0 || has_dim_unity_stride,
                  "innermost dim not contiguous");
}

//! append the alignment (in bits), packed into a SerializePack, to \p result
template <size_t SIZE_NBITS>
void LowbitsTensorFormatBase<SIZE_NBITS>::serialize_append(
        std::string& result) const {
    SerializePack pack{static_cast<uint8_t>(m_align_size_in_bits)};
    // the pack field is a single byte: make sure nothing was truncated
    megdnn_assert(pack.align_size_in_bits ==
                  m_align_size_in_bits);  // detect overflow;
    const char* raw = reinterpret_cast<char*>(&pack);
    result.append(raw, sizeof(pack));
}

/*!
 * \brief span covered by \p layout, starting at offset zero
 *
 * The layout is validated first. A layout without dimensions, or with any
 * zero extent, yields an all-zero span. Otherwise the high element bound is
 * one past the offset of the last element, and the high byte bound is the
 * dtype's size for that many elements.
 */
template <size_t SIZE_NBITS>
TensorLayout::Span LowbitsTensorFormatBase<SIZE_NBITS>::span_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    if (!layout.ndim)
        return {0, 0, 0, 0};

    // offset of the last element, accumulated axis by axis
    size_t last_offset = 0;
    for (size_t axis = 0; axis != layout.ndim; ++axis) {
        const auto extent = layout.shape[axis];
        if (extent == 0)
            return {0, 0, 0, 0};
        const auto stride_val = layout.stride[axis];
        megdnn_assert(stride_val >= 0,
                      "lowbit tensors shouldn't have negative strides");
        last_offset += (extent - 1) * stride_val;
    }
    const size_t elem_bound = last_offset + 1;
    const size_t byte_bound = layout.dtype.size(elem_bound);
    return TensorLayout::Span(0, 0, elem_bound, byte_bound);
}

/*!
 * \brief fill \p layout.stride with the contiguous strides of this format
 *
 * Strides are assigned from the innermost axis outwards. Only the innermost
 * extent is rounded up to the alignment (in elements) before it contributes
 * to the outer strides.
 *
 * \return product of all extents with the innermost one rounded up, or 0
 *      for a layout without dimensions
 */
template <size_t SIZE_NBITS>
size_t LowbitsTensorFormatBase<SIZE_NBITS>::init_contiguous_stride(
        TensorLayout& layout) const {
    if (layout.ndim == 0)
        return 0;
    megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);

    const size_t innermost = layout.ndim - 1;
    SafeMultiplies<size_t> checked_mul;
    size_t total = 1;
    for (size_t axis = layout.ndim; axis-- > 0;) {
        layout.stride[axis] = total;
        auto extent = layout.shape[axis];
        if (axis == innermost)
            extent = round_up(extent, m_align_size_in_elements);
        total = checked_mul(total, extent);
    }
    return total;
}

/*!
 * \brief whether \p layout has exactly the strides that contiguous storage
 *      in this format would give it
 *
 * The layout is validated first. Axes with extent 1 may carry any stride.
 * The innermost extent is rounded up to the alignment (in elements) when the
 * expected outer strides are computed. A layout with a zero extent is
 * reported as not contiguous.
 */
template <size_t SIZE_NBITS>
bool LowbitsTensorFormatBase<SIZE_NBITS>::is_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    ptrdiff_t want_stride = 1;
    for (size_t axis = layout.ndim; axis-- > 0;) {
        const bool stride_matters = layout.shape[axis] != 1;
        if (stride_matters && layout.stride[axis] != want_stride)
            return false;
        auto extent = layout.shape[axis];
        if (axis + 1 == layout.ndim)
            extent = round_up(extent, m_align_size_in_elements);
        want_stride *= extent;
    }
    return want_stride != 0;
}

//! Return a copy of \p layout (validated first) in which axes of extent 1
//! are dropped and adjacent axes are merged whenever the outer stride equals
//! inner stride * inner extent. A layout with any zero extent collapses to a
//! single axis of extent 0 and stride 1.
template <size_t SIZE_NBITS>
TensorLayout LowbitsTensorFormatBase<SIZE_NBITS>::collapse_contiguous_spec(
const TensorLayout& layout) const {
assert_valid(layout);
TensorLayout res{layout};
// first pass, innermost to outermost: handle empty tensors and drop axes of
// extent 1 (iterating downwards keeps the not-yet-visited indices stable)
// NOTE(review): when every extent is 1 this asks remove_axis_inplace() to
// drop the last remaining axis -- confirm that helper accepts that
for (int i = static_cast<int>(res.ndim) - 1; i >= 0; --i) {
if (!res.shape[i]) {
// empty tensor
res.ndim = 1;
res.shape[0] = 0;
res.stride[0] = 1;
return res;
}
if (res.shape[i] == 1) {
res.remove_axis_inplace(i);
}
}

megdnn_assert(res.ndim && res.shape[res.ndim - 1]);
// second pass: fold axis i+1 into axis i when stepping once along i equals
// walking over the whole extent of i+1
for (int i = static_cast<int>(res.ndim) - 2; i >= 0; --i) {
megdnn_assert(res.shape[i]);
if (res.stride[i] ==
res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1])) {
res.shape[i] *= res.shape[i + 1];
res.stride[i] = res.stride[i + 1];
res.remove_axis_inplace(i + 1);
}
}
return res;
}

// explicit instantiation of the 4-bit variant, so that the member templates
// defined above are emitted in this translation unit
namespace megdnn {
namespace detail {
template class LowbitsTensorFormatBase<4>;
} // namespace detail
} // namespace megdnn

/* ===================== Image2DPack4TensorFormat ===================== */ /* ===================== Image2DPack4TensorFormat ===================== */
TensorFormat Image2DPack4TensorFormat::make_raw( TensorFormat Image2DPack4TensorFormat::make_raw(
size_t align_axis, size_t align_size_in_elements, size_t align_axis, size_t align_size_in_elements,
@@ -466,4 +616,29 @@ TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const {
return make_raw(axis, align_size_in_elements(), vendor()); return make_raw(axis, align_size_in_elements(), vendor());
} }


/* ===================== FourBitsAlignedToBytesTensorFormat
* ===================== */
/*!
 * \brief get the format object for the given alignment
 * \param align_size_in_bits alignment in bits; must be a multiple of 4
 *
 * Format objects are created on first request and kept in a function-local
 * static cache, so the same alignment always maps to the same object. Cache
 * access is guarded by a function-local mutex.
 */
TensorFormat FourBitsAlignedToBytesTensorFormat::make(
        size_t align_size_in_bits) {
    static std::mutex mtx;
    // keyed by the full size_t value: a narrower key would make alignments
    // that differ only in the high bits share one cached format
    static std::unordered_map<
            size_t, std::unique_ptr<FourBitsAlignedToBytesTensorFormat>>
            cache;
    megdnn_assert(!(align_size_in_bits % 4));
    MEGDNN_LOCK_GUARD(mtx);
    auto&& ptr = cache[align_size_in_bits];
    if (!ptr) {
        // the constructor is private, so std::make_unique cannot be used
        ptr.reset(new FourBitsAlignedToBytesTensorFormat{align_size_in_bits});
    }
    return impl_to_tensor_format(ptr.get());
}

/*!
 * \brief rebuild a format from serialized bytes
 * \param buf points to one SerializePack
 * \param size must equal sizeof(SerializePack)
 *
 * The handle argument is unused.
 */
TensorFormat FourBitsAlignedToBytesTensorFormat::deserialize(const Handle*,
                                                             const void* buf,
                                                             size_t size) {
    megdnn_assert(size == sizeof(SerializePack));
    const auto* packed = static_cast<const SerializePack*>(buf);
    return make(packed->align_size_in_bits);
}

// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 5
- 3
dnn/test/common/checker.h View File

@@ -128,9 +128,11 @@ public:
for (size_t i = 0; i < shapes.size(); ++i) { for (size_t i = 0; i < shapes.size(); ++i) {
DType dt = (m_dtype.find(i) != m_dtype.end() ? m_dtype[i] DType dt = (m_dtype.find(i) != m_dtype.end() ? m_dtype[i]
: dtype::Float32()); : dtype::Float32());
TensorFormat fmt =
(m_fmt.find(i) != m_fmt.end() ? m_fmt[i] : TensorFormat{});
layouts[i] = TensorLayout(shapes[i], dt, fmt);
if (m_fmt.find(i) == m_fmt.end()) {
layouts[i] = TensorLayout(shapes[i], dt);
layouts[i].init_contiguous_stride();
} else
layouts[i] = TensorLayout(shapes[i], dt, m_fmt[i]);
} }
return layouts; return layouts;
} }


+ 46
- 0
dnn/test/common/test_basic_types.cpp View File

@@ -302,4 +302,50 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) {
} }
} }


TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS) {
    // contiguous strides: the innermost extent 7 is padded to 8 elements, so
    // the strides are 1, 8, 8 * 7 = 56 and 56 * 32 = 1792
    TensorLayout layout{{16, 32, 7, 7}, dtype::QuantizedS4{1.2f}};
    layout.init_contiguous_stride();
    ASSERT_EQ(layout.stride[0], 1792);
    ASSERT_EQ(layout.stride[1], 56);
    ASSERT_EQ(layout.stride[2], 8);
    ASSERT_EQ(layout.stride[3], 1);
    // last element offset: 15 * 1792 + 31 * 56 + 6 * 8 + 6 = 28670, hence
    // 28671 elements, stored two per byte
    auto span = layout.span();
    ASSERT_EQ(0, span.low_elem);
    ASSERT_EQ(28671, span.high_elem);
    ASSERT_EQ(0, span.low_byte);
    ASSERT_EQ(14336, span.high_byte);
    // the three outer axes merge (16 * 32 * 7 = 3584); the padded innermost
    // axis cannot be merged into them
    EXPECT_EQ(make_layout({3584, 7}, {8, 1}, dtype::QuantizedS4{1.2f}),
              layout.collapse_contiguous());

    // the same strides written by hand are recognised as contiguous
    layout = make_layout({16, 32, 7, 7}, {1792, 56, 8, 1},
                         dtype::QuantizedS4{1.3f});
    layout.format = FourBitsAlignedToBytesTensorFormat::make(8_z);
    EXPECT_TRUE(layout.is_contiguous());

    // broadcasting a layout constructed from shape and dtype only
    layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
    layout = layout.broadcast({16, 32, 7, 7});
    EXPECT_EQ(make_layout({16, 32, 49}, {0, 1, 0}, dtype::QuantizedS4{1.2f}),
              layout.collapse_contiguous());

    // broadcasting after init_contiguous_stride(): span() must reject the
    // resulting layout
    layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
    layout.init_contiguous_stride();
    layout = layout.broadcast({16, 32, 7, 7});
    ASSERT_THROW(layout.span(), MegDNNError);
}

TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS_VALID) {
    const TensorFormat lowbits_fmt =
            FourBitsAlignedToBytesTensorFormat::make(8_z);

    // a 4-bit dtype combined with the default format
    ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS4{1.2f},
                              DefaultTensorFormat::make()),
                 MegDNNError);
    // a 32-bit quantized dtype combined with the 4-bit format
    ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS32{1.2f},
                              lowbits_fmt)
                         .span(),
                 MegDNNError);
    // IntB2 combined with the 4-bit format
    ASSERT_THROW(
            TensorLayout({16, 32, 7, 7}, dtype::IntB2{}, lowbits_fmt).span(),
            MegDNNError);
}

// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

Loading…
Cancel
Save