Browse Source

modify AddGraphWithCopy

tags/v1.3.0
wuweikang 4 years ago
parent
commit
99bfc3132a
3 changed files with 60 additions and 51 deletions
  1. +28
    -51
      ge/graph/manager/graph_manager.cc
  2. +2
    -0
      ge/graph/manager/graph_manager.h
  3. +30
    -0
      tests/ut/ge/graph/manager/graph_manager_unittest.cc

+ 28
- 51
ge/graph/manager/graph_manager.cc View File

@@ -495,7 +495,7 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
auto compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
compute_graph->SetGraphID(graph_id);
(void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true);
SetSessionGraphId(compute_graph, graph_id);

if (CreateGraphNode(graph_id, graph, options) != SUCCESS) {
@@ -527,14 +527,7 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
return SUCCESS;
}

Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &graph,
const std::map<std::string, std::string> &options,
const OmgContext &omg_context) {
if (HasGraphNode(graph_id)) {
REPORT_INNER_ERROR("E19999", "graph_id:%u is exist, check invalid", graph_id);
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id);
return GE_GRAPH_GRAPH_ALREADY_EXIST;
}
Status GraphManager::CheckGraphAdded(const GraphId &graph_id, const Graph &graph) {
auto compute_graph = GraphUtils::GetComputeGraph(graph);
if (compute_graph != nullptr) {
compute_graph->SetGraphID(graph_id);
@@ -553,58 +546,44 @@ Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &grap
GELOGE(FAILED, "compute graph is null");
return FAILED;
}
std::vector<NodePtr> input_nodes;
std::vector<NodePtr> output_nodes;
auto new_compute_graph = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
std::string session_graph_id;
if (!AttrUtils::GetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) ||
session_graph_id.empty()) {
session_graph_id = "-1_" + to_string(graph_id);
if (!AttrUtils::SetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) {
GELOGW("Set attribute of compute graph failed.");
}
for (auto &subgraph : new_compute_graph->GetAllSubgraphs()) {
(void)AttrUtils::SetStr(*subgraph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id);
}
GELOGD("Get graph session_graph_id attr failed, set session id to default value: [0]");
}
return SUCCESS;
}

GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
if (graph_node == nullptr) {
REPORT_CALL_ERROR("E19999", "New GraphNode fail, graph_id:%u",
graph_id);
GELOGE(FAILED, "GraphNode make shared failed");
Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &graph,
const std::map<std::string, std::string> &options,
const OmgContext &omg_context) {
if (CheckGraphAdded(graph_id, graph) != SUCCESS) {
GELOGE(FAILED, "AddGraphWithCopy failed.");
return FAILED;
}
std::shared_ptr<Graph> graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(new_compute_graph);
if (graph_ptr == nullptr) {
REPORT_CALL_ERROR("E19999", "New Graph fail, graph_id:%u",
graph_id);
GELOGE(FAILED, "GraphPtr make shared failed");
IncreaseGraphCount(graph_id);
// Do add graph
auto compute_graph = GraphUtils::GetComputeGraph(graph);
std::vector<NodePtr> input_nodes;
std::vector<NodePtr> output_nodes;
auto new_compute_graph = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
GE_CHECK_NOTNULL(new_compute_graph);
new_compute_graph->SetGraphID(graph_id);
SetSessionGraphId(new_compute_graph, graph_id);
std::shared_ptr<Graph> new_graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(new_compute_graph);
if (CreateGraphNode(graph_id, *new_graph_ptr, options) != SUCCESS) {
GELOGE(FAILED, "Failed to create graph_node.");
return FAILED;
}
// update option about tuning graph
ParseOption(options, BUILD_MODE, options_.build_mode);
ParseOption(options, BUILD_STEP, options_.build_step);
ParseOption(options, TUNING_PATH, options_.tuning_path);

graph_node->SetGraph(graph_ptr);
graph_node->SetOptions(options);
AddGraphNode(graph_id, graph_node);

AddLocalOmgContext(graph_id, omg_context);
if (!options_.output_datatype.empty()) {
GetLocalOmgContext().output_type = options_.output_datatype;
}
if (InitDynamicParams(new_compute_graph) != SUCCESS) {
GELOGE(GRAPH_PARAM_INVALID, "Failed to init params when online infer is dynamic.");
return GRAPH_PARAM_INVALID;
}

CompilerStages &stages = GetCompilerStages(graph_id);
stages.preparer.SetOptions(options_);
Status status = stages.optimizer.SetOptions(options_);
if (status != SUCCESS) {
GELOGE(status, "Graph optimizer set options failed.");
return status;
if (SetStagesOptions(graph_id, options_) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Set stage options failed.");
return INTERNAL_ERROR;
}
stages.builder.SetOptions(options_);

var_acc_ctrl_.AddGraph(graph_id, new_compute_graph);
return SUCCESS;
@@ -1080,7 +1059,6 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std:
if (!graph_node->IsAsync()) {
ret = LoadGraph(ge_root_model, graph_node);
} else {
GE_CHECK_NOTNULL(ge_root_model);
ret = LoadGraphAsync(ge_root_model, graph_node);
}
if (ret != SUCCESS) {
@@ -1095,7 +1073,6 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std:
if (!graph_node->IsAsync()) {
ret = LoadGraph(ge_root_model_ptr, graph_node);
} else {
GE_CHECK_NOTNULL(ge_root_model);
ret = LoadGraphAsync(ge_root_model_ptr, graph_node);
}
if (ret != SUCCESS) {


+ 2
- 0
ge/graph/manager/graph_manager.h View File

@@ -413,6 +413,8 @@ class GraphManager {

void SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id);

static Status CheckGraphAdded(const GraphId &graph_id, const Graph &graph);

std::atomic_bool thread_run_flag_;
BlockingQueue<PreRunArgs> prerun_args_q_{};
BlockingQueue<RunArgs> run_args_q_{};


+ 30
- 0
tests/ut/ge/graph/manager/graph_manager_unittest.cc View File

@@ -373,3 +373,33 @@ TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_3) {
Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model);
EXPECT_NE(status, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_add_graph_with_copy_success) {
GraphId graph_id = 1;
GraphManager graph_manager;
// create graph
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);

std::map<std::string, std::string> options;
OmgContext context;
Status status = graph_manager.AddGraphWithCopy(graph_id, graph, options, context);
EXPECT_EQ(status, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_add_graph_with_copy_fail) {
GraphId graph_id = 1;
GraphManager graph_manager;
// create graph
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);

std::map<std::string, std::string> options;
OmgContext context;
Status status = graph_manager.AddGraph(graph_id, graph, options, context);
EXPECT_EQ(status, ge::SUCCESS);
status = graph_manager.RemoveGraph(graph_id);
EXPECT_EQ(status, ge::SUCCESS);
status = graph_manager.AddGraphWithCopy(graph_id, graph, options, context);
EXPECT_NE(status, ge::SUCCESS);
}

Loading…
Cancel
Save