@@ -170,6 +170,7 @@ struct TensorLayout : public TensorShape {
 #if MEGDNN_CC_HOST
 Format();
+Format(DType dtype);
 const ImplBase* impl() const { return m_impl; }
@@ -198,6 +199,9 @@ struct TensorLayout : public TensorShape {
 //! whether this is the default tensor format
 bool is_default() const;
+//! whether this is the lowbit aligned to bytes tensor format
+bool is_lowbit_aligned() const;
 bool operator==(Format rhs) const { return m_impl == rhs.m_impl; }
 bool operator!=(Format rhs) const { return m_impl != rhs.m_impl; }
 #endif
@@ -20,7 +20,7 @@ namespace megdnn {
 enum class TensorFormat::Type {
 DEFAULT = 0, //!< see DefaultTensorFormat
 IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
-FOURBITS_ALIGNED_TO_BYTE = 2, //!<
+LOWBITS_ALIGNED_TO_BYTE = 2, //!<
 };
 class TensorFormat::ImplBase {
@@ -205,21 +205,23 @@ using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
 /*!
 * \brief used for tensors storing lowbit data
 *
-* \p SIZE_NBITS is the size in bits of element of the tensor.
-*
+* \param m_size_nbits size in bits of elements in the tensor
+* \param m_align_size_in_bits aligned size in bits
+* \param m_align_size_in_elements aligned size in elements
 */
-template <size_t SIZE_NBITS_>
-class LowbitsTensorFormatBase : public TensorFormat::ImplBase {
-static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
-size_t m_align_size_in_bits, m_align_size_in_elements;
+class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase {
+size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements;
 protected: //?
-LowbitsTensorFormatBase(Type type, size_t align_size_in_bits);
+LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits,
+size_t align_size_in_bits);
-virtual ~LowbitsTensorFormatBase() = default;
+virtual ~LowbitsAlignedTensorFormatBase() = default;
 public:
 size_t align_size_in_bits() const { return m_align_size_in_bits; }
+size_t size_nbits() const { return m_size_nbits; }
 std::string to_string() const override;
@@ -240,10 +242,10 @@ public:
 const TensorLayout& layout) const override;
 protected:
 struct SerializePack {
+uint8_t size_nbits;
 uint8_t align_size_in_bits;
 };
 };
-using FourBitsAlignedToBytesTensorFormatBase = LowbitsTensorFormatBase<4>;
 } // namespace detail
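Editor's note (illustrative, not part of the patch): the alignment rule this base class enforces is easiest to see on a concrete 4-bit layout. The sketch below assumes the byte-aligned lowbits format that QuantizedS4 picks up after this change; the shape and stride values match the TENSOR_LAYOUT_FMT_LOW_BITS test further down.

// With size_nbits = 4 and align_size_in_bits = 8, align_size_in_elements is
// 8 / 4 = 2, so every non-innermost stride must be a multiple of 2 elements.
TensorLayout layout{{16, 32, 7, 7}, dtype::QuantizedS4{1.3f}};
// Contiguous strides come out as {1792, 56, 8, 1}: the innermost extent 7 is
// rounded up to 8 elements, so each row starts on a byte boundary.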
 /*!
@@ -296,19 +298,20 @@ private:
 * \brief Tensor for storing 4bit data that requires stride corresponding to
 * non-innermost dimension to be aligned to bytes, and pack 2 elems into a byte
 */
-class FourBitsAlignedToBytesTensorFormat final
-: public detail::FourBitsAlignedToBytesTensorFormatBase {
+class LowbitsAlignedToBytesTensorFormat final
+: public detail::LowbitsAlignedTensorFormatBase {
 public:
-static constexpr Type TYPE = Type::FOURBITS_ALIGNED_TO_BYTE;
+static constexpr Type TYPE = Type::LOWBITS_ALIGNED_TO_BYTE;
+static constexpr size_t BYTE_IN_BITS = 8;
-static TensorFormat make(size_t align_size_in_bits);
+static TensorFormat make(size_t size_nbits);
 static TensorFormat deserialize(const Handle* handle, const void* buf,
 size_t size);
 static bool is_valid_layout(const TensorLayout& layout) {
 if (layout.format.type() == TYPE) {
-layout.format.as_impl<FourBitsAlignedToBytesTensorFormat>()
+layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>()
 .assert_valid(layout);
 return true;
 }
@@ -316,9 +319,9 @@ public:
 }
 private:
-FourBitsAlignedToBytesTensorFormat(size_t align_size_in_bits)
-: detail::FourBitsAlignedToBytesTensorFormatBase(
-TYPE, align_size_in_bits) {}
+LowbitsAlignedToBytesTensorFormat(size_t size_nbits)
+: detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits,
+BYTE_IN_BITS) {}
 };
 } // namespace megdnn
@@ -195,21 +195,14 @@ bool TensorShape::is_empty() const {
 /* ===================== TensorLayout ===================== */
 TensorLayout::TensorLayout() = default;
-TensorLayout::TensorLayout(DType dtype_) : dtype{dtype_} {}
+TensorLayout::TensorLayout(DType dtype_)
+: dtype{dtype_}, format{Format(dtype)} {}
 TensorLayout::TensorLayout(DType dtype_, Format format_)
 : dtype{dtype_}, format{format_} {}
 TensorLayout::TensorLayout(const TensorShape& shape, DType dtype)
-: TensorShape(shape), dtype{dtype} {
-if (dtype.low_bit() == 4_z) {
-format = FourBitsAlignedToBytesTensorFormat::make(8_z);
-} else {
-megdnn_assert(!dtype.is_low_bit(), "Unsupported data type(%s)",
-dtype.name());
-format = DefaultTensorFormat::make();
-}
-}
+: TensorLayout(shape, dtype, Format(dtype)) {}
 TensorLayout::TensorLayout(const TensorShape& shape, DType dtype,
 TensorFormat format_)
@@ -722,7 +722,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
 megdnn_assert(src.ndim == 5 &&
 (filter.ndim == 5 || filter.ndim == 6) &&
 src[src.ndim - 1] == 64 &&
-filter[filter.ndim - 1] == 4,
+filter[filter.ndim - 1] == 64,
 "NCHW64 require src and filter's ndim is 5 or 6, and "
 "last shape is 64 but got src %s, filter %s",
 src.to_string().c_str(), filter.to_string().c_str());
@@ -754,7 +754,6 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
 src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
 cflt.stride[i], cflt.padding[i]);
 }
-dst.init_contiguous_stride();
 } else if (param().format == Param::Format::NCHW4) {
 megdnn_assert(src.ndim == 5,
 "invalid src ndim for NCHW4, expected=5, got=%zu",
@@ -35,8 +35,8 @@ TensorFormat TensorFormat::deserialize(const std::string& bin,
 case Type::IMAGE2D_PACK4:
 return Image2DPack4TensorFormat::deserialize(
 handle, type + 1, bin.size() - sizeof(Type));
-case Type::FOURBITS_ALIGNED_TO_BYTE:
-return FourBitsAlignedToBytesTensorFormat::deserialize(
+case Type::LOWBITS_ALIGNED_TO_BYTE:
+return LowbitsAlignedToBytesTensorFormat::deserialize(
 handle, type + 1, bin.size() - sizeof(Type));
 default:
 megdnn_throw("invalid tensor format type in deserialize");
@@ -45,6 +45,19 @@ TensorFormat TensorFormat::deserialize(const std::string& bin,
 TensorFormat::Format() : m_impl{DefaultTensorFormat::make().m_impl} {}
+TensorFormat::Format(DType dtype) {
+megdnn_assert(dtype.valid());
+if (dtype.is_low_bit()) {
+size_t size_nbits = dtype.low_bit();
+megdnn_assert(size_nbits == 1 || size_nbits == 2 || size_nbits == 4,
+"unsupported lowbits data type(%s, size in bits: %zu)",
+dtype.name(), size_nbits);
+m_impl = LowbitsAlignedToBytesTensorFormat::make(size_nbits).m_impl;
+} else {
+m_impl = DefaultTensorFormat::make().m_impl;
+}
+}
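Editor's note (illustrative, not part of the patch): with the dtype-aware constructor above, callers no longer pick a format by hand; the same TensorLayout construction works for ordinary and low-bit dtypes. A minimal sketch, assuming the accessors added in this diff:

TensorLayout f32{{8, 8}, dtype::Float32()};        // -> DefaultTensorFormat
TensorLayout q4{{8, 8}, dtype::QuantizedS4{1.f}};  // -> LOWBITS_ALIGNED_TO_BYTE, 4-bit elements
megdnn_assert(f32.format.is_default());
megdnn_assert(q4.format.is_lowbit_aligned());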
 std::string TensorFormat::to_string() const {
 return m_impl->to_string();
 }
@@ -69,6 +82,10 @@ bool TensorFormat::is_default() const {
 return m_impl == default_tensor_format_obj;
 }
+bool TensorFormat::is_lowbit_aligned() const {
+return type() == TensorFormat::Type::LOWBITS_ALIGNED_TO_BYTE;
+}
 /* ===================== DefaultFormat ===================== */
 void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const {
 megdnn_assert(
@@ -440,27 +457,26 @@ template class Image2DPackedTensorFormatBase<4>;
 } // namespace detail
 } // namespace megdnn
-/* =============== FourBitsAlignedToBytesTensorFormatBase ============== */
-template <size_t SIZE_NBITS>
-LowbitsTensorFormatBase<SIZE_NBITS>::LowbitsTensorFormatBase(
-Type type, size_t align_size_in_bits)
-: ImplBase(type), m_align_size_in_bits(align_size_in_bits) {
-megdnn_assert(!(m_align_size_in_bits % SIZE_NBITS),
+/* =============== LowbitsAlignedTensorFormatBase ============== */
+LowbitsAlignedTensorFormatBase::LowbitsAlignedTensorFormatBase(
+Type type, size_t size_nbits, size_t align_size_in_bits)
+: ImplBase(type),
+m_size_nbits(size_nbits),
+m_align_size_in_bits(align_size_in_bits) {
+megdnn_assert(!(m_align_size_in_bits % m_size_nbits),
 "align size(%zu) must be a multiple of element size(%zu)",
-m_align_size_in_bits, SIZE_NBITS);
-m_align_size_in_elements = m_align_size_in_bits / SIZE_NBITS;
+m_align_size_in_bits, m_size_nbits);
+m_align_size_in_elements = m_align_size_in_bits / m_size_nbits;
 }
-template <size_t SIZE_NBITS>
-std::string LowbitsTensorFormatBase<SIZE_NBITS>::to_string() const {
-return ssprintf("LOWBITS{%zu,%zu}", SIZE_NBITS, m_align_size_in_bits);
+std::string LowbitsAlignedTensorFormatBase::to_string() const {
+return ssprintf("LOWBITS{%zu,%zu}", m_size_nbits, m_align_size_in_bits);
 }
-template <size_t SIZE_NBITS>
-void LowbitsTensorFormatBase<SIZE_NBITS>::assert_valid(
+void LowbitsAlignedTensorFormatBase::assert_valid(
 const TensorLayout& layout) const {
 megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() &&
-layout.dtype.low_bit() == SIZE_NBITS);
+layout.dtype.low_bit() == m_size_nbits);
 bool has_dim_unity_stride = false;
 for (int i = layout.ndim - 1; i >= 0; --i) {
 if (!has_dim_unity_stride && layout.stride[i] == 1)
@@ -469,23 +485,28 @@ void LowbitsTensorFormatBase<SIZE_NBITS>::assert_valid(
 layout.stride[i] >= 0 &&
 (layout.stride[i] % m_align_size_in_elements == 0 ||
 layout.stride[i] == 1),
-"bad stride: %zu", layout.stride[i]);
+"bad stride:%s, %zu", layout.to_string().c_str(),
+layout.stride[i]);
 }
-megdnn_assert(has_dim_unity_stride, "innermost dim not contiguous");
+/// FIXME
+if (layout.ndim == 0) {
+printf("%s\n", layout.to_string().c_str());
+}
+megdnn_assert(layout.ndim == 0 || has_dim_unity_stride,
+"innermost dim not contiguous");
 }
-template <size_t SIZE_NBITS>
-void LowbitsTensorFormatBase<SIZE_NBITS>::serialize_append(
+void LowbitsAlignedTensorFormatBase::serialize_append(
 std::string& result) const {
 SerializePack pack;
+pack.size_nbits = m_size_nbits;
 pack.align_size_in_bits = m_align_size_in_bits;
 megdnn_assert(pack.align_size_in_bits ==
 m_align_size_in_bits); // detect overflow;
 result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
 }
-template <size_t SIZE_NBITS>
-TensorLayout::Span LowbitsTensorFormatBase<SIZE_NBITS>::span_spec(
+TensorLayout::Span LowbitsAlignedTensorFormatBase::span_spec(
 const TensorLayout& layout) const {
 assert_valid(layout);
 if (layout.ndim == 0)
@@ -507,8 +528,7 @@ TensorLayout::Span LowbitsTensorFormatBase<SIZE_NBITS>::span_spec(
 return TensorLayout::Span(0, 0, high_elem, high_byte);
 }
-template <size_t SIZE_NBITS>
-size_t LowbitsTensorFormatBase<SIZE_NBITS>::init_contiguous_stride(
+size_t LowbitsAlignedTensorFormatBase::init_contiguous_stride(
 TensorLayout& layout) const {
 if (!layout.ndim)
 return 0;
@@ -525,8 +545,7 @@ size_t LowbitsTensorFormatBase<SIZE_NBITS>::init_contiguous_stride(
 return accum;
 }
-template <size_t SIZE_NBITS>
-bool LowbitsTensorFormatBase<SIZE_NBITS>::is_contiguous_spec(
+bool LowbitsAlignedTensorFormatBase::is_contiguous_spec(
 const TensorLayout& layout) const {
 assert_valid(layout);
 ptrdiff_t expected = 1;
@@ -541,8 +560,7 @@ bool LowbitsTensorFormatBase<SIZE_NBITS>::is_contiguous_spec(
 return expected != 0;
 }
-template <size_t SIZE_NBITS>
-TensorLayout LowbitsTensorFormatBase<SIZE_NBITS>::collapse_contiguous_spec(
+TensorLayout LowbitsAlignedTensorFormatBase::collapse_contiguous_spec(
 const TensorLayout& layout) const {
 assert_valid(layout);
 TensorLayout res{layout};
@@ -572,12 +590,6 @@ TensorLayout LowbitsTensorFormatBase<SIZE_NBITS>::collapse_contiguous_spec(
 return res;
 }
-namespace megdnn {
-namespace detail {
-template class LowbitsTensorFormatBase<4>;
-} // namespace detail
-} // namespace megdnn
 /* ===================== Image2DPack4TensorFormat ===================== */
 TensorFormat Image2DPack4TensorFormat::make_raw(
 size_t align_axis, size_t align_size_in_elements,
@@ -616,29 +628,28 @@ TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const {
 return make_raw(axis, align_size_in_elements(), vendor());
 }
-/* ===================== FourBitsAlignedToBytesTensorFormat
+/* ===================== LowbitsAlignedToBytesTensorFormat
 * ===================== */
-TensorFormat FourBitsAlignedToBytesTensorFormat::make(
-size_t align_size_in_bits) {
+TensorFormat LowbitsAlignedToBytesTensorFormat::make(size_t size_nbits) {
 static std::mutex mtx;
 static std::unordered_map<
-uint32_t, std::unique_ptr<FourBitsAlignedToBytesTensorFormat>>
+uint64_t, std::unique_ptr<LowbitsAlignedToBytesTensorFormat>>
 cache;
-megdnn_assert(!(align_size_in_bits % 4));
+megdnn_assert(!(8 % size_nbits));
 MEGDNN_LOCK_GUARD(mtx);
-auto&& ptr = cache[static_cast<uint32_t>(align_size_in_bits)];
+auto&& ptr = cache[static_cast<uint32_t>(size_nbits)];
 if (!ptr) {
-ptr.reset(new FourBitsAlignedToBytesTensorFormat{align_size_in_bits});
+ptr.reset(new LowbitsAlignedToBytesTensorFormat{size_nbits});
 }
 return impl_to_tensor_format(ptr.get());
 }
-TensorFormat FourBitsAlignedToBytesTensorFormat::deserialize(const Handle*,
-const void* buf,
-size_t size) {
+TensorFormat LowbitsAlignedToBytesTensorFormat::deserialize(const Handle*,
+const void* buf,
+size_t size) {
 megdnn_assert(size == sizeof(SerializePack));
 auto pack = *static_cast<const SerializePack*>(buf);
-return make(pack.align_size_in_bits);
+return make(pack.size_nbits);
 }
 // vim: syntax=cpp.doxygen
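Editor's note (illustrative, not part of the patch): make() keeps one cached impl per element width, so asking for the same bit width twice hands back the same object, and the `!(8 % size_nbits)` assert restricts it to 1-, 2-, 4- and 8-bit elements. A minimal usage sketch:

auto fmt_a = LowbitsAlignedToBytesTensorFormat::make(4);
auto fmt_b = LowbitsAlignedToBytesTensorFormat::make(4);
megdnn_assert(fmt_a == fmt_b);  // same cached impl; Format::operator== compares impl pointers
megdnn_assert(fmt_a != LowbitsAlignedToBytesTensorFormat::make(2));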
@@ -24,6 +24,9 @@ using namespace conv_bias;
 bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
 const SizeArgs& args) const {
+if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 &&
+args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
+return false;
 if (args.src_layout->dtype == args.filter_layout->dtype &&
 args.src_layout->dtype == dtype::BFloat16()) {
 return false;
@@ -103,15 +103,18 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
 TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]},
 bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]},
 dst_{ws_dst, layouts[4]};
-ExecArgs args_{args.opr,
+auto conv_op = args.opr->handle()->create_operator<ConvBiasForward>();
+conv_op->param() = args.opr->param();
+using Format = param::ConvBias::Format;
+conv_op->param().format = Format::NCHW64;
+ExecArgs args_{dynamic_cast<ConvBiasForwardImpl*>(conv_op.get()),
 src_,
 filter_,
 bias_,
 z_,
 dst_,
-ws.get_workspace(3),
-args.preprocessed_filter};
-m_underlying_algo.exec(args);
+ws.get_workspace(3)};
+m_underlying_algo.exec(args_);
 // reformat dst
 nchw642nchw(dst_, {args.dst_tensor->raw_ptr, args.dst_tensor->layout});
 }
@@ -134,6 +137,9 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout(
 rst.emplace_back(TensorLayout{});
 }
 rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, dst.dtype});
+for (auto& i : rst) {
+i.init_contiguous_stride();
+}
 return rst;
 }
@@ -145,13 +151,16 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle(
 auto layouts = make_underlying_tensor_layout(
 *(args.src_layout), *(args.filter_layout), *(args.bias_layout),
 *(args.z_layout), *(args.dst_layout));
-SizeArgs args_{args.opr,
+auto conv_op = args.opr->handle()->create_operator<ConvBiasForward>();
+conv_op->param() = args.opr->param();
+using Format = param::ConvBias::Format;
+conv_op->param().format = Format::NCHW64;
+SizeArgs args_{dynamic_cast<ConvBiasForwardImpl*>(conv_op.get()),
 layouts[0],
 layouts[1],
 layouts[2],
 layouts[3],
-layouts[4],
-args.preprocessed_filter};
+layouts[4]};
 size_t ws_size_underlying_algo =
 m_underlying_algo.get_workspace_in_bytes(args_);
 if (args.z_layout->ndim > 0) {
@@ -136,6 +136,10 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param,
 namespace conv_bias {
 bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
+if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 &&
+args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
+return false;
 if (args.src_layout->dtype == args.filter_layout->dtype &&
 args.src_layout->dtype == dtype::BFloat16()) {
 return false;
@@ -72,11 +72,11 @@ std::string ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::kernel_key(
 auto&& param = args.opr->param();
 if (args.z_layout->ndim > 0) {
 kernel_key =
-ssprintf("%s_conv_bias_int4_fuse_z_imma_ldg16_%ux%u",
+ssprintf("%s_conv_bias_int4_fuse_z_imma8832_ldg16_%ux%u",
 current_device_arch_name(), m_tile_nhw, m_tile_oc);
 } else {
 kernel_key =
-ssprintf("%s_conv_bias_int4_imma_ldg16_%ux%u",
+ssprintf("%s_conv_bias_int4_imma8832_ldg16_%ux%u",
 current_device_arch_name(), m_tile_nhw, m_tile_oc);
 }
 if (param.nonlineMode == NonlineMode::H_SWISH) {
@@ -170,7 +170,7 @@ void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec(
 reorder_imma_filter_bias<4, 64>(
 reinterpret_cast<int8_t*>(filter_ptr),
 reinterpret_cast<int32_t*>(bias_ptr),
-args.filter_tensor->compatible_ptr<int8_t>(),
+reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
 args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
 stream);
 }
@@ -292,9 +292,10 @@ void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec_preprocess(
 param);
 auto&& stream = cuda_stream(args.opr->handle());
 reorder_imma_filter_bias<4, 64>(
-args.preprocessed_filter->tensors[0].compatible_ptr<int8_t>(),
+reinterpret_cast<int8_t*>(
+args.preprocessed_filter->tensors[0].raw_ptr),
 args.preprocessed_filter->tensors[1].compatible_ptr<int32_t>(),
-args.filter_tensor->compatible_ptr<int8_t>(),
+reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
 args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
 stream);
 }
@@ -320,7 +320,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS) {
 layout = make_layout({16, 32, 7, 7}, {1792, 56, 8, 1},
 dtype::QuantizedS4{1.3f});
-layout.format = FourBitsAlignedToBytesTensorFormat::make(8_z);
+layout.format = LowbitsAlignedToBytesTensorFormat::make(4_z);
 EXPECT_TRUE(layout.is_contiguous());
 layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
@@ -339,12 +339,10 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS_VALID) {
 DefaultTensorFormat::make()),
 MegDNNError);
 ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS32{1.2f},
-FourBitsAlignedToBytesTensorFormat::make(8_z))
-.span(),
+LowbitsAlignedToBytesTensorFormat::make(4_z)),
 MegDNNError);
 ASSERT_THROW(TensorLayout({16, 32, 7, 7}, dtype::IntB2{},
-FourBitsAlignedToBytesTensorFormat::make(8_z))
-.span(),
+LowbitsAlignedToBytesTensorFormat::make(2_z)),
 MegDNNError);
 }
@@ -338,21 +338,26 @@ void OperatorNodeBase::init_output_format() {
 TensorFormat format, default_;
 for (auto i : input()) {
 auto cur = i->format();
-if (cur != default_) {
+if (!cur.is_default() && !cur.is_lowbit_aligned()) {
 if (format == default_) {
 format = cur;
 } else {
 mgb_assert(format == cur,
-"multiple non-default formats in inputs: %s vs %s",
+"multiple non-default or non-lowbits aligned "
+"formats in inputs: %s vs %s",
 format.to_string().c_str(), cur.to_string().c_str());
 }
 }
 }
 for (auto i : output()) {
 if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
-i->format(default_);
+mgb_assert(format.is_default());
+i->format(TensorFormat(i->dtype()));
 } else {
-i->format(format);
+if (!format.is_default())
+i->format(format);
+else
+i->format(TensorFormat(i->dtype()));
 }
 }
 }
@@ -1063,15 +1063,22 @@ bool VarNodeMemManager::fwd_in2out_readonly(
 return false;
 }
-mgb_assert(
-src != dest &&
-src->comp_node().mem_node() == dest->comp_node().mem_node() &&
-dest->m_mem_plan.valid() && src->m_mem_plan.valid() &&
-dest->m_mem_plan.layout().eq_shape(sub.layout()) &&
-dest->m_mem_plan.layout().dtype.size() == sub.layout().dtype.size()
-);
-assert_in_mem_opt_phase(
-SeqMemOptimizer::Status::ALLOW_FWD_IN2OUT_READONLY);
+bool cond_low_bit = dest->m_mem_plan.layout().dtype.is_low_bit() &&
+sub.layout().dtype.is_low_bit() &&
+dest->m_mem_plan.layout().dtype.low_bit() ==
+sub.layout().dtype.low_bit();
+bool cond_normal =
+!dest->m_mem_plan.layout().dtype.is_low_bit() &&
+!sub.layout().dtype.is_low_bit() &&
+dest->m_mem_plan.layout().dtype.size() == sub.layout().dtype.size();
+MGB_MARK_USED_VAR(cond_low_bit);
+MGB_MARK_USED_VAR(cond_normal);
+mgb_assert(src != dest &&
+src->comp_node().mem_node() == dest->comp_node().mem_node() &&
+dest->m_mem_plan.valid() && src->m_mem_plan.valid() &&
+dest->m_mem_plan.layout().eq_shape(sub.layout()) &&
+(cond_normal || cond_low_bit));
+assert_in_mem_opt_phase(SeqMemOptimizer::Status::ALLOW_FWD_IN2OUT_READONLY);
 if (!m_owner_graph->options().seq_opt.enable_mem_plan_opt)
 return false;
@@ -443,8 +443,8 @@ TensorND<TensorStorage>::name
 DEF(resize, &)(const TensorShape& shape) {
 mgb_assert(m_layout.dtype.valid());
-auto nr_elems = m_layout.init_contiguous_stride(shape);
-m_storage.ensure_size(m_layout.dtype.size(nr_elems));
+m_layout = TensorLayout(shape, m_layout.dtype);
+m_storage.ensure_size(m_layout.span().dist_byte());
 return static_cast<ChainReturnType&>(*this);
 }
@@ -584,15 +584,19 @@ TensorND<TensorStorage>::copy_from(const TensorND<RStorage> &src) {
 m_layout.dtype.assert_is(src.dtype());
 else
 m_layout.dtype = src.dtype();
-m_layout.format = {};
-size_t size_bytes = dtype().size(
-m_layout.init_contiguous_stride(src.shape()));
+m_layout = TensorLayout(src.shape(), m_layout.dtype);
+size_t size_bytes = m_layout.span().dist_byte();
 m_storage.ensure_size(size_bytes);
 if (!size_bytes) {
 return static_cast<ChainReturnType&>(*this);
 }
-if (src.layout().is_physical_contiguous()) {
+// requirement:
+// default case, physical contiguous
+// lowbit aligned, logical contiguous
+if (src.layout().is_physical_contiguous() ||
+(src.layout().format.is_lowbit_aligned() &&
+src.layout().is_contiguous())) {
 if (should_check_overlap(*this, src)) {
 check_overlapped(m_storage.ptr(),
 m_storage.ptr() + size_bytes,
@@ -635,10 +639,17 @@ TensorND<TensorStorage>::copy_from_fixlayout(
 src.raw_ptr() + src_span.high_byte);
 }
-bool self_contig = m_layout.is_physical_contiguous(),
-src_contig = src.layout().is_physical_contiguous();
+bool self_contig = m_layout.is_physical_contiguous() ||
+(m_layout.format.is_lowbit_aligned() &&
+m_layout.is_contiguous()),
+src_contig = src.layout().is_physical_contiguous() ||
+(m_layout.format.is_lowbit_aligned() &&
+m_layout.is_contiguous());
 if (self_contig && src_contig) {
-if (m_layout.format.is_default() && src.layout().format.is_default()) {
+if ((m_layout.format.is_default() &&
+src.layout().format.is_default()) ||
+(m_layout.format.is_lowbit_aligned() &&
+src.layout().format.is_lowbit_aligned())) {
 mgb_assert(src_span.low_byte == 0 && dst_span.low_byte == 0 &&
 src_span.high_byte == dst_span.high_byte);
 m_storage.copy_from(src.storage(), src_span.high_byte);
@@ -261,7 +261,8 @@ PersistentCache::Blob AlgoChooserProfileCache::Key::build_blob() const {
 ret.push_back(';');
 ret.append(ly.dtype.name());
 ret.push_back('|');
-mgb_assert(ly.format.is_default(),
+mgb_assert(ly.format.is_default() || (ly.format.is_lowbit_aligned() &&
+ly.dtype.is_low_bit()),
 "currently only default format is supported");
 }
 if (m_param_size) {
@@ -68,7 +68,10 @@ class SubTensorSpec {
 //! get offset measured in bytes
 ptrdiff_t offset_byte() const {
-return m_offset_elem * m_layout.dtype.size();
+//! for lowbit cases, offset must aligned to bytes
+mgb_assert(!m_layout.dtype.is_low_bit() ||
+!(m_offset_elem * m_layout.dtype.low_bit() % 8));
+return m_layout.dtype.size(m_offset_elem);
 }
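Editor's note (illustrative, not part of the patch): worked numbers for the byte-offset rule above with a 4-bit dtype. An element offset of 16 is 16 * 4 = 64 bits, i.e. 8 bytes, so the assert passes and offset_byte() returns dtype.size(16) == 8; an offset of 7 elements (28 bits) does not land on a byte boundary and would trip the assert.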
 /*!
@@ -554,14 +554,16 @@ void ParamFusePass::apply(OptState &state) const {
 SymbolVar new_var;
 bool is_default_format = var->format().is_default();
-if (cg::is_static_var_value(var) && is_default_format) {
+bool is_lowbit_aligned = var->format().is_lowbit_aligned();
+if (cg::is_static_var_value(var) &&
+(is_default_format || is_lowbit_aligned)) {
 // use ImmutableTensor for inferable vars
 HostTensorND hv;
 hv.copy_from(*inferred_val).sync();
 new_var = opr::ImmutableTensor::make(
 *var->owner_graph(), hv, var_namer.name(var));
 } else {
-if (is_default_format) {
+if (is_default_format || is_lowbit_aligned) {
 new_var = opr::SharedDeviceTensor::make_const(
 *var->owner_graph(), inferred_val, var_namer.name(var));
 } else {
@@ -814,8 +814,13 @@ MGB_IMPL_OPR_GRAD(TypeCvt) {
 #endif
 void TypeCvt::mem_plan_fwd_in2out_writable() {
-if (input(0)->dtype().size() == output(0)->dtype().size() &&
-input(0)->layout().is_contiguous()) {
+bool cond_low_bit =
+input(0)->dtype().is_low_bit() && output(0)->dtype().is_low_bit() &&
+input(0)->dtype().low_bit() == output(0)->dtype().low_bit();
+bool cond_normal = !input(0)->dtype().is_low_bit() &&
+!output(0)->dtype().is_low_bit() &&
+input(0)->dtype().size() == output(0)->dtype().size();
+if ((cond_low_bit || cond_normal) && input(0)->layout().is_contiguous()) {
 output(0)->set_fwd_in2out_writable(input(0));
 }
 }
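Editor's note (illustrative, not part of the patch): under the condition above, a contiguous TypeCvt between two 4-bit types such as QuantizedS4 -> Quantized4Asymm may still forward its input memory in place, while a 4-bit to 8-bit cast such as QuantizedS4 -> QuantizedS8 no longer qualifies, since low-bit and ordinary dtypes are compared by bit width and byte size respectively.

// both operands are 4-bit, so cond_low_bit holds for this pair
bool same_width = dtype::QuantizedS4{1.f}.low_bit() ==
                  dtype::Quantized4Asymm(1.0f, static_cast<uint8_t>(0)).low_bit();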
@@ -120,12 +120,11 @@ public:
 explicit DevValueExecDep(DeviceTensorStorage val) : m_val{std::move(val)} {}
 };
 void intl::DeviceTensorHolder::init_output_format() {
 auto format = get_dev_tensor().format();
-mgb_assert(format.is_default(), "non-default tensor format: %s",
-format.to_string().c_str());
-// no need to set output foramt since it is initialized as default
+mgb_assert(format.is_default() || format.is_lowbit_aligned(),
+"invalid tensor format: %s", format.to_string().c_str());
+output(0)->format(format);
 }
 void intl::DeviceTensorHolder::init_output_mem_plan(bool dynamic) {
@@ -638,10 +638,18 @@ AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo(
 param.workspace = get_workspace_size_bytes(policy);
 for (int i = 0; i < arity; ++i) {
 auto&& src = m_layouts[i];
-mgb_assert(src.format.is_default() &&
+bool cond_normal = src.format.is_default() &&
 (src.dtype.category() == DTypeCategory::FLOAT ||
 src.dtype.category() == DTypeCategory::INT ||
-src.dtype.category() == DTypeCategory::QUANTIZED),
+src.dtype.category() == DTypeCategory::QUANTIZED);
+bool cond_low_bit = src.dtype.is_low_bit() &&
+src.format.is_lowbit_aligned() &&
+(src.dtype.category() == DTypeCategory::QUANTIZED ||
+src.dtype.category() == DTypeCategory::LOWBIT);
+MGB_MARK_USED_VAR(cond_normal);
+MGB_MARK_USED_VAR(cond_low_bit);
+mgb_assert(cond_normal || cond_low_bit,
 "unsupported layout in profiling: %s",
 src.to_string().c_str());
 param.dtypes[i] = src.dtype.enumv();
@@ -175,15 +175,17 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
 case DTypeTrait<_dt>::enumv: \
 return _dt(1.0f, static_cast<uint8_t>(0))
 cb(dtype::Quantized8Asymm);
+cb(dtype::Quantized4Asymm);
 #undef cb
 #define cb(_dt) \
 case DTypeTrait<_dt>::enumv: \
 return _dt(1.0f)
 cb(dtype::QuantizedS8);
 cb(dtype::QuantizedS16);
 cb(dtype::QuantizedS32);
+cb(dtype::QuantizedS4);
 default:
 return DType::from_enum(enumv);
 #undef cb
@@ -2603,4 +2603,306 @@ TEST_F(TestNoWeightPreprocess, NoPreprocess) {
 #endif
+namespace {
+// FIXME change comp node from "cpu0" to "gpu0"
+TEST(TestOprDNN, ConvBiasInt4NCHW) {
+auto run = [](size_t N, size_t C, size_t H, size_t W, size_t F, size_t S,
+size_t P) {
+auto cn = CompNode::load("cpu0");
+auto graph = ComputingGraph::make();
+HostTensorGenerator<dtype::Int8> gen;
+auto mkvar = [&gen](const char* name, const TensorShape& shp,
+const DType& dtype,
+std::shared_ptr<ComputingGraph> graph,
+const CompNode& cn) {
+return opr::TypeCvt::make(
+opr::Host2DeviceCopy::make(*graph, gen(shp, cn))
+.rename(name),
+dtype);
+};
+auto mkcvar = [&gen](const char* name, const TensorShape& shp,
+const DType& dtype,
+std::shared_ptr<ComputingGraph> graph,
+const CompNode& cn) {
+return opr::TypeCvt::make(
+opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+.rename(name),
+dtype);
+};
+using Policy = opr::ConvBias::ExecutionPolicy;
+using Strategy = Policy::Strategy;
+auto x = mkvar("x", {N, C * 4, H, W}, dtype::QuantizedS4(1.19960327f),
+graph, cn),
+w = mkcvar("w1", {C, C * 4, F, F}, dtype::QuantizedS4(1.19970327f),
+graph, cn),
+b = mkcvar("b1", {1, C, 1, 1},
+dtype::QuantizedS32(1.19960327f * 1.19970327f), graph,
+cn);
+opr::ConvBias::Param param;
+param.format = opr::ConvBias::Param::Format::NCHW;
+param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
+param.stride_h = param.stride_w = S;
+param.pad_h = param.pad_w = P;
+Policy policy;
+policy.strategy = Strategy::PROFILE;
+auto y = opr::ConvBias::make(
+x, w, b, param, policy,
+OperatorNodeConfig{dtype::QuantizedS4(11.9960501f)});
+y = opr::TypeCvt::make(y, dtype::Float32());
+auto x_f32 = opr::TypeCvt::make(x, dtype::Float32()),
+w_f32 = opr::TypeCvt::make(w, dtype::Float32()),
+b_f32 = opr::TypeCvt::make(b, dtype::Float32());
+auto y_f32 = opr::ConvBias::make(x_f32, w_f32, b_f32, param, policy);
+auto y_q4 = opr::TypeCvt::make(y_f32, dtype::QuantizedS4{11.9960501f});
+y_q4 = opr::TypeCvt::make(y_q4, dtype::Float32());
+HostTensorND host_y, host_y_q4;
+auto func = graph->compile({make_callback_copy(y, host_y),
+make_callback_copy(y_q4, host_y_q4)});
+func->execute();
+MGB_ASSERT_TENSOR_NEAR(host_y, host_y_q4, 1e-3);
+};
+run(2, 64, 14, 14, 3, 2, 1);
+run(2, 64, 7, 7, 3, 1, 1);
+run(2, 64, 14, 14, 1, 2, 0);
+run(2, 64, 7, 7, 1, 1, 0);
+}
+TEST(TestOprDNN, ConvBiasInt4NCHW64) {
+auto nchw2nchw64 = [](SymbolVar x) {
+auto y = opr::RelayoutFormat::make(
+x, opr::RelayoutFormat::Param::Mode::NCHW_NCHW64);
+return y;
+};
+auto nchw642nchw = [](SymbolVar x) {
+auto y = opr::RelayoutFormat::make(
+x, opr::RelayoutFormat::Param::Mode::NCHW64_NCHW);
+return y;
+};
+auto run = [&](size_t N, size_t C, size_t H, size_t W, size_t F, size_t S,
+size_t P) {
+auto cn = CompNode::load("cpu0");
+auto graph = ComputingGraph::make();
+HostTensorGenerator<dtype::Int8> gen;
+auto mkvar = [&gen](const char* name, const TensorShape& shp,
+const DType& dtype,
+std::shared_ptr<ComputingGraph> graph,
+const CompNode& cn) {
+return opr::TypeCvt::make(
+opr::Host2DeviceCopy::make(*graph, gen(shp, cn))
+.rename(name),
+dtype);
+};
+auto mkcvar = [&gen](const char* name, const TensorShape& shp,
+const DType& dtype,
+std::shared_ptr<ComputingGraph> graph,
+const CompNode& cn) {
+return opr::TypeCvt::make(
+opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+.rename(name),
+dtype);
+};
+using Policy = opr::ConvBias::ExecutionPolicy;
+using Strategy = Policy::Strategy;
+auto x = mkvar("x", {N, C / 16, H, W, 64},
+dtype::QuantizedS4(1.19960327f), graph, cn),
+w = mkcvar("w1", {C, C / 16, F, F, 64},
+dtype::QuantizedS4(1.19970327f), graph, cn),
+b = mkcvar("b1", {1, C / 64, 1, 1, 64},
+dtype::QuantizedS32(1.19960327f * 1.19970327f), graph,
+cn);
+opr::ConvBias::Param param;
+param.format = opr::ConvBias::Param::Format::NCHW64;
+param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
+param.stride_h = param.stride_w = S;
+param.pad_h = param.pad_w = P;
+Policy policy;
+policy.strategy = Strategy::PROFILE;
+auto y = opr::ConvBias::make(
+x, w, b, param, policy,
+OperatorNodeConfig{dtype::QuantizedS4(11.9960501f)});
+y = opr::TypeCvt::make(y, dtype::Float32());
+x = nchw642nchw(x);
+w = nchw642nchw(w);
+b = nchw642nchw(b);
+auto x_f32 = opr::TypeCvt::make(x, dtype::Float32()),
+w_f32 = opr::TypeCvt::make(w, dtype::Float32()),
+b_f32 = opr::TypeCvt::make(b, dtype::Float32());
+param.format = opr::ConvBias::Param::Format::NCHW;
+auto y_f32 = opr::ConvBias::make(x_f32, w_f32, b_f32, param, policy);
+auto y_q4 = opr::TypeCvt::make(y_f32, dtype::QuantizedS4{11.9960501f});
+y_q4 = opr::TypeCvt::make(y_q4, dtype::Float32());
+y_q4 = nchw2nchw64(y_q4);
+HostTensorND host_y, host_y_q4;
+auto func = graph->compile({make_callback_copy(y, host_y),
+make_callback_copy(y_q4, host_y_q4)});
+func->execute();
+MGB_ASSERT_TENSOR_NEAR(host_y, host_y_q4, 1e-3);
+};
+run(2, 64, 14, 14, 3, 2, 1);
+run(2, 64, 7, 7, 3, 1, 1);
+run(2, 64, 14, 14, 1, 2, 0);
+run(2, 64, 7, 7, 1, 1, 0);
+}
+TEST(TestOprDNN, ConvBiasInt4Serialize) {
+using namespace serialization;
+float inp_scale = 1.20210327f;
+float filt_scale = 1.20210406f;
+float bias_scale = inp_scale * filt_scale;
+DType output_dtype = dtype::QuantizedS4{inp_scale};
+HostTensorGenerator<dtype::Int8> gen;
+std::shared_ptr<HostTensorND> xv;
+auto mkvar = [&gen](const char* name, const DType& dtype,
+std::shared_ptr<ComputingGraph> graph,
+std::shared_ptr<HostTensorND> val) {
+return opr::TypeCvt::make(
+opr::Host2DeviceCopy::make(*graph, val).rename(name), dtype);
+};
+auto mkcvar =
+[&gen](const char* name, const TensorShape& shp, const DType& dtype,
+std::shared_ptr<ComputingGraph> graph, const CompNode& cn) {
+return opr::TypeCvt::make(
+opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+.rename(name),
+dtype);
+};
+auto fname = output_file("ConvBiasInt4Serialize");
+HostTensorND y1, y2;
+auto dump = [&]() {
+opr::ConvBias::Param param;
+param.mode = Mode::CONVOLUTION;
+auto cn = CompNode::load("cpu0");
+auto graph = ComputingGraph::make();
+xv = gen({1, 64, 56, 56}, cn);
+auto x = mkvar("x", dtype::QuantizedS4{inp_scale}, graph, xv);
+auto w = mkcvar("w", {256, 64, 1, 1}, dtype::QuantizedS4{filt_scale}, graph, cn);
+auto b = mkcvar("b", {1, 256, 1, 1}, dtype::QuantizedS32{bias_scale}, graph, cn);
+auto y = opr::ConvBiasForward::make(x, w, b, param, {},
+OperatorNodeConfig{output_dtype});
+auto w1 = mkcvar("w1", {64, 256, 1, 1}, dtype::QuantizedS4{filt_scale},
+graph, cn);
+auto b1 = mkcvar("b1", {1, 64, 1, 1}, dtype::QuantizedS32{bias_scale},
+graph, cn);
+y = opr::ConvBiasForward::make(y, w1, b1, param, {},
+OperatorNodeConfig{output_dtype});
+y = opr::TypeCvt::make(y, dtype::Float32());
+auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()));
+auto func = graph->compile({make_callback_copy(y, y1)});
+func->execute();
+func->wait();
+auto rst = dumper->dump({y});
+ASSERT_EQ(rst.outputs.size(), 1u);
+};
+auto load = [&]() {
+auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()));
+auto rst = loader->load();
+for (const auto& t : rst.tensor_map) {
+t.second->copy_from(*xv).sync();
+}
+auto func = rst.graph->compile(
+{make_callback_copy(rst.output_var_list[0], y2)});
+func->execute();
+func->wait();
+ASSERT_EQ(rst.output_var_list.size(), 1u);
+EXPECT_EQ(rst.output_var_list[0].dtype(), dtype::Float32());
+};
+dump();
+load();
+MGB_ASSERT_TENSOR_NEAR(y1, y2, 1e-3);
+}
+TEST(TestOprDNN, ConvBiasInt4SerializeWithParamFuse) {
+using namespace serialization;
+float inp_scale = 1.20210327f;
+float filt_scale = 1.20210406f;
+float bias_scale = inp_scale * filt_scale;
+DType output_dtype = dtype::QuantizedS4{inp_scale};
+HostTensorGenerator<dtype::Int8> gen;
+std::shared_ptr<HostTensorND> xv;
+auto mkvar = [&gen](const char* name, const DType& dtype,
+std::shared_ptr<ComputingGraph> graph,
+std::shared_ptr<HostTensorND> val) {
+return opr::TypeCvt::make(
+opr::Host2DeviceCopy::make(*graph, val).rename(name), dtype);
+};
+auto mkcvar =
+[&gen](const char* name, const TensorShape& shp, const DType& dtype,
+std::shared_ptr<ComputingGraph> graph, const CompNode& cn) {
+return opr::TypeCvt::make(
+opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+.rename(name),
+dtype);
+};
+auto fname = output_file("ConvBiasInt4SerializeWithParamFuse");
+HostTensorND y1, y2;
+auto dump = [&]() {
+opr::ConvBias::Param param;
+param.mode = Mode::CONVOLUTION;
+auto cn = CompNode::load("cpu0");
+auto graph = ComputingGraph::make();
+xv = gen({1, 64, 56, 56}, cn);
+auto x = mkvar("x", dtype::QuantizedS4{inp_scale}, graph, xv);
+auto w = mkcvar("w", {256, 64, 1, 1}, dtype::QuantizedS4{filt_scale}, graph, cn);
+auto b = mkcvar("b", {1, 256, 1, 1}, dtype::QuantizedS32{bias_scale}, graph, cn);
+auto y = opr::ConvBiasForward::make(x, w, b, param, {},
+OperatorNodeConfig{output_dtype});
+auto w1 = mkcvar("w1", {64, 256, 1, 1}, dtype::QuantizedS4{filt_scale},
+graph, cn);
+auto b1 = mkcvar("b1", {1, 64, 1, 1}, dtype::QuantizedS32{bias_scale},
+graph, cn);
+y = opr::ConvBiasForward::make(y, w1, b1, param, {},
+OperatorNodeConfig{output_dtype});
+y = opr::TypeCvt::make(y, dtype::Float32());
+SymbolVar y_param_fused;
+unpack_vector(gopt::GraphOptimizer{}
+.add_pass<gopt::ParamFusePass>()
+.apply({{y}})
+.endpoint_vars(),
+y_param_fused);
+auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()));
+auto func = graph->compile({make_callback_copy(y_param_fused, y1)});
+func->execute();
+func->wait();
+auto rst = dumper->dump({y_param_fused});
+ASSERT_EQ(rst.outputs.size(), 1u);
+};
+auto load = [&]() {
+auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()));
+auto rst = loader->load();
+for (const auto& t : rst.tensor_map) {
+t.second->copy_from(*xv).sync();
+}
+auto func = rst.graph->compile(
+{make_callback_copy(rst.output_var_list[0], y2)});
+func->execute();
+func->wait();
+ASSERT_EQ(rst.output_var_list.size(), 1u);
+EXPECT_EQ(rst.output_var_list[0].dtype(), dtype::Float32());
+};
+dump();
+load();
+MGB_ASSERT_TENSOR_NEAR(y1, y2, 1e-3);
+}
+} // namespace
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}