@@ -38,6 +38,17 @@ class Handle { | |||
CAMBRICON = 12, | |||
}; | |||
//! Device vendor | |||
enum class HandleVendorType : uint32_t { | |||
NOT_SPEC = 0, | |||
MALI = 1, | |||
ADRENO = 2, | |||
CUDA = 3, | |||
INTEL = 4, | |||
POWERVR = 5, | |||
AMD = 6, | |||
}; | |||
protected: | |||
Handle(megcoreComputingHandle_t computing_handle, HandleType type); | |||
@@ -130,6 +141,9 @@ class Handle { | |||
//! get alignment in bytes for rows of image 2D tensor format | |||
virtual size_t image2d_pitch_alignment() const; | |||
//! get vendor type | |||
virtual HandleVendorType vendor_type() const; | |||
HandleType type() const { | |||
return m_handle_type; | |||
} | |||
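// Editorial aside (not part of this patch): backends report their device
// vendor by overriding vendor_type(); format-related code can then branch on
// it, e.g. with a hypothetical `handle` pointer:
//
//     if (handle->vendor_type() == Handle::HandleVendorType::MALI) {
//         // MALI expresses image pitch alignment as a pixel count rather
//         // than bytes, so the tensor-format code converts it (see below).
//     }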
@@ -12,6 +12,7 @@ | |||
#pragma once | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/handle.h" | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
@@ -43,7 +44,7 @@ public: | |||
protected: | |||
ImplBase(Type type) : m_type{type} {} | |||
~ImplBase() = default; | |||
virtual ~ImplBase() = default; | |||
static TensorFormat impl_to_tensor_format(ImplBase* impl) { return {impl}; } | |||
@@ -93,8 +94,8 @@ namespace detail { | |||
* | |||
* \p align_axis is the axis to be aligned, also the first axis of image width. | |||
* More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p | |||
* align_size_in_byte. Axes from 0 to align_axis-1 would be considered as the | |||
* height of the image, and other axes are the width. | |||
* align_size_in_elements. Axes from 0 to align_axis-1 are considered the
* height of the image, and the remaining axes form the width.
* | |||
* Empty tensors and negative strides are not allowed. Only contiguous or | |||
* broadcasted cases are allowed. | |||
@@ -103,41 +104,32 @@ namespace detail { | |||
* considered as contiguous. | |||
*/ | |||
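// Editorial example (mirrors the MALI test below; not part of this patch):
// with align_axis == 1 and a contiguous layout of shape {5, 3, 8}, axis 0
// forms the image height (5 rows) while axes 1..2 form the width:
// image_width_elems() == (3 - 1) * stride[1] + (8 - 1) * stride[2] + 1
//                     == 2 * 8 + 7 + 1 == 24 elements, i.e. 6 PACK4 pixels.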
class Image2DTensorFormatBase : public TensorFormat::ImplBase { | |||
size_t m_align_axis, m_align_size_in_byte_log2; | |||
size_t m_align_axis, m_align_size_in_elements_log2; | |||
protected: | |||
Image2DTensorFormatBase(Type type, size_t align_axis, | |||
size_t align_size_in_byte); | |||
~Image2DTensorFormatBase() = default; | |||
size_t align_size_in_elements); | |||
virtual ~Image2DTensorFormatBase() = default; | |||
public: | |||
/*! | |||
* \brief get alignment requirement in bytes | |||
* \brief get alignment requirement in elements | |||
* \param div_log2 the result would be divided by `(1 << div_log2)` | |||
*/ | |||
size_t align_size_in_byte(size_t div_log2 = 0) const { | |||
return 1 << (m_align_size_in_byte_log2 > div_log2 | |||
? m_align_size_in_byte_log2 - div_log2 | |||
size_t align_size_in_elements(size_t div_log2 = 0) const { | |||
return 1 << (m_align_size_in_elements_log2 > div_log2 | |||
? m_align_size_in_elements_log2 - div_log2 | |||
: 0); | |||
} | |||
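// Worked example (editorial, not part of this patch): for an alignment of
// 512 elements, m_align_size_in_elements_log2 == 9, so
//     align_size_in_elements()  == 512
//     align_size_in_elements(2) == 128  // e.g. div_log2 == float32 size_log()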
size_t align_axis() const { return m_align_axis; } | |||
size_t init_contiguous_stride(TensorLayout& layout) const override; | |||
bool is_contiguous_spec(const TensorLayout& layout) const override; | |||
TensorLayout collapse_contiguous_spec( | |||
const TensorLayout& layout) const override; | |||
//! span for image must include the padding at the last row | |||
TensorLayout::Span span_spec(const TensorLayout& layout) const override; | |||
size_t align_size_in_elements_log2() const { | |||
return m_align_size_in_elements_log2; | |||
} | |||
std::string to_string() const override; | |||
//! raise exception if preconditions violated | |||
virtual void assert_valid(const TensorLayout& layout) const; | |||
//! modify the align axis and return a new TensorFormat | |||
virtual TensorFormat change_axis(size_t axis) const = 0; | |||
@@ -147,9 +139,6 @@ public: | |||
//! number of rows | |||
size_t image_height(const TensorLayout& layout) const; | |||
//! delta of addresses of consecutive rows (in bytes) | |||
size_t image_row_pitch(const TensorLayout& layout) const; | |||
void serialize_append(std::string& result) const override; | |||
protected: | |||
struct SerializePack { | |||
@@ -159,9 +148,27 @@ protected: | |||
template <size_t PIXEL_SIZE> | |||
class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase { | |||
Handle::HandleVendorType m_vendor_type = Handle::HandleVendorType::NOT_SPEC; | |||
/*! | |||
* \brief get the adjusted alignment requirement in bytes, taking
* m_vendor_type into account.
*
* For example, on MALI, CL_DEVICE_IMAGE_PITCH_ALIGNMENT gives the
* image_width alignment as a pixel count, but here the alignment is
* needed in bytes, which equals
* (image_width align count) * sizeof(data_type) * pixel_size.
*/ | |||
size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements, | |||
const TensorLayout& layout) const; | |||
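// Editorial example of the conversion above (not part of this patch): on
// MALI with align_size_in_elements == 128, a float32 dtype and
// PIXEL_SIZE == 4, the returned alignment is 128 * 4 * 4 == 2048; for any
// other vendor the input value is returned unchanged.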
protected: | |||
using Image2DTensorFormatBase::Image2DTensorFormatBase; | |||
~Image2DPackedTensorFormatBase() = default; | |||
Image2DPackedTensorFormatBase(Type type, size_t align_axis, | |||
size_t align_size_in_elements, | |||
Handle::HandleVendorType vendor_type) | |||
: detail::Image2DTensorFormatBase(type, align_axis, | |||
align_size_in_elements), | |||
m_vendor_type(vendor_type) {} | |||
virtual ~Image2DPackedTensorFormatBase() = default; | |||
Handle::HandleVendorType vendor() const { return m_vendor_type; } | |||
public: | |||
/*! | |||
@@ -173,7 +180,20 @@ public: | |||
*/ | |||
size_t image_width(const TensorLayout& layout) const; | |||
void assert_valid(const TensorLayout& layout) const override; | |||
//! raise exception if preconditions violated | |||
void assert_valid(const TensorLayout& layout) const; | |||
size_t image_row_pitch(const TensorLayout& layout) const; | |||
//! span for image must include the padding at the last row | |||
TensorLayout::Span span_spec(const TensorLayout& layout) const override; | |||
size_t init_contiguous_stride(TensorLayout& layout) const override; | |||
bool is_contiguous_spec(const TensorLayout& layout) const override; | |||
TensorLayout collapse_contiguous_spec( | |||
const TensorLayout& layout) const override; | |||
}; | |||
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | |||
} // namespace detail | |||
@@ -190,7 +210,10 @@ public: | |||
static constexpr Type TYPE = Type::IMAGE2D_PACK4; | |||
//! for internal usage or test purposes | |||
static TensorFormat make_raw(size_t align_axis, size_t align_size_in_byte); | |||
static TensorFormat make_raw(size_t align_axis, | |||
size_t align_size_in_elements, | |||
Handle::HandleVendorType vendor_type = | |||
Handle::HandleVendorType::NOT_SPEC); | |||
static TensorFormat make(size_t align_axis, const Handle* handle); | |||
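// Editorial usage sketch (not part of this patch), mirroring the tests
// further below:
//
//     TensorFormat fmt = Image2DPack4TensorFormat::make_raw(
//             1, 512, Handle::HandleVendorType::MALI);
//     TensorLayout layout{{5, 3, 8}, dtype::Float32{}, fmt};
//
// make() instead derives both the alignment and the vendor from the handle
// via image2d_pitch_alignment() and vendor_type().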
@@ -215,9 +238,10 @@ public: | |||
TensorFormat change_axis(size_t axis) const override; | |||
private: | |||
Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_byte) | |||
: detail::Image2DPack4TensorFormatBase(TYPE, align_axis, | |||
align_size_in_byte) {} | |||
Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements, | |||
Handle::HandleVendorType vendor_type) | |||
: detail::Image2DPack4TensorFormatBase( | |||
TYPE, align_axis, align_size_in_elements, vendor_type) {} | |||
}; | |||
} // namespace megdnn | |||
@@ -147,6 +147,10 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle, | |||
megdnn_throw("image2d tensor format not supported on this handle"); | |||
} | |||
megdnn::HandleImplHelper::HandleVendorType Handle::vendor_type() const { | |||
return HandleVendorType::NOT_SPEC; | |||
} | |||
bool Handle::check_cross_dev_copy_constraint(const TensorLayout& src) { | |||
return src.is_contiguous(); | |||
} | |||
@@ -236,6 +236,7 @@ void RelayoutFormat::deduce_layout(const TensorLayout& src, TensorLayout& dst) { | |||
void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||
size_t align = handle()->image2d_pitch_alignment(); | |||
auto vendor_type = handle()->vendor_type(); | |||
using Param = param::RelayoutFormat; | |||
#define CHECK_SRC(_expect) \ | |||
megdnn_assert(src == _expect, "invalid src format: expect=%s got=%s", \ | |||
@@ -251,7 +252,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||
break; | |||
case Param::Mode::NHWC_NHWCD4I: | |||
CHECK_SRC(DefaultTensorFormat::make()); | |||
dst = Image2DPack4TensorFormat::make_raw(2, align); | |||
dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type); | |||
break; | |||
case Param::Mode::NCHW_NHWCD4: | |||
CHECK_SRC(DefaultTensorFormat::make()); | |||
@@ -263,10 +264,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||
break; | |||
case Param::Mode::NCHW_NHWCD4I: | |||
CHECK_SRC(DefaultTensorFormat::make()); | |||
dst = Image2DPack4TensorFormat::make_raw(2, align); | |||
dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type); | |||
break; | |||
case Param::Mode::NHWCD4I_NCHW: | |||
CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align)); | |||
CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type)); | |||
dst = DefaultTensorFormat::make(); | |||
break; | |||
case Param::Mode::NHWCD4_NCHW: | |||
@@ -280,7 +281,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||
case Param::Mode::INTER_WEIGHT_DENSEI: | |||
case Param::Mode::INTER_WEIGHT_DENSEI_DOT: | |||
CHECK_SRC(DefaultTensorFormat::make()); | |||
dst = Image2DPack4TensorFormat::make_raw(3, align); | |||
dst = Image2DPack4TensorFormat::make_raw(3, align, vendor_type); | |||
break; | |||
case Param::Mode::INTER_WEIGHT_GROUP: | |||
CHECK_SRC(DefaultTensorFormat::make()); | |||
@@ -289,7 +290,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||
case Param::Mode::INTER_WEIGHT_GROUPI: | |||
case Param::Mode::INTER_WEIGHT_GROUPI_DOT: | |||
CHECK_SRC(DefaultTensorFormat::make()); | |||
dst = Image2DPack4TensorFormat::make_raw(4, align); | |||
dst = Image2DPack4TensorFormat::make_raw(4, align, vendor_type); | |||
break; | |||
case Param::Mode::INTER_WEIGHT_CHAN: | |||
CHECK_SRC(DefaultTensorFormat::make()); | |||
@@ -297,7 +298,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||
break; | |||
case Param::Mode::INTER_WEIGHT_CHANI: | |||
CHECK_SRC(DefaultTensorFormat::make()); | |||
dst = Image2DPack4TensorFormat::make_raw(1, align); | |||
dst = Image2DPack4TensorFormat::make_raw(1, align, vendor_type); | |||
break; | |||
case Param::Mode::NCHW4_CHWN4: | |||
CHECK_SRC(DefaultTensorFormat::make()); | |||
@@ -185,23 +185,134 @@ TensorFormat DefaultTensorFormat::make() { | |||
/* ===================== Image2DTensorFormatBase ===================== */ | |||
Image2DTensorFormatBase::Image2DTensorFormatBase(Type type, size_t align_axis, | |||
size_t align_size_in_byte) | |||
: ImplBase(type) { | |||
megdnn_assert(align_size_in_byte && align_axis); | |||
m_align_axis = align_axis; | |||
m_align_size_in_byte_log2 = __builtin_ctz(align_size_in_byte); | |||
megdnn_assert((1u << m_align_size_in_byte_log2) == align_size_in_byte, | |||
"align size not power of 2: %zu", align_size_in_byte); | |||
size_t align_size_in_elements) | |||
: ImplBase(type), m_align_axis(align_axis) { | |||
megdnn_assert(align_size_in_elements && align_axis); | |||
m_align_size_in_elements_log2 = __builtin_ctz(align_size_in_elements); | |||
megdnn_assert( | |||
(1u << m_align_size_in_elements_log2) == align_size_in_elements, | |||
"align size not power of 2: %zu", align_size_in_elements); | |||
} | |||
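// Editorial note (not part of this patch): only the log2 of the alignment is
// stored, e.g. align_size_in_elements == 512 becomes
// m_align_size_in_elements_log2 == __builtin_ctz(512) == 9; a
// non-power-of-two alignment fails the assertion above.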
size_t Image2DTensorFormatBase::init_contiguous_stride( | |||
void Image2DTensorFormatBase::serialize_append(std::string& result) const { | |||
SerializePack pack; | |||
pack.align_axis = m_align_axis; | |||
megdnn_assert(pack.align_axis == m_align_axis); // detect overflow | |||
result.append(reinterpret_cast<char*>(&pack), sizeof(pack)); | |||
} | |||
size_t Image2DTensorFormatBase::image_height(const TensorLayout& layout) const { | |||
size_t accum = 1; | |||
for (int i = m_align_axis - 1; i >= 0; --i) { | |||
if (layout.stride[i] == 0) { | |||
// this dimension is broadcasted | |||
} else { | |||
accum *= layout.shape[i]; | |||
} | |||
} | |||
return accum; | |||
} | |||
size_t Image2DTensorFormatBase::image_width_elems( | |||
const TensorLayout& layout) const { | |||
size_t high_elem = 0; | |||
for (size_t i = m_align_axis; i < layout.ndim; ++i) { | |||
high_elem += (layout.shape[i] - 1) * layout.stride[i]; | |||
} | |||
return high_elem + 1; | |||
} | |||
std::string Image2DTensorFormatBase::to_string() const { | |||
return ssprintf("I2D{%zu,%d}", m_align_axis, | |||
1 << m_align_size_in_elements_log2); | |||
} | |||
/* ===================== Image2DPackedTensorFormatBase ===================== */ | |||
template <size_t PIXEL_SIZE> | |||
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_width( | |||
const TensorLayout& layout) const { | |||
auto ret = image_width_elems(layout); | |||
megdnn_assert(ret % PIXEL_SIZE == 0); | |||
return ret / PIXEL_SIZE; | |||
} | |||
template <size_t PIXEL_SIZE> | |||
void Image2DPackedTensorFormatBase<PIXEL_SIZE>::assert_valid( | |||
const TensorLayout& layout) const { | |||
auto m_align_axis = align_axis(); | |||
megdnn_assert(!(layout.shape[layout.ndim - 1] % PIXEL_SIZE), | |||
"bad shape: %zu", layout.shape[layout.ndim - 1]); | |||
megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis); | |||
ptrdiff_t first_non_zero_stride = 0; | |||
for (int i = layout.ndim - 1; i >= 0; --i) { | |||
megdnn_assert(layout.shape[i] && layout.stride[i] >= 0); | |||
if (i < static_cast<int>(m_align_axis) && !first_non_zero_stride) { | |||
first_non_zero_stride = layout.stride[i]; | |||
} | |||
} | |||
size_t mask = | |||
image_pitch_alignment_in_bytes( | |||
align_size_in_elements(layout.dtype.size_log()), layout) - | |||
1; | |||
megdnn_assert(!(first_non_zero_stride & mask), | |||
"first stride is %d, but alignment is %zu", | |||
static_cast<int>(first_non_zero_stride), mask + 1); | |||
} | |||
template <size_t PIXEL_SIZE> | |||
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_row_pitch( | |||
const TensorLayout& layout) const { | |||
for (int i = align_axis() - 1; i >= 0; --i) { | |||
// find a non-broadcast axis | |||
if (auto s = layout.stride[i]) { | |||
return layout.dtype.size(s); | |||
} | |||
} | |||
// all height axes are broadcasted (stride 0): use the image width in bytes,
// rounded up to the alignment
size_t alignment_in_bytes_log2 = align_size_in_elements_log2(); | |||
if (m_vendor_type == Handle::HandleVendorType::MALI) { | |||
alignment_in_bytes_log2 += | |||
__builtin_ctz(layout.dtype.size() * PIXEL_SIZE); | |||
} | |||
return get_aligned_power2<size_t>( | |||
layout.dtype.size(image_width_elems(layout)), | |||
1 << alignment_in_bytes_log2); | |||
} | |||
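// Editorial walk-through (matches the MALI test further below in this patch;
// not part of it): for a contiguous {5, 3, 8} float32 PACK4 layout with a
// 512-element alignment on MALI, the height-axis stride is 2048 elements, so
// image_row_pitch() returns 2048 * sizeof(float) == 8192 bytes; the same
// layout in float16 keeps stride 2048 and yields 4096 bytes.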
template <size_t PIXEL_SIZE> | |||
size_t | |||
Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_pitch_alignment_in_bytes( | |||
size_t align_size_in_elements, const TensorLayout& layout) const { | |||
return m_vendor_type == Handle::HandleVendorType::MALI | |||
? (align_size_in_elements * layout.dtype.size() * PIXEL_SIZE) | |||
: align_size_in_elements; | |||
} | |||
template <size_t PIXEL_SIZE> | |||
TensorLayout::Span Image2DPackedTensorFormatBase<PIXEL_SIZE>::span_spec( | |||
const TensorLayout& layout) const { | |||
assert_valid(layout); | |||
size_t size = image_height(layout) * image_row_pitch(layout); | |||
auto mask = (1 << layout.dtype.size_log()) - 1; | |||
megdnn_assert(!(size & mask), "unaligned size: %zu", size); | |||
return {0, 0, size >> layout.dtype.size_log(), size}; | |||
} | |||
template <size_t PIXEL_SIZE> | |||
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::init_contiguous_stride( | |||
TensorLayout& layout) const { | |||
auto m_align_axis = align_axis(); | |||
if (!layout.ndim) | |||
return 0; | |||
megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis, | |||
"dtype=%s ndim=%zu align=%zu", layout.dtype.name(), | |||
layout.ndim, m_align_axis); | |||
size_t align_size = align_size_in_byte(layout.dtype.size_log()); | |||
size_t align_size = image_pitch_alignment_in_bytes( | |||
align_size_in_elements(layout.dtype.size_log()), layout); | |||
size_t accum = 1; | |||
SafeMultiplies<size_t> mul; | |||
for (size_t i = layout.ndim; i; --i) { | |||
@@ -216,12 +327,15 @@ size_t Image2DTensorFormatBase::init_contiguous_stride( | |||
return accum; | |||
}; | |||
bool Image2DTensorFormatBase::is_contiguous_spec( | |||
template <size_t PIXEL_SIZE> | |||
bool Image2DPackedTensorFormatBase<PIXEL_SIZE>::is_contiguous_spec( | |||
const TensorLayout& layout) const { | |||
megdnn_assert(layout.dtype.valid()); | |||
size_t align_size = align_size_in_byte(layout.dtype.size_log()); | |||
size_t align_size = image_pitch_alignment_in_bytes( | |||
align_size_in_elements(layout.dtype.size_log()), layout); | |||
ptrdiff_t expected = 1; | |||
int height_axis = static_cast<int>(m_align_axis - 1); | |||
int height_axis = static_cast<int>(align_axis() - 1); | |||
for (int i = layout.ndim - 1; i >= 0; --i) { | |||
if (i == height_axis) { | |||
expected = megdnn::get_aligned_power2<size_t>(expected, align_size); | |||
@@ -235,7 +349,12 @@ bool Image2DTensorFormatBase::is_contiguous_spec( | |||
return false; | |||
} | |||
size_t mask = align_size_in_byte(layout.dtype.size_log()) - 1; | |||
size_t mask = | |||
image_pitch_alignment_in_bytes( | |||
align_size_in_elements(layout.dtype.size_log()), | |||
layout) - | |||
1; | |||
megdnn_assert(s > expected && !(s & mask), | |||
"invalid row pitch: %d; layout: %s", | |||
static_cast<int>(s), layout.to_string().c_str()); | |||
@@ -250,11 +369,12 @@ bool Image2DTensorFormatBase::is_contiguous_spec( | |||
return expected != 0; | |||
} | |||
TensorLayout Image2DTensorFormatBase::collapse_contiguous_spec( | |||
template <size_t PIXEL_SIZE> | |||
TensorLayout Image2DPackedTensorFormatBase<PIXEL_SIZE>::collapse_contiguous_spec( | |||
const TensorLayout& layout) const { | |||
assert_valid(layout); | |||
TensorLayout res{layout}; | |||
int new_axis = m_align_axis; | |||
int new_axis = align_axis(); | |||
// remove all dims with shape 1 | |||
for (int i = static_cast<int>(res.ndim) - 1; i >= 0 && res.ndim >= 3; --i) { | |||
if (i == new_axis && static_cast<int>(res.ndim) == new_axis + 1) { | |||
@@ -302,95 +422,6 @@ TensorLayout Image2DTensorFormatBase::collapse_contiguous_spec( | |||
return res; | |||
} | |||
TensorLayout::Span Image2DTensorFormatBase::span_spec( | |||
const TensorLayout& layout) const { | |||
assert_valid(layout); | |||
size_t size = image_height(layout) * image_row_pitch(layout); | |||
auto mask = (1 << layout.dtype.size_log()) - 1; | |||
megdnn_assert(!(size & mask), "unaligned size: %zu", size); | |||
return {0, 0, size >> layout.dtype.size_log(), size}; | |||
} | |||
void Image2DTensorFormatBase::serialize_append(std::string& result) const { | |||
SerializePack pack; | |||
pack.align_axis = m_align_axis; | |||
megdnn_assert(pack.align_axis == m_align_axis); // detect overflow | |||
result.append(reinterpret_cast<char*>(&pack), sizeof(pack)); | |||
} | |||
size_t Image2DTensorFormatBase::image_height(const TensorLayout& layout) const { | |||
size_t accum = 1; | |||
for (int i = m_align_axis - 1; i >= 0; --i) { | |||
if (layout.stride[i] == 0) { | |||
// this dimension is broadcasted | |||
} else { | |||
accum *= layout.shape[i]; | |||
} | |||
} | |||
return accum; | |||
} | |||
size_t Image2DTensorFormatBase::image_row_pitch( | |||
const TensorLayout& layout) const { | |||
for (int i = m_align_axis - 1; i >= 0; --i) { | |||
// find a non-broadcast axis | |||
if (auto s = layout.stride[i]) { | |||
return layout.dtype.size(s); | |||
} | |||
} | |||
// use width for all broadcasted case | |||
return get_aligned_power2<size_t>( | |||
layout.dtype.size(image_width_elems(layout)), | |||
1 << m_align_size_in_byte_log2); | |||
} | |||
void Image2DTensorFormatBase::assert_valid(const TensorLayout& layout) const { | |||
megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis); | |||
ptrdiff_t first_non_zero_stride = 0; | |||
for (int i = layout.ndim - 1; i >= 0; --i) { | |||
megdnn_assert(layout.shape[i] && layout.stride[i] >= 0); | |||
if (i < static_cast<int>(m_align_axis) && !first_non_zero_stride) { | |||
first_non_zero_stride = layout.stride[i]; | |||
} | |||
} | |||
size_t mask = align_size_in_byte(layout.dtype.size_log()) - 1; | |||
megdnn_assert(!(first_non_zero_stride & mask), | |||
"first stride is %d, but alignment is %zu", | |||
static_cast<int>(first_non_zero_stride), mask + 1); | |||
} | |||
size_t Image2DTensorFormatBase::image_width_elems( | |||
const TensorLayout& layout) const { | |||
size_t high_elem = 0; | |||
for (size_t i = m_align_axis; i < layout.ndim; ++i) { | |||
high_elem += (layout.shape[i] - 1) * layout.stride[i]; | |||
} | |||
return high_elem + 1; | |||
} | |||
std::string Image2DTensorFormatBase::to_string() const { | |||
return ssprintf("I2D{%zu,%d}", m_align_axis, | |||
1 << m_align_size_in_byte_log2); | |||
} | |||
/* ===================== Image2DPackedTensorFormatBase ===================== */ | |||
template <size_t PIXEL_SIZE> | |||
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_width( | |||
const TensorLayout& layout) const { | |||
auto ret = image_width_elems(layout); | |||
megdnn_assert(ret % PIXEL_SIZE == 0); | |||
return ret / PIXEL_SIZE; | |||
} | |||
template <size_t PIXEL_SIZE> | |||
void Image2DPackedTensorFormatBase<PIXEL_SIZE>::assert_valid( | |||
const TensorLayout& layout) const { | |||
Image2DTensorFormatBase::assert_valid(layout); | |||
megdnn_assert(!(layout.shape[layout.ndim - 1] % PIXEL_SIZE), | |||
"bad shape: %zu", layout.shape[layout.ndim - 1]); | |||
} | |||
namespace megdnn { | |||
namespace detail { | |||
template class Image2DPackedTensorFormatBase<4>; | |||
@@ -398,26 +429,29 @@ template class Image2DPackedTensorFormatBase<4>; | |||
} // namespace megdnn | |||
/* ===================== Image2DPack4TensorFormat ===================== */ | |||
TensorFormat Image2DPack4TensorFormat::make_raw(size_t align_axis, | |||
size_t align_size_in_byte) { | |||
TensorFormat Image2DPack4TensorFormat::make_raw( | |||
size_t align_axis, size_t align_size_in_elements, | |||
Handle::HandleVendorType vendor_type) { | |||
static std::mutex mtx; | |||
static std::unordered_map<uint64_t, | |||
std::unique_ptr<Image2DPack4TensorFormat>> | |||
cache; | |||
megdnn_assert(std::max(align_axis, align_size_in_byte) <= | |||
megdnn_assert(std::max(align_axis, align_size_in_elements) <= | |||
std::numeric_limits<uint32_t>::max()); | |||
MEGDNN_LOCK_GUARD(mtx); | |||
auto&& ptr = cache[(static_cast<uint64_t>(align_axis) << 32) | | |||
align_size_in_byte]; | |||
align_size_in_elements]; | |||
if (!ptr) { | |||
ptr.reset(new Image2DPack4TensorFormat{align_axis, align_size_in_byte}); | |||
ptr.reset(new Image2DPack4TensorFormat{ | |||
align_axis, align_size_in_elements, vendor_type}); | |||
} | |||
return impl_to_tensor_format(ptr.get()); | |||
} | |||
TensorFormat Image2DPack4TensorFormat::make(size_t align_axis, | |||
const Handle* handle) { | |||
return make_raw(align_axis, handle->image2d_pitch_alignment()); | |||
return make_raw(align_axis, handle->image2d_pitch_alignment(), | |||
handle->vendor_type()); | |||
} | |||
TensorFormat Image2DPack4TensorFormat::deserialize(const Handle* handle, | |||
@@ -429,7 +463,7 @@ TensorFormat Image2DPack4TensorFormat::deserialize(const Handle* handle, | |||
} | |||
TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const { | |||
return make_raw(axis, align_size_in_byte()); | |||
return make_raw(axis, align_size_in_elements(), vendor()); | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -123,6 +123,10 @@ size_t HandleImpl::image2d_pitch_alignment() const { | |||
return align; | |||
} | |||
HandleImpl::HandleVendorType HandleImpl::vendor_type() const { | |||
return HandleVendorType::CUDA; | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn | |||
@@ -123,6 +123,7 @@ class HandleImpl: public HandleImplHelper { | |||
TypeCvt* typecvt_opr() { return get_helper_opr<TypeCvt, 0>(this); } | |||
size_t image2d_pitch_alignment() const override; | |||
HandleVendorType vendor_type() const override; | |||
private: | |||
bool m_is_tegra_k1; | |||
int m_device_id; | |||
@@ -118,6 +118,10 @@ size_t HandleImpl::image2d_pitch_alignment() const { | |||
return g_image2d_pitch_alignment; | |||
} | |||
HandleImpl::HandleVendorType HandleImpl::vendor_type() const { | |||
return HandleVendorType::NOT_SPEC; | |||
} | |||
size_t HandleImpl::exchange_image2d_pitch_alignment(size_t alignment) { | |||
auto ret = g_image2d_pitch_alignment; | |||
g_image2d_pitch_alignment = alignment; | |||
@@ -169,6 +169,7 @@ public: | |||
* \param alignment the new alignment value to set | |||
*/ | |||
static size_t exchange_image2d_pitch_alignment(size_t alignment); | |||
HandleVendorType vendor_type() const override; | |||
}; | |||
} // namespace naive | |||
@@ -175,6 +175,30 @@ namespace { | |||
} | |||
} | |||
TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_WITH_VENDOR_MALI) { | |||
TensorFormat fmt = Image2DPack4TensorFormat::make_raw( | |||
1, 512, Handle::HandleVendorType::MALI); | |||
TensorLayout layout{{5, 3, 8}, dtype::Float32{}, fmt}; | |||
ASSERT_EQ(layout.stride[2], 1); | |||
ASSERT_EQ(layout.stride[1], 8); | |||
ASSERT_EQ(layout.stride[0], 2048); | |||
ASSERT_EQ(8192u, image_row_pitch(layout)); | |||
ASSERT_EQ(6u, image_width(layout)); | |||
ASSERT_EQ(5u, image_height(layout)); | |||
fmt = Image2DPack4TensorFormat::make_raw(1, 512, | |||
Handle::HandleVendorType::MALI); | |||
TensorLayout layout_s{{5, 3, 8}, dtype::Float16{}, fmt}; | |||
ASSERT_EQ(layout_s.stride[2], 1); | |||
ASSERT_EQ(layout_s.stride[1], 8); | |||
ASSERT_EQ(layout_s.stride[0], 2048); | |||
ASSERT_EQ(4096u, image_row_pitch(layout_s)); | |||
ASSERT_EQ(6u, image_width(layout_s)); | |||
ASSERT_EQ(5u, image_height(layout_s)); | |||
} | |||
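// Editorial note on the expected values (not part of this patch): the MALI
// pitch alignment of 512 here is a pixel count, i.e. 512 * 4 == 2048 elements
// per row, so the height-axis stride is padded from 24 to 2048 for both
// dtypes; the row pitch then differs only by the element size
// (2048 * 4 == 8192 bytes for Float32, 2048 * 2 == 4096 for Float16).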
TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT) { | |||
TensorFormat fmt = Image2DPack4TensorFormat::make_raw(1, 1024); | |||
ASSERT_FALSE(fmt.is_default()); | |||
@@ -233,7 +257,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT) { | |||
auto&& impl = contig.format.as_impl<Image2DPack4TensorFormat>(); | |||
ASSERT_EQ(make_layout({1, 8}, {32, 1}, layout.dtype), contig); | |||
ASSERT_EQ(1u, impl.align_axis()); | |||
ASSERT_EQ(64u, impl.align_size_in_byte()); | |||
ASSERT_EQ(64u, impl.align_size_in_elements()); | |||
} | |||
} | |||
@@ -258,7 +282,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_H) { | |||
auto&& impl = contig.format.as_impl<Image2DPack4TensorFormat>(); | |||
ASSERT_EQ(make_layout({v0, 8}, {32, 1}, layout.dtype), contig); | |||
ASSERT_EQ(1u, impl.align_axis()); | |||
ASSERT_EQ(64u, impl.align_size_in_byte()); | |||
ASSERT_EQ(64u, impl.align_size_in_elements()); | |||
} | |||
} | |||
@@ -274,7 +298,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) { | |||
layout.dtype), | |||
contig); | |||
ASSERT_EQ(1u, impl.align_axis()); | |||
ASSERT_EQ(64u, impl.align_size_in_byte()); | |||
ASSERT_EQ(64u, impl.align_size_in_elements()); | |||
} | |||
} | |||