diff --git a/src/core/impl/comp_node/cuda/comp_node.cpp b/src/core/impl/comp_node/cuda/comp_node.cpp
index 83804687..e89d5670 100644
--- a/src/core/impl/comp_node/cuda/comp_node.cpp
+++ b/src/core/impl/comp_node/cuda/comp_node.cpp
@@ -172,6 +172,10 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
         struct StaticData;
         static StaticData *sd;
         static Spinlock sd_mtx;
+#if !MGB_BUILD_SLIM_SERVING
+        //! guards m_ptr2size and m_used_mem
+        std::mutex m_update_mem;
+#endif
 
         //! set to true when m_locator is assigned; set to false if async init
         //! failed
@@ -210,7 +214,19 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
 
         void* alloc_device(size_t size) override {
             activate();
+#if MGB_BUILD_SLIM_SERVING
             return m_mem_alloc->alloc(size);
+#else
+            void* ptr = m_mem_alloc->alloc(size);
+            {
+                // remember the size of this block so that free_device() can
+                // subtract it from the usage counter later
+                MGB_LOCK_GUARD(m_update_mem);
+                m_ptr2size[ptr] = size;
+                m_used_mem += size;
+            }
+            return ptr;
+#endif
         }
 
         void free_device(void *ptr);
@@ -287,8 +303,26 @@ class CudaCompNode::CompNodeImpl final: public CompNode::Impl {
         uint64_t get_uid() override {
             return m_uid;
         }
+
+#if !MGB_BUILD_SLIM_SERVING
+        //! sum of the sizes passed to alloc_device() for all blocks that have
+        //! not yet been handed to free_device()
+        size_t get_used_memory() override {
+            // alloc_device()/free_device() update the counter under this
+            // mutex, so it must be held here too; an unlocked read would race
+            MGB_LOCK_GUARD(m_update_mem);
+            return m_used_mem;
+        }
+#endif
+
     private:
         uint64_t m_uid;
+#if !MGB_BUILD_SLIM_SERVING
+        //! size of every live device allocation, keyed by its address
+        std::unordered_map<void*, size_t> m_ptr2size;
+        //! running total of the sizes stored in m_ptr2size
+        size_t m_used_mem = 0;
+#endif
 };
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(CudaCompNode::CompNodeImpl);
@@ -419,7 +453,18 @@ void CudaCompNodeImpl::free_device(void *ptr) {
         return;
 
     activate();
+#if !MGB_BUILD_SLIM_SERVING
+    {
+        // drop the bookkeeping entry before the block goes back to the
+        // allocator; a single find() serves the check, the counter update
+        // and the erase
+        MGB_LOCK_GUARD(m_update_mem);
+        auto iter = m_ptr2size.find(ptr);
+        mgb_assert(iter != m_ptr2size.end(), "ptr %p not found!", ptr);
+        m_used_mem -= iter->second;
+        m_ptr2size.erase(iter);
+    }
+#endif
     m_mem_alloc->free(ptr);
 }
 
 void* CudaCompNodeImpl::alloc_host(size_t size) {
diff --git a/src/core/include/megbrain/comp_node.h b/src/core/include/megbrain/comp_node.h
index 0b2250a8..99112723 100644
--- a/src/core/include/megbrain/comp_node.h
+++ b/src/core/include/megbrain/comp_node.h
@@ -351,6 +351,13 @@ class CompNode {
             return m_impl->get_mem_status_bytes();
         }
 
+#if !MGB_BUILD_SLIM_SERVING
+        //! used memory size as reported by Impl::get_used_memory()
+        size_t get_used_memory() const {
+            return m_impl->get_used_memory();
+        }
+#endif
+
         //! change to another stream on the same memory node
         CompNode change_stream(int dest_stream) const;
 
@@ -528,5 +535,14 @@ class CompNode {
                 virtual MemNode mem_node() = 0;
                 virtual std::pair<size_t, size_t> get_mem_status_bytes() = 0;
 
+#if !MGB_BUILD_SLIM_SERVING
+                //! amount of memory currently in use on this comp node;
+                //! implementations that do not override this report 0
+                virtual size_t get_used_memory() {
+                    return 0;
+                }
+#endif
+
                 virtual Locator locator() = 0;
                 virtual Locator locator_logical() = 0;