Browse Source

feat(cuda/comp_node): enable to directly query memory status

GitOrigin-RevId: 23c1af9a55
tags/v1.4.0-rc1
Megvii Engine Team 4 years ago
parent
commit
fd61f09540
2 changed files with 45 additions and 0 deletions
  1. +33
    -0
      src/core/impl/comp_node/cuda/comp_node.cpp
  2. +12
    -0
      src/core/include/megbrain/comp_node.h

+ 33
- 0
src/core/impl/comp_node/cuda/comp_node.cpp View File

@@ -172,6 +172,9 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
struct StaticData;
static StaticData *sd;
static Spinlock sd_mtx;
#if !MGB_BUILD_SLIM_SERVING
std::mutex m_update_mem;
#endif

//! set to true when m_locator is assigned; set to false if async init
//! failed
@@ -210,7 +213,17 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {

void* alloc_device(size_t size) override {
activate();
#if MGB_BUILD_SLIM_SERVING
return m_mem_alloc->alloc(size);
#else
void* ptr = m_mem_alloc->alloc(size);
{
MGB_LOCK_GUARD(m_update_mem);
ptr2size[ptr] = size;
m_used_mem += size;
}
return ptr;
#endif
}

void free_device(void *ptr);
@@ -287,8 +300,19 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
uint64_t get_uid() override {
return m_uid;
}

#if !MGB_BUILD_SLIM_SERVING
size_t get_used_memory() override {
return m_used_mem;
}
#endif

private:
uint64_t m_uid;
#if !MGB_BUILD_SLIM_SERVING
std::unordered_map<void*, size_t> ptr2size;
size_t m_used_mem = 0;
#endif
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CudaCompNode::CompNodeImpl);

@@ -419,7 +443,16 @@ void CudaCompNodeImpl::free_device(void *ptr) {
return;

activate();
#if !MGB_BUILD_SLIM_SERVING
{
MGB_LOCK_GUARD(m_update_mem);
mgb_assert(ptr2size.find(ptr) != ptr2size.end(), "ptr %p not found!", ptr);
m_used_mem -= ptr2size.at(ptr);
ptr2size.erase(ptr);
}
#endif
m_mem_alloc->free(ptr);

}

void* CudaCompNodeImpl::alloc_host(size_t size) {


+ 12
- 0
src/core/include/megbrain/comp_node.h View File

@@ -351,6 +351,12 @@ class CompNode {
return m_impl->get_mem_status_bytes();
}

#if !MGB_BUILD_SLIM_SERVING
size_t get_used_memory() const {
return m_impl->get_used_memory();
}
#endif

//! change to another stream on the same memory node
CompNode change_stream(int dest_stream) const;

@@ -528,6 +534,12 @@ class CompNode {
virtual MemNode mem_node() = 0;
virtual std::pair<size_t, size_t> get_mem_status_bytes() = 0;

#if !MGB_BUILD_SLIM_SERVING
virtual size_t get_used_memory() {
return 0;
}
#endif

virtual Locator locator() = 0;
virtual Locator locator_logical() = 0;



Loading…
Cancel
Save