2. If a broadcast node has more than one input and an input comes from a variable, add a memcpy node between them. Delete the code that moves variable data to the broadcast input during davinci model run. pull/277/head
@@ -425,6 +425,13 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | |||
atomic_addr_clean_id_ = node_op_desc->GetId(); | |||
} | |||
// If the node has at most one input, there is no need to reassign continuous memory. | |||
bool is_input_continuous = false; | |||
(void)ge::AttrUtils::GetBool(node_op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous); | |||
if (is_input_continuous && (node_op_desc->GetInputSize() <= 1)) { | |||
(void)ge::AttrUtils::SetBool(node_op_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, true); | |||
} | |||
for (auto &out_anchor : n->GetAllOutDataAnchors()) { | |||
GeTensorDesc output_desc = node_op_desc->GetOutputDesc(out_anchor->GetIdx()); | |||
bool reuse_input = false; | |||
@@ -928,6 +935,13 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "input node is null."); | |||
auto node_op_desc = n->GetOpDesc(); | |||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node_op_desc == nullptr, return nullptr, "node_op_desc is null."); | |||
// If the node has exactly one output, there is no need to reassign continuous memory. | |||
if (node_op_desc->GetOutputsSize() == 1) { | |||
zero_memory_list_.emplace_back(n, kOutput, 0); | |||
return nullptr; | |||
} | |||
MemoryBlock *block = nullptr; | |||
int64_t total_size = 0; | |||
int64_t memory_type = RT_MEMORY_HBM; | |||
@@ -1746,9 +1760,8 @@ Status BlockMemAssigner::Assign() { | |||
// Returns true when node_type is equal to one of the op-type constants compared below,
// and false otherwise. The function only compares strings; it does not touch any state.
// (Presumably these are the node types whose outputs get no memory block assigned here —
// inferred from the function name only, confirm against the callers.)
//
// NOTE(review): this span looks like a rendered diff in which removed and added lines were
// merged. The return statement ending at "(node_type == HVDCALLBACKBROADCAST);" appears to be
// the old tail, and the two lines after it (starting with a second CONSTANTOP comparison)
// appear to be its replacement, which no longer lists HCOMBROADCAST or HVDCALLBACKBROADCAST.
// Read as literal C++, those last two lines form an unreachable expression statement after
// the return. Confirm which version is intended before relying on this text.
bool BlockMemAssigner::CheckIsZeroMemNodeType(const string &node_type) const { | |||
return (node_type == VARIABLE) || (node_type == CONSTANT) || (node_type == MULTISHAPE) || | |||
(node_type == HCOMBROADCAST) || (node_type == CONSTANTOP) || | |||
(node_type == ASSIGNADD) || (node_type == ASSIGNSUB) || (node_type == ASSIGN) || (node_type == HVDWAIT) || | |||
(node_type == HVDCALLBACKBROADCAST); | |||
(node_type == CONSTANTOP) || (node_type == ASSIGNADD) || (node_type == ASSIGNSUB) || | |||
(node_type == ASSIGN) || (node_type == HVDWAIT); | |||
} | |||
bool BlockMemAssigner::GetWorkSpaceMemoryType(const NodePtr &node, size_t index, int64_t &memory_type) { | |||
@@ -1993,12 +1993,6 @@ Status DavinciModel::SyncVarData() { | |||
RT_MEMCPY_HOST_TO_DEVICE)); | |||
} | |||
for (auto op_desc : variable_op_list_) { | |||
ret = | |||
VarManager::Instance(session_id_)->SyncVarData(runtime_param_.graph_id, op_desc->GetName(), op_desc, mem_base_); | |||
GE_CHK_BOOL_EXEC(ret == SUCCESS, break, "sync var data ret failed, model id:%u, op name:%s.", model_id_, | |||
op_desc->GetName().c_str()); | |||
} | |||
return ret; | |||
} | |||
@@ -37,6 +37,12 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
auto op_desc = node->GetOpDesc(); | |||
GE_IF_BOOL_EXEC(op_desc == nullptr, continue); | |||
Status ret = ProcessBroadcastMemcpy(graph, node); | |||
if (ret != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "failed ProcessBroadcastMemcpy."); | |||
return ret; | |||
} | |||
bool node_input_mutable = false; | |||
if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { | |||
continue; | |||
@@ -61,7 +67,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
// Memcpyasync needs to be inserted between constant (/data) and hcomallreduce to avoid constant being cleared. | |||
NodePtr src_node = src_out_anchor->GetOwnerNode(); | |||
std::string src_type = src_node->GetType(); | |||
bool check_src_type = (src_type == CONSTANTOP) || (src_type == DATA) || (src_type == CONSTANT); | |||
bool check_src_type = (src_type == CONSTANTOP) || (src_type == VARIABLE) || (src_type == DATA) || (src_type == CONSTANT); | |||
if (check_src_type) { | |||
Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor); | |||
if (ret != SUCCESS) { | |||
@@ -82,6 +88,44 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
return SUCCESS; | |||
} | |||
///
/// @brief Insert a memcpy node in front of the inputs of a multi-input broadcast node.
///
/// Rationale (from the original author's note): when a broadcast node has more than one
/// input and an input comes from a variable, the broadcast inputs must be continuous, so a
/// separate feature-map buffer is allocated for them. The data would then have to be moved
/// from variable memory into that buffer on every step. To keep that move inside the model,
/// a memcpy node is used instead of separate move code.
///
/// Only nodes of type HCOMBROADCAST or HVDCALLBACKBROADCAST with more than one input are
/// processed; every other node is left untouched. For each input whose producer is of type
/// CONSTANTOP, VARIABLE, DATA or CONSTANT, the edge is rewired through ModifyEdgeConnection.
///
/// @param [in] graph  graph forwarded to ModifyEdgeConnection
/// @param [in] node   node to inspect
/// @return SUCCESS if the node needs no processing or every matching edge was rewired;
///         INTERNAL_ERROR if node, its op_desc, an input's peer anchor or that anchor's
///         owner node is null; otherwise the failing status of ModifyEdgeConnection
///
Status HcclMemcpyPass::ProcessBroadcastMemcpy(const ComputeGraphPtr &graph, const NodePtr node) {
  // Guard before the first dereference; the op_desc error log below also reads node->GetName().
  if (node == nullptr) {
    GELOGE(INTERNAL_ERROR, "node is null.");
    return INTERNAL_ERROR;
  }

  auto op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(INTERNAL_ERROR, "node has no op_desc, node_name : %s.", node->GetName().c_str());
    return INTERNAL_ERROR;
  }

  const std::string node_type = node->GetType();
  const bool is_broadcast = (node_type == HCOMBROADCAST) || (node_type == HVDCALLBACKBROADCAST);
  // The input count is queried only for broadcast nodes, as in the original condition.
  if (!is_broadcast || op_desc->GetInputSize() <= 1) {
    return SUCCESS;
  }

  for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
    if (hccl_in_anchor == nullptr) {
      continue;
    }

    auto src_out_anchor = hccl_in_anchor->GetPeerOutAnchor();
    if (src_out_anchor == nullptr) {
      GELOGE(INTERNAL_ERROR, "hcom op input has no peer anchor, node_name:%s", node->GetName().c_str());
      return INTERNAL_ERROR;
    }

    // The owner node is dereferenced right below, so reject a null owner explicitly.
    NodePtr src_node = src_out_anchor->GetOwnerNode();
    if (src_node == nullptr) {
      GELOGE(INTERNAL_ERROR, "hcom op input anchor has no owner node, node_name:%s", node->GetName().c_str());
      return INTERNAL_ERROR;
    }

    const std::string src_type = src_node->GetType();
    const bool check_src_type =
        (src_type == CONSTANTOP) || (src_type == VARIABLE) || (src_type == DATA) || (src_type == CONSTANT);
    if (!check_src_type) {
      continue;
    }

    Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
    if (ret != SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
      return ret;
    }
  }
  return SUCCESS;
}
/// | |||
/// @brief Add MemcpyAsync Node | |||
/// @param [in] ge::ComputeGraphPtr graph | |||
@@ -37,6 +37,8 @@ class HcclMemcpyPass : public GraphPass { | |||
Status ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, | |||
const InDataAnchorPtr &hccl_in_anchor); | |||
Status ProcessBroadcastMemcpy(const ComputeGraphPtr &graph, const NodePtr node); | |||
std::unordered_map<std::string, uint32_t> node_num_map_; | |||
}; | |||
} // namespace ge | |||