GitOrigin-RevId: 0aa3753f37
release-1.5
@@ -505,9 +505,9 @@ class DType { | |||
return std::numeric_limits<size_t>::max() >> m_trait->size_log; | |||
} | |||
bool is_low_bit() const { | |||
return m_trait->low_bit != 0; | |||
} | |||
    //! number of bits per element for low-bit dtypes; 0 for ordinary dtypes
    size_t low_bit() const { return m_trait->low_bit; }
    //! whether elements of this dtype occupy less than one byte
    bool is_low_bit() const { return low_bit() != 0; }
/*! | |||
* \brief size of this data type, in bytes | |||
@@ -20,12 +20,15 @@ namespace megdnn { | |||
enum class TensorFormat::Type {
    DEFAULT = 0,                   //!< see DefaultTensorFormat
    IMAGE2D_PACK4 = 1,             //!< see Image2DPack4TensorFormat
    FOURBITS_ALIGNED_TO_BYTE = 2,  //!< see FourBitsAlignedToBytesTensorFormat
};
class TensorFormat::ImplBase { | |||
public: | |||
using Type = TensorFormat::Type; | |||
virtual void assert_valid(const TensorLayout& layout) const = 0; | |||
virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0; | |||
virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; | |||
@@ -63,6 +66,8 @@ public: | |||
DefaultTensorFormat() : ImplBase(TYPE) {} | |||
void assert_valid(const TensorLayout& layout) const override; | |||
size_t init_contiguous_stride(TensorLayout& layout) const override; | |||
/*! | |||
@@ -180,11 +185,11 @@ public: | |||
*/ | |||
size_t image_width(const TensorLayout& layout) const; | |||
//! raise exception if preconditions violated | |||
void assert_valid(const TensorLayout& layout) const; | |||
size_t image_row_pitch(const TensorLayout& layout) const; | |||
//! raise exception if preconditions violated | |||
void assert_valid(const TensorLayout& layout) const override; | |||
//! span for image must include the padding at the last row | |||
TensorLayout::Span span_spec(const TensorLayout& layout) const override; | |||
@@ -197,31 +202,48 @@ public: | |||
}; | |||
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | |||
///*! | |||
// * \brief used for tensors with lowbit data type | |||
// * | |||
// * \p SIZE_NBITS is the size in bits of element of the tensor. | |||
// * | |||
// */ | |||
//template <size_t SIZE_NBITS_> | |||
//class LowbitTensorFormat : public TensorFormat::ImplBase { | |||
// static constexpr size_t SIZE_NBITS = SIZE_NBITS_; | |||
// size_t m_align_size_in_bits; | |||
// | |||
//protected: //? | |||
// LowbitTensorFormat(Type type, size_t m_align_size_in_bits); | |||
// | |||
//public: | |||
// size_t align_size_in_bits() const { | |||
// return m_align_size_in_bits; | |||
// } | |||
// | |||
// std::string to_string() const override; | |||
// | |||
// void serialize_append( | |||
// | |||
// | |||
//}; | |||
/*!
 * \brief base class for tensor formats storing low-bit data
 *
 * \p SIZE_NBITS_ is the size in bits of one element of the tensor.
 */
template <size_t SIZE_NBITS_>
class LowbitsTensorFormatBase : public TensorFormat::ImplBase {
    static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
    //! alignment of the innermost dim, in bits and in elements respectively
    size_t m_align_size_in_bits, m_align_size_in_elements;

protected:
    //! \p align_size_in_bits must be a multiple of SIZE_NBITS
    LowbitsTensorFormatBase(Type type, size_t align_size_in_bits);
    virtual ~LowbitsTensorFormatBase() = default;

public:
    size_t align_size_in_bits() const { return m_align_size_in_bits; }

    std::string to_string() const override;

    //! raise exception if given layout is illegal
    void assert_valid(const TensorLayout& layout) const;

    void serialize_append(std::string& result) const override;

    //! span for lowbit tensor must include the padding at the innermost
    //! dimension that make lowbit tensor be aligned to bytes
    TensorLayout::Span span_spec(const TensorLayout& layout) const override;

    size_t init_contiguous_stride(TensorLayout& layout) const override;

    bool is_contiguous_spec(const TensorLayout& layout) const override;

