diff --git a/ge/graph/build/memory/block_mem_assigner.cc b/ge/graph/build/memory/block_mem_assigner.cc index d59023f8..c00163f8 100755 --- a/ge/graph/build/memory/block_mem_assigner.cc +++ b/ge/graph/build/memory/block_mem_assigner.cc @@ -415,6 +415,15 @@ BlockMemAssigner::~BlockMemAssigner() { } } +void BlockMemAssigner::MarkContinuousAllocedForOneInput(OpDescPtr &node_op_desc) { + // if input size just one, no need to reassign continuous memory + bool is_input_continuous = false; + (void)ge::AttrUtils::GetBool(node_op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous); + if (is_input_continuous && (node_op_desc->GetInputsSize() <= 1)) { + (void)ge::AttrUtils::SetBool(node_op_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, true); + } +} + void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { vector temp; for (const NodePtr &n : compute_graph_->GetAllNodes()) { @@ -425,12 +434,7 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { atomic_addr_clean_id_ = node_op_desc->GetId(); } - // if input size just one, no need to reassign continuous memory - bool is_input_continuous = false; - (void)ge::AttrUtils::GetBool(node_op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous); - if (is_input_continuous && (node_op_desc->GetInputsSize() <= 1)) { - (void)ge::AttrUtils::SetBool(node_op_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, true); - } + MarkContinuousAllocedForOneInput(node_op_desc); for (auto &out_anchor : n->GetAllOutDataAnchors()) { GeTensorDesc output_desc = node_op_desc->GetOutputDesc(out_anchor->GetIdx()); diff --git a/ge/graph/build/memory/block_mem_assigner.h b/ge/graph/build/memory/block_mem_assigner.h index f3d26c1d..c79e695b 100755 --- a/ge/graph/build/memory/block_mem_assigner.h +++ b/ge/graph/build/memory/block_mem_assigner.h @@ -409,6 +409,8 @@ class BlockMemAssigner : public MemAssigner { MemoryBlock *ApplyContinuousMemory(const NodePtr &n, const vector &ranges, const bool is_op_reuse_mem); + void MarkContinuousAllocedForOneInput(OpDescPtr &node_op_desc); + std::unordered_map>> reusable_blocks_; std::map reusable_block_counts_;