From 653b91452103c74105e58c6036fd08852e88ff91 Mon Sep 17 00:00:00 2001 From: TangQunzhang Date: Fri, 14 May 2021 14:14:36 +0800 Subject: [PATCH] Support session scope memory --- ge/graph/manager/graph_mem_manager.cc | 2 ++ ge/graph/manager/session_scope_mem_allocator.cc | 4 +--- ge/graph/manager/session_scope_mem_allocator.h | 1 + .../manager/session_scope_mem_allocator_unittest.cc | 19 +++++++++++++++++++ 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/ge/graph/manager/graph_mem_manager.cc b/ge/graph/manager/graph_mem_manager.cc index 8d300dc2..21eaf302 100644 --- a/ge/graph/manager/graph_mem_manager.cc +++ b/ge/graph/manager/graph_mem_manager.cc @@ -65,6 +65,7 @@ Status MemManager::Initialize(const std::vector &memory_type) { return ret; } init_ = true; + memory_type_ = memory_type; return SUCCESS; } @@ -90,6 +91,7 @@ void MemManager::Finalize() noexcept { FinalizeAllocatorMap(host_allocator_map_); FinalizeAllocatorMap(memory_allocator_map_); init_ = false; + memory_type_.clear(); } MemoryAllocator &MemManager::MemInstance(rtMemType_t memory_type) { diff --git a/ge/graph/manager/session_scope_mem_allocator.cc b/ge/graph/manager/session_scope_mem_allocator.cc index 8eb01445..aedc2e92 100644 --- a/ge/graph/manager/session_scope_mem_allocator.cc +++ b/ge/graph/manager/session_scope_mem_allocator.cc @@ -65,9 +65,7 @@ Status SessionScopeMemAllocator::Free(uint64_t session_id, uint32_t device_id) { std::lock_guard lock(mutex_); auto it = allocated_memory_.find(session_id); if (it == allocated_memory_.end()) { - REPORT_INNER_ERROR("E19999", "Param memory not allocated before, session_id:%lu device_id:%u, check invalid", - session_id, device_id); - GELOGE(PARAM_INVALID, "Invalid session_id"); + GELOGW("Invalid session_id"); return ge::PARAM_INVALID; } allocated_memory_.erase(it); diff --git a/ge/graph/manager/session_scope_mem_allocator.h b/ge/graph/manager/session_scope_mem_allocator.h index 5aea9554..3dbf3cb0 100644 --- a/ge/graph/manager/session_scope_mem_allocator.h +++ b/ge/graph/manager/session_scope_mem_allocator.h @@ -53,6 +53,7 @@ class SessionScopeMemoryInfo { } size = other.size; ptr = other.ptr; + return *this; }; private: diff --git a/tests/ut/ge/graph/manager/session_scope_mem_allocator_unittest.cc b/tests/ut/ge/graph/manager/session_scope_mem_allocator_unittest.cc index 4a336af9..87af585a 100644 --- a/tests/ut/ge/graph/manager/session_scope_mem_allocator_unittest.cc +++ b/tests/ut/ge/graph/manager/session_scope_mem_allocator_unittest.cc @@ -73,3 +73,22 @@ TEST_F(UtestSessionScopeMemAllocator, free_success) { EXPECT_NE(SUCCESS, MemManager::Instance().SessionScopeMemInstance(RT_MEMORY_HBM).Free(0)); MemManager::Instance().Finalize(); } + +TEST_F(UtestSessionScopeMemAllocator, free_success_session) { + std::vector mem_type; + mem_type.push_back(RT_MEMORY_HBM); + mem_type.push_back(RT_MEMORY_P2P_DDR); + EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); + uint8_t *ptr = MemManager::Instance().SessionScopeMemInstance(RT_MEMORY_HBM).Malloc(100, 0); + EXPECT_NE(nullptr, ptr); + ptr = MemManager::Instance().SessionScopeMemInstance(RT_MEMORY_HBM).Malloc(100, 0); + EXPECT_NE(nullptr, ptr); + for (auto memory_type : MemManager::Instance().GetAllMemoryType()) { + if (RT_MEMORY_P2P_DDR == memory_type) { + EXPECT_NE(MemManager::Instance().SessionScopeMemInstance(memory_type).Free(0), SUCCESS); + } else { + EXPECT_EQ(MemManager::Instance().SessionScopeMemInstance(memory_type).Free(0), SUCCESS); + } + } + MemManager::Instance().Finalize(); +}