    TensorLayout collapse_contiguous_spec(
            const TensorLayout& layout) const override;

protected:
    //! serialized representation; the alignment must fit into a uint8_t
    //! (checked in serialize_append)
    struct SerializePack {
        uint8_t align_size_in_bits;
    };
};
using FourBitsAlignedToBytesTensorFormatBase = LowbitsTensorFormatBase<4>; | |||
} // namespace detail | |||
/*! | |||
@@ -270,6 +292,34 @@ private: | |||
TYPE, align_axis, align_size_in_elements, vendor_type) {} | |||
}; | |||
/*!
 * \brief Tensor for storing 4bit data that requires stride corresponding to
 * non-innermost dimension to be aligned to bytes, and pack 2 elems into a byte
 */
class FourBitsAlignedToBytesTensorFormat final
        : public detail::FourBitsAlignedToBytesTensorFormatBase {
public:
    static constexpr Type TYPE = Type::FOURBITS_ALIGNED_TO_BYTE;

    //! get (or create) the cached format instance for \p align_size_in_bits
    static TensorFormat make(size_t align_size_in_bits);

    static TensorFormat deserialize(const Handle* handle, const void* buf,
                                    size_t size);

    //! whether \p layout uses this format type; raises (via assert_valid)
    //! if it does but the layout itself is illegal
    static bool is_valid_layout(const TensorLayout& layout) {
        if (layout.format.type() == TYPE) {
            layout.format.as_impl<FourBitsAlignedToBytesTensorFormat>()
                    .assert_valid(layout);
            return true;
        }
        return false;
    }

private:
    FourBitsAlignedToBytesTensorFormat(size_t align_size_in_bits)
            : detail::FourBitsAlignedToBytesTensorFormatBase(
                      TYPE, align_size_in_bits) {}
};
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
@@ -201,7 +201,15 @@ TensorLayout::TensorLayout(DType dtype_, Format format_) | |||
: dtype{dtype_}, format{format_} {} | |||
TensorLayout::TensorLayout(const TensorShape& shape, DType dtype) | |||
: TensorLayout(shape, dtype, DefaultTensorFormat::make()) {} | |||
: TensorShape(shape), dtype{dtype} { | |||
if (dtype.low_bit() == 4_z) { | |||
format = FourBitsAlignedToBytesTensorFormat::make(8_z); | |||
} else { | |||
megdnn_assert(!dtype.is_low_bit(), "Unsupported data type(%s)", | |||
dtype.name()); | |||
format = DefaultTensorFormat::make(); | |||
} | |||
} | |||
TensorLayout::TensorLayout(const TensorShape& shape, DType dtype, | |||
TensorFormat format_) | |||
@@ -35,6 +35,9 @@ TensorFormat TensorFormat::deserialize(const std::string& bin, | |||
case Type::IMAGE2D_PACK4: | |||
return Image2DPack4TensorFormat::deserialize( | |||
handle, type + 1, bin.size() - sizeof(Type)); | |||
case Type::FOURBITS_ALIGNED_TO_BYTE: | |||
return FourBitsAlignedToBytesTensorFormat::deserialize( | |||
handle, type + 1, bin.size() - sizeof(Type)); | |||
default: | |||
megdnn_throw("invalid tensor format type in deserialize"); | |||
} | |||
@@ -67,7 +70,15 @@ bool TensorFormat::is_default() const { | |||
} | |||
/* ===================== DefaultFormat ===================== */
//! reject low-bit dtypes: the default format assumes byte-addressable
//! elements (low-bit layouts must use a LowbitsTensorFormatBase subclass)
void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const {
    megdnn_assert(
            !layout.dtype.valid() || !layout.dtype.is_low_bit(),
            "DefaultTensorFormat does not support low-bits tensor(dtype:%s)",
            layout.dtype.name());
}
size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const { | |||
assert_valid(layout); | |||
if (!layout.ndim) | |||
return 0; | |||
megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM); | |||
@@ -81,11 +92,13 @@ size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const { | |||
} | |||
bool DefaultTensorFormat::is_contiguous_spec(const TensorLayout& layout) const {
    // for the default format, contiguity is plain physical contiguity;
    // assert_valid only rejects low-bit dtypes
    assert_valid(layout);
    return layout.is_physical_contiguous();
}
TensorLayout DefaultTensorFormat::collapse_contiguous_spec( | |||
const TensorLayout& layout) const { | |||
assert_valid(layout); | |||
megdnn_assert(layout.ndim); | |||
TensorLayout res{layout}; | |||
@@ -126,6 +139,7 @@ TensorLayout DefaultTensorFormat::collapse_contiguous_spec( | |||
TensorLayout::Span DefaultTensorFormat::span_spec( | |||
const TensorLayout& layout) const { | |||
assert_valid(layout); | |||
if (layout.ndim == 0) | |||
return {0, 0, 0, 0}; | |||
@@ -146,9 +160,6 @@ TensorLayout::Span DefaultTensorFormat::span_spec( | |||
++high_elem; | |||
ptrdiff_t low_byte; | |||
if (low_elem < 0) { | |||
megdnn_assert(!layout.dtype.is_low_bit(), | |||
"tensors with low-bit dytes shouldn't have negative " | |||
"strides"); | |||
low_byte = low_elem * layout.dtype.size(); | |||
} else { | |||
low_byte = 0; | |||
@@ -422,12 +433,151 @@ TensorLayout Image2DPackedTensorFormatBase<PIXEL_SIZE>::collapse_contiguous_spec | |||
return res; | |||
} | |||
namespace megdnn { | |||
namespace detail { | |||
template class Image2DPackedTensorFormatBase<4>; | |||
} // namespace detail | |||
} // namespace megdnn | |||
/* ================== LowbitsTensorFormatBase ================== */
template <size_t SIZE_NBITS>
LowbitsTensorFormatBase<SIZE_NBITS>::LowbitsTensorFormatBase(
        Type type, size_t align_size_in_bits)
        : ImplBase(type), m_align_size_in_bits(align_size_in_bits) {
    megdnn_assert(!(m_align_size_in_bits % SIZE_NBITS),
                  "align size(%zu) must be a multiple of element size(%zu)",
                  m_align_size_in_bits, SIZE_NBITS);
    // cache the alignment expressed as an element count; used by stride
    // computation and validity checks
    m_align_size_in_elements = m_align_size_in_bits / SIZE_NBITS;
}
template <size_t SIZE_NBITS>
std::string LowbitsTensorFormatBase<SIZE_NBITS>::to_string() const {
    // e.g. "LOWBITS{4,8}" for 4-bit elements aligned to 8 bits
    return ssprintf("LOWBITS{%zu,%zu}", SIZE_NBITS, m_align_size_in_bits);
}
template <size_t SIZE_NBITS> | |||
void LowbitsTensorFormatBase<SIZE_NBITS>::assert_valid( | |||
const TensorLayout& layout) const { | |||
megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() && | |||
layout.dtype.low_bit() == SIZE_NBITS); | |||
bool has_dim_unity_stride = false; | |||
for (int i = layout.ndim - 1; i >= 0; --i) { | |||
if (!has_dim_unity_stride && layout.stride[i] == 1) | |||
has_dim_unity_stride = true; | |||
megdnn_assert( | |||
layout.stride[i] >= 0 && | |||
(layout.stride[i] % m_align_size_in_elements == 0 || | |||
layout.stride[i] == 1), | |||
"bad stride: %zu", layout.stride[i]); | |||
} | |||
megdnn_assert(has_dim_unity_stride, "innermost dim not contiguous"); | |||
} | |||
template <size_t SIZE_NBITS>
void LowbitsTensorFormatBase<SIZE_NBITS>::serialize_append(
        std::string& result) const {
    SerializePack pack;
    // SerializePack stores the alignment in a uint8_t; the round-trip
    // comparison below rejects values that do not fit
    pack.align_size_in_bits = m_align_size_in_bits;
    megdnn_assert(pack.align_size_in_bits ==
                  m_align_size_in_bits);  // detect overflow;
    result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
}
template <size_t SIZE_NBITS>
TensorLayout::Span LowbitsTensorFormatBase<SIZE_NBITS>::span_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    if (layout.ndim == 0)
        return {0, 0, 0, 0};
    // one-past-the-last element offset touched by the layout; strides are
    // non-negative for lowbit tensors, so the low bound is always 0
    size_t high_elem = 0;
    for (size_t i = 0; i < layout.ndim; ++i) {
        auto shape_val = layout.shape[i];
        if (!shape_val) {
            // a zero-extent dim means an empty tensor: no storage at all
            return {0, 0, 0, 0};
        }
        auto stride_val = layout.stride[i];
        megdnn_assert(stride_val >= 0,
                      "lowbit tensors shouldn't have negative strides");
        high_elem += (shape_val - 1) * stride_val;
    }
    ++high_elem;
    // dtype.size(n) converts an element count to bytes, rounding up so the
    // lowbit data is padded to whole bytes
    size_t high_byte = layout.dtype.size(high_elem);
    return TensorLayout::Span(0, 0, high_elem, high_byte);
}
template <size_t SIZE_NBITS>
size_t LowbitsTensorFormatBase<SIZE_NBITS>::init_contiguous_stride(
        TensorLayout& layout) const {
    // note: assert_valid() cannot be called here — the strides are only
    // being initialized now
    if (!layout.ndim)
        return 0;
    megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);
    size_t accum = 1;
    // overflow-checked multiplication (presumably raises on overflow —
    // see SafeMultiplies' definition)
    SafeMultiplies<size_t> mul;
    for (size_t i = layout.ndim; i; --i) {
        layout.stride[i - 1] = accum;
        auto multiplier = layout.shape[i - 1];
        // pad the innermost dim up to the alignment so all outer strides
        // stay multiples of m_align_size_in_elements
        if (i == layout.ndim)
            multiplier = round_up(multiplier, m_align_size_in_elements);
        accum = mul(accum, multiplier);
    }
    // total element count including the alignment padding
    return accum;
}
template <size_t SIZE_NBITS> | |||
bool LowbitsTensorFormatBase<SIZE_NBITS>::is_contiguous_spec( | |||
const TensorLayout& layout) const { | |||
assert_valid(layout); | |||
ptrdiff_t expected = 1; | |||
for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) { | |||
if (layout.shape[i] != 1 && layout.stride[i] != expected) | |||
return false; | |||
auto multiplier = layout.shape[i]; | |||
if (i == layout.ndim - 1) | |||
multiplier = round_up(multiplier, m_align_size_in_elements); | |||
expected *= multiplier; | |||
} | |||
return expected != 0; | |||
} | |||
template <size_t SIZE_NBITS>
TensorLayout LowbitsTensorFormatBase<SIZE_NBITS>::collapse_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    TensorLayout res{layout};

    // first pass (backwards, so removal does not shift unvisited axes):
    // drop extent-1 axes and short-circuit empty tensors
    for (int i = static_cast<int>(res.ndim) - 1; i >= 0; --i) {
        if (!res.shape[i]) {
            // empty tensor: canonicalize to a single zero-extent dim
            res.ndim = 1;
            res.shape[0] = 0;
            res.stride[0] = 1;
            return res;
        }
        if (res.shape[i] == 1) {
            res.remove_axis_inplace(i);
        }
    }

    megdnn_assert(res.ndim && res.shape[res.ndim - 1]);

    // second pass: merge adjacent axes whose strides line up contiguously
    // (axis i directly spans axis i+1)
    for (int i = static_cast<int>(res.ndim) - 2; i >= 0; --i) {
        megdnn_assert(res.shape[i]);
        if (res.stride[i] ==
            res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1])) {
            res.shape[i] *= res.shape[i + 1];
            res.stride[i] = res.stride[i + 1];
            res.remove_axis_inplace(i + 1);
        }
    }
    return res;
}
namespace megdnn { | |||
namespace detail { | |||
template class LowbitsTensorFormatBase<4>; | |||
} // namespace detail | |||
} // namespace megdnn | |||
/* ===================== Image2DPack4TensorFormat ===================== */ | |||
TensorFormat Image2DPack4TensorFormat::make_raw( | |||
size_t align_axis, size_t align_size_in_elements, | |||
@@ -466,4 +616,29 @@ TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const { | |||
return make_raw(axis, align_size_in_elements(), vendor()); | |||
} | |||
/* ===================== FourBitsAlignedToBytesTensorFormat
 * ===================== */
TensorFormat FourBitsAlignedToBytesTensorFormat::make(
        size_t align_size_in_bits) {
    // cache one immutable instance per alignment, guarded by a mutex, so
    // that TensorFormat can hold a stable non-owning pointer
    static std::mutex mtx;
    static std::unordered_map<
            uint32_t, std::unique_ptr<FourBitsAlignedToBytesTensorFormat>>
            cache;
    // alignment must be a whole number of 4-bit elements
    megdnn_assert(!(align_size_in_bits % 4));
    MEGDNN_LOCK_GUARD(mtx);
    auto&& ptr = cache[static_cast<uint32_t>(align_size_in_bits)];
    if (!ptr) {
        ptr.reset(new FourBitsAlignedToBytesTensorFormat{align_size_in_bits});
    }
    return impl_to_tensor_format(ptr.get());
}
TensorFormat FourBitsAlignedToBytesTensorFormat::deserialize(const Handle*,
                                                             const void* buf,
                                                             size_t size) {
    megdnn_assert(size == sizeof(SerializePack));
    auto pack = *static_cast<const SerializePack*>(buf);
    // make() deduplicates, so deserializing reuses the cached instance
    return make(pack.align_size_in_bits);
}
// vim: syntax=cpp.doxygen |
@@ -128,9 +128,11 @@ public: | |||
for (size_t i = 0; i < shapes.size(); ++i) { | |||
DType dt = (m_dtype.find(i) != m_dtype.end() ? m_dtype[i] | |||
: dtype::Float32()); | |||
TensorFormat fmt = | |||
(m_fmt.find(i) != m_fmt.end() ? m_fmt[i] : TensorFormat{}); | |||
layouts[i] = TensorLayout(shapes[i], dt, fmt); | |||
if (m_fmt.find(i) == m_fmt.end()) { | |||
layouts[i] = TensorLayout(shapes[i], dt); | |||
layouts[i].init_contiguous_stride(); | |||
} else | |||
layouts[i] = TensorLayout(shapes[i], dt, m_fmt[i]); | |||
} | |||
return layouts; | |||
} | |||
@@ -302,4 +302,50 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) { | |||
} | |||
} | |||
TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS) {
    // contiguous strides for a 4-bit dtype: innermost dim (7) is padded to
    // 8 elements, giving strides {32*7*8, 7*8, 8, 1}
    TensorLayout layout{{16, 32, 7, 7}, dtype::QuantizedS4{1.2f}};
    layout.init_contiguous_stride();
    ASSERT_EQ(layout.stride[0], 1792);
    ASSERT_EQ(layout.stride[1], 56);
    ASSERT_EQ(layout.stride[2], 8);
    ASSERT_EQ(layout.stride[3], 1);
    // span: high_elem = 15*1792 + 31*56 + 6*8 + 6 + 1; two 4-bit elems/byte
    auto span = layout.span();
    ASSERT_EQ(0, span.low_elem);
    ASSERT_EQ(28671, span.high_elem);
    ASSERT_EQ(0, span.low_byte);
    ASSERT_EQ(14336, span.high_byte);
    // the first three dims collapse; the padded innermost dim stays separate
    EXPECT_EQ(make_layout({3584, 7}, {8, 1}, dtype::QuantizedS4{1.2f}),
              layout.collapse_contiguous());

    // strides matching the padded contiguous layout are contiguous
    layout = make_layout({16, 32, 7, 7}, {1792, 56, 8, 1},
                         dtype::QuantizedS4{1.3f});
    layout.format = FourBitsAlignedToBytesTensorFormat::make(8_z);
    EXPECT_TRUE(layout.is_contiguous());

    // broadcast layouts collapse around the zero-stride dims
    layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
    layout = layout.broadcast({16, 32, 7, 7});
    EXPECT_EQ(make_layout({16, 32, 49}, {0, 1, 0}, dtype::QuantizedS4{1.2}),
              layout.collapse_contiguous());

    // span() rejects broadcast (zero-stride) lowbit layouts
    layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
    layout.init_contiguous_stride();
    layout = layout.broadcast({16, 32, 7, 7});
    ASSERT_THROW(layout.span(), MegDNNError);
}
TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS_VALID) {
    // low-bit dtype with the default format must be rejected
    ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS4{1.2f},
                              DefaultTensorFormat::make()),
                 MegDNNError);
    // non-low-bit dtype with the 4-bit format must be rejected
    ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS32{1.2f},
                              FourBitsAlignedToBytesTensorFormat::make(8_z))
                         .span(),
                 MegDNNError);
    // low-bit dtype of the wrong width (2-bit) must also be rejected
    ASSERT_THROW(TensorLayout({16, 32, 7, 7}, dtype::IntB2{},
                              FourBitsAlignedToBytesTensorFormat::make(8_z))
                         .span(),
                 MegDNNError);
}
// vim: syntax=cpp.doxygen |