
refactor(dnn): refactor lowbit tensor format

GitOrigin-RevId: b646dc085b
release-1.5
Megvii Engine Team, 4 years ago
parent commit 3b9b87809d
21 changed files with 497 additions and 127 deletions
  1. dnn/include/megdnn/basic_types.h (+4, -0)
  2. dnn/include/megdnn/tensor_format.h (+21, -18)
  3. dnn/src/common/basic_types.cpp (+3, -10)
  4. dnn/src/common/convolution.cpp (+1, -2)
  5. dnn/src/common/tensor_format.cpp (+56, -45)
  6. dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp (+3, -0)
  7. dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp (+16, -7)
  8. dnn/src/cuda/conv_bias/helper.cpp (+4, -0)
  9. dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp (+6, -5)
  10. dnn/test/common/test_basic_types.cpp (+3, -5)
  11. src/core/impl/graph/operator_node.cpp (+9, -4)
  12. src/core/impl/graph/var_node_mem_mgr.cpp (+16, -9)
  13. src/core/impl/tensor.cpp (+20, -9)
  14. src/core/impl/utils/persistent_cache.cpp (+2, -1)
  15. src/core/include/megbrain/tensor.h (+4, -1)
  16. src/gopt/impl/inference.cpp (+4, -2)
  17. src/opr/impl/basic_arith.cpp (+7, -2)
  18. src/opr/impl/io.cpp (+3, -4)
  19. src/opr/impl/search_policy/algo_chooser.cpp (+10, -2)
  20. src/opr/impl/search_policy/profiler.cpp (+3, -1)
  21. src/opr/test/dnn/convolution.cpp (+302, -0)

dnn/include/megdnn/basic_types.h (+4, -0)

@@ -170,6 +170,7 @@ struct TensorLayout : public TensorShape {
 #if MEGDNN_CC_HOST
     Format();
+    Format(DType dtype);

     const ImplBase* impl() const { return m_impl; }

@@ -198,6 +199,9 @@ struct TensorLayout : public TensorShape {
     //! whether this is the default tensor format
     bool is_default() const;

+    //! whether this is the lowbit aligned to bytes tensor format
+    bool is_lowbit_aligned() const;
+
     bool operator==(Format rhs) const { return m_impl == rhs.m_impl; }
     bool operator!=(Format rhs) const { return m_impl != rhs.m_impl; }
 #endif


dnn/include/megdnn/tensor_format.h (+21, -18)

@@ -20,7 +20,7 @@ namespace megdnn {
 enum class TensorFormat::Type {
     DEFAULT = 0,       //!< see DefaultTensorFormat
     IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat
-    FOURBITS_ALIGNED_TO_BYTE = 2, //!<
+    LOWBITS_ALIGNED_TO_BYTE = 2,  //!<
 };

 class TensorFormat::ImplBase {
@@ -205,21 +205,23 @@ using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
 /*!
  * \brief used for tensors storing lowbit data
  *
- * \p SIZE_NBITS is the size in bits of element of the tensor.
- *
+ * \param m_size_nbits size in bits of elements in the tensor
+ * \param m_align_size_in_bits aligned size in bits
+ * \param m_align_size_in_elements aligned size in elements
  */
-template <size_t SIZE_NBITS_>
-class LowbitsTensorFormatBase : public TensorFormat::ImplBase {
-    static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
-    size_t m_align_size_in_bits, m_align_size_in_elements;
+class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase {
+    size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements;

 protected: //?
-    LowbitsTensorFormatBase(Type type, size_t align_size_in_bits);
+    LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits,
+                                   size_t align_size_in_bits);

-    virtual ~LowbitsTensorFormatBase() = default;
+    virtual ~LowbitsAlignedTensorFormatBase() = default;

 public:
     size_t align_size_in_bits() const { return m_align_size_in_bits; }
+    size_t size_nbits() const { return m_size_nbits; }

     std::string to_string() const override;

@@ -240,10 +242,10 @@
             const TensorLayout& layout) const override;
 protected:
     struct SerializePack {
+        uint8_t size_nbits;
         uint8_t align_size_in_bits;
     };
 };
-using FourBitsAlignedToBytesTensorFormatBase = LowbitsTensorFormatBase<4>;
 } // namespace detail

 /*!
@@ -296,19 +298,20 @@
  * \brief Tensor for storing 4bit data that requires stride corresponding to
  * non-innermost dimension to be aligned to bytes, and pack 2 elems into a byte
  */
-class FourBitsAlignedToBytesTensorFormat final
-        : public detail::FourBitsAlignedToBytesTensorFormatBase {
+class LowbitsAlignedToBytesTensorFormat final
+        : public detail::LowbitsAlignedTensorFormatBase {
 public:
-    static constexpr Type TYPE = Type::FOURBITS_ALIGNED_TO_BYTE;
+    static constexpr Type TYPE = Type::LOWBITS_ALIGNED_TO_BYTE;
+    static constexpr size_t BYTE_IN_BITS = 8;

-    static TensorFormat make(size_t align_size_in_bits);
+    static TensorFormat make(size_t size_nbits);

     static TensorFormat deserialize(const Handle* handle, const void* buf,
                                     size_t size);

     static bool is_valid_layout(const TensorLayout& layout) {
         if (layout.format.type() == TYPE) {
-            layout.format.as_impl<FourBitsAlignedToBytesTensorFormat>()
+            layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>()
                     .assert_valid(layout);
             return true;
         }
@@ -316,9 +319,9 @@
     }

 private:
-    FourBitsAlignedToBytesTensorFormat(size_t align_size_in_bits)
-            : detail::FourBitsAlignedToBytesTensorFormatBase(
-                      TYPE, align_size_in_bits) {}
+    LowbitsAlignedToBytesTensorFormat(size_t size_nbits)
+            : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits,
+                                                     BYTE_IN_BITS) {}
 };
 } // namespace megdnn
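
A quick standalone sketch of the alignment rule this class encodes (plain C++ with illustrative names, not the MegDNN API): every non-innermost stride must cover a whole number of bytes, so for 4-bit elements strides are padded to a multiple of 8 / 4 = 2 elements.

    #include <cassert>
    #include <cstddef>
    #include <cstdio>
    using std::size_t;

    // Illustrative only: mimics the LOWBITS_ALIGNED_TO_BYTE stride rule.
    struct LowbitsAligned {
        size_t size_nbits;  // 1, 2 or 4
        size_t align_elems() const { return 8 / size_nbits; }
        // round a stride (counted in elements) up to whole bytes
        size_t align_stride(size_t stride) const {
            size_t a = align_elems();
            return (stride + a - 1) / a * a;
        }
    };

    int main() {
        LowbitsAligned fmt{4};                 // QuantizedS4: 2 elems per byte
        size_t s2 = 1;                         // innermost stride may stay 1
        size_t s1 = fmt.align_stride(3 * s2);  // 3 -> 4 elements (2 bytes)
        size_t s0 = fmt.align_stride(3 * s1);  // 12 elements (6 bytes)
        assert(s1 == 4 && s0 == 12);
        std::printf("strides: %zu %zu %zu\n", s0, s1, s2);
    }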




dnn/src/common/basic_types.cpp (+3, -10)

@@ -195,21 +195,14 @@ bool TensorShape::is_empty() const {
 /* ===================== TensorLayout ===================== */
 TensorLayout::TensorLayout() = default;

-TensorLayout::TensorLayout(DType dtype_) : dtype{dtype_} {}
+TensorLayout::TensorLayout(DType dtype_)
+        : dtype{dtype_}, format{Format(dtype)} {}

 TensorLayout::TensorLayout(DType dtype_, Format format_)
         : dtype{dtype_}, format{format_} {}

 TensorLayout::TensorLayout(const TensorShape& shape, DType dtype)
-        : TensorShape(shape), dtype{dtype} {
-    if (dtype.low_bit() == 4_z) {
-        format = FourBitsAlignedToBytesTensorFormat::make(8_z);
-    } else {
-        megdnn_assert(!dtype.is_low_bit(), "Unsupported data type(%s)",
-                      dtype.name());
-        format = DefaultTensorFormat::make();
-    }
-}
+        : TensorLayout(shape, dtype, Format(dtype)) {}

 TensorLayout::TensorLayout(const TensorShape& shape, DType dtype,
                            TensorFormat format_)

dnn/src/common/convolution.cpp (+1, -2)

@@ -722,7 +722,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
         megdnn_assert(src.ndim == 5 &&
                               (filter.ndim == 5 || filter.ndim == 6) &&
                               src[src.ndim - 1] == 64 &&
-                              filter[filter.ndim - 1] == 4,
+                              filter[filter.ndim - 1] == 64,
                       "NCHW64 require src and filter's ndim is 5 or 6, and "
                       "last shape is 64 but got src %s, filter %s",
                       src.to_string().c_str(), filter.to_string().c_str());
@@ -754,7 +754,6 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
                     src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                     cflt.stride[i], cflt.padding[i]);
         }
-        dst.init_contiguous_stride();
     } else if (param().format == Param::Format::NCHW4) {
         megdnn_assert(src.ndim == 5,
                       "invalid src ndim for NCHW4, expected=5, got=%zu",


dnn/src/common/tensor_format.cpp (+56, -45)

@@ -35,8 +35,8 @@ TensorFormat TensorFormat::deserialize(const std::string& bin,
         case Type::IMAGE2D_PACK4:
             return Image2DPack4TensorFormat::deserialize(
                     handle, type + 1, bin.size() - sizeof(Type));
-        case Type::FOURBITS_ALIGNED_TO_BYTE:
-            return FourBitsAlignedToBytesTensorFormat::deserialize(
+        case Type::LOWBITS_ALIGNED_TO_BYTE:
+            return LowbitsAlignedToBytesTensorFormat::deserialize(
                     handle, type + 1, bin.size() - sizeof(Type));
         default:
             megdnn_throw("invalid tensor format type in deserialize");
@@ -45,6 +45,19 @@

 TensorFormat::Format() : m_impl{DefaultTensorFormat::make().m_impl} {}

+TensorFormat::Format(DType dtype) {
+    megdnn_assert(dtype.valid());
+    if (dtype.is_low_bit()) {
+        size_t size_nbits = dtype.low_bit();
+        megdnn_assert(size_nbits == 1 || size_nbits == 2 || size_nbits == 4,
+                      "unsupported lowbits data type(%s, size in bits: %zu)",
+                      dtype.name(), size_nbits);
+        m_impl = LowbitsAlignedToBytesTensorFormat::make(size_nbits).m_impl;
+    } else {
+        m_impl = DefaultTensorFormat::make().m_impl;
+    }
+}
+
 std::string TensorFormat::to_string() const {
     return m_impl->to_string();
 }
@@ -69,6 +82,10 @@ bool TensorFormat::is_default() const {
     return m_impl == default_tensor_format_obj;
 }

+bool TensorFormat::is_lowbit_aligned() const {
+    return type() == TensorFormat::Type::LOWBITS_ALIGNED_TO_BYTE;
+}
+
 /* ===================== DefaultFormat ===================== */
 void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const {
     megdnn_assert(
@@ -440,27 +457,26 @@ template class Image2DPackedTensorFormatBase<4>;
 } // namespace detail
 } // namespace megdnn

-/* =============== FourBitsAlignedToBytesTensorFormatBase ============== */
-template <size_t SIZE_NBITS>
-LowbitsTensorFormatBase<SIZE_NBITS>::LowbitsTensorFormatBase(
-        Type type, size_t align_size_in_bits)
-        : ImplBase(type), m_align_size_in_bits(align_size_in_bits) {
-    megdnn_assert(!(m_align_size_in_bits % SIZE_NBITS),
+/* =============== LowbitsAlignedTensorFormatBase ============== */
+LowbitsAlignedTensorFormatBase::LowbitsAlignedTensorFormatBase(
+        Type type, size_t size_nbits, size_t align_size_in_bits)
+        : ImplBase(type),
+          m_size_nbits(size_nbits),
+          m_align_size_in_bits(align_size_in_bits) {
+    megdnn_assert(!(m_align_size_in_bits % m_size_nbits),
                   "align size(%zu) must be a multiple of element size(%zu)",
-                  m_align_size_in_bits, SIZE_NBITS);
-    m_align_size_in_elements = m_align_size_in_bits / SIZE_NBITS;
+                  m_align_size_in_bits, m_size_nbits);
+    m_align_size_in_elements = m_align_size_in_bits / m_size_nbits;
 }

-template <size_t SIZE_NBITS>
-std::string LowbitsTensorFormatBase<SIZE_NBITS>::to_string() const {
-    return ssprintf("LOWBITS{%zu,%zu}", SIZE_NBITS, m_align_size_in_bits);
+std::string LowbitsAlignedTensorFormatBase::to_string() const {
+    return ssprintf("LOWBITS{%zu,%zu}", m_size_nbits, m_align_size_in_bits);
 }

-template <size_t SIZE_NBITS>
-void LowbitsTensorFormatBase<SIZE_NBITS>::assert_valid(
+void LowbitsAlignedTensorFormatBase::assert_valid(
         const TensorLayout& layout) const {
     megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() &&
-                  layout.dtype.low_bit() == SIZE_NBITS);
+                  layout.dtype.low_bit() == m_size_nbits);
     bool has_dim_unity_stride = false;
     for (int i = layout.ndim - 1; i >= 0; --i) {
         if (!has_dim_unity_stride && layout.stride[i] == 1)
@@ -469,23 +485,28 @@
                 layout.stride[i] >= 0 &&
                         (layout.stride[i] % m_align_size_in_elements == 0 ||
                          layout.stride[i] == 1),
-                "bad stride: %zu", layout.stride[i]);
+                "bad stride:%s, %zu", layout.to_string().c_str(),
+                layout.stride[i]);
     }
-    megdnn_assert(has_dim_unity_stride, "innermost dim not contiguous");
+    /// FIXME
+    if (layout.ndim == 0) {
+        printf("%s\n", layout.to_string().c_str());
+    }
+    megdnn_assert(layout.ndim == 0 || has_dim_unity_stride,
+                  "innermost dim not contiguous");
 }

-template <size_t SIZE_NBITS>
-void LowbitsTensorFormatBase<SIZE_NBITS>::serialize_append(
+void LowbitsAlignedTensorFormatBase::serialize_append(
         std::string& result) const {
     SerializePack pack;
+    pack.size_nbits = m_size_nbits;
     pack.align_size_in_bits = m_align_size_in_bits;
     megdnn_assert(pack.align_size_in_bits ==
                   m_align_size_in_bits);  // detect overflow;
     result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
 }

-template <size_t SIZE_NBITS>
-TensorLayout::Span LowbitsTensorFormatBase<SIZE_NBITS>::span_spec(
+TensorLayout::Span LowbitsAlignedTensorFormatBase::span_spec(
         const TensorLayout& layout) const {
     assert_valid(layout);
     if (layout.ndim == 0)
@@ -507,8 +528,7 @@
     return TensorLayout::Span(0, 0, high_elem, high_byte);
 }

-template <size_t SIZE_NBITS>
-size_t LowbitsTensorFormatBase<SIZE_NBITS>::init_contiguous_stride(
+size_t LowbitsAlignedTensorFormatBase::init_contiguous_stride(
         TensorLayout& layout) const {
     if (!layout.ndim)
         return 0;
@@ -525,8 +545,7 @@
     return accum;
 }

-template <size_t SIZE_NBITS>
-bool LowbitsTensorFormatBase<SIZE_NBITS>::is_contiguous_spec(
+bool LowbitsAlignedTensorFormatBase::is_contiguous_spec(
         const TensorLayout& layout) const {
     assert_valid(layout);
     ptrdiff_t expected = 1;
@@ -541,8 +560,7 @@
     return expected != 0;
 }

-template <size_t SIZE_NBITS>
-TensorLayout LowbitsTensorFormatBase<SIZE_NBITS>::collapse_contiguous_spec(
+TensorLayout LowbitsAlignedTensorFormatBase::collapse_contiguous_spec(
         const TensorLayout& layout) const {
     assert_valid(layout);
     TensorLayout res{layout};
@@ -572,12 +590,6 @@
     return res;
 }

-namespace megdnn {
-namespace detail {
-template class LowbitsTensorFormatBase<4>;
-} // namespace detail
-} // namespace megdnn
-
 /* ===================== Image2DPack4TensorFormat ===================== */
 TensorFormat Image2DPack4TensorFormat::make_raw(
         size_t align_axis, size_t align_size_in_elements,
@@ -616,29 +628,28 @@ TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const {
     return make_raw(axis, align_size_in_elements(), vendor());
 }

-/* ===================== FourBitsAlignedToBytesTensorFormat
+/* ===================== LowbitsAlignedToBytesTensorFormat
  * ===================== */
-TensorFormat FourBitsAlignedToBytesTensorFormat::make(
-        size_t align_size_in_bits) {
+TensorFormat LowbitsAlignedToBytesTensorFormat::make(size_t size_nbits) {
     static std::mutex mtx;
     static std::unordered_map<
-            uint32_t, std::unique_ptr<FourBitsAlignedToBytesTensorFormat>>
+            uint64_t, std::unique_ptr<LowbitsAlignedToBytesTensorFormat>>
             cache;
-    megdnn_assert(!(align_size_in_bits % 4));
+    megdnn_assert(!(8 % size_nbits));
     MEGDNN_LOCK_GUARD(mtx);
-    auto&& ptr = cache[static_cast<uint32_t>(align_size_in_bits)];
+    auto&& ptr = cache[static_cast<uint32_t>(size_nbits)];
     if (!ptr) {
-        ptr.reset(new FourBitsAlignedToBytesTensorFormat{align_size_in_bits});
+        ptr.reset(new LowbitsAlignedToBytesTensorFormat{size_nbits});
     }
     return impl_to_tensor_format(ptr.get());
 }

-TensorFormat FourBitsAlignedToBytesTensorFormat::deserialize(const Handle*,
-                                                             const void* buf,
-                                                             size_t size) {
+TensorFormat LowbitsAlignedToBytesTensorFormat::deserialize(const Handle*,
+                                                            const void* buf,
+                                                            size_t size) {
     megdnn_assert(size == sizeof(SerializePack));
     auto pack = *static_cast<const SerializePack*>(buf);
-    return make(pack.align_size_in_bits);
+    return make(pack.size_nbits);
 }

 // vim: syntax=cpp.doxygen
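
A hedged usage sketch of the new constructor path, using only signatures visible in this diff: a layout built from a lowbit dtype now picks the byte-aligned lowbit format automatically, so callers no longer hardcode FourBitsAlignedToBytesTensorFormat::make(8).

    // Sketch, assuming the MegDNN headers from this diff are in scope.
    #include "megdnn/basic_types.h"

    using namespace megdnn;

    void format_from_dtype_example() {
        // Format(DType) maps 1/2/4-bit dtypes to LOWBITS_ALIGNED_TO_BYTE...
        TensorLayout lay({16, 32, 7, 7}, dtype::QuantizedS4{1.f});
        megdnn_assert(lay.format.is_lowbit_aligned());  // LOWBITS{4,8}
        // ...and every other dtype to the default format.
        TensorLayout lay_f32({16, 32, 7, 7}, dtype::Float32());
        megdnn_assert(lay_f32.format.is_default());
    }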

dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp (+3, -0)

@@ -24,6 +24,9 @@ using namespace conv_bias;

 bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
         const SizeArgs& args) const {
+    if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 &&
+        args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
+        return false;
     if (args.src_layout->dtype == args.filter_layout->dtype &&
         args.src_layout->dtype == dtype::BFloat16()) {
         return false;


dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp (+16, -7)

@@ -103,15 +103,18 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
     TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]},
             bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]},
             dst_{ws_dst, layouts[4]};
-    ExecArgs args_{args.opr,
+    auto conv_op = args.opr->handle()->create_operator<ConvBiasForward>();
+    conv_op->param() = args.opr->param();
+    using Format = param::ConvBias::Format;
+    conv_op->param().format = Format::NCHW64;
+    ExecArgs args_{dynamic_cast<ConvBiasForwardImpl*>(conv_op.get()),
                    src_,
                    filter_,
                    bias_,
                    z_,
                    dst_,
-                   ws.get_workspace(3),
-                   args.preprocessed_filter};
-    m_underlying_algo.exec(args);
+                   ws.get_workspace(3)};
+    m_underlying_algo.exec(args_);
     // reformat dst
     nchw642nchw(dst_, {args.dst_tensor->raw_ptr, args.dst_tensor->layout});
 }
@@ -134,6 +137,9 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout(
         rst.emplace_back(TensorLayout{});
     }
     rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, dst.dtype});
+    for (auto& i : rst) {
+        i.init_contiguous_stride();
+    }
     return rst;
 }

@@ -145,13 +151,16 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle(
     auto layouts = make_underlying_tensor_layout(
             *(args.src_layout), *(args.filter_layout), *(args.bias_layout),
             *(args.z_layout), *(args.dst_layout));
-    SizeArgs args_{args.opr,
+    auto conv_op = args.opr->handle()->create_operator<ConvBiasForward>();
+    conv_op->param() = args.opr->param();
+    using Format = param::ConvBias::Format;
+    conv_op->param().format = Format::NCHW64;
+    SizeArgs args_{dynamic_cast<ConvBiasForwardImpl*>(conv_op.get()),
                    layouts[0],
                    layouts[1],
                    layouts[2],
                    layouts[3],
-                   layouts[4],
-                   args.preprocessed_filter};
+                   layouts[4]};
     size_t ws_size_underlying_algo =
             m_underlying_algo.get_workspace_in_bytes(args_);
     if (args.z_layout->ndim > 0) {


dnn/src/cuda/conv_bias/helper.cpp (+4, -0)

@@ -136,6 +136,10 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param,
 namespace conv_bias {

 bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
+    if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 &&
+        args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
+        return false;
+
     if (args.src_layout->dtype == args.filter_layout->dtype &&
         args.src_layout->dtype == dtype::BFloat16()) {
         return false;


dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp (+6, -5)

@@ -72,11 +72,11 @@ std::string ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::kernel_key(
     auto&& param = args.opr->param();
     if (args.z_layout->ndim > 0) {
         kernel_key =
-                ssprintf("%s_conv_bias_int4_fuse_z_imma_ldg16_%ux%u",
+                ssprintf("%s_conv_bias_int4_fuse_z_imma8832_ldg16_%ux%u",
                          current_device_arch_name(), m_tile_nhw, m_tile_oc);
     } else {
         kernel_key =
-                ssprintf("%s_conv_bias_int4_imma_ldg16_%ux%u",
+                ssprintf("%s_conv_bias_int4_imma8832_ldg16_%ux%u",
                          current_device_arch_name(), m_tile_nhw, m_tile_oc);
     }
     if (param.nonlineMode == NonlineMode::H_SWISH) {
@@ -170,7 +170,7 @@ void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec(
         reorder_imma_filter_bias<4, 64>(
                 reinterpret_cast<int8_t*>(filter_ptr),
                 reinterpret_cast<int32_t*>(bias_ptr),
-                args.filter_tensor->compatible_ptr<int8_t>(),
+                reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
                 args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
                 stream);
     }
@@ -292,9 +292,10 @@ void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec_preprocess(
                                  param);
     auto&& stream = cuda_stream(args.opr->handle());
     reorder_imma_filter_bias<4, 64>(
-            args.preprocessed_filter->tensors[0].compatible_ptr<int8_t>(),
+            reinterpret_cast<int8_t*>(
+                    args.preprocessed_filter->tensors[0].raw_ptr),
             args.preprocessed_filter->tensors[1].compatible_ptr<int32_t>(),
-            args.filter_tensor->compatible_ptr<int8_t>(),
+            reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr),
             args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
             stream);
 }


dnn/test/common/test_basic_types.cpp (+3, -5)

@@ -320,7 +320,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS) {

     layout = make_layout({16, 32, 7, 7}, {1792, 56, 8, 1},
                          dtype::QuantizedS4{1.3f});
-    layout.format = FourBitsAlignedToBytesTensorFormat::make(8_z);
+    layout.format = LowbitsAlignedToBytesTensorFormat::make(4_z);
     EXPECT_TRUE(layout.is_contiguous());

     layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
@@ -339,12 +339,10 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS_VALID) {
                               DefaultTensorFormat::make()),
                  MegDNNError);
     ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS32{1.2f},
-                              FourBitsAlignedToBytesTensorFormat::make(8_z))
-                         .span(),
+                              LowbitsAlignedToBytesTensorFormat::make(4_z)),
                  MegDNNError);
     ASSERT_THROW(TensorLayout({16, 32, 7, 7}, dtype::IntB2{},
-                              FourBitsAlignedToBytesTensorFormat::make(8_z))
-                         .span(),
+                              LowbitsAlignedToBytesTensorFormat::make(2_z)),
                  MegDNNError);
 }




src/core/impl/graph/operator_node.cpp (+9, -4)

@@ -338,21 +338,26 @@ void OperatorNodeBase::init_output_format() {
     TensorFormat format, default_;
     for (auto i : input()) {
         auto cur = i->format();
-        if (cur != default_) {
+        if (!cur.is_default() && !cur.is_lowbit_aligned()) {
             if (format == default_) {
                 format = cur;
             } else {
                 mgb_assert(format == cur,
-                           "multiple non-default formats in inputs: %s vs %s",
+                           "multiple non-default or non-lowbits aligned "
+                           "formats in inputs: %s vs %s",
                            format.to_string().c_str(), cur.to_string().c_str());
             }
         }
     }
     for (auto i : output()) {
         if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
-            i->format(default_);
+            mgb_assert(format.is_default());
+            i->format(TensorFormat(i->dtype()));
         } else {
-            i->format(format);
+            if (!format.is_default())
+                i->format(format);
+            else
+                i->format(TensorFormat(i->dtype()));
         }
     }
 }


src/core/impl/graph/var_node_mem_mgr.cpp (+16, -9)

@@ -1063,15 +1063,22 @@ bool VarNodeMemManager::fwd_in2out_readonly(
         return false;
     }

-    mgb_assert(
-            src != dest &&
-            src->comp_node().mem_node() == dest->comp_node().mem_node() &&
-            dest->m_mem_plan.valid() && src->m_mem_plan.valid() &&
-            dest->m_mem_plan.layout().eq_shape(sub.layout()) &&
-            dest->m_mem_plan.layout().dtype.size() == sub.layout().dtype.size()
-    );
-    assert_in_mem_opt_phase(
-            SeqMemOptimizer::Status::ALLOW_FWD_IN2OUT_READONLY);
+    bool cond_low_bit = dest->m_mem_plan.layout().dtype.is_low_bit() &&
+                        sub.layout().dtype.is_low_bit() &&
+                        dest->m_mem_plan.layout().dtype.low_bit() ==
+                                sub.layout().dtype.low_bit();
+    bool cond_normal =
+            !dest->m_mem_plan.layout().dtype.is_low_bit() &&
+            !sub.layout().dtype.is_low_bit() &&
+            dest->m_mem_plan.layout().dtype.size() == sub.layout().dtype.size();
+    MGB_MARK_USED_VAR(cond_low_bit);
+    MGB_MARK_USED_VAR(cond_normal);
+    mgb_assert(src != dest &&
+               src->comp_node().mem_node() == dest->comp_node().mem_node() &&
+               dest->m_mem_plan.valid() && src->m_mem_plan.valid() &&
+               dest->m_mem_plan.layout().eq_shape(sub.layout()) &&
+               (cond_normal || cond_low_bit));
+    assert_in_mem_opt_phase(SeqMemOptimizer::Status::ALLOW_FWD_IN2OUT_READONLY);

     if (!m_owner_graph->options().seq_opt.enable_mem_plan_opt)
         return false;


src/core/impl/tensor.cpp (+20, -9)

@@ -443,8 +443,8 @@ TensorND<TensorStorage>::name

 DEF(resize, &)(const TensorShape& shape) {
     mgb_assert(m_layout.dtype.valid());
-    auto nr_elems = m_layout.init_contiguous_stride(shape);
-    m_storage.ensure_size(m_layout.dtype.size(nr_elems));
+    m_layout = TensorLayout(shape, m_layout.dtype);
+    m_storage.ensure_size(m_layout.span().dist_byte());
     return static_cast<ChainReturnType&>(*this);
 }

@@ -584,15 +584,19 @@ TensorND<TensorStorage>::copy_from(const TensorND<RStorage> &src) {
         m_layout.dtype.assert_is(src.dtype());
     else
         m_layout.dtype = src.dtype();
-    m_layout.format = {};

-    size_t size_bytes = dtype().size(
-            m_layout.init_contiguous_stride(src.shape()));
+    m_layout = TensorLayout(src.shape(), m_layout.dtype);
+    size_t size_bytes = m_layout.span().dist_byte();
     m_storage.ensure_size(size_bytes);
     if (!size_bytes) {
         return static_cast<ChainReturnType&>(*this);
     }
-    if (src.layout().is_physical_contiguous()) {
+    // requirement:
+    //  default case, physical contiguous
+    //  lowbit aligned, logical contiguous
+    if (src.layout().is_physical_contiguous() ||
+        (src.layout().format.is_lowbit_aligned() &&
+         src.layout().is_contiguous())) {
         if (should_check_overlap(*this, src)) {
             check_overlapped(m_storage.ptr(),
                              m_storage.ptr() + size_bytes,
@@ -635,10 +639,17 @@ TensorND<TensorStorage>::copy_from_fixlayout(
                          src.raw_ptr() + src_span.high_byte);
     }

-    bool self_contig = m_layout.is_physical_contiguous(),
-         src_contig = src.layout().is_physical_contiguous();
+    bool self_contig = m_layout.is_physical_contiguous() ||
+                       (m_layout.format.is_lowbit_aligned() &&
+                        m_layout.is_contiguous()),
+         src_contig = src.layout().is_physical_contiguous() ||
+                      (m_layout.format.is_lowbit_aligned() &&
+                       m_layout.is_contiguous());
     if (self_contig && src_contig) {
-        if (m_layout.format.is_default() && src.layout().format.is_default()) {
+        if ((m_layout.format.is_default() &&
+             src.layout().format.is_default()) ||
+            (m_layout.format.is_lowbit_aligned() &&
+             src.layout().format.is_lowbit_aligned())) {
             mgb_assert(src_span.low_byte == 0 && dst_span.low_byte == 0 &&
                        src_span.high_byte == dst_span.high_byte);
             m_storage.copy_from(src.storage(), src_span.high_byte);
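
Why resize/copy_from now size storage via span().dist_byte(): for lowbit layouts the byte count follows the byte-aligned strides, not dtype.size(nr_elems). A standalone arithmetic sketch (illustrative names, not the MegBrain API):

    #include <cassert>
    #include <cstddef>
    using std::size_t;

    // Illustrative only: bytes spanned by a contiguous 4-bit tensor whose
    // non-innermost strides are byte-aligned (2 elements per byte).
    size_t lowbit_span_bytes(const size_t* shape, size_t ndim,
                             size_t size_nbits) {
        size_t align_elems = 8 / size_nbits;  // elements per byte
        size_t stride = 1;                    // innermost stride, in elements
        for (size_t i = ndim; i-- > 0;) {
            size_t high_elem = stride * shape[i];
            // pad so the next (outer) stride covers whole bytes
            stride = (high_elem + align_elems - 1) / align_elems * align_elems;
        }
        return stride * size_nbits / 8;  // elements -> bytes
    }

    int main() {
        size_t shp[] = {2, 3, 3};  // 18 int4 elements
        // the innermost dim of 3 pads to 4 elements (2 bytes), so the span
        // is 2 * 3 * 2 = 12 bytes; plain packing would give only 9.
        assert(lowbit_span_bytes(shp, 3, 4) == 12);
    }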


src/core/impl/utils/persistent_cache.cpp (+2, -1)

@@ -261,7 +261,8 @@ PersistentCache::Blob AlgoChooserProfileCache::Key::build_blob() const {
             ret.push_back(';');
             ret.append(ly.dtype.name());
             ret.push_back('|');
-            mgb_assert(ly.format.is_default(),
+            mgb_assert(ly.format.is_default() || (ly.format.is_lowbit_aligned() &&
+                       ly.dtype.is_low_bit()),
                        "currently only default format is supported");
         }
         if (m_param_size) {


src/core/include/megbrain/tensor.h (+4, -1)

@@ -68,7 +68,10 @@ class SubTensorSpec {

     //! get offset measured in bytes
     ptrdiff_t offset_byte() const {
-        return m_offset_elem * m_layout.dtype.size();
+        //! for lowbit cases, the offset must be aligned to bytes
+        mgb_assert(!m_layout.dtype.is_low_bit() ||
+                   !(m_offset_elem * m_layout.dtype.low_bit() % 8));
+        return m_layout.dtype.size(m_offset_elem);
     }

     /*!
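
The new assert makes the byte conversion explicit: a lowbit sub-tensor offset is only addressable when it starts on a byte boundary, and dtype.size(n) here is the byte size of n elements. A tiny sketch (illustrative, not the MegBrain API):

    #include <cassert>
    #include <cstddef>

    // Illustrative only: byte offset of element index offset_elem in a
    // tensor of low_bit-sized elements (e.g. 4 for QuantizedS4).
    std::ptrdiff_t lowbit_offset_byte(std::ptrdiff_t offset_elem,
                                      std::size_t low_bit) {
        assert(offset_elem * low_bit % 8 == 0);  // must start at a whole byte
        return offset_elem * low_bit / 8;
    }

    int main() {
        assert(lowbit_offset_byte(16, 4) == 8);  // 16 int4 elems == 8 bytes
        // lowbit_offset_byte(3, 4) would trip the assert: bit offset 12
    }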


src/gopt/impl/inference.cpp (+4, -2)

@@ -554,14 +554,16 @@ void ParamFusePass::apply(OptState &state) const {

         SymbolVar new_var;
         bool is_default_format = var->format().is_default();
-        if (cg::is_static_var_value(var) && is_default_format) {
+        bool is_lowbit_aligned = var->format().is_lowbit_aligned();
+        if (cg::is_static_var_value(var) &&
+            (is_default_format || is_lowbit_aligned)) {
             // use ImmutableTensor for inferable vars
             HostTensorND hv;
             hv.copy_from(*inferred_val).sync();
             new_var = opr::ImmutableTensor::make(
                     *var->owner_graph(), hv, var_namer.name(var));
         } else {
-            if (is_default_format) {
+            if (is_default_format || is_lowbit_aligned) {
                 new_var = opr::SharedDeviceTensor::make_const(
                         *var->owner_graph(), inferred_val, var_namer.name(var));
             } else {


src/opr/impl/basic_arith.cpp (+7, -2)

@@ -814,8 +814,13 @@ MGB_IMPL_OPR_GRAD(TypeCvt) {
 #endif

 void TypeCvt::mem_plan_fwd_in2out_writable() {
-    if (input(0)->dtype().size() == output(0)->dtype().size() &&
-        input(0)->layout().is_contiguous()) {
+    bool cond_low_bit =
+            input(0)->dtype().is_low_bit() && output(0)->dtype().is_low_bit() &&
+            input(0)->dtype().low_bit() == output(0)->dtype().low_bit();
+    bool cond_normal = !input(0)->dtype().is_low_bit() &&
+                       !output(0)->dtype().is_low_bit() &&
+                       input(0)->dtype().size() == output(0)->dtype().size();
+    if ((cond_low_bit || cond_normal) && input(0)->layout().is_contiguous()) {
        output(0)->set_fwd_in2out_writable(input(0));
     }
 }
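
A minimal sketch of the in-place predicate this rewrite encodes (standalone, illustrative names): TypeCvt may forward its input buffer only when both dtypes occupy the same storage width, compared in bits for lowbit types and in bytes otherwise.

    #include <cassert>

    // Illustrative stand-in for the DType width queries used above.
    struct Dt {
        bool low_bit;  // true for 1/2/4-bit types
        int width;     // bits if low_bit, bytes otherwise
    };

    // Mirrors the new TypeCvt::mem_plan_fwd_in2out_writable condition.
    bool can_forward_in2out(Dt in, Dt out, bool contiguous) {
        bool cond_low_bit = in.low_bit && out.low_bit && in.width == out.width;
        bool cond_normal =
                !in.low_bit && !out.low_bit && in.width == out.width;
        return (cond_low_bit || cond_normal) && contiguous;
    }

    int main() {
        assert(can_forward_in2out({true, 4}, {true, 4}, true));   // 4b -> 4b
        assert(!can_forward_in2out({true, 4}, {false, 1}, true)); // 4b -> s8
    }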


src/opr/impl/io.cpp (+3, -4)

@@ -120,12 +120,11 @@ public:
     explicit DevValueExecDep(DeviceTensorStorage val) : m_val{std::move(val)} {}
 };

-
 void intl::DeviceTensorHolder::init_output_format() {
     auto format = get_dev_tensor().format();
-    mgb_assert(format.is_default(), "non-default tensor format: %s",
-               format.to_string().c_str());
-    // no need to set output foramt since it is initialized as default
+    mgb_assert(format.is_default() || format.is_lowbit_aligned(),
+               "invalid tensor format: %s", format.to_string().c_str());
+    output(0)->format(format);
 }

 void intl::DeviceTensorHolder::init_output_mem_plan(bool dynamic) {


src/opr/impl/search_policy/algo_chooser.cpp (+10, -2)

@@ -638,10 +638,18 @@ AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo(
     param.workspace = get_workspace_size_bytes(policy);
     for (int i = 0; i < arity; ++i) {
         auto&& src = m_layouts[i];
-        mgb_assert(src.format.is_default() &&
+        bool cond_normal = src.format.is_default() &&
                            (src.dtype.category() == DTypeCategory::FLOAT ||
                             src.dtype.category() == DTypeCategory::INT ||
-                            src.dtype.category() == DTypeCategory::QUANTIZED),
+                            src.dtype.category() == DTypeCategory::QUANTIZED);
+
+        bool cond_low_bit = src.dtype.is_low_bit() &&
+                            src.format.is_lowbit_aligned() &&
+                            (src.dtype.category() == DTypeCategory::QUANTIZED ||
+                             src.dtype.category() == DTypeCategory::LOWBIT);
+        MGB_MARK_USED_VAR(cond_normal);
+        MGB_MARK_USED_VAR(cond_low_bit);
+        mgb_assert(cond_normal || cond_low_bit,
                    "unsupported layout in profiling: %s",
                    src.to_string().c_str());
         param.dtypes[i] = src.dtype.enumv();


src/opr/impl/search_policy/profiler.cpp (+3, -1)

@@ -175,15 +175,17 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
     case DTypeTrait<_dt>::enumv: \
         return _dt(1.0f, static_cast<uint8_t>(0))
                 cb(dtype::Quantized8Asymm);
+                cb(dtype::Quantized4Asymm);
 #undef cb

 #define cb(_dt)                  \
     case DTypeTrait<_dt>::enumv: \
         return _dt(1.0f)
                 cb(dtype::QuantizedS8);
                 cb(dtype::QuantizedS16);
                 cb(dtype::QuantizedS32);
+                cb(dtype::QuantizedS4);
             default:
                 return DType::from_enum(enumv);
 #undef cb


src/opr/test/dnn/convolution.cpp (+302, -0)

@@ -2603,4 +2603,306 @@ TEST_F(TestNoWeightPreprocess, NoPreprocess) {

 #endif

namespace {
// FIXME change comp node from "cpu0" to "gpu0"
TEST(TestOprDNN, ConvBiasInt4NCHW) {
auto run = [](size_t N, size_t C, size_t H, size_t W, size_t F, size_t S,
size_t P) {
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();

HostTensorGenerator<dtype::Int8> gen;
auto mkvar = [&gen](const char* name, const TensorShape& shp,
const DType& dtype,
std::shared_ptr<ComputingGraph> graph,
const CompNode& cn) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp, cn))
.rename(name),
dtype);
};
auto mkcvar = [&gen](const char* name, const TensorShape& shp,
const DType& dtype,
std::shared_ptr<ComputingGraph> graph,
const CompNode& cn) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};

using Policy = opr::ConvBias::ExecutionPolicy;
using Strategy = Policy::Strategy;
auto x = mkvar("x", {N, C * 4, H, W}, dtype::QuantizedS4(1.19960327f),
graph, cn),
w = mkcvar("w1", {C, C * 4, F, F}, dtype::QuantizedS4(1.19970327f),
graph, cn),
b = mkcvar("b1", {1, C, 1, 1},
dtype::QuantizedS32(1.19960327f * 1.19970327f), graph,
cn);
opr::ConvBias::Param param;
param.format = opr::ConvBias::Param::Format::NCHW;
param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
param.stride_h = param.stride_w = S;
param.pad_h = param.pad_w = P;
Policy policy;
policy.strategy = Strategy::PROFILE;

auto y = opr::ConvBias::make(
x, w, b, param, policy,
OperatorNodeConfig{dtype::QuantizedS4(11.9960501f)});
y = opr::TypeCvt::make(y, dtype::Float32());
auto x_f32 = opr::TypeCvt::make(x, dtype::Float32()),
w_f32 = opr::TypeCvt::make(w, dtype::Float32()),
b_f32 = opr::TypeCvt::make(b, dtype::Float32());
auto y_f32 = opr::ConvBias::make(x_f32, w_f32, b_f32, param, policy);
auto y_q4 = opr::TypeCvt::make(y_f32, dtype::QuantizedS4{11.9960501f});
y_q4 = opr::TypeCvt::make(y_q4, dtype::Float32());
HostTensorND host_y, host_y_q4;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_q4, host_y_q4)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_q4, 1e-3);
};
run(2, 64, 14, 14, 3, 2, 1);
run(2, 64, 7, 7, 3, 1, 1);
run(2, 64, 14, 14, 1, 2, 0);
run(2, 64, 7, 7, 1, 1, 0);
}

TEST(TestOprDNN, ConvBiasInt4NCHW64) {
auto nchw2nchw64 = [](SymbolVar x) {
auto y = opr::RelayoutFormat::make(
x, opr::RelayoutFormat::Param::Mode::NCHW_NCHW64);
return y;
};

auto nchw642nchw = [](SymbolVar x) {
auto y = opr::RelayoutFormat::make(
x, opr::RelayoutFormat::Param::Mode::NCHW64_NCHW);
return y;
};

auto run = [&](size_t N, size_t C, size_t H, size_t W, size_t F, size_t S,
size_t P) {
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();

HostTensorGenerator<dtype::Int8> gen;
auto mkvar = [&gen](const char* name, const TensorShape& shp,
const DType& dtype,
std::shared_ptr<ComputingGraph> graph,
const CompNode& cn) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp, cn))
.rename(name),
dtype);
};
auto mkcvar = [&gen](const char* name, const TensorShape& shp,
const DType& dtype,
std::shared_ptr<ComputingGraph> graph,
const CompNode& cn) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};

using Policy = opr::ConvBias::ExecutionPolicy;
using Strategy = Policy::Strategy;
auto x = mkvar("x", {N, C / 16, H, W, 64},
dtype::QuantizedS4(1.19960327f), graph, cn),
w = mkcvar("w1", {C, C / 16, F, F, 64},
dtype::QuantizedS4(1.19970327f), graph, cn),
b = mkcvar("b1", {1, C / 64, 1, 1, 64},
dtype::QuantizedS32(1.19960327f * 1.19970327f), graph,
cn);
opr::ConvBias::Param param;
param.format = opr::ConvBias::Param::Format::NCHW64;
param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
param.stride_h = param.stride_w = S;
param.pad_h = param.pad_w = P;
Policy policy;
policy.strategy = Strategy::PROFILE;

auto y = opr::ConvBias::make(
x, w, b, param, policy,
OperatorNodeConfig{dtype::QuantizedS4(11.9960501f)});
y = opr::TypeCvt::make(y, dtype::Float32());
x = nchw642nchw(x);
w = nchw642nchw(w);
b = nchw642nchw(b);
auto x_f32 = opr::TypeCvt::make(x, dtype::Float32()),
w_f32 = opr::TypeCvt::make(w, dtype::Float32()),
b_f32 = opr::TypeCvt::make(b, dtype::Float32());
param.format = opr::ConvBias::Param::Format::NCHW;
auto y_f32 = opr::ConvBias::make(x_f32, w_f32, b_f32, param, policy);
auto y_q4 = opr::TypeCvt::make(y_f32, dtype::QuantizedS4{11.9960501f});
y_q4 = opr::TypeCvt::make(y_q4, dtype::Float32());
y_q4 = nchw2nchw64(y_q4);
HostTensorND host_y, host_y_q4;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_q4, host_y_q4)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_q4, 1e-3);
};
run(2, 64, 14, 14, 3, 2, 1);
run(2, 64, 7, 7, 3, 1, 1);
run(2, 64, 14, 14, 1, 2, 0);
run(2, 64, 7, 7, 1, 1, 0);
}

TEST(TestOprDNN, ConvBiasInt4Serialize) {
using namespace serialization;

float inp_scale = 1.20210327f;
float filt_scale = 1.20210406f;
float bias_scale = inp_scale * filt_scale;
DType output_dtype = dtype::QuantizedS4{inp_scale};

HostTensorGenerator<dtype::Int8> gen;
std::shared_ptr<HostTensorND> xv;
auto mkvar = [&gen](const char* name, const DType& dtype,
std::shared_ptr<ComputingGraph> graph,
std::shared_ptr<HostTensorND> val) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, val).rename(name), dtype);
};
auto mkcvar =
[&gen](const char* name, const TensorShape& shp, const DType& dtype,
std::shared_ptr<ComputingGraph> graph, const CompNode& cn) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};

auto fname = output_file("ConvBiasInt4Serialize");
HostTensorND y1, y2;
auto dump = [&]() {
opr::ConvBias::Param param;
param.mode = Mode::CONVOLUTION;

auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
xv = gen({1, 64, 56, 56}, cn);
auto x = mkvar("x", dtype::QuantizedS4{inp_scale}, graph, xv);
auto w = mkcvar("w", {256, 64, 1, 1}, dtype::QuantizedS4{filt_scale}, graph, cn);
auto b = mkcvar("b", {1, 256, 1, 1}, dtype::QuantizedS32{bias_scale}, graph, cn);
auto y = opr::ConvBiasForward::make(x, w, b, param, {},
OperatorNodeConfig{output_dtype});
auto w1 = mkcvar("w1", {64, 256, 1, 1}, dtype::QuantizedS4{filt_scale},
graph, cn);
auto b1 = mkcvar("b1", {1, 64, 1, 1}, dtype::QuantizedS32{bias_scale},
graph, cn);
y = opr::ConvBiasForward::make(y, w1, b1, param, {},
OperatorNodeConfig{output_dtype});
y = opr::TypeCvt::make(y, dtype::Float32());
auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()));
auto func = graph->compile({make_callback_copy(y, y1)});
func->execute();
func->wait();
auto rst = dumper->dump({y});
ASSERT_EQ(rst.outputs.size(), 1u);
};

auto load = [&]() {
auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()));
auto rst = loader->load();
for (const auto& t : rst.tensor_map) {
t.second->copy_from(*xv).sync();
}
auto func = rst.graph->compile(
{make_callback_copy(rst.output_var_list[0], y2)});
func->execute();
func->wait();
ASSERT_EQ(rst.output_var_list.size(), 1u);
EXPECT_EQ(rst.output_var_list[0].dtype(), dtype::Float32());
};

dump();
load();
MGB_ASSERT_TENSOR_NEAR(y1, y2, 1e-3);
}

TEST(TestOprDNN, ConvBiasInt4SerializeWithParamFuse) {
using namespace serialization;

float inp_scale = 1.20210327f;
float filt_scale = 1.20210406f;
float bias_scale = inp_scale * filt_scale;
DType output_dtype = dtype::QuantizedS4{inp_scale};

HostTensorGenerator<dtype::Int8> gen;
std::shared_ptr<HostTensorND> xv;
auto mkvar = [&gen](const char* name, const DType& dtype,
std::shared_ptr<ComputingGraph> graph,
std::shared_ptr<HostTensorND> val) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, val).rename(name), dtype);
};
auto mkcvar =
[&gen](const char* name, const TensorShape& shp, const DType& dtype,
std::shared_ptr<ComputingGraph> graph, const CompNode& cn) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};

auto fname = output_file("ConvBiasInt4SerializeWithParamFuse");
HostTensorND y1, y2;
auto dump = [&]() {
opr::ConvBias::Param param;
param.mode = Mode::CONVOLUTION;

auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
xv = gen({1, 64, 56, 56}, cn);
auto x = mkvar("x", dtype::QuantizedS4{inp_scale}, graph, xv);
auto w = mkcvar("w", {256, 64, 1, 1}, dtype::QuantizedS4{filt_scale}, graph, cn);
auto b = mkcvar("b", {1, 256, 1, 1}, dtype::QuantizedS32{bias_scale}, graph, cn);
auto y = opr::ConvBiasForward::make(x, w, b, param, {},
OperatorNodeConfig{output_dtype});
auto w1 = mkcvar("w1", {64, 256, 1, 1}, dtype::QuantizedS4{filt_scale},
graph, cn);
auto b1 = mkcvar("b1", {1, 64, 1, 1}, dtype::QuantizedS32{bias_scale},
graph, cn);
y = opr::ConvBiasForward::make(y, w1, b1, param, {},
OperatorNodeConfig{output_dtype});
y = opr::TypeCvt::make(y, dtype::Float32());
SymbolVar y_param_fused;
unpack_vector(gopt::GraphOptimizer{}
.add_pass<gopt::ParamFusePass>()
.apply({{y}})
.endpoint_vars(),
y_param_fused);
auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()));
auto func = graph->compile({make_callback_copy(y_param_fused, y1)});
func->execute();
func->wait();
auto rst = dumper->dump({y_param_fused});
ASSERT_EQ(rst.outputs.size(), 1u);
};

auto load = [&]() {
auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()));
auto rst = loader->load();
for (const auto& t : rst.tensor_map) {
t.second->copy_from(*xv).sync();
}
auto func = rst.graph->compile(
{make_callback_copy(rst.output_var_list[0], y2)});
func->execute();
func->wait();
ASSERT_EQ(rst.output_var_list.size(), 1u);
EXPECT_EQ(rst.output_var_list[0].dtype(), dtype::Float32());
};

dump();
load();
MGB_ASSERT_TENSOR_NEAR(y1, y2, 1e-3);
}
} // namespace

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
