GitOrigin-RevId: 0aa3753f37
release-1.5
@@ -505,9 +505,9 @@ class DType { | |||||
return std::numeric_limits<size_t>::max() >> m_trait->size_log; | return std::numeric_limits<size_t>::max() >> m_trait->size_log; | ||||
} | } | ||||
bool is_low_bit() const { | |||||
return m_trait->low_bit != 0; | |||||
} | |||||
size_t low_bit() const { return m_trait->low_bit; } | |||||
bool is_low_bit() const { return low_bit() != 0; } | |||||
/*! | /*! | ||||
* \brief size of this data type, in bytes | * \brief size of this data type, in bytes | ||||
@@ -20,12 +20,15 @@ namespace megdnn { | |||||
enum class TensorFormat::Type { | enum class TensorFormat::Type { | ||||
DEFAULT = 0, //!< see DefaultTensorFormat | DEFAULT = 0, //!< see DefaultTensorFormat | ||||
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | ||||
FOURBITS_ALIGNED_TO_BYTE = 2, //!< | |||||
}; | }; | ||||
class TensorFormat::ImplBase { | class TensorFormat::ImplBase { | ||||
public: | public: | ||||
using Type = TensorFormat::Type; | using Type = TensorFormat::Type; | ||||
virtual void assert_valid(const TensorLayout& layout) const = 0; | |||||
virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0; | virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0; | ||||
virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; | virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; | ||||
@@ -63,6 +66,8 @@ public: | |||||
DefaultTensorFormat() : ImplBase(TYPE) {} | DefaultTensorFormat() : ImplBase(TYPE) {} | ||||
void assert_valid(const TensorLayout& layout) const override; | |||||
size_t init_contiguous_stride(TensorLayout& layout) const override; | size_t init_contiguous_stride(TensorLayout& layout) const override; | ||||
/*! | /*! | ||||
@@ -180,11 +185,11 @@ public: | |||||
*/ | */ | ||||
size_t image_width(const TensorLayout& layout) const; | size_t image_width(const TensorLayout& layout) const; | ||||
//! raise exception if preconditions violated | |||||
void assert_valid(const TensorLayout& layout) const; | |||||
size_t image_row_pitch(const TensorLayout& layout) const; | size_t image_row_pitch(const TensorLayout& layout) const; | ||||
//! raise exception if preconditions violated | |||||
void assert_valid(const TensorLayout& layout) const override; | |||||
//! span for image must include the padding at the last row | //! span for image must include the padding at the last row | ||||
TensorLayout::Span span_spec(const TensorLayout& layout) const override; | TensorLayout::Span span_spec(const TensorLayout& layout) const override; | ||||
@@ -197,31 +202,48 @@ public: | |||||
}; | }; | ||||
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | ||||
///*! | |||||
// * \brief used for tensors with lowbit data type | |||||
// * | |||||
// * \p SIZE_NBITS is the size in bits of element of the tensor. | |||||
// * | |||||
// */ | |||||
//template <size_t SIZE_NBITS_> | |||||
//class LowbitTensorFormat : public TensorFormat::ImplBase { | |||||
// static constexpr size_t SIZE_NBITS = SIZE_NBITS_; | |||||
// size_t m_align_size_in_bits; | |||||
// | |||||
//protected: //? | |||||
// LowbitTensorFormat(Type type, size_t m_align_size_in_bits); | |||||
// | |||||
//public: | |||||
// size_t align_size_in_bits() const { | |||||
// return m_align_size_in_bits; | |||||
// } | |||||
// | |||||
// std::string to_string() const override; | |||||
// | |||||
// void serialize_append( | |||||
// | |||||
// | |||||
//}; | |||||
/*! | |||||
* \brief used for tensors storing lowbit data | |||||
* | |||||
* \p SIZE_NBITS is the size in bits of element of the tensor. | |||||
* | |||||
*/ | |||||
template <size_t SIZE_NBITS_> | |||||
class LowbitsTensorFormatBase : public TensorFormat::ImplBase { | |||||
static constexpr size_t SIZE_NBITS = SIZE_NBITS_; | |||||
size_t m_align_size_in_bits, m_align_size_in_elements; | |||||
protected: //? | |||||
LowbitsTensorFormatBase(Type type, size_t align_size_in_bits); | |||||
virtual ~LowbitsTensorFormatBase() = default; | |||||
public: | |||||
size_t align_size_in_bits() const { return m_align_size_in_bits; } | |||||
std::string to_string() const override; | |||||
//! raise exception if given layout is illegal | |||||
void assert_valid(const TensorLayout& layout) const; | |||||
void serialize_append(std::string& result) const override; | |||||
//! span for lowbit tensor must include the padding at the innermost | |||||
//! dimemsion that make lowbit tensor be aligned to bytes | |||||
TensorLayout::Span span_spec(const TensorLayout& layout) const override; | |||||
size_t init_contiguous_stride(TensorLayout& layout) const override; | |||||
bool is_contiguous_spec(const TensorLayout& layout) const override; | |||||
TensorLayout collapse_contiguous_spec( | |||||
const TensorLayout& layout) const override; | |||||
protected: | |||||
struct SerializePack { | |||||
uint8_t align_size_in_bits; | |||||
}; | |||||
}; | |||||
using FourBitsAlignedToBytesTensorFormatBase = LowbitsTensorFormatBase<4>; | |||||
} // namespace detail | } // namespace detail | ||||
/*! | /*! | ||||
@@ -270,6 +292,34 @@ private: | |||||
TYPE, align_axis, align_size_in_elements, vendor_type) {} | TYPE, align_axis, align_size_in_elements, vendor_type) {} | ||||
}; | }; | ||||
/*! | |||||
* \brief Tensor for storing 4bit data that requires stride corresponding to | |||||
* non-innermost dimension to be aligned to bytes, and pack 2 elems into a byte | |||||
*/ | |||||
class FourBitsAlignedToBytesTensorFormat final | |||||
: public detail::FourBitsAlignedToBytesTensorFormatBase { | |||||
public: | |||||
static constexpr Type TYPE = Type::FOURBITS_ALIGNED_TO_BYTE; | |||||
static TensorFormat make(size_t align_size_in_bits); | |||||
static TensorFormat deserialize(const Handle* handle, const void* buf, | |||||
size_t size); | |||||
static bool is_valid_layout(const TensorLayout& layout) { | |||||
if (layout.format.type() == TYPE) { | |||||
layout.format.as_impl<FourBitsAlignedToBytesTensorFormat>() | |||||
.assert_valid(layout); | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
private: | |||||
FourBitsAlignedToBytesTensorFormat(size_t align_size_in_bits) | |||||
: detail::FourBitsAlignedToBytesTensorFormatBase( | |||||
TYPE, align_size_in_bits) {} | |||||
}; | |||||
} // namespace megdnn | } // namespace megdnn | ||||
#include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
@@ -201,7 +201,15 @@ TensorLayout::TensorLayout(DType dtype_, Format format_) | |||||
: dtype{dtype_}, format{format_} {} | : dtype{dtype_}, format{format_} {} | ||||
TensorLayout::TensorLayout(const TensorShape& shape, DType dtype) | TensorLayout::TensorLayout(const TensorShape& shape, DType dtype) | ||||
: TensorLayout(shape, dtype, DefaultTensorFormat::make()) {} | |||||
: TensorShape(shape), dtype{dtype} { | |||||
if (dtype.low_bit() == 4_z) { | |||||
format = FourBitsAlignedToBytesTensorFormat::make(8_z); | |||||
} else { | |||||
megdnn_assert(!dtype.is_low_bit(), "Unsupported data type(%s)", | |||||
dtype.name()); | |||||
format = DefaultTensorFormat::make(); | |||||
} | |||||
} | |||||
TensorLayout::TensorLayout(const TensorShape& shape, DType dtype, | TensorLayout::TensorLayout(const TensorShape& shape, DType dtype, | ||||
TensorFormat format_) | TensorFormat format_) | ||||
@@ -35,6 +35,9 @@ TensorFormat TensorFormat::deserialize(const std::string& bin, | |||||
case Type::IMAGE2D_PACK4: | case Type::IMAGE2D_PACK4: | ||||
return Image2DPack4TensorFormat::deserialize( | return Image2DPack4TensorFormat::deserialize( | ||||
handle, type + 1, bin.size() - sizeof(Type)); | handle, type + 1, bin.size() - sizeof(Type)); | ||||
case Type::FOURBITS_ALIGNED_TO_BYTE: | |||||
return FourBitsAlignedToBytesTensorFormat::deserialize( | |||||
handle, type + 1, bin.size() - sizeof(Type)); | |||||
default: | default: | ||||
megdnn_throw("invalid tensor format type in deserialize"); | megdnn_throw("invalid tensor format type in deserialize"); | ||||
} | } | ||||
@@ -67,7 +70,15 @@ bool TensorFormat::is_default() const { | |||||
} | } | ||||
/* ===================== DefaultFormat ===================== */ | /* ===================== DefaultFormat ===================== */ | ||||
void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const { | |||||
megdnn_assert( | |||||
!layout.dtype.valid() || !layout.dtype.is_low_bit(), | |||||
"DefaultTensorFormat does not support low-bits tensor(dtype:%s)", | |||||
layout.dtype.name()); | |||||
} | |||||
size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const { | size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const { | ||||
assert_valid(layout); | |||||
if (!layout.ndim) | if (!layout.ndim) | ||||
return 0; | return 0; | ||||
megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM); | megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM); | ||||
@@ -81,11 +92,13 @@ size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const { | |||||
} | } | ||||
bool DefaultTensorFormat::is_contiguous_spec(const TensorLayout& layout) const { | bool DefaultTensorFormat::is_contiguous_spec(const TensorLayout& layout) const { | ||||
assert_valid(layout); | |||||
return layout.is_physical_contiguous(); | return layout.is_physical_contiguous(); | ||||
} | } | ||||
TensorLayout DefaultTensorFormat::collapse_contiguous_spec( | TensorLayout DefaultTensorFormat::collapse_contiguous_spec( | ||||
const TensorLayout& layout) const { | const TensorLayout& layout) const { | ||||
assert_valid(layout); | |||||
megdnn_assert(layout.ndim); | megdnn_assert(layout.ndim); | ||||
TensorLayout res{layout}; | TensorLayout res{layout}; | ||||
@@ -126,6 +139,7 @@ TensorLayout DefaultTensorFormat::collapse_contiguous_spec( | |||||
TensorLayout::Span DefaultTensorFormat::span_spec( | TensorLayout::Span DefaultTensorFormat::span_spec( | ||||
const TensorLayout& layout) const { | const TensorLayout& layout) const { | ||||
assert_valid(layout); | |||||
if (layout.ndim == 0) | if (layout.ndim == 0) | ||||
return {0, 0, 0, 0}; | return {0, 0, 0, 0}; | ||||
@@ -146,9 +160,6 @@ TensorLayout::Span DefaultTensorFormat::span_spec( | |||||
++high_elem; | ++high_elem; | ||||
ptrdiff_t low_byte; | ptrdiff_t low_byte; | ||||
if (low_elem < 0) { | if (low_elem < 0) { | ||||
megdnn_assert(!layout.dtype.is_low_bit(), | |||||
"tensors with low-bit dytes shouldn't have negative " | |||||
"strides"); | |||||
low_byte = low_elem * layout.dtype.size(); | low_byte = low_elem * layout.dtype.size(); | ||||
} else { | } else { | ||||
low_byte = 0; | low_byte = 0; | ||||
@@ -422,12 +433,151 @@ TensorLayout Image2DPackedTensorFormatBase<PIXEL_SIZE>::collapse_contiguous_spec | |||||
return res; | return res; | ||||
} | } | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace detail { | namespace detail { | ||||
template class Image2DPackedTensorFormatBase<4>; | template class Image2DPackedTensorFormatBase<4>; | ||||
} // namespace detail | } // namespace detail | ||||
} // namespace megdnn | } // namespace megdnn | ||||
/* =============== FourBitsAlignedToBytesTensorFormatBase ============== */ | |||||
template <size_t SIZE_NBITS> | |||||
LowbitsTensorFormatBase<SIZE_NBITS>::LowbitsTensorFormatBase( | |||||
Type type, size_t align_size_in_bits) | |||||
: ImplBase(type), m_align_size_in_bits(align_size_in_bits) { | |||||
megdnn_assert(!(m_align_size_in_bits % SIZE_NBITS), | |||||
"align size(%zu) must be a multiple of element size(%zu)", | |||||
m_align_size_in_bits, SIZE_NBITS); | |||||
m_align_size_in_elements = m_align_size_in_bits / SIZE_NBITS; | |||||
} | |||||
template <size_t SIZE_NBITS> | |||||
std::string LowbitsTensorFormatBase<SIZE_NBITS>::to_string() const { | |||||
return ssprintf("LOWBITS{%zu,%zu}", SIZE_NBITS, m_align_size_in_bits); | |||||
} | |||||
template <size_t SIZE_NBITS> | |||||
void LowbitsTensorFormatBase<SIZE_NBITS>::assert_valid( | |||||
const TensorLayout& layout) const { | |||||
megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() && | |||||
layout.dtype.low_bit() == SIZE_NBITS); | |||||
bool has_dim_unity_stride = false; | |||||
for (int i = layout.ndim - 1; i >= 0; --i) { | |||||
if (!has_dim_unity_stride && layout.stride[i] == 1) | |||||
has_dim_unity_stride = true; | |||||
megdnn_assert( | |||||
layout.stride[i] >= 0 && | |||||
(layout.stride[i] % m_align_size_in_elements == 0 || | |||||
layout.stride[i] == 1), | |||||
"bad stride: %zu", layout.stride[i]); | |||||
} | |||||
megdnn_assert(has_dim_unity_stride, "innermost dim not contiguous"); | |||||
} | |||||
template <size_t SIZE_NBITS> | |||||
void LowbitsTensorFormatBase<SIZE_NBITS>::serialize_append( | |||||
std::string& result) const { | |||||
SerializePack pack; | |||||
pack.align_size_in_bits = m_align_size_in_bits; | |||||
megdnn_assert(pack.align_size_in_bits == | |||||
m_align_size_in_bits); // detect overflow; | |||||
result.append(reinterpret_cast<char*>(&pack), sizeof(pack)); | |||||
} | |||||
template <size_t SIZE_NBITS> | |||||
TensorLayout::Span LowbitsTensorFormatBase<SIZE_NBITS>::span_spec( | |||||
const TensorLayout& layout) const { | |||||
assert_valid(layout); | |||||
if (layout.ndim == 0) | |||||
return {0, 0, 0, 0}; | |||||
size_t high_elem = 0; | |||||
for (size_t i = 0; i < layout.ndim; ++i) { | |||||
auto shape_val = layout.shape[i]; | |||||
if (!shape_val) { | |||||
return {0, 0, 0, 0}; | |||||
} | |||||
auto stride_val = layout.stride[i]; | |||||
megdnn_assert(stride_val >= 0, | |||||
"lowbit tensors shouldn't have negative strides"); | |||||
high_elem += (shape_val - 1) * stride_val; | |||||
} | |||||
++high_elem; | |||||
size_t high_byte = layout.dtype.size(high_elem); | |||||
return TensorLayout::Span(0, 0, high_elem, high_byte); | |||||
} | |||||
template <size_t SIZE_NBITS> | |||||
size_t LowbitsTensorFormatBase<SIZE_NBITS>::init_contiguous_stride( | |||||
TensorLayout& layout) const { | |||||
if (!layout.ndim) | |||||
return 0; | |||||
megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM); | |||||
size_t accum = 1; | |||||
SafeMultiplies<size_t> mul; | |||||
for (size_t i = layout.ndim; i; --i) { | |||||
layout.stride[i - 1] = accum; | |||||
auto multiplier = layout.shape[i - 1]; | |||||
if (i == layout.ndim) | |||||
multiplier = round_up(multiplier, m_align_size_in_elements); | |||||
accum = mul(accum, multiplier); | |||||
} | |||||
return accum; | |||||
} | |||||
template <size_t SIZE_NBITS> | |||||
bool LowbitsTensorFormatBase<SIZE_NBITS>::is_contiguous_spec( | |||||
const TensorLayout& layout) const { | |||||
assert_valid(layout); | |||||
ptrdiff_t expected = 1; | |||||
for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) { | |||||
if (layout.shape[i] != 1 && layout.stride[i] != expected) | |||||
return false; | |||||
auto multiplier = layout.shape[i]; | |||||
if (i == layout.ndim - 1) | |||||
multiplier = round_up(multiplier, m_align_size_in_elements); | |||||
expected *= multiplier; | |||||
} | |||||
return expected != 0; | |||||
} | |||||
template <size_t SIZE_NBITS> | |||||
TensorLayout LowbitsTensorFormatBase<SIZE_NBITS>::collapse_contiguous_spec( | |||||
const TensorLayout& layout) const { | |||||
assert_valid(layout); | |||||
TensorLayout res{layout}; | |||||
for (int i = static_cast<int>(res.ndim) - 1; i >= 0; --i) { | |||||
if (!res.shape[i]) { | |||||
// empty tensor | |||||
res.ndim = 1; | |||||
res.shape[0] = 0; | |||||
res.stride[0] = 1; | |||||
return res; | |||||
} | |||||
if (res.shape[i] == 1) { | |||||
res.remove_axis_inplace(i); | |||||
} | |||||
} | |||||
megdnn_assert(res.ndim && res.shape[res.ndim - 1]); | |||||
for (int i = static_cast<int>(res.ndim) - 2; i >= 0; --i) { | |||||
megdnn_assert(res.shape[i]); | |||||
if (res.stride[i] == | |||||
res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1])) { | |||||
res.shape[i] *= res.shape[i + 1]; | |||||
res.stride[i] = res.stride[i + 1]; | |||||
res.remove_axis_inplace(i + 1); | |||||
} | |||||
} | |||||
return res; | |||||
} | |||||
namespace megdnn { | |||||
namespace detail { | |||||
template class LowbitsTensorFormatBase<4>; | |||||
} // namespace detail | |||||
} // namespace megdnn | |||||
/* ===================== Image2DPack4TensorFormat ===================== */ | /* ===================== Image2DPack4TensorFormat ===================== */ | ||||
TensorFormat Image2DPack4TensorFormat::make_raw( | TensorFormat Image2DPack4TensorFormat::make_raw( | ||||
size_t align_axis, size_t align_size_in_elements, | size_t align_axis, size_t align_size_in_elements, | ||||
@@ -466,4 +616,29 @@ TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const { | |||||
return make_raw(axis, align_size_in_elements(), vendor()); | return make_raw(axis, align_size_in_elements(), vendor()); | ||||
} | } | ||||
/* ===================== FourBitsAlignedToBytesTensorFormat | |||||
* ===================== */ | |||||
TensorFormat FourBitsAlignedToBytesTensorFormat::make( | |||||
size_t align_size_in_bits) { | |||||
static std::mutex mtx; | |||||
static std::unordered_map< | |||||
uint32_t, std::unique_ptr<FourBitsAlignedToBytesTensorFormat>> | |||||
cache; | |||||
megdnn_assert(!(align_size_in_bits % 4)); | |||||
MEGDNN_LOCK_GUARD(mtx); | |||||
auto&& ptr = cache[static_cast<uint32_t>(align_size_in_bits)]; | |||||
if (!ptr) { | |||||
ptr.reset(new FourBitsAlignedToBytesTensorFormat{align_size_in_bits}); | |||||
} | |||||
return impl_to_tensor_format(ptr.get()); | |||||
} | |||||
TensorFormat FourBitsAlignedToBytesTensorFormat::deserialize(const Handle*, | |||||
const void* buf, | |||||
size_t size) { | |||||
megdnn_assert(size == sizeof(SerializePack)); | |||||
auto pack = *static_cast<const SerializePack*>(buf); | |||||
return make(pack.align_size_in_bits); | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -128,9 +128,11 @@ public: | |||||
for (size_t i = 0; i < shapes.size(); ++i) { | for (size_t i = 0; i < shapes.size(); ++i) { | ||||
DType dt = (m_dtype.find(i) != m_dtype.end() ? m_dtype[i] | DType dt = (m_dtype.find(i) != m_dtype.end() ? m_dtype[i] | ||||
: dtype::Float32()); | : dtype::Float32()); | ||||
TensorFormat fmt = | |||||
(m_fmt.find(i) != m_fmt.end() ? m_fmt[i] : TensorFormat{}); | |||||
layouts[i] = TensorLayout(shapes[i], dt, fmt); | |||||
if (m_fmt.find(i) == m_fmt.end()) { | |||||
layouts[i] = TensorLayout(shapes[i], dt); | |||||
layouts[i].init_contiguous_stride(); | |||||
} else | |||||
layouts[i] = TensorLayout(shapes[i], dt, m_fmt[i]); | |||||
} | } | ||||
return layouts; | return layouts; | ||||
} | } | ||||
@@ -302,4 +302,50 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) { | |||||
} | } | ||||
} | } | ||||
TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS) { | |||||
TensorLayout layout{{16, 32, 7, 7}, dtype::QuantizedS4{1.2f}}; | |||||
layout.init_contiguous_stride(); | |||||
ASSERT_EQ(layout.stride[0], 1792); | |||||
ASSERT_EQ(layout.stride[1], 56); | |||||
ASSERT_EQ(layout.stride[2], 8); | |||||
ASSERT_EQ(layout.stride[3], 1); | |||||
auto span = layout.span(); | |||||
ASSERT_EQ(0, span.low_elem); | |||||
ASSERT_EQ(28671, span.high_elem); | |||||
ASSERT_EQ(0, span.low_byte); | |||||
ASSERT_EQ(14336, span.high_byte); | |||||
EXPECT_EQ(make_layout({3584, 7}, {8, 1}, dtype::QuantizedS4{1.2f}), | |||||
layout.collapse_contiguous()); | |||||
layout = make_layout({16, 32, 7, 7}, {1792, 56, 8, 1}, | |||||
dtype::QuantizedS4{1.3f}); | |||||
layout.format = FourBitsAlignedToBytesTensorFormat::make(8_z); | |||||
EXPECT_TRUE(layout.is_contiguous()); | |||||
layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}}; | |||||
layout = layout.broadcast({16, 32, 7, 7}); | |||||
EXPECT_EQ(make_layout({16, 32, 49}, {0, 1, 0}, dtype::QuantizedS4{1.2}), | |||||
layout.collapse_contiguous()); | |||||
layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}}; | |||||
layout.init_contiguous_stride(); | |||||
layout = layout.broadcast({16, 32, 7, 7}); | |||||
ASSERT_THROW(layout.span(), MegDNNError); | |||||
} | |||||
TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS_VALID) { | |||||
ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS4{1.2f}, | |||||
DefaultTensorFormat::make()), | |||||
MegDNNError); | |||||
ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS32{1.2f}, | |||||
FourBitsAlignedToBytesTensorFormat::make(8_z)) | |||||
.span(), | |||||
MegDNNError); | |||||
ASSERT_THROW(TensorLayout({16, 32, 7, 7}, dtype::IntB2{}, | |||||
FourBitsAlignedToBytesTensorFormat::make(8_z)) | |||||
.span(), | |||||
MegDNNError); | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |