@@ -28,10 +28,9 @@ const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, | |||
kBinSizeUnit8 * kMByteSize, | |||
kBinSizeUnit32 * kMByteSize, | |||
kBinSizeUnit128 * kMByteSize, | |||
kGByteSize, | |||
kBinSizeUnit4 * kGByteSize, | |||
kBinSizeUnit16 * kGByteSize, | |||
kBinSizeUnit26 * kGByteSize}; | |||
kBinSizeUnit256 * kMByteSize, | |||
kBinSizeUnit512 * kMByteSize, | |||
kGByteSize}; | |||
static bool BlockComparator(const Block *left, const Block *right) { | |||
if (left->size != right->size) { | |||
@@ -63,7 +62,10 @@ size_t GetBinIndex(size_t size) { | |||
size_t GetAllocationSize(size_t size) { | |||
size_t index = GetBinIndex(size); | |||
return bin_ranges[index]; | |||
if (bin_ranges[index] >= size) { | |||
return bin_ranges[index]; | |||
} | |||
return kGByteSize * ((size + kGByteSize - 1) / kGByteSize); | |||
} | |||
/// | |||
@@ -36,17 +36,17 @@ namespace ge { | |||
// Sizing constants for the caching allocator.
constexpr size_t kRoundBlockSize = 512;  // all block sizes are rounded to at least 512 bytes

// Multipliers combined with kMByteSize / kGByteSize to form bin boundaries.
constexpr size_t kBinSizeUnit4 = 4;
constexpr size_t kBinSizeUnit8 = 8;
constexpr size_t kBinSizeUnit16 = 16;
constexpr size_t kBinSizeUnit26 = 26;
constexpr size_t kBinSizeUnit32 = 32;
constexpr size_t kBinSizeUnit128 = 128;
constexpr size_t kBinSizeUnit256 = 256;
constexpr size_t kBinSizeUnit512 = 512;

// Each constant is defined exactly once; where two definitions were present,
// the later one is kept (a second definition is a redefinition error).
constexpr double kSplitThreshold = 0.5;  // split when malloc size <= small block size * kSplitThreshold

constexpr size_t kKByteSize = 1024;
constexpr size_t kMByteSize = 1048576;     // 1024 * 1024
constexpr size_t kGByteSize = 1073741824;  // 1024 * 1024 * 1024

static const uint32_t kNumBins = 7;  // number of entries in bin_ranges
class MemoryAllocator; | |||
@@ -323,6 +323,8 @@ Status NodeDoneCallback::OnNodeDone() { | |||
node_item.NodeName().c_str()); | |||
} | |||
// release workspace | |||
context_->ReleaseWorkspace(); | |||
// release inputs | |||
for (int i = 0; i < context_->NumInputs(); ++i) { | |||
context_->ReleaseInput(i); | |||
@@ -36,10 +36,6 @@ TaskContext::TaskContext(GraphExecutionContext *execution_context, | |||
TaskContext::~TaskContext() { | |||
GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | |||
for (auto ws_addr : workspaces_) { | |||
execution_context_->allocator->Deallocate(ws_addr); | |||
} | |||
// release output | |||
for (int i = 0; i < NumOutputs(); ++i) { | |||
auto output_tensor = MutableOutput(i); | |||
@@ -49,6 +45,13 @@ TaskContext::~TaskContext() { | |||
} | |||
} | |||
void TaskContext::ReleaseWorkspace() { | |||
GELOGD("[%s] Start ReleaseWorkspace.", node_item_->NodeName().c_str()); | |||
for (auto ws_addr : workspaces_) { | |||
execution_context_->allocator->Deallocate(ws_addr); | |||
} | |||
} | |||
std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||
GraphExecutionContext *execution_context, | |||
SubgraphContext *subgraph_context) { | |||
@@ -56,6 +56,7 @@ class TaskContext { | |||
void ReleaseInputsAndOutputs(); | |||
bool NeedCallback(); | |||
void ReleaseInput(int index); | |||
void ReleaseWorkspace(); | |||
const TensorValue *GetInput(int index) const; | |||
const TensorValue *GetOutput(int index) const; | |||
TensorValue *MutableOutput(int index); | |||
@@ -752,6 +752,7 @@ set(MULTI_PARTS_TEST_FILES | |||
"graph/build/mem_assigner_unittest.cc" | |||
"graph/preprocess/graph_preprocess_unittest.cc" | |||
"graph/manager/hcom_util_unittest.cc" | |||
"graph/manager/graph_caching_allocator_unittest.cc" | |||
"session/omg_omg_unittest.cc" | |||
) | |||
@@ -0,0 +1,76 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include <memory> | |||
#include "graph/anchor.h" | |||
#include "graph/attr_value.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "omg/omg_inner_types.h" | |||
#define protected public | |||
#define private public | |||
#include "graph/manager/graph_caching_allocator.h" | |||
#include "graph/manager/graph_mem_allocator.h" | |||
#undef protected | |||
#undef private | |||
using namespace std; | |||
using namespace testing; | |||
using namespace ge; | |||
using domi::GetContext; | |||
class UtestGraphCachingAllocatorTest : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() { GetContext().out_nodes_map.clear(); } | |||
}; | |||
// Initializing the memory manager for the HBM memory type must report SUCCESS.
TEST_F(UtestGraphCachingAllocatorTest, initialize_success) {
  std::vector<rtMemType_t> memory_types{RT_MEMORY_HBM};
  const auto init_status = MemManager::Instance().Initialize(memory_types);
  EXPECT_EQ(init_status, SUCCESS);
  // Always finalize so this case leaves the manager finalized.
  MemManager::Instance().Finalize();
}
// After a successful Initialize, a kMByteSize request to the HBM caching
// allocator must return a non-null address.
TEST_F(UtestGraphCachingAllocatorTest, malloc_success) {
  std::vector<rtMemType_t> memory_types{RT_MEMORY_HBM};
  const auto init_status = MemManager::Instance().Initialize(memory_types);
  EXPECT_EQ(init_status, SUCCESS);

  uint8_t *const block = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kMByteSize);
  EXPECT_NE(nullptr, block);

  MemManager::Instance().Finalize();
}
// Allocates two blocks of different sizes (kMByteSize and kKByteSize), checks
// that both allocations succeed, frees both, then releases the cached blocks.
TEST_F(UtestGraphCachingAllocatorTest, malloc_statics) {
  std::vector<rtMemType_t> mem_type;
  mem_type.push_back(RT_MEMORY_HBM);
  EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS);
  uint8_t *ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kMByteSize);
  EXPECT_NE(nullptr, ptr);
  uint8_t *ptr1 = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kKByteSize);
  // Validate the second allocation itself; re-checking `ptr` here would let a
  // failed second Malloc go unnoticed.
  EXPECT_NE(nullptr, ptr1);
  EXPECT_EQ(MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(ptr), SUCCESS);
  EXPECT_EQ(MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(ptr1), SUCCESS);
  MemManager::Instance().CachingInstance(RT_MEMORY_HBM).FreeCachedBlocks();
  MemManager::Instance().Finalize();
}