From: @dimitri_rose Reviewed-by: @xchu42, @wqtshg Signed-off-by: @wqtshg pull/906/MERGE
@@ -89,6 +89,14 @@ Status VariableMemoryAssigner::AssignVarAttr2Nodes() { | |||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
// Hands compute_graph_ to VarMemAssignUtil::AssignMemory2HasRefAttrNode and
// returns whatever status that helper reports (ge::SUCCESS when it succeeds,
// the helper's own error code otherwise).
Status VariableMemoryAssigner::AssignMemory2HasRefAttrNode() {
  return ge::VarMemAssignUtil::AssignMemory2HasRefAttrNode(compute_graph_);
}
Status GraphMemoryAssigner::AssignMemory() { | Status GraphMemoryAssigner::AssignMemory() { | ||||
ge::HybridMemAssignerPtr mem_assigner(new (std::nothrow) HybridMemAssigner(compute_graph_)); | ge::HybridMemAssignerPtr mem_assigner(new (std::nothrow) HybridMemAssigner(compute_graph_)); | ||||
if (mem_assigner->Assign() != ge::SUCCESS) { | if (mem_assigner->Assign() != ge::SUCCESS) { | ||||
@@ -131,6 +139,19 @@ ge::Status GraphMemoryAssigner::AssignVarAttr2Nodes() { | |||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
// Creates a short-lived VariableMemoryAssigner for compute_graph_ and runs its
// AssignMemory2HasRefAttrNode step.
// Returns ge::SUCCESS when the step succeeds; ge::FAILED when the assigner
// cannot be allocated (logged) or when the step reports any other status.
ge::Status GraphMemoryAssigner::AssignMemory2HasRefAttrNode() {
  // nothrow allocation: a failed allocation yields a null pointer instead of an exception.
  std::unique_ptr<ge::VariableMemoryAssigner> assigner(new (std::nothrow) ge::VariableMemoryAssigner(compute_graph_));
  if (!assigner) {
    GELOGE(ge::FAILED, "Alloc VariableMemoryAssigner failed.");
    return ge::FAILED;
  }

  if (assigner->AssignMemory2HasRefAttrNode() == ge::SUCCESS) {
    return ge::SUCCESS;
  }
  // The underlying status code is deliberately collapsed to ge::FAILED here.
  return ge::FAILED;
}
ge::Status GraphMemoryAssigner::CalculateTensorRealSizeAndOutSize(const ge::ConstGeTensorDescPtr &output_desc, | ge::Status GraphMemoryAssigner::CalculateTensorRealSizeAndOutSize(const ge::ConstGeTensorDescPtr &output_desc, | ||||
int64_t dim_index, int64_t &output_mem_size, | int64_t dim_index, int64_t &output_mem_size, | ||||
int64_t &batch_dim_num, int64_t &out_size) { | int64_t &batch_dim_num, int64_t &out_size) { | ||||
@@ -63,6 +63,8 @@ class VariableMemoryAssigner { | |||||
/// | /// | ||||
ge::Status AssignVarAttr2Nodes(); | ge::Status AssignVarAttr2Nodes(); | ||||
///
/// @brief Forwards compute_graph_ to VarMemAssignUtil::AssignMemory2HasRefAttrNode.
/// @return ge::SUCCESS when the helper succeeds, otherwise the helper's status.
///
ge::Status AssignMemory2HasRefAttrNode(); | |||||
private: | private: | ||||
ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
}; | }; | ||||
@@ -99,6 +101,8 @@ class GraphMemoryAssigner { | |||||
ge::Status ReAssignMemory(bool is_loop_graph, size_t &mem_offset); | ge::Status ReAssignMemory(bool is_loop_graph, size_t &mem_offset); | ||||
///
/// @brief Allocates a temporary VariableMemoryAssigner for compute_graph_ and
///        calls its AssignMemory2HasRefAttrNode.
/// @return ge::SUCCESS on success; ge::FAILED if the assigner cannot be
///         allocated or its call does not return ge::SUCCESS.
///
ge::Status AssignMemory2HasRefAttrNode(); | |||||
ge::Status AssignZeroCopyMemory(size_t &mem_offset, size_t &zero_mem_copy_size); | ge::Status AssignZeroCopyMemory(size_t &mem_offset, size_t &zero_mem_copy_size); | ||||
ge::Status SetInputOffset(); | ge::Status SetInputOffset(); | ||||
@@ -40,6 +40,11 @@ Status MemoryAssigner::AssignMemory(bool is_loop_graph, size_t &mem_offset, size | |||||
return ge::FAILED; | return ge::FAILED; | ||||
} | } | ||||
if (graph_mem_assigner.AssignMemory2HasRefAttrNode() != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Assign reference memory failed!"); | |||||
return ge::FAILED; | |||||
} | |||||
// Assign memory for reference | // Assign memory for reference | ||||
if (graph_mem_assigner.AssignReferenceMemory() != ge::SUCCESS) { | if (graph_mem_assigner.AssignReferenceMemory() != ge::SUCCESS) { | ||||
GELOGE(ge::FAILED, "Assign reference memory failed!"); | GELOGE(ge::FAILED, "Assign reference memory failed!"); | ||||
@@ -34,7 +34,6 @@ using std::vector; | |||||
namespace ge { | namespace ge { | ||||
// Runs AssignMemory2VariableNode on compute_graph and then returns SUCCESS.
// GE_CHK_STATUS_RET is defined elsewhere; presumably it returns early when the
// wrapped call does not succeed -- TODO confirm against the macro definition.
Status VarMemAssignUtil::AssignVarMemory(ge::ComputeGraphPtr &compute_graph) { | Status VarMemAssignUtil::AssignVarMemory(ge::ComputeGraphPtr &compute_graph) { | ||||
GE_CHK_STATUS_RET(AssignMemory2VariableNode(compute_graph)); | GE_CHK_STATUS_RET(AssignMemory2VariableNode(compute_graph)); | ||||
// NOTE(review): the following call appears only in the old (left) column of
// this diff view, i.e. this change removes it from AssignVarMemory. The same
// change adds a call to GraphMemoryAssigner::AssignMemory2HasRefAttrNode in
// MemoryAssigner::AssignMemory.
GE_CHK_STATUS_RET(AssignMemory2HasRefAttrNode(compute_graph)); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -45,16 +45,9 @@ NpuMemoryAllocator *NpuMemoryAllocator::GetAllocator() { | |||||
// Constructs an allocator and records the id of the device it is bound to.
NpuMemoryAllocator::NpuMemoryAllocator(uint32_t device_id)
    : device_id_(device_id) {
}
void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { | void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { | ||||
void *try_reuse_addr = nullptr; | |||||
size_t allocate_size = size; | size_t allocate_size = size; | ||||
MemStorageType mem_type = HBM; | MemStorageType mem_type = HBM; | ||||
if (attr != nullptr) { | if (attr != nullptr) { | ||||
try_reuse_addr = attr->try_reuse_addr_; | |||||
if (attr->padding_ != 0) { | |||||
// padding up to multiple of attr->padding, and add extra attr->padding_ | |||||
allocate_size = (size + 2 * attr->padding_ - 1) / attr->padding_ * attr->padding_; | |||||
GELOGD("Padding size %ld by %d. final size = %zu.", size, attr->padding_, allocate_size); | |||||
} | |||||
mem_type = attr->mem_type_; | mem_type = attr->mem_type_; | ||||
} | } | ||||
@@ -69,6 +62,18 @@ void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { | |||||
} else if (mem_type == HOST_DDR) { | } else if (mem_type == HOST_DDR) { | ||||
buffer = malloc(allocate_size); | buffer = malloc(allocate_size); | ||||
} else { | } else { | ||||
void *try_reuse_addr = nullptr; | |||||
int padding = kDefaultPadding; | |||||
if (attr != nullptr) { | |||||
try_reuse_addr = attr->try_reuse_addr_; | |||||
if (attr->padding_ > 0) { | |||||
padding = attr->padding_; | |||||
} | |||||
// padding up to multiple of padding, and add extra padding | |||||
allocate_size = (size + 2 * padding - 1) / padding * padding; | |||||
GELOGD("Padding size %ld by %d. final size = %zu.", size, padding, allocate_size); | |||||
} | |||||
buffer = MemManager::Instance() | buffer = MemManager::Instance() | ||||
.CachingInstance(RT_MEMORY_HBM) | .CachingInstance(RT_MEMORY_HBM) | ||||
.Malloc(allocate_size, reinterpret_cast<uint8_t *>(try_reuse_addr), device_id_); | .Malloc(allocate_size, reinterpret_cast<uint8_t *>(try_reuse_addr), device_id_); | ||||