diff --git a/dnn/src/common/utils.cpp b/dnn/src/common/utils.cpp index a11f340a..2dd2ca70 100644 --- a/dnn/src/common/utils.cpp +++ b/dnn/src/common/utils.cpp @@ -156,42 +156,15 @@ void megdnn::infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw, WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector sizes_in_bytes, size_t align_in_bytes) : m_ptr(ptr), + m_sizes(std::move(sizes_in_bytes)), m_align_in_bytes(align_in_bytes) { m_aligned_sizes.reserve(m_sizes.size()); - m_sizes.push_back(sizes_in_bytes); - size_t reduce_size = 0_z; - m_reduce_num.push_back(0_z); - for (auto size : m_sizes[0]) { + for (auto size : m_sizes) { auto aligned_size = size; if (size % m_align_in_bytes != 0) { aligned_size += m_align_in_bytes - size % m_align_in_bytes; } m_aligned_sizes.push_back(aligned_size); - m_reduce_sizes.push_back(reduce_size); - reduce_size += aligned_size; - } -} - -WorkspaceBundle::WorkspaceBundle( - SmallVector> vector_sizes_in_bytes, void* ptr, - size_t align_in_bytes) - : m_ptr(ptr), - m_sizes(vector_sizes_in_bytes), - m_align_in_bytes(align_in_bytes) { - size_t nr_workspace = 0_z; - size_t reduce_size = 0_z; - for (auto sizes_in_bytes: vector_sizes_in_bytes) { - m_reduce_num.push_back(nr_workspace); - for (auto size : sizes_in_bytes) { - auto aligned_size = size; - if (size % m_align_in_bytes != 0) { - aligned_size += m_align_in_bytes - size % m_align_in_bytes; - } - m_aligned_sizes.push_back(aligned_size); - m_reduce_sizes.push_back(reduce_size); - reduce_size += aligned_size; - nr_workspace++; - } } } @@ -199,39 +172,22 @@ void* WorkspaceBundle::ptr() const { return m_ptr; } -void* WorkspaceBundle::get(size_t dim1, size_t dim0) const { - megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range"); - megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range"); +void* WorkspaceBundle::get(size_t i) const { auto addr = reinterpret_cast(m_ptr); if (addr % m_align_in_bytes != 0) addr += m_align_in_bytes - addr % m_align_in_bytes; - size_t index = m_reduce_num[dim1] + dim0; - addr += m_reduce_sizes[index]; - return reinterpret_cast(addr); -} - -void* WorkspaceBundle::get(size_t dim0) const { - megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range"); - auto addr = reinterpret_cast(m_ptr); - if (addr % m_align_in_bytes != 0) - addr += m_align_in_bytes - addr % m_align_in_bytes; - addr += m_reduce_sizes[dim0]; + for (size_t j = 0; j < i; ++j) { + addr += m_aligned_sizes[j]; + } return reinterpret_cast(addr); } size_t WorkspaceBundle::nr_workspace() const { - return m_aligned_sizes.size(); -} - -size_t WorkspaceBundle::get_size(size_t dim1, size_t dim0) const { - megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range"); - megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range"); - return m_sizes[dim1][dim0]; + return m_sizes.size(); } -size_t WorkspaceBundle::get_size(size_t dim0) const { - megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range"); - return m_sizes[0][dim0]; +size_t WorkspaceBundle::get_size(size_t i) const { + return m_sizes[i]; } void WorkspaceBundle::set(void* ptr) { diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h index d73a01e7..449c9b04 100644 --- a/dnn/src/common/utils.h +++ b/dnn/src/common/utils.h @@ -194,15 +194,8 @@ std::unique_ptr make_unique(Args&&... args) { */ class WorkspaceBundle { public: - WorkspaceBundle(void* ptr = nullptr, - SmallVector sizes_in_bytes = {}, + WorkspaceBundle(void* ptr, SmallVector sizes_in_bytes, size_t align_in_bytes = 512); - - /** - * construct 2D workspace buldle - */ - WorkspaceBundle(SmallVector> vector_sizes_in_bytes, - void* ptr, size_t align_in_bytes = 512); /** * \returns raw workspace ptr. * @@ -211,45 +204,26 @@ public: */ void* ptr() const; /** - * \returns the 2D [dim1, dim0] workspace ptr (aligned) + * \returns the i-th workspace ptr (aligned) */ - void* get(size_t dim1, size_t dim0) const; - /** - * \returns the 1D [dim0] workspace ptr (aligned) - */ - void* get(size_t dim0) const; + void* get(size_t i) const; /** * \returns total size taking into account paddings to solve alignment * issue. */ size_t total_size_in_bytes() const; - /** - * \return the 2D [dim1, dim0] workspace size - */ - size_t get_size(size_t dim1, size_t dim0) const; - - /** - * \return the 1D [dim0] workspace size - */ - size_t get_size(size_t dim0) const; + size_t get_size(size_t i) const; size_t nr_workspace() const; void set(void* ptr); - Workspace get_workspace(size_t dim1, size_t dim0) const { - return {static_cast(get(dim1, dim0)), get_size(dim1, dim0)}; - } - Workspace get_workspace(size_t dim0) const { - return {static_cast(get(dim0)), get_size(dim0)}; + Workspace get_workspace(size_t i) const { + return {static_cast(get(i)), get_size(i)}; } private: void* m_ptr; - SmallVector> m_sizes; + SmallVector m_sizes; SmallVector m_aligned_sizes; - //! all workspace size prefix sum - SmallVector m_reduce_sizes; - //! dim1 workspace number prefix sum - SmallVector m_reduce_num; size_t m_align_in_bytes; }; diff --git a/dnn/test/common/test_basic_types.cpp b/dnn/test/common/test_basic_types.cpp index 899dc7d2..31a6c2e2 100644 --- a/dnn/test/common/test_basic_types.cpp +++ b/dnn/test/common/test_basic_types.cpp @@ -11,7 +11,6 @@ #include "megdnn/basic_types.h" #include "megdnn/tensor_format.h" -#include "src/common/utils.h" // clang-format off #include "test/common/utils.h" @@ -279,21 +278,4 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) { } } -TEST(MISC, WORKSPACE_BUNDLE) { - WorkspaceBundle bundle{ - {{100, 200}, {435, 234, 143}, {422, 1325, 728}}, nullptr, 64}; - bundle.set(reinterpret_cast(82l)); - ASSERT_EQ(bundle.get(0), reinterpret_cast(128l)); - void* dst = reinterpret_cast(128 + round_up(100, 64)); - ASSERT_EQ(bundle.get(0, 1), dst); - dst = reinterpret_cast(128 + round_up(100, 64) + round_up(200, 64) + - round_up(435, 64)); - ASSERT_EQ(bundle.get(1, 1), dst); - dst = reinterpret_cast(128l + round_up(100, 64) + round_up(200, 64) + - round_up(435, 64) + round_up(234, 64) + - round_up(143, 64) + round_up(422, 64) + - round_up(1325, 64)); - ASSERT_EQ(bundle.get(2, 2), dst); -} - // vim: syntax=cpp.doxygen