|
@@ -156,15 +156,42 @@ void megdnn::infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw, |
|
|
WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes, |
|
|
WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector<size_t> sizes_in_bytes, |
|
|
size_t align_in_bytes) |
|
|
size_t align_in_bytes) |
|
|
: m_ptr(ptr), |
|
|
: m_ptr(ptr), |
|
|
m_sizes(std::move(sizes_in_bytes)), |
|
|
|
|
|
m_align_in_bytes(align_in_bytes) { |
|
|
m_align_in_bytes(align_in_bytes) { |
|
|
m_aligned_sizes.reserve(m_sizes.size()); |
|
|
m_aligned_sizes.reserve(m_sizes.size()); |
|
|
for (auto size : m_sizes) { |
|
|
|
|
|
|
|
|
m_sizes.push_back(sizes_in_bytes); |
|
|
|
|
|
size_t reduce_size = 0_z; |
|
|
|
|
|
m_reduce_num.push_back(0_z); |
|
|
|
|
|
for (auto size : m_sizes[0]) { |
|
|
auto aligned_size = size; |
|
|
auto aligned_size = size; |
|
|
if (size % m_align_in_bytes != 0) { |
|
|
if (size % m_align_in_bytes != 0) { |
|
|
aligned_size += m_align_in_bytes - size % m_align_in_bytes; |
|
|
aligned_size += m_align_in_bytes - size % m_align_in_bytes; |
|
|
} |
|
|
} |
|
|
m_aligned_sizes.push_back(aligned_size); |
|
|
m_aligned_sizes.push_back(aligned_size); |
|
|
|
|
|
m_reduce_sizes.push_back(reduce_size); |
|
|
|
|
|
reduce_size += aligned_size; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
WorkspaceBundle::WorkspaceBundle( |
|
|
|
|
|
SmallVector<SmallVector<size_t>> vector_sizes_in_bytes, void* ptr, |
|
|
|
|
|
size_t align_in_bytes) |
|
|
|
|
|
: m_ptr(ptr), |
|
|
|
|
|
m_sizes(vector_sizes_in_bytes), |
|
|
|
|
|
m_align_in_bytes(align_in_bytes) { |
|
|
|
|
|
size_t nr_workspace = 0_z; |
|
|
|
|
|
size_t reduce_size = 0_z; |
|
|
|
|
|
for (auto sizes_in_bytes: vector_sizes_in_bytes) { |
|
|
|
|
|
m_reduce_num.push_back(nr_workspace); |
|
|
|
|
|
for (auto size : sizes_in_bytes) { |
|
|
|
|
|
auto aligned_size = size; |
|
|
|
|
|
if (size % m_align_in_bytes != 0) { |
|
|
|
|
|
aligned_size += m_align_in_bytes - size % m_align_in_bytes; |
|
|
|
|
|
} |
|
|
|
|
|
m_aligned_sizes.push_back(aligned_size); |
|
|
|
|
|
m_reduce_sizes.push_back(reduce_size); |
|
|
|
|
|
reduce_size += aligned_size; |
|
|
|
|
|
nr_workspace++; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@@ -172,22 +199,39 @@ void* WorkspaceBundle::ptr() const { |
|
|
return m_ptr; |
|
|
return m_ptr; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void* WorkspaceBundle::get(size_t i) const { |
|
|
|
|
|
|
|
|
void* WorkspaceBundle::get(size_t dim1, size_t dim0) const { |
|
|
|
|
|
megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range"); |
|
|
|
|
|
megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range"); |
|
|
auto addr = reinterpret_cast<uintptr_t>(m_ptr); |
|
|
auto addr = reinterpret_cast<uintptr_t>(m_ptr); |
|
|
if (addr % m_align_in_bytes != 0) |
|
|
if (addr % m_align_in_bytes != 0) |
|
|
addr += m_align_in_bytes - addr % m_align_in_bytes; |
|
|
addr += m_align_in_bytes - addr % m_align_in_bytes; |
|
|
for (size_t j = 0; j < i; ++j) { |
|
|
|
|
|
addr += m_aligned_sizes[j]; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
size_t index = m_reduce_num[dim1] + dim0; |
|
|
|
|
|
addr += m_reduce_sizes[index]; |
|
|
|
|
|
return reinterpret_cast<void*>(addr); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void* WorkspaceBundle::get(size_t dim0) const { |
|
|
|
|
|
megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range"); |
|
|
|
|
|
auto addr = reinterpret_cast<uintptr_t>(m_ptr); |
|
|
|
|
|
if (addr % m_align_in_bytes != 0) |
|
|
|
|
|
addr += m_align_in_bytes - addr % m_align_in_bytes; |
|
|
|
|
|
addr += m_reduce_sizes[dim0]; |
|
|
return reinterpret_cast<void*>(addr); |
|
|
return reinterpret_cast<void*>(addr); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
size_t WorkspaceBundle::nr_workspace() const { |
|
|
size_t WorkspaceBundle::nr_workspace() const { |
|
|
return m_sizes.size(); |
|
|
|
|
|
|
|
|
return m_aligned_sizes.size(); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
size_t WorkspaceBundle::get_size(size_t dim1, size_t dim0) const { |
|
|
|
|
|
megdnn_assert(dim1 < m_sizes.size(), "dim1 is out of range"); |
|
|
|
|
|
megdnn_assert(dim0 < m_sizes[dim1].size(), "dim0 is out of range"); |
|
|
|
|
|
return m_sizes[dim1][dim0]; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
size_t WorkspaceBundle::get_size(size_t i) const { |
|
|
|
|
|
return m_sizes[i]; |
|
|
|
|
|
|
|
|
size_t WorkspaceBundle::get_size(size_t dim0) const { |
|
|
|
|
|
megdnn_assert(dim0 < m_aligned_sizes.size(), "dim0 is out of range"); |
|
|
|
|
|
return m_sizes[0][dim0]; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void WorkspaceBundle::set(void* ptr) { |
|
|
void WorkspaceBundle::set(void* ptr) { |
|
|