Browse Source

For sc

tags/v1.3.0
zhaozhixuan 4 years ago
parent
commit
005f4a0972
3 changed files with 71 additions and 43 deletions
  1. +42
    -43
      ge/hybrid/model/hybrid_model_builder.cc
  2. +2
    -0
      ge/hybrid/model/hybrid_model_builder.h
  3. +27
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 42
- 43
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -126,20 +126,6 @@ Status CollectDependenciesForFusedGraph(NodeItem &node_item, std::set<OpDesc *>

return SUCCESS;
}

bool CheckHasHostMem(NodeItem &node_item) {
if (node_item.NodeType() == DATA) {
auto op_desc = node_item.GetOpDesc();
if (op_desc == nullptr) {
return false;
}
auto tensor = op_desc->MutableInputDesc(0);
if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) {
return true;
}
}
return false;
}
} // namespace
HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model)
: hybrid_model_(hybrid_model), runtime_param_(hybrid_model.root_runtime_param_) {
@@ -298,6 +284,47 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt
return SUCCESS;
}

Status HybridModelBuilder::ParseDependencies(NodeItem &node_item, const std::vector<string> &dependencies,
std::set<NodePtr> &dependent_for_shape_inference) {
auto &ge_node = node_item.node;
for (const auto &input_name : dependencies) {
int input_index = node_item.op_desc->GetInputIndexByName(input_name);
if (input_index < 0) {
GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.",
node_item.NodeName().c_str(), input_name.c_str());
REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.",
node_item.NodeName().c_str(), input_name.c_str());
return INTERNAL_ERROR;
}

const auto &in_anchor = ge_node->GetInDataAnchor(input_index);
GE_CHECK_NOTNULL(in_anchor);
const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);
const auto &src_node = peer_out_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
auto src_node_item = MutableNodeItem(src_node);
GE_CHECK_NOTNULL(src_node_item);
if (src_node_item->NodeType() == DATA) {
auto op_desc = src_node_item->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
auto tensor = op_desc->MutableInputDesc(0);
if (AttrUtils::HasAttr(tensor, ATTR_NAME_VALUE)) {
GELOGD("Skip d2h memcpy, get hostmem from node %s.", src_node_item->NodeName().c_str());
continue;
}
}
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx());
dependent_for_shape_inference.emplace(src_node);
host_input_value_dependencies_[&node_item].emplace_back(peer_out_anchor->GetIdx(), src_node_item);
GELOGD("[%s] Dependent added from output of [%s:%d]",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str(),
peer_out_anchor->GetIdx());
}
return SUCCESS;
}

Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) {
std::set<NodePtr> dependent_for_shape_inference;
std::set<NodePtr> dependent_for_execution;
@@ -361,35 +388,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
src_node_item->NodeName().c_str());
}

for (const auto &input_name : dependencies) {
int input_index = node_item.op_desc->GetInputIndexByName(input_name);
if (input_index < 0) {
GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.",
node_item.NodeName().c_str(), input_name.c_str());
REPORT_CALL_ERROR("E19999", "GetInputIndexByName failed, node:[%s] inputname: %s.",
node_item.NodeName().c_str(), input_name.c_str());
return INTERNAL_ERROR;
}

const auto &in_anchor = ge_node->GetInDataAnchor(input_index);
GE_CHECK_NOTNULL(in_anchor);
const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_anchor);
const auto &src_node = peer_out_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
auto src_node_item = MutableNodeItem(src_node);
GE_CHECK_NOTNULL(src_node_item);
GE_IF_BOOL_EXEC(CheckHasHostMem(*src_node_item),
GELOGD("Skip d2h memcpy, get hostmem from node %s.", src_node_item->NodeName().c_str());
continue;)
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx());
dependent_for_shape_inference.emplace(src_node);
host_input_value_dependencies_[&node_item].emplace_back(peer_out_anchor->GetIdx(), src_node_item);
GELOGD("[%s] Dependent added from output of [%s:%d]",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str(),
peer_out_anchor->GetIdx());
}
GE_CHK_STATUS_RET(ParseDependencies(node_item, dependencies, dependent_for_shape_inference));

GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item, dependent_for_shape_inference));
for (const auto &dep_node : dependent_for_shape_inference) {


+ 2
- 0
ge/hybrid/model/hybrid_model_builder.h View File

@@ -66,6 +66,8 @@ class HybridModelBuilder {
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item);
Status CollectParallelGroups(NodeItem *node_item);
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies);
Status ParseDependencies(NodeItem &node_item, const std::vector<string> &dependencies,
std::set<NodePtr> &dependent_for_shape_inference);
Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies);
Status ParseDependentByParallelGroup();
Status IndexTaskDefs();


+ 27
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -670,3 +670,30 @@ TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) {
ASSERT_TRUE(model.GetNodeItem(node)->has_observer);
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
}

TEST_F(UtestGeHybrid, TestParseDependencies) {
// make graph
ut::GraphBuilder graph_builder = ut::GraphBuilder("graph");
auto data = graph_builder.AddNode("Data", "Data", 0, 1);
auto netoutput = graph_builder.AddNode("Netoutput", "NetOutput", 1, 0);
graph_builder.AddDataEdge(data, 0, netoutput, 0);
auto graph = graph_builder.GetGraph();

GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(graph);
HybridModel model(root_model);
HybridModelBuilder builder(model);

std::unique_ptr<NodeItem> node_item;
NodeItem::Create(netoutput, node_item);

std::vector<std::string> deps;
deps.push_back("data");
auto op_desc = netoutput->GetOpDesc();
op_desc->input_name_idx_["Data"] = 0;
auto tensor = std::make_shared<GeTensor>();
auto tensor_desc = op_desc->MutableInputDesc(0);
AttrUtils::SetTensor(tensor_desc, "_value", tensor);

std::set<NodePtr> dependent_for_shape_inference;
ASSERT_EQ(builder.ParseDependencies(*node_item, deps, dependent_for_shape_inference), SUCCESS);
}

Loading…
Cancel
Save