@@ -156,15 +156,42 @@ void megdnn::infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw, | |||
WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes, | |||
size_t align_in_bytes) | |||
: m_ptr(ptr), | |||
m_sizes(std::move(sizes_in_bytes)), | |||
m_align_in_bytes(align_in_bytes) { | |||
m_aligned_sizes.reserve(m_sizes.size()); | |||
for (auto size : m_sizes) { | |||
m_sizes.push_back(sizes_in_bytes); | |||
size_t reduce_size = 0_z; | |||
m_reduce_num.push_back(0_z); | |||
for (auto size : m_sizes[0]) { | |||
auto aligned_size = size; | |||
if (size % m_align_in_bytes != 0) { | |||
aligned_size += m_align_in_bytes - size % m_align_in_bytes; | |||
} | |||
m_aligned_sizes.push_back(aligned_size); | |||
m_reduce_sizes.push_back(reduce_size); | |||
reduce_size += aligned_size; | |||
} | |||
} | |||
WorkspaceBundle::WorkspaceBundle( | |||
SmallVector<SmallVector<size_t>> vector_sizes_in_bytes, void* ptr, | |||
size_t align_in_bytes) | |||
: m_ptr(ptr), | |||
m_sizes(vector_sizes_in_bytes), | |||
m_align_in_bytes(align_in_bytes) { | |||
size_t nr_workspace = 0_z; | |||
size_t reduce_size = 0_z; | |||
for (auto sizes_in_bytes: vector_sizes_in_bytes) { | |||
m_reduce_num.push_back(nr_workspace); | |||
for (auto size : sizes_in_bytes) { | |||
auto aligned_size = size; | |||
if (size % m_align_in_bytes != 0) { | |||
aligned_size += m_align_in_bytes - size % m_align_in_bytes; | |||
} | |||
m_aligned_sizes.push_back(aligned_size); | |||
m_reduce_sizes.push_back(reduce_size); | |||
reduce_size += aligned_size; | |||
nr_workspace++; | |||
} | |||
} | |||
} | |||
@@ -172,22 +199,39 @@ void* WorkspaceBundle::ptr() const { | |||
return m_ptr; | |||
} | |||
void* WorkspaceBundle::get(size_t i) const { | |||
void* WorkspaceBundle::get(size_t dim1, size_t dim0) const { | |||
megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range"); | |||
megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range"); | |||
auto addr = reinterpret_cast<uintptr_t>(m_ptr); | |||
if (addr % m_align_in_bytes != 0) | |||
addr += m_align_in_bytes - addr % m_align_in_bytes; | |||
for (size_t j = 0; j < i; ++j) { | |||
addr += m_aligned_sizes[j]; | |||
} | |||
size_t index = m_reduce_num[dim1] + dim0; | |||
addr += m_reduce_sizes[index]; | |||
return reinterpret_cast<void*>(addr); | |||
} | |||
void* WorkspaceBundle::get(size_t dim0) const { | |||
megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range"); | |||
auto addr = reinterpret_cast<uintptr_t>(m_ptr); | |||
if (addr % m_align_in_bytes != 0) | |||
addr += m_align_in_bytes - addr % m_align_in_bytes; | |||
addr += m_reduce_sizes[dim0]; | |||
return reinterpret_cast<void*>(addr); | |||
} | |||
size_t WorkspaceBundle::nr_workspace() const { | |||
return m_sizes.size(); | |||
return m_aligned_sizes.size(); | |||
} | |||
size_t WorkspaceBundle::get_size(size_t dim1, size_t dim0) const { | |||
megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range"); | |||
megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range"); | |||
return m_sizes[dim1][dim0]; | |||
} | |||
size_t WorkspaceBundle::get_size(size_t i) const { | |||
return m_sizes[i]; | |||
size_t WorkspaceBundle::get_size(size_t dim0) const { | |||
megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range"); | |||
return m_sizes[0][dim0]; | |||
} | |||
void WorkspaceBundle::set(void* ptr) { | |||
@@ -194,8 +194,15 @@ std::unique_ptr<T> make_unique(Args&&... args) { | |||
*/ | |||
class WorkspaceBundle { | |||
public: | |||
WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes, | |||
WorkspaceBundle(void* ptr = nullptr, | |||
SmallVector<size_t> sizes_in_bytes = {}, | |||
size_t align_in_bytes = 512); | |||
/** | |||
* construct 2D workspace buldle | |||
*/ | |||
WorkspaceBundle(SmallVector<SmallVector<size_t>> vector_sizes_in_bytes, | |||
void* ptr, size_t align_in_bytes = 512); | |||
/** | |||
* \returns raw workspace ptr. | |||
* | |||
@@ -204,26 +211,45 @@ public: | |||
*/ | |||
void* ptr() const; | |||
/** | |||
* \returns the i-th workspace ptr (aligned) | |||
* \returns the 2D [dim1, dim0] workspace ptr (aligned) | |||
*/ | |||
void* get(size_t i) const; | |||
void* get(size_t dim1, size_t dim0) const; | |||
/** | |||
* \returns the 1D [dim0] workspace ptr (aligned) | |||
*/ | |||
void* get(size_t dim0) const; | |||
/** | |||
* \returns total size taking into account paddings to solve alignment | |||
* issue. | |||
*/ | |||
size_t total_size_in_bytes() const; | |||
size_t get_size(size_t i) const; | |||
/** | |||
* \return the 2D [dim1, dim0] workspace size | |||
*/ | |||
size_t get_size(size_t dim1, size_t dim0) const; | |||
/** | |||
* \return the 1D [dim0] workspace size | |||
*/ | |||
size_t get_size(size_t dim0) const; | |||
size_t nr_workspace() const; | |||
void set(void* ptr); | |||
Workspace get_workspace(size_t i) const { | |||
return {static_cast<dt_byte*>(get(i)), get_size(i)}; | |||
Workspace get_workspace(size_t dim1, size_t dim0) const { | |||
return {static_cast<dt_byte*>(get(dim1, dim0)), get_size(dim1, dim0)}; | |||
} | |||
Workspace get_workspace(size_t dim0) const { | |||
return {static_cast<dt_byte*>(get(dim0)), get_size(dim0)}; | |||
} | |||
private: | |||
void* m_ptr; | |||
SmallVector<size_t> m_sizes; | |||
SmallVector<SmallVector<size_t>> m_sizes; | |||
SmallVector<size_t> m_aligned_sizes; | |||
//! all workspace size prefix sum | |||
SmallVector<size_t> m_reduce_sizes; | |||
//! dim1 workspace number prefix sum | |||
SmallVector<size_t> m_reduce_num; | |||
size_t m_align_in_bytes; | |||
}; | |||
@@ -11,6 +11,7 @@ | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/tensor_format.h" | |||
#include "src/common/utils.h" | |||
// clang-format off | |||
#include "test/common/utils.h" | |||
@@ -278,4 +279,21 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) { | |||
} | |||
} | |||
TEST(MISC, WORKSPACE_BUNDLE) { | |||
WorkspaceBundle bundle{ | |||
{{100, 200}, {435, 234, 143}, {422, 1325, 728}}, nullptr, 64}; | |||
bundle.set(reinterpret_cast<void*>(82l)); | |||
ASSERT_EQ(bundle.get(0), reinterpret_cast<void*>(128l)); | |||
void* dst = reinterpret_cast<void*>(128 + round_up(100, 64)); | |||
ASSERT_EQ(bundle.get(0, 1), dst); | |||
dst = reinterpret_cast<void*>(128 + round_up(100, 64) + round_up(200, 64) + | |||
round_up(435, 64)); | |||
ASSERT_EQ(bundle.get(1, 1), dst); | |||
dst = reinterpret_cast<void*>(128l + round_up(100, 64) + round_up(200, 64) + | |||
round_up(435, 64) + round_up(234, 64) + | |||
round_up(143, 64) + round_up(422, 64) + | |||
round_up(1325, 64)); | |||
ASSERT_EQ(bundle.get(2, 2), dst); | |||
} | |||
// vim: syntax=cpp.doxygen |