This reverts commit release-0.6 4408bb9e1d.
GitOrigin-RevId: b5b23a8aae
@@ -156,42 +156,15 @@ void megdnn::infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw, | |||||
WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes, | WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes, | ||||
size_t align_in_bytes) | size_t align_in_bytes) | ||||
: m_ptr(ptr), | : m_ptr(ptr), | ||||
m_sizes(std::move(sizes_in_bytes)), | |||||
m_align_in_bytes(align_in_bytes) { | m_align_in_bytes(align_in_bytes) { | ||||
m_aligned_sizes.reserve(m_sizes.size()); | m_aligned_sizes.reserve(m_sizes.size()); | ||||
m_sizes.push_back(sizes_in_bytes); | |||||
size_t reduce_size = 0_z; | |||||
m_reduce_num.push_back(0_z); | |||||
for (auto size : m_sizes[0]) { | |||||
for (auto size : m_sizes) { | |||||
auto aligned_size = size; | auto aligned_size = size; | ||||
if (size % m_align_in_bytes != 0) { | if (size % m_align_in_bytes != 0) { | ||||
aligned_size += m_align_in_bytes - size % m_align_in_bytes; | aligned_size += m_align_in_bytes - size % m_align_in_bytes; | ||||
} | } | ||||
m_aligned_sizes.push_back(aligned_size); | m_aligned_sizes.push_back(aligned_size); | ||||
m_reduce_sizes.push_back(reduce_size); | |||||
reduce_size += aligned_size; | |||||
} | |||||
} | |||||
WorkspaceBundle::WorkspaceBundle( | |||||
SmallVector<SmallVector<size_t>> vector_sizes_in_bytes, void* ptr, | |||||
size_t align_in_bytes) | |||||
: m_ptr(ptr), | |||||
m_sizes(vector_sizes_in_bytes), | |||||
m_align_in_bytes(align_in_bytes) { | |||||
size_t nr_workspace = 0_z; | |||||
size_t reduce_size = 0_z; | |||||
for (auto sizes_in_bytes: vector_sizes_in_bytes) { | |||||
m_reduce_num.push_back(nr_workspace); | |||||
for (auto size : sizes_in_bytes) { | |||||
auto aligned_size = size; | |||||
if (size % m_align_in_bytes != 0) { | |||||
aligned_size += m_align_in_bytes - size % m_align_in_bytes; | |||||
} | |||||
m_aligned_sizes.push_back(aligned_size); | |||||
m_reduce_sizes.push_back(reduce_size); | |||||
reduce_size += aligned_size; | |||||
nr_workspace++; | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -199,39 +172,22 @@ void* WorkspaceBundle::ptr() const { | |||||
return m_ptr; | return m_ptr; | ||||
} | } | ||||
void* WorkspaceBundle::get(size_t dim1, size_t dim0) const { | |||||
megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range"); | |||||
megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range"); | |||||
void* WorkspaceBundle::get(size_t i) const { | |||||
auto addr = reinterpret_cast<uintptr_t>(m_ptr); | auto addr = reinterpret_cast<uintptr_t>(m_ptr); | ||||
if (addr % m_align_in_bytes != 0) | if (addr % m_align_in_bytes != 0) | ||||
addr += m_align_in_bytes - addr % m_align_in_bytes; | addr += m_align_in_bytes - addr % m_align_in_bytes; | ||||
size_t index = m_reduce_num[dim1] + dim0; | |||||
addr += m_reduce_sizes[index]; | |||||
return reinterpret_cast<void*>(addr); | |||||
} | |||||
void* WorkspaceBundle::get(size_t dim0) const { | |||||
megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range"); | |||||
auto addr = reinterpret_cast<uintptr_t>(m_ptr); | |||||
if (addr % m_align_in_bytes != 0) | |||||
addr += m_align_in_bytes - addr % m_align_in_bytes; | |||||
addr += m_reduce_sizes[dim0]; | |||||
for (size_t j = 0; j < i; ++j) { | |||||
addr += m_aligned_sizes[j]; | |||||
} | |||||
return reinterpret_cast<void*>(addr); | return reinterpret_cast<void*>(addr); | ||||
} | } | ||||
size_t WorkspaceBundle::nr_workspace() const { | size_t WorkspaceBundle::nr_workspace() const { | ||||
return m_aligned_sizes.size(); | |||||
} | |||||
size_t WorkspaceBundle::get_size(size_t dim1, size_t dim0) const { | |||||
megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range"); | |||||
megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range"); | |||||
return m_sizes[dim1][dim0]; | |||||
return m_sizes.size(); | |||||
} | } | ||||
size_t WorkspaceBundle::get_size(size_t dim0) const { | |||||
megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range"); | |||||
return m_sizes[0][dim0]; | |||||
size_t WorkspaceBundle::get_size(size_t i) const { | |||||
return m_sizes[i]; | |||||
} | } | ||||
void WorkspaceBundle::set(void* ptr) { | void WorkspaceBundle::set(void* ptr) { | ||||
@@ -194,15 +194,8 @@ std::unique_ptr<T> make_unique(Args&&... args) { | |||||
*/ | */ | ||||
class WorkspaceBundle { | class WorkspaceBundle { | ||||
public: | public: | ||||
WorkspaceBundle(void* ptr = nullptr, | |||||
SmallVector<size_t> sizes_in_bytes = {}, | |||||
WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes, | |||||
size_t align_in_bytes = 512); | size_t align_in_bytes = 512); | ||||
/** | |||||
* construct 2D workspace buldle | |||||
*/ | |||||
WorkspaceBundle(SmallVector<SmallVector<size_t>> vector_sizes_in_bytes, | |||||
void* ptr, size_t align_in_bytes = 512); | |||||
/** | /** | ||||
* \returns raw workspace ptr. | * \returns raw workspace ptr. | ||||
* | * | ||||
@@ -211,45 +204,26 @@ public: | |||||
*/ | */ | ||||
void* ptr() const; | void* ptr() const; | ||||
/** | /** | ||||
* \returns the 2D [dim1, dim0] workspace ptr (aligned) | |||||
* \returns the i-th workspace ptr (aligned) | |||||
*/ | */ | ||||
void* get(size_t dim1, size_t dim0) const; | |||||
/** | |||||
* \returns the 1D [dim0] workspace ptr (aligned) | |||||
*/ | |||||
void* get(size_t dim0) const; | |||||
void* get(size_t i) const; | |||||
/** | /** | ||||
* \returns total size taking into account paddings to solve alignment | * \returns total size taking into account paddings to solve alignment | ||||
* issue. | * issue. | ||||
*/ | */ | ||||
size_t total_size_in_bytes() const; | size_t total_size_in_bytes() const; | ||||
/** | |||||
* \return the 2D [dim1, dim0] workspace size | |||||
*/ | |||||
size_t get_size(size_t dim1, size_t dim0) const; | |||||
/** | |||||
* \return the 1D [dim0] workspace size | |||||
*/ | |||||
size_t get_size(size_t dim0) const; | |||||
size_t get_size(size_t i) const; | |||||
size_t nr_workspace() const; | size_t nr_workspace() const; | ||||
void set(void* ptr); | void set(void* ptr); | ||||
Workspace get_workspace(size_t dim1, size_t dim0) const { | |||||
return {static_cast<dt_byte*>(get(dim1, dim0)), get_size(dim1, dim0)}; | |||||
} | |||||
Workspace get_workspace(size_t dim0) const { | |||||
return {static_cast<dt_byte*>(get(dim0)), get_size(dim0)}; | |||||
Workspace get_workspace(size_t i) const { | |||||
return {static_cast<dt_byte*>(get(i)), get_size(i)}; | |||||
} | } | ||||
private: | private: | ||||
void* m_ptr; | void* m_ptr; | ||||
SmallVector<SmallVector<size_t>> m_sizes; | |||||
SmallVector<size_t> m_sizes; | |||||
SmallVector<size_t> m_aligned_sizes; | SmallVector<size_t> m_aligned_sizes; | ||||
//! all workspace size prefix sum | |||||
SmallVector<size_t> m_reduce_sizes; | |||||
//! dim1 workspace number prefix sum | |||||
SmallVector<size_t> m_reduce_num; | |||||
size_t m_align_in_bytes; | size_t m_align_in_bytes; | ||||
}; | }; | ||||
@@ -11,7 +11,6 @@ | |||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
#include "megdnn/tensor_format.h" | #include "megdnn/tensor_format.h" | ||||
#include "src/common/utils.h" | |||||
// clang-format off | // clang-format off | ||||
#include "test/common/utils.h" | #include "test/common/utils.h" | ||||
@@ -279,21 +278,4 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) { | |||||
} | } | ||||
} | } | ||||
TEST(MISC, WORKSPACE_BUNDLE) { | |||||
WorkspaceBundle bundle{ | |||||
{{100, 200}, {435, 234, 143}, {422, 1325, 728}}, nullptr, 64}; | |||||
bundle.set(reinterpret_cast<void*>(82l)); | |||||
ASSERT_EQ(bundle.get(0), reinterpret_cast<void*>(128l)); | |||||
void* dst = reinterpret_cast<void*>(128 + round_up(100, 64)); | |||||
ASSERT_EQ(bundle.get(0, 1), dst); | |||||
dst = reinterpret_cast<void*>(128 + round_up(100, 64) + round_up(200, 64) + | |||||
round_up(435, 64)); | |||||
ASSERT_EQ(bundle.get(1, 1), dst); | |||||
dst = reinterpret_cast<void*>(128l + round_up(100, 64) + round_up(200, 64) + | |||||
round_up(435, 64) + round_up(234, 64) + | |||||
round_up(143, 64) + round_up(422, 64) + | |||||
round_up(1325, 64)); | |||||
ASSERT_EQ(bundle.get(2, 2), dst); | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |