From c5e9c669075d7b0dadba25d542055f7bcdcd796e Mon Sep 17 00:00:00 2001 From: zhengyuanhua Date: Thu, 12 Nov 2020 20:04:41 +0800 Subject: [PATCH] external interface modify: string change to ascendstring --- ge/client/ge_api.cc | 121 +++++++++++++++++++++++++++++- ge/common/CMakeLists.txt | 2 + ge/common/ge_common.mk | 12 ++- ge/executor/CMakeLists.txt | 1 + ge/executor/module.mk | 8 +- ge/graph/manager/graph_manager.cc | 27 +++++++ ge/graph/manager/graph_manager.h | 6 ++ ge/graph/passes/data_pass.cc | 16 +++- ge/graph/passes/multi_batch_clone_pass.cc | 15 +++- ge/ir_build/ge_ir_build.cc | 53 ++++++++++++- ge/offline/CMakeLists.txt | 1 + ge/offline/module.mk | 2 +- ge/session/inner_session.cc | 19 +++++ ge/session/inner_session.h | 4 + ge/session/session_manager.cc | 20 +++++ ge/session/session_manager.h | 3 + inc/external/ge/ge_api.h | 34 +++++++++ inc/external/ge/ge_api_error_codes.h | 26 +++++++ inc/external/ge/ge_api_types.h | 42 ++++++++++- inc/external/ge/ge_ir_build.h | 11 +++ 20 files changed, 406 insertions(+), 17 deletions(-) diff --git a/ge/client/ge_api.cc b/ge/client/ge_api.cc index 522985fa..844cacca 100644 --- a/ge/client/ge_api.cc +++ b/ge/client/ge_api.cc @@ -76,7 +76,7 @@ Status CheckOptionsValid(const std::map &options) { } // Initialize GE, prepare for execution, call GELib::Initialize -Status GEInitialize(const std::map &options) { +Status GEInitializeImpl(const std::map &options) { GELOGT(TRACE_INIT, "GEInitialize start"); // 0.check init status if (g_ge_initialized) { @@ -127,6 +127,26 @@ Status GEInitialize(const std::map &options) { return ret; } +// Initialize GE, prepare for execution, call GELib::Initialize +Status GEInitialize(const std::map &options) { + return GEInitializeImpl(options); +} + +Status GEInitialize(const std::map &options) { + std::map str_options; + for (auto & option : options) { + if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { + GELOGE(FAILED, "GEInitialize options is nullptr."); + return FAILED; + } + std::string key = option.first.GetString(); + std::string val = option.second.GetString(); + str_options[key] = val; + } + return GEInitializeImpl(str_options); +} + + // GE finalize, releasing all resources Status GEFinalize() { GELOGT(TRACE_INIT, "GEFinalize start"); @@ -202,6 +222,46 @@ Session::Session(const std::map &options) { GELOGT(TRACE_STOP, "Session Constructor finished"); } +Session::Session(const std::map &options) { + GELOGT(TRACE_INIT, "Session Constructor start"); + // check init status + sessionId_ = 0; + if (!g_ge_initialized) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED); + return; + } + // call Initialize + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Session Constructor failed"); + return; + } + + GELOGT(TRACE_RUNNING, "Creating session"); + std::map str_options; + for (auto &option : options) { + if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { + GELOGE(FAILED, "Session options is nullptr."); + return; + } + std::string key = option.first.GetString(); + std::string val = option.second.GetString(); + str_options[key] = val; + } + uint64_t session_id = 0; + Status ret = instance_ptr->SessionManagerObj().CreateSession(str_options, session_id); + GELOGT(TRACE_RUNNING, "Session id is %lu", session_id); + + // check return status, return, update session id if success + if (ret == SUCCESS) { + sessionId_ = session_id; + } else { + GELOGE(ret, "Session constructor failed, session Id not initialized"); + return; + } + GELOGT(TRACE_STOP, "Session Constructor finished"); +} + // session destructor Session::~Session() { GELOGT(TRACE_INIT, "Session Destructor start"); @@ -260,6 +320,34 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map &options) { + GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "AddGraph failed in Session."); + return FAILED; + } + GELOGD("Adding graph to session"); + std::map str_options; + for (auto &option : options) { + if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { + GELOGE(FAILED, "AddGraph options is nullptr."); + return FAILED; + } + std::string key = option.first.GetString(); + std::string val = option.second.GetString(); + str_options[key] = val; + } + Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, str_options); + if (ret != SUCCESS) { + GELOGE(ret, "AddGraph failed in Session."); + return FAILED; + } + GELOGD("AddGraph finished in Session."); + return ret; +} + Status Session::RemoveGraph(uint32_t graph_id) { GELOGT(TRACE_INIT, "Session RemoveGraph start"); @@ -360,6 +448,14 @@ Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); } +Status Session::RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback) { + std::string str_key; + if (key != nullptr) { + str_key = key; + } + return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, str_key, callback); +} + Status Session::BuildGraph(uint32_t graph_id, const std::vector &inputs) { std::shared_ptr instance_ptr = ge::GELib::GetInstance(); if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { @@ -409,6 +505,29 @@ Status Session::GetVariables(const std::vector &var_names, std::vec return SUCCESS; } +Status Session::GetVariables(const std::vector &var_names, std::vector &var_values) { + auto instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed"); + return FAILED; + } + GELOGT(TRACE_RUNNING, "Get Variables"); + std::vector str_var_names; + for (auto &var_name : var_names) { + if (var_name.GetString() == nullptr) { + GELOGE(FAILED, "GetVariables name is nullptr."); + return FAILED; + } + str_var_names.emplace_back(var_name.GetString()); + } + Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, str_var_names, var_values); + if (ret != SUCCESS) { + GELOGE(ret, "SessionManager RunGraphAsync failed"); + return FAILED; + } + return SUCCESS; +} + bool Session::IsGraphNeedRebuild(uint32_t graph_id) { return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id); } diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index f068baff..54cdd130 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -77,6 +77,7 @@ target_compile_options(ge_common PRIVATE -fvisibility=hidden -O2 -Werror + -wno-deprecated-declarations ) target_include_directories(ge_common PRIVATE @@ -131,6 +132,7 @@ target_compile_options(ge_common_static PRIVATE -fvisibility=hidden -O2 -Werror + -wno-deprecated-declarations ) target_include_directories(ge_common_static PRIVATE diff --git a/ge/common/ge_common.mk b/ge/common/ge_common.mk index 45ee1057..a462daa4 100755 --- a/ge/common/ge_common.mk +++ b/ge/common/ge_common.mk @@ -81,8 +81,9 @@ include $(CLEAR_VARS) LOCAL_MODULE := libge_common -LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP +LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP -Wno-deprecated-declarations LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -Dgoogle=ascend_private + ifeq ($(DEBUG), 1) LOCAL_CFLAGS += -g -O0 else @@ -122,8 +123,9 @@ include $(CLEAR_VARS) LOCAL_MODULE := libge_common -LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP +LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP -Wno-deprecated-declarations LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -Dgoogle=ascend_private + ifeq ($(DEBUG), 1) LOCAL_CFLAGS += -g -O0 else @@ -168,8 +170,9 @@ include $(CLEAR_VARS) LOCAL_MODULE := libge_common -LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP +LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP -Wno-deprecated-declarations LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -Dgoogle=ascend_private + ifeq ($(DEBUG), 1) LOCAL_CFLAGS += -g -O0 endif @@ -210,8 +213,9 @@ include $(CLEAR_VARS) LOCAL_MODULE := libge_common -LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP +LOCAL_CFLAGS += -Werror -DFMK_SUPPORT_DUMP -Wno-deprecated-declarations LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -Dgoogle=ascend_private + ifeq ($(DEBUG), 1) LOCAL_CFLAGS += -g -O0 endif diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt index bba1927f..f30b2a51 100755 --- a/ge/executor/CMakeLists.txt +++ b/ge/executor/CMakeLists.txt @@ -80,6 +80,7 @@ add_library(ge_executor STATIC ${SRC_LIST} ${PROTO_HDRS}) target_compile_options(ge_executor PRIVATE -Werror -O2 + -Wno-deprecated-declarations ) target_compile_definitions(ge_executor PRIVATE diff --git a/ge/executor/module.mk b/ge/executor/module.mk index eaa611d2..dd352c14 100755 --- a/ge/executor/module.mk +++ b/ge/executor/module.mk @@ -100,7 +100,7 @@ local_ge_executor_ldflags := -lrt -ldl \ include $(CLEAR_VARS) LOCAL_MODULE := libge_executor -LOCAL_CFLAGS += -Werror +LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -DDAVINCI_SUPPORT_PROFILING -Dgoogle=ascend_private LOCAL_SRC_FILES := $(local_ge_executor_src_files) @@ -126,7 +126,7 @@ include $(BUILD_SHARED_LIBRARY) include $(CLEAR_VARS) LOCAL_MODULE := libge_executor -LOCAL_CFLAGS += -Werror +LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DDAVINCI_SUPPORT_PROFILING -Dgoogle=ascend_private ifeq ($(DEBUG), 1) LOCAL_CFLAGS += -g -O0 @@ -162,7 +162,7 @@ include $(BUILD_HOST_SHARED_LIBRARY) include $(CLEAR_VARS) LOCAL_MODULE := libge_executor -LOCAL_CFLAGS += -Werror +LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DDAVINCI_SUPPORT_PROFILING -Dgoogle=ascend_private ifeq ($(DEBUG), 1) LOCAL_CFLAGS += -g -O0 @@ -195,7 +195,7 @@ include $(BUILD_HOST_STATIC_LIBRARY) include $(CLEAR_VARS) LOCAL_MODULE := libge_executor -LOCAL_CFLAGS += -Werror +LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DDAVINCI_SUPPORT_PROFILING -Dgoogle=ascend_private ifeq ($(DEBUG), 1) LOCAL_CFLAGS += -g -O0 diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 282cd7a6..e698371b 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -1730,12 +1730,30 @@ Status GraphManager::RegisterCallBackFunc( return SUCCESS; } +Status GraphManager::RegisterCallBackFunc( + const std::string &key, + const std::function &)> &callback) { + std::lock_guard lock(member_mutex_); + GELOGI("[GraphManager] RegisterCallBackFunc, key=%s.", key.c_str()); + callback_map_[key] = callback; + return SUCCESS; +} + Status GraphManager::PushSummaryData2ME(const GraphId &graph_id, const std::map &summary_data) { std::lock_guard lock(member_mutex_); GELOGI("[GraphManager] PushSummaryData2ME, dataSize=%zu.", summary_data.size()); auto itr = me_callback_map_.find(kSummary); if (itr == me_callback_map_.end()) { + auto iter = callback_map_.find(kSummary); + if (iter != callback_map_.end()) { + std::map tmp_summary_data; + for (auto &data : summary_data) { + AscendString tmp(data.first.c_str()); + tmp_summary_data[tmp] = data.second; + } + return iter->second(graph_id, tmp_summary_data); + } GELOGE(FAILED, "[GraphManager] PushSummaryData2ME failed, not found summary callback."); return FAILED; } @@ -1747,6 +1765,15 @@ Status GraphManager::PushSaveData2ME(const GraphId &graph_id, const std::map tmp_save_data; + for (auto &data : save_data) { + AscendString tmp(data.first.c_str()); + tmp_save_data[tmp] = data.second; + } + return iter->second(graph_id, tmp_save_data); + } GELOGE(FAILED, "[GraphManager] PushSaveData2ME failed, not found checkpoint callback."); return FAILED; } diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index fc3601af..d48a2c0f 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -152,6 +152,10 @@ class GraphManager { const std::string &key, const std::function &)> &callback); + Status RegisterCallBackFunc( + const std::string &key, + const std::function &)> &callback); + const bool GetTrainFlag() const { return options_.train_graph_flag; } bool IsGraphNeedRebuild(uint32_t graph_id); @@ -373,6 +377,8 @@ class GraphManager { // summary and checkpoint callback function list for ME, key is summary or checkpoint std::map &)>> me_callback_map_; + std::map &)>> callback_map_; + bool init_flag_; GraphManagerOptions options_; diff --git a/ge/graph/passes/data_pass.cc b/ge/graph/passes/data_pass.cc index 38688848..9ec1a729 100644 --- a/ge/graph/passes/data_pass.cc +++ b/ge/graph/passes/data_pass.cc @@ -62,16 +62,26 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) { node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); } - + domi::ParseSubgraphFuncV1 parse_subgraph = nullptr; auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(parent_node->GetType()); if (post_func == nullptr) { GELOGW("The subgraph post func for node %s type %s is null.", parent_node->GetName().c_str(), parent_node->GetType().c_str()); - return SUCCESS; + if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(parent_node->GetType(), parse_subgraph) != SUCCESS || + parse_subgraph == nullptr) { + GELOGW("The subgraph new post func for node[%s] type [%s] is null", + parent_node->GetName().c_str(), parent_node->GetType().c_str()); + return SUCCESS; + } } auto graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); - auto ret = post_func(subgraph_name, graph); + Status ret = FAILED; + if (post_func != nullptr) { + ret = post_func(subgraph_name, graph); + } else if (parse_subgraph != nullptr) { + ret = parse_subgraph(subgraph_name.c_str(), graph); + } if (ret != SUCCESS) { GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), parent_node->GetName().c_str(), parent_node->GetType().c_str()); diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 732844e5..36aa8a5c 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -610,11 +610,17 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const /// Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { auto func_desc = case_node_->GetOpDesc(); + domi::ParseSubgraphFuncV1 parse_subgraph = nullptr; auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); if (post_func == nullptr) { GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), case_node_->GetType().c_str()); - return FAILED; + if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_subgraph) != SUCCESS || + parse_subgraph == nullptr) { + GELOGW("The subgraph new post func for node[%s] type [%s] is null", case_node_->GetName().c_str(), + case_node_->GetType().c_str()); + return FAILED; + } } for (const auto &name : func_desc->GetSubgraphInstanceNames()) { @@ -629,7 +635,12 @@ Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); - auto ret = post_func(subgraph_name, graph); + Status ret = FAILED; + if (post_func != nullptr) { + ret = post_func(subgraph_name, graph); + } else if (parse_subgraph != nullptr) { + ret = parse_subgraph(subgraph_name.c_str(), graph); + } if (ret != SUCCESS) { GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), case_node_->GetName().c_str(), case_node_->GetType().c_str()); diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index b845bd75..03f919ea 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -109,7 +109,7 @@ static graphStatus CheckGlobalOptions(std::map &global return GRAPH_SUCCESS; } -graphStatus aclgrphBuildInitialize(std::map global_options) { +graphStatus aclgrphBuildInitializeImpl(std::map &global_options) { GELOGD("Enter aclgrphInitialize start!"); // check global options if (CheckGlobalOptions(global_options) != GRAPH_SUCCESS) { @@ -132,6 +132,24 @@ graphStatus aclgrphBuildInitialize(std::map global_opt return GRAPH_SUCCESS; } +graphStatus aclgrphBuildInitialize(std::map global_options) { + return aclgrphBuildInitializeImpl(global_options); +} + +graphStatus aclgrphBuildInitialize(std::map &global_options) { + std::map tmp_global_options; + for (auto &option : global_options) { + if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { + GELOGE(GRAPH_FAILED, "AclgrphBuildInitialize option is nullptr."); + return GRAPH_FAILED; + } + std::string key = option.first.GetString(); + std::string val = option.second.GetString(); + tmp_global_options[key] = val; + } + return aclgrphBuildInitializeImpl(tmp_global_options); +} + void aclgrphBuildFinalize() { if (ge::GELib::GetInstance() != nullptr && ge::GELib::GetInstance()->InitFlag()) { (void)ge::GELib::GetInstance()->Finalize(); @@ -417,6 +435,24 @@ graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map &build_options, + ModelBufferData &model) { + GELOGD("Enter aclmdlBuildModel process!"); + std::map tmp_build_options; + for (auto &option : build_options) { + if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { + GELOGE(GRAPH_FAILED, "AclgrphBuildInitialize option is nullptr."); + return GRAPH_FAILED; + } + std::string key = option.first.GetString(); + std::string val = option.second.GetString(); + tmp_build_options[key] = val; + } + + Impl builder; + return builder.BuildModel(graph, tmp_build_options, model); +} + graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model) { GELOGD("Enter aclmdlSaveModel process!"); if (model.data.get() == nullptr || model.length == 0) { @@ -427,6 +463,21 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m static_cast(model.length)); } +graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model) { + GELOGD("Enter aclmdlSaveModel process!"); + if (model.data.get() == nullptr || model.length == 0) { + GELOGE(GRAPH_PARAM_INVALID, "Input model is illegal"); + return GRAPH_PARAM_INVALID; + } + if (output_file == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "Output file is nullptr."); + return GRAPH_PARAM_INVALID; + } + std::string str_output_file = output_file; + return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast(model.data.get()), + static_cast(model.length)); +} + graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version) { GELOGD("Enter aclgrphGetIRVersion process!"); GE_CHECK_NOTNULL(major_version); diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt index e6233634..e8e91327 100644 --- a/ge/offline/CMakeLists.txt +++ b/ge/offline/CMakeLists.txt @@ -20,6 +20,7 @@ add_executable(atc ${SRC_LIST} ${PROTO_HDRS}) target_compile_options(atc PRIVATE -Werror -O2 + -Wno-deprecated-declarations ) target_compile_definitions(atc PRIVATE diff --git a/ge/offline/module.mk b/ge/offline/module.mk index d84734d0..7d205fc3 100755 --- a/ge/offline/module.mk +++ b/ge/offline/module.mk @@ -5,7 +5,7 @@ include $(CLEAR_VARS) LOCAL_MODULE := atc -LOCAL_CFLAGS += -Werror +LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private LOCAL_SRC_FILES := \ diff --git a/ge/session/inner_session.cc b/ge/session/inner_session.cc index aa825a4b..7b74af76 100755 --- a/ge/session/inner_session.cc +++ b/ge/session/inner_session.cc @@ -236,6 +236,25 @@ Status InnerSession::RegisterCallBackFunc( return SUCCESS; } +Status InnerSession::RegisterCallBackFunc( + const std::string &key, + const std::function &)> &callback) { + std::lock_guard lock(resource_mutex_); + if (!init_flag_) { + GELOGE(GE_SESS_INIT_FAILED, "[InnerSession:%lu] initialize failed.", session_id_); + return GE_SESS_INIT_FAILED; + } + UpdateThreadContext(std::map{}); + Status ret = graph_manager_.RegisterCallBackFunc(key, callback); + if (ret != SUCCESS) { + GELOGE(ret, "[InnerSession:%lu] register %s callback function failed.", session_id_, key.c_str()); + return ret; + } + + GELOGI("[InnerSession:%lu] register %s callback function success.", session_id_, key.c_str()); + return SUCCESS; +} + Status InnerSession::BuildGraph(uint32_t graph_id, const std::vector &inputs) { UpdateThreadContext(graph_id); GELOGI("[InnerSession:%lu] build graph on session, graph_id=%u.", session_id_, graph_id); diff --git a/ge/session/inner_session.h b/ge/session/inner_session.h index 25f5c307..a887e303 100644 --- a/ge/session/inner_session.h +++ b/ge/session/inner_session.h @@ -60,6 +60,10 @@ class InnerSession { const std::string &key, const std::function &)> &callback); + Status RegisterCallBackFunc( + const std::string &key, + const std::function &)> &callback); + const GraphManager &getGraphManagerObj() const; bool IsGraphNeedRebuild(uint32_t graph_id); diff --git a/ge/session/session_manager.cc b/ge/session/session_manager.cc index 6f8c9432..d9613ec3 100755 --- a/ge/session/session_manager.cc +++ b/ge/session/session_manager.cc @@ -246,6 +246,26 @@ Status SessionManager::RegisterCallBackFunc( return innerSession->RegisterCallBackFunc(key, callback); } +Status SessionManager::RegisterCallBackFunc( + SessionId session_id, const std::string &key, + const std::function &)> &callback) { + if (!init_flag_) { + GELOGE(GE_SESSION_MANAGER_NOT_INIT); + return GE_SESSION_MANAGER_NOT_INIT; + } + SessionPtr innerSession = nullptr; + { + std::lock_guard lock(mutex_); + std::map::iterator it = session_manager_map_.find(session_id); + if (it == session_manager_map_.end()) { + return GE_SESSION_NOT_EXIST; + } else { + innerSession = it->second; + } + } + return innerSession->RegisterCallBackFunc(key, callback); +} + Status SessionManager::BuildGraph(SessionId session_id, uint32_t graph_id, const std::vector &inputs) { if (!init_flag_) { GELOGE(GE_SESSION_MANAGER_NOT_INIT); diff --git a/ge/session/session_manager.h b/ge/session/session_manager.h index 88864f61..cb825ccd 100644 --- a/ge/session/session_manager.h +++ b/ge/session/session_manager.h @@ -146,6 +146,9 @@ class SessionManager { Status RegisterCallBackFunc( SessionId session_id, const std::string &key, const std::function &)> &callback); + Status RegisterCallBackFunc( + SessionId session_id, const std::string &key, + const std::function &)> &callback); bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id); diff --git a/inc/external/ge/ge_api.h b/inc/external/ge/ge_api.h index b4b9bb2a..dd11e42e 100644 --- a/inc/external/ge/ge_api.h +++ b/inc/external/ge/ge_api.h @@ -29,16 +29,26 @@ namespace ge { typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map ¶ms_list); +namespace session { +typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map ¶ms_list); +} + // Initialize GE +ATTRIBUTED_DEPRECATED(Status GEInitialize(const std::map &)) Status GEInitialize(const std::map &options); +Status GEInitialize(const std::map &options); + // Finalize GE, release all resources Status GEFinalize(); class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { public: + ATTRIBUTED_DEPRECATED(Session(const std::map &)) explicit Session(const std::map &options); + explicit Session(const std::map &options); + ~Session(); /// @@ -57,9 +67,20 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { /// @param [in] options graph options /// @return Status result of function /// + ATTRIBUTED_DEPRECATED(Status AddGraph(uint32_t, const Graph &, const std::map &)) Status AddGraph(uint32_t graphId, const Graph &graph, const std::map &options); /// + /// @ingroup client + /// @brief add a graph with a specific graphId and graphOptions + /// @param [in] graphId graph id + /// @param [in] graph the graph + /// @param [in] options graph options + /// @return Status result of function + /// + Status AddGraph(uint32_t graphId, const Graph &graph, const std::map &options); + + /// /// @ingroup ge_graph /// @brief remove a graph of the session with specific session id /// @param [in] graphId graph id @@ -105,10 +126,20 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { /// @param [out] var_values: variable values /// @return Status result of function /// + ATTRIBUTED_DEPRECATED(Status GetVariables(const std::vector &, std::vector &)) Status GetVariables(const std::vector &var_names, std::vector &var_values); /// /// @ingroup ge_graph + /// @brief get variables in the session with specific session id + /// @param [in] var_names: variable names + /// @param [out] var_values: variable values + /// @return Status result of function + /// + Status GetVariables(const std::vector &var_names, std::vector &var_values); + + /// + /// @ingroup ge_graph /// @brief register callback func with specific summary or checkpoint by users /// @param [in] key: func key /// @param [in] callback: callback specific summary or checkpoint. @@ -116,8 +147,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { /// Please ensure that the implementation of the function is trusted. /// @return Status result of function /// + ATTRIBUTED_DEPRECATED( Status RegisterCallBackFunc(const char *, const session::pCallBackFunc &)) Status RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback); + Status RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback); + bool IsGraphNeedRebuild(uint32_t graphId); private: diff --git a/inc/external/ge/ge_api_error_codes.h b/inc/external/ge/ge_api_error_codes.h index 7b045d54..251ceffc 100644 --- a/inc/external/ge/ge_api_error_codes.h +++ b/inc/external/ge/ge_api_error_codes.h @@ -19,8 +19,15 @@ #include #include +#include "graph/ascend_string.h" namespace ge { +#ifdef __GNUC__ +#define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead."))) +#else +#define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead.")) +#endif + class StatusFactory { public: static StatusFactory *Instance() { @@ -36,6 +43,17 @@ class StatusFactory { err_desc_[err] = desc; } + void RegisterErrorNo(uint32_t err, const char *desc) { + if (desc == nullptr) { + return; + } + std::string error_desc = desc; + if (err_desc_.find(err) != err_desc_.end()) { + return; + } + err_desc_[err] = error_desc; + } + std::string GetErrDesc(uint32_t err) { auto iter_find = err_desc_.find(err); if (iter_find == err_desc_.end()) { @@ -44,6 +62,13 @@ class StatusFactory { return iter_find->second; } + void GetErrDesc(uint32_t err, ge::AscendString &err_desc) { + auto iter_find = err_desc_.find(err); + if (iter_find != err_desc_.end()) { + err_desc = ge::AscendString((iter_find->second).c_str()); + } + } + protected: StatusFactory() {} ~StatusFactory() {} @@ -55,6 +80,7 @@ class StatusFactory { class ErrorNoRegisterar { public: ErrorNoRegisterar(uint32_t err, const std::string &desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } + ErrorNoRegisterar(uint32_t err, const char *desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } ~ErrorNoRegisterar() {} }; diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index 37e2dccf..dca02450 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -65,7 +65,47 @@ const char *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOp // Option key: memory init const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; const char *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize"; - +namespace configure_option { +const char *const STREAM_NUM = "ge.streamNum"; +const char *const HEAD_STREAM = "ge.headStream"; +const char *const PERF_LEVEL = "ge.perfLevel"; +const char *const ENCRYPT_MODE = "ge.encryptMode"; +const char *const EK_FILE = "ge.ekFile"; +const char *const CERT_FILE = "ge.certFile"; +const char *const HW_KEY_FILE = "ge.hwKeyFile"; +const char *const PRIVATE_KEY_FILE = "ge.privateKeyFile"; +const char *const FRAMEWORK_TYPE = "ge.frameworkType"; +const char *const CALIBRATION_CONF_FILE = "ge.calibrationConfFile"; +const char *const INSERT_OP_FILE = "ge.insertOpFile"; +const char *const OUTPUT_NODE_NAME = "ge.outputNodeName"; +const char *const COMPRESS_FLAG = "ge.compressFlag"; +const char *const PRECISION_MODE = "ge.exec.precision_mode"; +const char *const SINGLE_OP_FLAG = "ge.exec.single_op"; +const char *const TRAIN_FLAG = "ge.trainFlag"; +const char *const RUN_FLAG = "ge.runFlag"; +const char *const LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop"; +const char *const TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path"; +const char *const DDK_VERSION_FLAG = "ge.DDK_version"; +const char *const GE_FE_FLAG = "ge.feFlag"; +const char *const STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; +const char *const OUTPUT_DATATYPE = "ge.outputDatatype"; +const char *const OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; +const char *const OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode"; +const char *const HCOM_PARALLEL = "ge.hcomParallel"; +const char *const AUTO_TUNE_MODE = "ge.autoTuneMode"; +const char *const SOC_VERSION = "ge.socVersion"; +const char *const CORE_TYPE = "ge.engineType"; +const char *const AICORE_NUM = "ge.aicoreNum"; +const char *const L1_FUSION = "ge.l1Fusion"; +const char *const BUFFER_OPTIMIZE = "ge.bufferOptimize"; +const char *const ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; +const char *const ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; +const char *const FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; +const char *const SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; +const char *const ORIGINAL_MODEL_FILE = "ge.originalModelFile"; +const char *const INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; +const char *const OP_DEBUG_LEVEL = "ge.opDebugLevel"; +} // namespace configure_option // Configure stream num by Session constructor options param, // its value should be int32_t type, default value is "1" const std::string STREAM_NUM = "ge.streamNum"; diff --git a/inc/external/ge/ge_ir_build.h b/inc/external/ge/ge_ir_build.h index e6401093..7038f6fb 100644 --- a/inc/external/ge/ge_ir_build.h +++ b/inc/external/ge/ge_ir_build.h @@ -45,8 +45,11 @@ struct ModelBufferData * @retval GRAPH_SUCCESS The function is successfully executed. * @retval OtherValues Failure */ +ATTRIBUTED_DEPRECATED(graphStatus aclgrphBuildInitialize(std::map &)) graphStatus aclgrphBuildInitialize(std::map global_options); +graphStatus aclgrphBuildInitialize(std::map &global_options); + /** * @ingroup AscendCL * @brief build model.Notice the model is stored in buffer @@ -64,8 +67,13 @@ void aclgrphBuildFinalize(); * @retval GRAPH_SUCCESS The function is successfully executed. * @retval OtherValues Failure */ +ATTRIBUTED_DEPRECATED(graphStatus aclgrphBuildModel(const ge::Graph &, const std::map &, + ModelBufferData&)) graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map &build_options, ModelBufferData& model); +graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map &build_options, + ModelBufferData& model); + /** * @ingroup AscendCL * @brief save model buffer to file @@ -75,8 +83,11 @@ graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map