@@ -76,7 +76,7 @@ Status CheckOptionsValid(const std::map<string, string> &options) { | |||||
} | } | ||||
// Initialize GE, prepare for execution, call GELib::Initialize | // Initialize GE, prepare for execution, call GELib::Initialize | ||||
Status GEInitializeImpl(const std::map<string, string> &options) { | |||||
Status GEInitialize(const std::map<string, string> &options) { | |||||
GELOGT(TRACE_INIT, "GEInitialize start"); | GELOGT(TRACE_INIT, "GEInitialize start"); | ||||
// 0.check init status | // 0.check init status | ||||
if (g_ge_initialized) { | if (g_ge_initialized) { | ||||
@@ -127,26 +127,6 @@ Status GEInitializeImpl(const std::map<string, string> &options) { | |||||
return ret; | return ret; | ||||
} | } | ||||
// Initialize GE, prepare for execution, call GELib::Initialize | |||||
Status GEInitialize(const std::map<string, string> &options) { | |||||
return GEInitializeImpl(options); | |||||
} | |||||
Status GEInitialize(const std::map<AscendString, AscendString> &options) { | |||||
std::map<std::string, std::string> str_options; | |||||
for (auto & option : options) { | |||||
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
GELOGE(FAILED, "GEInitialize options is nullptr."); | |||||
return FAILED; | |||||
} | |||||
std::string key = option.first.GetString(); | |||||
std::string val = option.second.GetString(); | |||||
str_options[key] = val; | |||||
} | |||||
return GEInitializeImpl(str_options); | |||||
} | |||||
// GE finalize, releasing all resources | // GE finalize, releasing all resources | ||||
Status GEFinalize() { | Status GEFinalize() { | ||||
GELOGT(TRACE_INIT, "GEFinalize start"); | GELOGT(TRACE_INIT, "GEFinalize start"); | ||||
@@ -222,46 +202,6 @@ Session::Session(const std::map<string, string> &options) { | |||||
GELOGT(TRACE_STOP, "Session Constructor finished"); | GELOGT(TRACE_STOP, "Session Constructor finished"); | ||||
} | } | ||||
Session::Session(const std::map<AscendString, AscendString> &options) { | |||||
GELOGT(TRACE_INIT, "Session Constructor start"); | |||||
// check init status | |||||
sessionId_ = 0; | |||||
if (!g_ge_initialized) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GE is not initialized."); | |||||
return; | |||||
} | |||||
// call Initialize | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Session Constructor failed"); | |||||
return; | |||||
} | |||||
GELOGT(TRACE_RUNNING, "Creating session"); | |||||
std::map<std::string, std::string> str_options; | |||||
for (auto &option : options) { | |||||
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
GELOGE(FAILED, "Session options is nullptr."); | |||||
return; | |||||
} | |||||
std::string key = option.first.GetString(); | |||||
std::string val = option.second.GetString(); | |||||
str_options[key] = val; | |||||
} | |||||
uint64_t session_id = 0; | |||||
Status ret = instance_ptr->SessionManagerObj().CreateSession(str_options, session_id); | |||||
GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); | |||||
// check return status, return, update session id if success | |||||
if (ret == SUCCESS) { | |||||
sessionId_ = session_id; | |||||
} else { | |||||
GELOGE(ret, "Session constructor failed, session Id not initialized"); | |||||
return; | |||||
} | |||||
GELOGT(TRACE_STOP, "Session Constructor finished"); | |||||
} | |||||
// session destructor | // session destructor | ||||
Session::~Session() { | Session::~Session() { | ||||
GELOGT(TRACE_INIT, "Session Destructor start"); | GELOGT(TRACE_INIT, "Session Destructor start"); | ||||
@@ -320,34 +260,6 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<s | |||||
return ret; | return ret; | ||||
} | } | ||||
Status Session::AddGraph(uint32_t graph_id, const Graph &graph, | |||||
const std::map<AscendString, AscendString> &options) { | |||||
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "AddGraph failed in Session."); | |||||
return FAILED; | |||||
} | |||||
GELOGD("Adding graph to session"); | |||||
std::map<std::string, std::string> str_options; | |||||
for (auto &option : options) { | |||||
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
GELOGE(FAILED, "AddGraph options is nullptr."); | |||||
return FAILED; | |||||
} | |||||
std::string key = option.first.GetString(); | |||||
std::string val = option.second.GetString(); | |||||
str_options[key] = val; | |||||
} | |||||
Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, str_options); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "AddGraph failed in Session."); | |||||
return FAILED; | |||||
} | |||||
GELOGD("AddGraph finished in Session."); | |||||
return ret; | |||||
} | |||||
Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) { | Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) { | ||||
std::map<AscendString, AscendString> options; | std::map<AscendString, AscendString> options; | ||||
return AddGraphWithCopy(graph_id, graph, options); | return AddGraphWithCopy(graph_id, graph, options); | ||||
@@ -475,14 +387,6 @@ Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc | |||||
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | ||||
} | } | ||||
Status Session::RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback) { | |||||
std::string str_key; | |||||
if (key != nullptr) { | |||||
str_key = key; | |||||
} | |||||
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, str_key, callback); | |||||
} | |||||
Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | ||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | ||||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | ||||
@@ -532,29 +436,6 @@ Status Session::GetVariables(const std::vector<std::string> &var_names, std::vec | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status Session::GetVariables(const std::vector<AscendString> &var_names, std::vector<Tensor> &var_values) { | |||||
auto instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed"); | |||||
return FAILED; | |||||
} | |||||
GELOGT(TRACE_RUNNING, "Get Variables"); | |||||
std::vector<ge::string> str_var_names; | |||||
for (auto &var_name : var_names) { | |||||
if (var_name.GetString() == nullptr) { | |||||
GELOGE(FAILED, "GetVariables name is nullptr."); | |||||
return FAILED; | |||||
} | |||||
str_var_names.emplace_back(var_name.GetString()); | |||||
} | |||||
Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, str_var_names, var_values); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "SessionManager RunGraphAsync failed"); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
bool Session::IsGraphNeedRebuild(uint32_t graph_id) { | bool Session::IsGraphNeedRebuild(uint32_t graph_id) { | ||||
return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id); | return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id); | ||||
} | } | ||||
@@ -1870,30 +1870,12 @@ Status GraphManager::RegisterCallBackFunc( | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphManager::RegisterCallBackFunc( | |||||
const std::string &key, | |||||
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback) { | |||||
std::lock_guard<std::mutex> lock(member_mutex_); | |||||
GELOGI("[GraphManager] RegisterCallBackFunc, key=%s.", key.c_str()); | |||||
callback_map_[key] = callback; | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::PushSummaryData2ME(const GraphId &graph_id, | Status GraphManager::PushSummaryData2ME(const GraphId &graph_id, | ||||
const std::map<std::string, ge::Tensor> &summary_data) { | const std::map<std::string, ge::Tensor> &summary_data) { | ||||
std::lock_guard<std::mutex> lock(member_mutex_); | std::lock_guard<std::mutex> lock(member_mutex_); | ||||
GELOGI("[GraphManager] PushSummaryData2ME, dataSize=%zu.", summary_data.size()); | GELOGI("[GraphManager] PushSummaryData2ME, dataSize=%zu.", summary_data.size()); | ||||
auto itr = me_callback_map_.find(kSummary); | auto itr = me_callback_map_.find(kSummary); | ||||
if (itr == me_callback_map_.end()) { | if (itr == me_callback_map_.end()) { | ||||
auto iter = callback_map_.find(kSummary); | |||||
if (iter != callback_map_.end()) { | |||||
std::map<AscendString, ge::Tensor> tmp_summary_data; | |||||
for (auto &data : summary_data) { | |||||
AscendString tmp(data.first.c_str()); | |||||
tmp_summary_data[tmp] = data.second; | |||||
} | |||||
return iter->second(graph_id, tmp_summary_data); | |||||
} | |||||
GELOGE(FAILED, "[GraphManager] PushSummaryData2ME failed, not found summary callback."); | GELOGE(FAILED, "[GraphManager] PushSummaryData2ME failed, not found summary callback."); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -1905,15 +1887,6 @@ Status GraphManager::PushSaveData2ME(const GraphId &graph_id, const std::map<std | |||||
GELOGI("[GraphManager] PushSaveData2ME, dataSize=%zu.", save_data.size()); | GELOGI("[GraphManager] PushSaveData2ME, dataSize=%zu.", save_data.size()); | ||||
auto itr = me_callback_map_.find(kSave); | auto itr = me_callback_map_.find(kSave); | ||||
if (itr == me_callback_map_.end()) { | if (itr == me_callback_map_.end()) { | ||||
auto iter = callback_map_.find(kSave); | |||||
if (iter != callback_map_.end()) { | |||||
std::map<AscendString, ge::Tensor> tmp_save_data; | |||||
for (auto &data : save_data) { | |||||
AscendString tmp(data.first.c_str()); | |||||
tmp_save_data[tmp] = data.second; | |||||
} | |||||
return iter->second(graph_id, tmp_save_data); | |||||
} | |||||
GELOGE(FAILED, "[GraphManager] PushSaveData2ME failed, not found checkpoint callback."); | GELOGE(FAILED, "[GraphManager] PushSaveData2ME failed, not found checkpoint callback."); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -163,10 +163,6 @@ class GraphManager { | |||||
const std::string &key, | const std::string &key, | ||||
const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | ||||
Status RegisterCallBackFunc( | |||||
const std::string &key, | |||||
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback); | |||||
const bool GetTrainFlag() const { return options_.train_graph_flag; } | const bool GetTrainFlag() const { return options_.train_graph_flag; } | ||||
bool IsGraphNeedRebuild(uint32_t graph_id); | bool IsGraphNeedRebuild(uint32_t graph_id); | ||||
@@ -394,8 +390,6 @@ class GraphManager { | |||||
// summary and checkpoint callback function list for ME, key is summary or checkpoint | // summary and checkpoint callback function list for ME, key is summary or checkpoint | ||||
std::map<std::string, std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)>> me_callback_map_; | std::map<std::string, std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)>> me_callback_map_; | ||||
std::map<std::string, std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)>> callback_map_; | |||||
bool init_flag_; | bool init_flag_; | ||||
GraphManagerOptions options_; | GraphManagerOptions options_; | ||||
@@ -610,17 +610,11 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||||
/// | /// | ||||
Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { | Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { | ||||
auto func_desc = case_node_->GetOpDesc(); | auto func_desc = case_node_->GetOpDesc(); | ||||
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr; | |||||
auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); | auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); | ||||
if (post_func == nullptr) { | if (post_func == nullptr) { | ||||
GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), | GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), | ||||
case_node_->GetType().c_str()); | case_node_->GetType().c_str()); | ||||
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS || | |||||
parse_func_v2 == nullptr) { | |||||
GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(), | |||||
case_node_->GetType().c_str()); | |||||
return FAILED; | |||||
} | |||||
return FAILED; | |||||
} | } | ||||
for (const auto &name : func_desc->GetSubgraphInstanceNames()) { | for (const auto &name : func_desc->GetSubgraphInstanceNames()) { | ||||
@@ -635,12 +629,7 @@ Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { | |||||
"Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); | "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); | ||||
auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); | auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); | ||||
Status ret = FAILED; | |||||
if (post_func != nullptr) { | |||||
ret = post_func(subgraph_name, graph); | |||||
} else if (parse_func_v2 != nullptr) { | |||||
ret = parse_func_v2(subgraph_name.c_str(), graph); | |||||
} | |||||
auto ret = post_func(subgraph_name, graph); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), | GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), | ||||
case_node_->GetName().c_str(), case_node_->GetType().c_str()); | case_node_->GetName().c_str(), case_node_->GetType().c_str()); | ||||
@@ -141,7 +141,7 @@ static void LoadOpsProto() { | |||||
(void)manager->Initialize(option_tmp); | (void)manager->Initialize(option_tmp); | ||||
} | } | ||||
graphStatus aclgrphBuildInitializeImpl(std::map<std::string, std::string> &global_options) { | |||||
graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { | |||||
GELOGD("Enter aclgrphInitialize start!"); | GELOGD("Enter aclgrphInitialize start!"); | ||||
// check global options | // check global options | ||||
if (CheckGlobalOptions(global_options) != GRAPH_SUCCESS) { | if (CheckGlobalOptions(global_options) != GRAPH_SUCCESS) { | ||||
@@ -167,24 +167,6 @@ graphStatus aclgrphBuildInitializeImpl(std::map<std::string, std::string> &globa | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { | |||||
return aclgrphBuildInitializeImpl(global_options); | |||||
} | |||||
graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &global_options) { | |||||
std::map<std::string, std::string> tmp_global_options; | |||||
for (auto &option : global_options) { | |||||
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "AclgrphBuildInitialize option is nullptr."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
std::string key = option.first.GetString(); | |||||
std::string val = option.second.GetString(); | |||||
tmp_global_options[key] = val; | |||||
} | |||||
return aclgrphBuildInitializeImpl(tmp_global_options); | |||||
} | |||||
void aclgrphBuildFinalize() { | void aclgrphBuildFinalize() { | ||||
if (ge::GELib::GetInstance() != nullptr && ge::GELib::GetInstance()->InitFlag()) { | if (ge::GELib::GetInstance() != nullptr && ge::GELib::GetInstance()->InitFlag()) { | ||||
(void)ge::GELib::GetInstance()->Finalize(); | (void)ge::GELib::GetInstance()->Finalize(); | ||||
@@ -471,24 +453,6 @@ graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string | |||||
return builder.BuildModel(graph, build_options, model); | return builder.BuildModel(graph, build_options, model); | ||||
} | } | ||||
graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendString, AscendString> &build_options, | |||||
ModelBufferData &model) { | |||||
GELOGD("Enter aclmdlBuildModel process!"); | |||||
std::map<std::string, std::string> tmp_build_options; | |||||
for (auto &option : build_options) { | |||||
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "AclgrphBuildInitialize option is nullptr."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
std::string key = option.first.GetString(); | |||||
std::string val = option.second.GetString(); | |||||
tmp_build_options[key] = val; | |||||
} | |||||
Impl builder; | |||||
return builder.BuildModel(graph, tmp_build_options, model); | |||||
} | |||||
graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model) { | graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model) { | ||||
GELOGD("Enter aclmdlSaveModel process!"); | GELOGD("Enter aclmdlSaveModel process!"); | ||||
if (model.data.get() == nullptr || model.length == 0) { | if (model.data.get() == nullptr || model.length == 0) { | ||||
@@ -499,21 +463,6 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m | |||||
static_cast<uint32_t>(model.length)); | static_cast<uint32_t>(model.length)); | ||||
} | } | ||||
graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model) { | |||||
GELOGD("Enter aclmdlSaveModel process!"); | |||||
if (model.data.get() == nullptr || model.length == 0) { | |||||
GELOGE(GRAPH_PARAM_INVALID, "Input model is illegal"); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
if (output_file == nullptr) { | |||||
GELOGE(GRAPH_PARAM_INVALID, "Output file is nullptr."); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
std::string str_output_file = output_file; | |||||
return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast<void*>(model.data.get()), | |||||
static_cast<uint32_t>(model.length)); | |||||
} | |||||
graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version) { | graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version) { | ||||
GELOGD("Enter aclgrphGetIRVersion process!"); | GELOGD("Enter aclgrphGetIRVersion process!"); | ||||
GE_CHECK_NOTNULL(major_version); | GE_CHECK_NOTNULL(major_version); | ||||
@@ -254,25 +254,6 @@ Status InnerSession::RegisterCallBackFunc( | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status InnerSession::RegisterCallBackFunc( | |||||
const std::string &key, | |||||
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback) { | |||||
std::lock_guard<std::mutex> lock(resource_mutex_); | |||||
if (!init_flag_) { | |||||
GELOGE(GE_SESS_INIT_FAILED, "[InnerSession:%lu] initialize failed.", session_id_); | |||||
return GE_SESS_INIT_FAILED; | |||||
} | |||||
UpdateThreadContext(std::map<std::string, std::string>{}); | |||||
Status ret = graph_manager_.RegisterCallBackFunc(key, callback); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[InnerSession:%lu] register %s callback function failed.", session_id_, key.c_str()); | |||||
return ret; | |||||
} | |||||
GELOGI("[InnerSession:%lu] register %s callback function success.", session_id_, key.c_str()); | |||||
return SUCCESS; | |||||
} | |||||
Status InnerSession::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | Status InnerSession::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | ||||
UpdateThreadContext(graph_id); | UpdateThreadContext(graph_id); | ||||
GELOGI("[InnerSession:%lu] build graph on session, graph_id=%u.", session_id_, graph_id); | GELOGI("[InnerSession:%lu] build graph on session, graph_id=%u.", session_id_, graph_id); | ||||
@@ -62,10 +62,6 @@ class InnerSession { | |||||
const std::string &key, | const std::string &key, | ||||
const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | ||||
Status RegisterCallBackFunc( | |||||
const std::string &key, | |||||
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback); | |||||
const GraphManager &getGraphManagerObj() const; | const GraphManager &getGraphManagerObj() const; | ||||
bool IsGraphNeedRebuild(uint32_t graph_id); | bool IsGraphNeedRebuild(uint32_t graph_id); | ||||
@@ -276,26 +276,6 @@ Status SessionManager::RegisterCallBackFunc( | |||||
return innerSession->RegisterCallBackFunc(key, callback); | return innerSession->RegisterCallBackFunc(key, callback); | ||||
} | } | ||||
Status SessionManager::RegisterCallBackFunc( | |||||
SessionId session_id, const std::string &key, | |||||
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback) { | |||||
if (!init_flag_) { | |||||
GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized."); | |||||
return GE_SESSION_MANAGER_NOT_INIT; | |||||
} | |||||
SessionPtr innerSession = nullptr; | |||||
{ | |||||
std::lock_guard<std::mutex> lock(mutex_); | |||||
std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id); | |||||
if (it == session_manager_map_.end()) { | |||||
return GE_SESSION_NOT_EXIST; | |||||
} else { | |||||
innerSession = it->second; | |||||
} | |||||
} | |||||
return innerSession->RegisterCallBackFunc(key, callback); | |||||
} | |||||
Status SessionManager::BuildGraph(SessionId session_id, uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | Status SessionManager::BuildGraph(SessionId session_id, uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) { | ||||
if (!init_flag_) { | if (!init_flag_) { | ||||
GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized."); | GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized."); | ||||
@@ -158,9 +158,6 @@ class SessionManager { | |||||
Status RegisterCallBackFunc( | Status RegisterCallBackFunc( | ||||
SessionId session_id, const std::string &key, | SessionId session_id, const std::string &key, | ||||
const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback); | ||||
Status RegisterCallBackFunc( | |||||
SessionId session_id, const std::string &key, | |||||
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback); | |||||
bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id); | bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id); | ||||
@@ -29,26 +29,16 @@ | |||||
namespace ge { | namespace ge { | ||||
typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map<std::string, ge::Tensor> ¶ms_list); | typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map<std::string, ge::Tensor> ¶ms_list); | ||||
namespace session { | |||||
typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map<AscendString, ge::Tensor> ¶ms_list); | |||||
} | |||||
// Initialize GE | // Initialize GE | ||||
ATTRIBUTED_DEPRECATED(Status GEInitialize(const std::map<AscendString, AscendString> &)) | |||||
Status GEInitialize(const std::map<std::string, std::string> &options); | Status GEInitialize(const std::map<std::string, std::string> &options); | ||||
Status GEInitialize(const std::map<AscendString, AscendString> &options); | |||||
// Finalize GE, release all resources | // Finalize GE, release all resources | ||||
Status GEFinalize(); | Status GEFinalize(); | ||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | ||||
public: | public: | ||||
ATTRIBUTED_DEPRECATED(Session(const std::map<AscendString, AscendString> &)) | |||||
explicit Session(const std::map<std::string, std::string> &options); | explicit Session(const std::map<std::string, std::string> &options); | ||||
explicit Session(const std::map<AscendString, AscendString> &options); | |||||
~Session(); | ~Session(); | ||||
/// | /// | ||||
@@ -67,21 +57,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||||
/// @param [in] options graph options | /// @param [in] options graph options | ||||
/// @return Status result of function | /// @return Status result of function | ||||
/// | /// | ||||
ATTRIBUTED_DEPRECATED(Status AddGraph(uint32_t, const Graph &, const std::map<AscendString, AscendString> &)) | |||||
Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options); | Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options); | ||||
/// | /// | ||||
/// @ingroup client | /// @ingroup client | ||||
/// @brief add a graph with a specific graphId and graphOptions | |||||
/// @param [in] graphId graph id | |||||
/// @param [in] graph the graph | |||||
/// @param [in] options graph options | |||||
/// @return Status result of function | |||||
/// | |||||
Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<AscendString, AscendString> &options); | |||||
/// | |||||
/// @ingroup client | |||||
/// @brief add a copy graph with a specific graphId | /// @brief add a copy graph with a specific graphId | ||||
/// @param [in] graphId graph id | /// @param [in] graphId graph id | ||||
/// @param [in] graph the graph | /// @param [in] graph the graph | ||||
@@ -145,20 +124,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||||
/// @param [out] var_values: variable values | /// @param [out] var_values: variable values | ||||
/// @return Status result of function | /// @return Status result of function | ||||
/// | /// | ||||
ATTRIBUTED_DEPRECATED(Status GetVariables(const std::vector<std::string> &, std::vector<Tensor> &)) | |||||
Status GetVariables(const std::vector<std::string> &var_names, std::vector<Tensor> &var_values); | Status GetVariables(const std::vector<std::string> &var_names, std::vector<Tensor> &var_values); | ||||
/// | /// | ||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
/// @brief get variables in the session with specific session id | |||||
/// @param [in] var_names: variable names | |||||
/// @param [out] var_values: variable values | |||||
/// @return Status result of function | |||||
/// | |||||
Status GetVariables(const std::vector<AscendString> &var_names, std::vector<Tensor> &var_values); | |||||
/// | |||||
/// @ingroup ge_graph | |||||
/// @brief register callback func with specific summary or checkpoint by users | /// @brief register callback func with specific summary or checkpoint by users | ||||
/// @param [in] key: func key | /// @param [in] key: func key | ||||
/// @param [in] callback: callback specific summary or checkpoint. | /// @param [in] callback: callback specific summary or checkpoint. | ||||
@@ -166,11 +135,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||||
/// Please ensure that the implementation of the function is trusted. | /// Please ensure that the implementation of the function is trusted. | ||||
/// @return Status result of function | /// @return Status result of function | ||||
/// | /// | ||||
ATTRIBUTED_DEPRECATED(Status RegisterCallBackFunc(const char *, const session::pCallBackFunc &)) | |||||
Status RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback); | Status RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback); | ||||
Status RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback); | |||||
bool IsGraphNeedRebuild(uint32_t graphId); | bool IsGraphNeedRebuild(uint32_t graphId); | ||||
private: | private: | ||||
@@ -20,15 +20,8 @@ | |||||
#include <map> | #include <map> | ||||
#include <string> | #include <string> | ||||
#include "ge_error_codes.h" | #include "ge_error_codes.h" | ||||
#include "graph/ascend_string.h" | |||||
namespace ge { | namespace ge { | ||||
#ifdef __GNUC__ | |||||
#define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead."))) | |||||
#else | |||||
#define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead.")) | |||||
#endif | |||||
class StatusFactory { | class StatusFactory { | ||||
public: | public: | ||||
static StatusFactory *Instance() { | static StatusFactory *Instance() { | ||||
@@ -44,17 +37,6 @@ class StatusFactory { | |||||
err_desc_[err] = desc; | err_desc_[err] = desc; | ||||
} | } | ||||
void RegisterErrorNo(uint32_t err, const char *desc) { | |||||
if (desc == nullptr) { | |||||
return; | |||||
} | |||||
std::string error_desc = desc; | |||||
if (err_desc_.find(err) != err_desc_.end()) { | |||||
return; | |||||
} | |||||
err_desc_[err] = error_desc; | |||||
} | |||||
std::string GetErrDesc(uint32_t err) { | std::string GetErrDesc(uint32_t err) { | ||||
auto iter_find = err_desc_.find(err); | auto iter_find = err_desc_.find(err); | ||||
if (iter_find == err_desc_.end()) { | if (iter_find == err_desc_.end()) { | ||||
@@ -63,13 +45,6 @@ class StatusFactory { | |||||
return iter_find->second; | return iter_find->second; | ||||
} | } | ||||
void GetErrDesc(uint32_t err, AscendString &err_desc) { | |||||
auto iter_find = err_desc_.find(err); | |||||
if (iter_find != err_desc_.end()) { | |||||
err_desc = AscendString((iter_find->second).c_str()); | |||||
} | |||||
} | |||||
protected: | protected: | ||||
StatusFactory() {} | StatusFactory() {} | ||||
~StatusFactory() {} | ~StatusFactory() {} | ||||
@@ -81,7 +56,6 @@ class StatusFactory { | |||||
class ErrorNoRegisterar { | class ErrorNoRegisterar { | ||||
public: | public: | ||||
ErrorNoRegisterar(uint32_t err, const std::string &desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } | ErrorNoRegisterar(uint32_t err, const std::string &desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } | ||||
ErrorNoRegisterar(uint32_t err, const char *desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } | |||||
~ErrorNoRegisterar() {} | ~ErrorNoRegisterar() {} | ||||
}; | }; | ||||
@@ -65,47 +65,7 @@ const char *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOp | |||||
// Option key: memory init | // Option key: memory init | ||||
const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | ||||
const char *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize"; | const char *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize"; | ||||
namespace configure_option { | |||||
const char *const STREAM_NUM = "ge.streamNum"; | |||||
const char *const HEAD_STREAM = "ge.headStream"; | |||||
const char *const PERF_LEVEL = "ge.perfLevel"; | |||||
const char *const ENCRYPT_MODE = "ge.encryptMode"; | |||||
const char *const EK_FILE = "ge.ekFile"; | |||||
const char *const CERT_FILE = "ge.certFile"; | |||||
const char *const HW_KEY_FILE = "ge.hwKeyFile"; | |||||
const char *const PRIVATE_KEY_FILE = "ge.privateKeyFile"; | |||||
const char *const FRAMEWORK_TYPE = "ge.frameworkType"; | |||||
const char *const CALIBRATION_CONF_FILE = "ge.calibrationConfFile"; | |||||
const char *const INSERT_OP_FILE = "ge.insertOpFile"; | |||||
const char *const OUTPUT_NODE_NAME = "ge.outputNodeName"; | |||||
const char *const COMPRESS_FLAG = "ge.compressFlag"; | |||||
const char *const PRECISION_MODE = "ge.exec.precision_mode"; | |||||
const char *const SINGLE_OP_FLAG = "ge.exec.single_op"; | |||||
const char *const TRAIN_FLAG = "ge.trainFlag"; | |||||
const char *const RUN_FLAG = "ge.runFlag"; | |||||
const char *const LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop"; | |||||
const char *const TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path"; | |||||
const char *const DDK_VERSION_FLAG = "ge.DDK_version"; | |||||
const char *const GE_FE_FLAG = "ge.feFlag"; | |||||
const char *const STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; | |||||
const char *const OUTPUT_DATATYPE = "ge.outputDatatype"; | |||||
const char *const OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; | |||||
const char *const OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode"; | |||||
const char *const HCOM_PARALLEL = "ge.hcomParallel"; | |||||
const char *const AUTO_TUNE_MODE = "ge.autoTuneMode"; | |||||
const char *const SOC_VERSION = "ge.socVersion"; | |||||
const char *const CORE_TYPE = "ge.engineType"; | |||||
const char *const AICORE_NUM = "ge.aicoreNum"; | |||||
const char *const L1_FUSION = "ge.l1Fusion"; | |||||
const char *const BUFFER_OPTIMIZE = "ge.bufferOptimize"; | |||||
const char *const ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | |||||
const char *const ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; | |||||
const char *const FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; | |||||
const char *const SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | |||||
const char *const ORIGINAL_MODEL_FILE = "ge.originalModelFile"; | |||||
const char *const INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; | |||||
const char *const OP_DEBUG_LEVEL = "ge.opDebugLevel"; | |||||
} // namespace configure_option | |||||
// Configure stream num by Session constructor options param, | // Configure stream num by Session constructor options param, | ||||
// its value should be int32_t type, default value is "1" | // its value should be int32_t type, default value is "1" | ||||
const std::string STREAM_NUM = "ge.streamNum"; | const std::string STREAM_NUM = "ge.streamNum"; | ||||
@@ -44,11 +44,8 @@ struct ModelBufferData { | |||||
* @retval GRAPH_SUCCESS The function is successfully executed. | * @retval GRAPH_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ATTRIBUTED_DEPRECATED(graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &)) | |||||
graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options); | graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options); | ||||
graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &global_options); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief build model.Notice the model is stored in buffer | * @brief build model.Notice the model is stored in buffer | ||||
@@ -66,14 +63,9 @@ void aclgrphBuildFinalize(); | |||||
* @retval GRAPH_SUCCESS The function is successfully executed. | * @retval GRAPH_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ATTRIBUTED_DEPRECATED(graphStatus aclgrphBuildModel(const ge::Graph &, const std::map<AscendString, AscendString> &, | |||||
ModelBufferData &)) | |||||
graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | ||||
ModelBufferData &model); | ModelBufferData &model); | ||||
graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendString, AscendString> &build_options, | |||||
ModelBufferData &model); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief save model buffer to file | * @brief save model buffer to file | ||||
@@ -83,11 +75,8 @@ graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendStrin | |||||
* @retval GRAPH_SUCCESS The function is successfully executed. | * @retval GRAPH_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ATTRIBUTED_DEPRECATED(graphStatus aclgrphSaveModel(const char *, const ModelBufferData &)) | |||||
graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model); | graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model); | ||||
graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief query IR interface version | * @brief query IR interface version | ||||
@@ -121,5 +110,6 @@ graphStatus aclgrphInferShapeAndType(ge::Graph &graph); | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len); | graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len); | ||||
}; // namespace ge | |||||
}; // namespace ge | |||||
#endif // INC_EXTERNAL_GE_IR_BUILD_H_ | #endif // INC_EXTERNAL_GE_IR_BUILD_H_ |