Browse Source

!230 AddGraphWithCopy

From: @HW_KK
Reviewed-by: @ji_chen,@youui,@ji_chen
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
a23a4c2b58
8 changed files with 191 additions and 1 deletions
  1. +27
    -0
      ge/client/ge_api.cc
  2. +72
    -0
      ge/graph/manager/graph_manager.cc
  3. +10
    -0
      ge/graph/manager/graph_manager.h
  4. +18
    -0
      ge/session/inner_session.cc
  5. +2
    -0
      ge/session/inner_session.h
  6. +30
    -0
      ge/session/session_manager.cc
  7. +13
    -1
      ge/session/session_manager.h
  8. +19
    -0
      inc/external/ge/ge_api.h

+ 27
- 0
ge/client/ge_api.cc View File

@@ -260,6 +260,33 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<s
return ret; return ret;
} }


Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) {
std::map<AscendString, AscendString> options;
return AddGraphWithCopy(graph_id, graph, options);
}

Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph,
const std::map<AscendString, AscendString> &options) {
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "AddGraph failed in Session.");
return FAILED;
}
std::map<std::string, std::string> str_options;
for (auto it = options.begin(); it != options.end(); ++it) {
str_options.insert({it->first.GetString(), it->second.GetString()});
}
GELOGD("Adding graph to session");
Status ret = instance_ptr->SessionManagerObj().AddGraphWithCopy(sessionId_, graph_id, graph, str_options);
if (ret != SUCCESS) {
GELOGE(ret, "AddGraph failed in Session.");
return FAILED;
}
GELOGD("AddGraph finished in Session.");
return ret;
}

Status Session::RemoveGraph(uint32_t graph_id) { Status Session::RemoveGraph(uint32_t graph_id) {
GELOGT(TRACE_INIT, "Session RemoveGraph start"); GELOGT(TRACE_INIT, "Session RemoveGraph start");




+ 72
- 0
ge/graph/manager/graph_manager.cc View File

@@ -354,6 +354,78 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
return SUCCESS; return SUCCESS;
} }


Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &graph,
const std::map<std::string, std::string> &options,
const OmgContext &omg_context) {
if (HasGraphNode(graph_id)) {
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id);
return GE_GRAPH_GRAPH_ALREADY_EXIST;
}
auto compute_graph = GraphUtils::GetComputeGraph(graph);
if (compute_graph != nullptr) {
compute_graph->SetGraphID(graph_id);
bool graph_has_been_added = false;
if (AttrUtils::GetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, graph_has_been_added)
&& graph_has_been_added) {
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST,
"[GraphManager] same graph object can not be added again, graph_id = %u.", graph_id);
return GE_GRAPH_GRAPH_ALREADY_EXIST;
}
} else {
GELOGE(FAILED, "compute graph is null");
return FAILED;
}
std::vector<NodePtr> input_nodes;
std::vector<NodePtr> output_nodes;
auto new_compute_graph = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
std::string session_graph_id;
if (!AttrUtils::GetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) ||
session_graph_id.empty()) {
session_graph_id = "-1_" + to_string(graph_id);
if (!AttrUtils::SetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) {
GELOGW("Set attribute of compute graph failed.");
}
for (auto &subgraph : new_compute_graph->GetAllSubgraphs()) {
(void)AttrUtils::SetStr(*subgraph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id);
}
GELOGW("Get graph session_graph_id attr failed, set session id to default value: [0]");
}

GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
if (graph_node == nullptr) {
GELOGE(FAILED, "GraphNode make shared failed");
return FAILED;
}
std::shared_ptr<Graph> graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(new_compute_graph);
if (graph_ptr == nullptr) {
GELOGE(FAILED, "GraphPtr make shared failed");
return FAILED;
}

graph_node->SetGraph(graph_ptr);
graph_node->SetOptions(options);
AddGraphNode(graph_id, graph_node);

AddLocalOmgContext(graph_id, omg_context);
if (!options_.output_datatype.empty()) {
GetLocalOmgContext().output_type = options_.output_datatype;
}

CompilerStages &stages = GetCompilerStages(graph_id);
stages.preparer.SetOptions(options_);
Status status = stages.optimizer.SetOptions(options_);
if (status != SUCCESS) {
GELOGE(status, "Graph optimizer set options failed.");
return status;
}
stages.builder.SetOptions(options_);

var_acc_ctrl_.AddGraph(graph_id, new_compute_graph);

GELOGI("[GraphManager] add graph success, graph_id = %u.", graph_id);
return SUCCESS;
}

