Browse Source

add assign node in hccl memcpy pass

pull/277/head
wangxiaotian22 4 years ago
parent
commit
62dbaeb8c1
2 changed files with 150 additions and 0 deletions
  1. +141
    -0
      ge/graph/passes/hccl_memcpy_pass.cc
  2. +9
    -0
      ge/graph/passes/hccl_memcpy_pass.h

+ 141
- 0
ge/graph/passes/hccl_memcpy_pass.cc View File

@@ -28,6 +28,8 @@
namespace {
const int32_t kAnchorSize = 1;
const int kAnchorNum = 0;
const int32_t kAnchorAssignRefIndex = 0;
const int32_t kAnchorAssignValueIndex = 1;
const char *const kInputMutable = "_input_mutable";
} // namespace
namespace ge {
@@ -265,6 +267,33 @@ std::string HcclMemcpyPass::CheckDuplicateName(const std::string &node_name) {
///
Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor) {
Status ret = InsertIdentityBeforeHccl(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "add identity failed, var_node:%s, hccl_node:%s.",
src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
return ret;
}

ret = InsertAssignAfterBroadcastIfNeed(graph, src_out_anchor, hccl_in_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "add assign failed, var_node:%s, hccl_node:%s.",
src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
return ret;
}
return SUCCESS;
}

///
/// @brief Insert Identity node Between Hccl node and variable
/// @param [in] ComputeGraphPtr graph
/// @param [in] OutDataAnchorPtr src_out_anchor
/// @param [in] InDataAnchorPtr hccl_in_anchor
/// @return status
///
Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor) {
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor);
@@ -291,8 +320,120 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const
memcpy_node->GetName().c_str());
return FAILED;
}
}

///
/// @brief Insert assign node after broadcast node and variable to refresh variable data
/// @param [in] ComputeGraphPtr graph
/// @param [in] OutDataAnchorPtr var_out_anchor
/// @param [in] InDataAnchorPtr hccl_in_anchor
/// @return status
///
Status HcclMemcpyPass::InsertAssignAfterBroadcastIfNeed(const ComputeGraphPtr &graph,
const OutDataAnchorPtr &var_out_anchor,
const InDataAnchorPtr &hccl_in_anchor) {
if (hccl_in_anchor->GetOwnerNode()->GetType() != HCOMBROADCAST) {
GELOGI("%s not broadcast, no need to insert assign node", hccl_in_anchor->GetOwnerNode()->GetName().c_str());
return SUCCESS;
}

GELOGI("after op %s and op %s need insert assign op.", var_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());

NodePtr assign_node = CreateAssignNode(graph, var_out_anchor);
GE_CHECK_NOTNULL(assign_node);

OutDataAnchorPtr hccl_out_anchor = hccl_in_anchor->GetOwnerNode()->GetOutDataAnchor(hccl_in_anchor->GetIdx());

Status ret = hccl_out_anchor->LinkTo(assign_node->GetInDataAnchor(kAnchorAssignValueIndex));
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", hccl_out_anchor->GetOwnerNode()->GetName().c_str(),
assign_node->GetName().c_str());
return FAILED;
}

ret = var_out_anchor->LinkTo(assign_node->GetInDataAnchor(kAnchorAssignRefIndex));
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", var_out_anchor->GetOwnerNode()->GetName().c_str(),
assign_node->GetName().c_str());
return FAILED;
}

// add control edge between assign node and node after broadcast node
OutControlAnchorPtr assign_out_control_anchor = assign_node->GetOutControlAnchor();

for (auto in_data_anchor : hccl_out_anchor->GetPeerInDataAnchors()) {
ret = assign_out_control_anchor->LinkTo(in_data_anchor->GetOwnerNode()->GetInControlAnchor());
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "The op %s link control anchor %s fail.", assign_out_control_anchor->GetOwnerNode()->GetName().c_str(),
in_data_anchor->GetOwnerNode()->GetName().c_str());
return FAILED;
}
}

for (auto in_control_anchor : hccl_out_anchor->GetOwnerNode()->GetOutControlAnchor()->GetPeerInControlAnchors()) {
ret = assign_out_control_anchor->LinkTo(in_control_anchor);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "The op %s link control anchor %s fail.", assign_out_control_anchor->GetOwnerNode()->GetName().c_str(),
in_control_anchor->GetOwnerNode()->GetName().c_str());
return FAILED;
}
}
return SUCCESS;
}

///
/// @brief create assign Node, add to graph
/// @param [in] ge::ComputeGraphPtr graph
/// @param [in] ge::OutDataAnchorPtr variable node out anchor
/// @return ge::NodePtr
///
NodePtr HcclMemcpyPass::CreateAssignNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) {
GE_IF_BOOL_EXEC(graph == nullptr, return nullptr);
NodePtr pre_node = out_data_anchor->GetOwnerNode();
OpDescPtr pre_op_desc = pre_node->GetOpDesc();
if (pre_op_desc == nullptr) {
GELOGE(INTERNAL_ERROR, "OpDesc of pre node is invalid.");
return nullptr;
}

std::string node_name = pre_node->GetName() + "_" + ASSIGN;
node_name = CheckDuplicateName(node_name);
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), ASSIGN);
if (op_desc == nullptr) {
GELOGE(INTERNAL_ERROR, "Create Assign op: MakeShared op_desc fail.");
return nullptr;
}
GELOGI("Create Assign op:%s.", op_desc->GetName().c_str());

graphStatus ret = op_desc->AddInputDesc("ref", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Create Assign op: add ref input desc fail.");
return nullptr;
}

ret = op_desc->AddInputDesc("value", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Create Assign op: add value input desc fail.");
return nullptr;
}

ret = op_desc->AddOutputDesc("ref", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Create Assign op: add output desc fail.");
return nullptr;
}

NodePtr assign_node = graph->AddNode(op_desc);
if (assign_node == nullptr) {
GELOGE(INTERNAL_ERROR, "Insert Identity node fail.");
return nullptr;
}

return assign_node;
}


///
/// @brief Clear Status, used for subgraph pass
/// @return SUCCESS


+ 9
- 0
ge/graph/passes/hccl_memcpy_pass.h View File

@@ -32,11 +32,20 @@ class HcclMemcpyPass : public GraphPass {
private:
NodePtr CreateIdentityNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor);

NodePtr CreateAssignNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor);

std::string CheckDuplicateName(const std::string &node_name);

Status ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor);

Status InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor);

Status InsertAssignAfterBroadcastIfNeed(const ComputeGraphPtr &graph,
const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor)

Status ContinuousInputProcess(const ComputeGraphPtr &graph, const NodePtr node);

Status MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node);


Loading…
Cancel
Save