Status GraphManager::MergeSubGraph(ComputeGraphPtr &compute_graph, const ge::ComputeGraphPtr &original_compute_graph, Status GraphManager::MergeSubGraph(ComputeGraphPtr &compute_graph, const ge::ComputeGraphPtr &original_compute_graph,
GraphId root_graph_id) { GraphId root_graph_id) {
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();


+ 10
- 0
ge/graph/manager/graph_manager.h View File

@@ -75,6 +75,16 @@ class GraphManager {


/// ///
/// @ingroup ge_graph /// @ingroup ge_graph
/// @brief add a copy graph
/// @param [in] graph_id graph id
/// @param [out] Graph output graph
/// @return Status result of function
///
Status AddGraphWithCopy(const GraphId &graph_id, const Graph &graph,
const std::map<std::string, std::string> &options, const OmgContext &omg_context);

///
/// @ingroup ge_graph
/// @brief remove specific graph /// @brief remove specific graph
/// @param [in] graph_id graph id /// @param [in] graph_id graph id
/// @return Status result of function /// @return Status result of function


+ 18
- 0
ge/session/inner_session.cc View File

@@ -166,6 +166,24 @@ Status InnerSession::AddGraph(uint32_t graph_id, const Graph &graph,
return SUCCESS; return SUCCESS;
} }


Status InnerSession::AddGraphWithCopy(uint32_t graph_id, const Graph &graph,
const std::map<std::string, std::string> &options) {
std::lock_guard<std::mutex> lock(resource_mutex_);
if (!init_flag_) {
GELOGE(GE_SESS_INIT_FAILED, "[InnerSession:%lu] initialize failed.", session_id_);
return GE_SESS_INIT_FAILED;
}
UpdateThreadContext(options);
Status ret = graph_manager_.AddGraphWithCopy(graph_id, graph, options, domi::GetContext());
if (ret != SUCCESS) {
GELOGE(ret, "[InnerSession:%lu] add graph %u failed.", session_id_, graph_id);
return ret;
}

GELOGI("[InnerSession:%lu] add graph success, graph_id=%u.", session_id_, graph_id);
return SUCCESS;
}

Status InnerSession::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, std::vector<Tensor> &outputs) { Status InnerSession::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, std::vector<Tensor> &outputs) {
GELOGI("[InnerSession:%lu] run graph on session, graph_id=%u.", session_id_, graph_id); GELOGI("[InnerSession:%lu] run graph on session, graph_id=%u.", session_id_, graph_id);
if (mutex_.try_lock()) { if (mutex_.try_lock()) {


+ 2
- 0
ge/session/inner_session.h View File

@@ -37,6 +37,8 @@ class InnerSession {


Status AddGraph(uint32_t graph_id, const Graph &graph, const std::map<std::string, std::string> &options); Status AddGraph(uint32_t graph_id, const Graph &graph, const std::map<std::string, std::string> &options);


Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph, const std::map<std::string, std::string> &options);

Status RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, std::vector<Tensor> &outputs); Status RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, std::vector<Tensor> &outputs);


Status RemoveGraph(uint32_t graph_id); Status RemoveGraph(uint32_t graph_id);


+ 30
- 0
ge/session/session_manager.cc View File

@@ -170,6 +170,36 @@ Status SessionManager::AddGraph(SessionId session_id, uint32_t graph_id, const G
return innerSession->AddGraph(graph_id, graph, options); return innerSession->AddGraph(graph_id, graph, options);
} }


Status SessionManager::AddGraphWithCopy(SessionId session_id, uint32_t graph_id, const Graph &graph,
const std::map<std::string, std::string> &options) {
if (!init_flag_) {
GELOGE(GE_SESSION_MANAGER_NOT_INIT);
return GE_SESSION_MANAGER_NOT_INIT;
}
SessionPtr innerSession = nullptr;
{
std::lock_guard<std::mutex> lock(mutex_);
std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
if (it == session_manager_map_.end()) {
return GE_SESSION_NOT_EXIST;
} else {
innerSession = it->second;
}
auto compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
std::string session_graph_id = std::to_string(session_id) + "_" + std::to_string(graph_id);
if (!AttrUtils::SetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) {
GELOGW("Set graph session_graph_id attr failed.");
} else {
GELOGD("Set graph session_graph_id attr to [%s]", session_graph_id.c_str());
}
for (auto graph : compute_graph->GetAllSubgraphs()) {
AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id);
}
}
return innerSession->AddGraphWithCopy(graph_id, graph, options);
}

Status SessionManager::RunGraph(SessionId session_id, uint32_t graph_id, const std::vector<Tensor> &inputs, Status SessionManager::RunGraph(SessionId session_id, uint32_t graph_id, const std::vector<Tensor> &inputs,
std::vector<Tensor> &outputs) { std::vector<Tensor> &outputs) {
if (!init_flag_) { if (!init_flag_) {


+ 13
- 1
ge/session/session_manager.h View File

@@ -62,7 +62,7 @@ class SessionManager {


/// ///
/// @ingroup ge_session /// @ingroup ge_session
/// @brief add a graph to the session with specific session id
/// @brief add a graph to the session with specific session id and graphOptions
/// @param [in] session_id session id /// @param [in] session_id session id
/// @param [in] graph_id graph id /// @param [in] graph_id graph id
/// @param [in] graph the graph to add /// @param [in] graph the graph to add
@@ -74,6 +74,18 @@ class SessionManager {


/// ///
/// @ingroup ge_session /// @ingroup ge_session
/// @brief add a copy graph to the session with specific session id and graphOptions
/// @param [in] session_id session id
/// @param [in] graph_id graph id
/// @param [in] graph the graph to add
/// @param [in] options graph level options
/// @return Status result of function
///
Status AddGraphWithCopy(SessionId session_id, uint32_t graph_id, const Graph &graph,
const std::map<std::string, std::string> &options);

///
/// @ingroup ge_session
/// @brief run a graph of the session with specific session id /// @brief run a graph of the session with specific session id
/// @param [in] session_id session id /// @param [in] session_id session id
/// @param [in] graph_id graph id /// @param [in] graph_id graph id


+ 19
- 0
inc/external/ge/ge_api.h View File

@@ -60,6 +60,25 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session {
Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options); Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options);


/// ///
/// @ingroup client
/// @brief add a copy graph with a specific graphId
/// @param [in] graphId graph id
/// @param [in] graph the graph
/// @return Status result of function
///
Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph);

///
/// @ingroup client
/// @brief add a copy graph with a specific graphId and graphOptions
/// @param [in] graphId graph id
/// @param [in] graph the graph
/// @param [in] options graph options
/// @return Status result of function
///
Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph, const std::map<AscendString, AscendString> &options);

///
/// @ingroup ge_graph /// @ingroup ge_graph
/// @brief remove a graph of the session with specific session id /// @brief remove a graph of the session with specific session id
/// @param [in] graphId graph id /// @param [in] graphId graph id


Loading…
Cancel
Save