Browse Source

Merge branch 'development' of gitee.com:mindspore/graphengine into review_control_passes

pull/401/head
陈叶朦 Gitee 4 years ago
parent
commit
94da88bd04
22 changed files with 461 additions and 33 deletions
  1. +15
    -9
      CMakeLists.txt
  2. +1
    -1
      cmake/external_libs/protobuf_static.cmake
  3. +120
    -1
      ge/client/ge_api.cc
  4. +32
    -3
      ge/graph/manager/graph_manager.cc
  5. +8
    -1
      ge/graph/manager/graph_manager.h
  6. +6
    -1
      ge/graph/optimize/graph_optimize.cc
  7. +2
    -1
      ge/graph/optimize/graph_optimize.h
  8. +11
    -4
      ge/graph/partition/graph_partition.cc
  9. +36
    -0
      ge/graph/passes/merge_pass.cc
  10. +1
    -0
      ge/graph/passes/merge_pass.h
  11. +13
    -2
      ge/graph/passes/multi_batch_clone_pass.cc
  12. +59
    -1
      ge/ir_build/ge_ir_build.cc
  13. +19
    -0
      ge/session/inner_session.cc
  14. +4
    -0
      ge/session/inner_session.h
  15. +20
    -0
      ge/session/session_manager.cc
  16. +3
    -0
      ge/session/session_manager.h
  17. +34
    -0
      inc/external/ge/ge_api.h
  18. +18
    -0
      inc/external/ge/ge_api_error_codes.h
  19. +45
    -5
      inc/external/ge/ge_api_types.h
  20. +12
    -2
      inc/external/ge/ge_ir_build.h
  21. +1
    -1
      metadef
  22. +1
    -1
      parser

+ 15
- 9
CMakeLists.txt View File

@@ -21,6 +21,13 @@ set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})

set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR})
set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR})
set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)
set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64)
set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64)
set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR})

option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE)

if (ENABLE_OPEN_SRC)
@@ -129,14 +136,6 @@ if (ENABLE_OPEN_SRC)
#add_subdirectory(metadef/graph)
#add_subdirectory(metadef/register)
elseif (ENABLE_D OR ENABLE_ACL)

set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR})
set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR})
set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)
set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64)
set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64)
set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR})

# compiling with MindSpore
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
@@ -158,11 +157,18 @@ elseif (ENABLE_D OR ENABLE_ACL)

set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
elseif(ENABLE_MS_TESTCASE)
elseif(ENABLE_MS_TESTCASES)
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)

# common libraries
find_module(slog libslog.so ${ASCEND_MS_DRIVER_PATH})
find_module(error_manager liberror_manager.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})

set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
else()


+ 1
- 1
cmake/external_libs/protobuf_static.cmake View File

@@ -42,7 +42,7 @@ include(GNUInstallDirs)
add_library(ascend_protobuf_static_lib STATIC IMPORTED)

set_target_properties(ascend_protobuf_static_lib PROPERTIES
IMPORTED_LOCATION ${PROTOBUF_STATIC_PKG_DIR}/lib64/libascend_protobuf.a
IMPORTED_LOCATION ${PROTOBUF_STATIC_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/libascend_protobuf.a
)

add_library(ascend_protobuf_static INTERFACE)


+ 120
- 1
ge/client/ge_api.cc View File

@@ -76,7 +76,7 @@ Status CheckOptionsValid(const std::map<string, string> &options) {
}

// Initialize GE, prepare for execution, call GELib::Initialize
Status GEInitialize(const std::map<string, string> &options) {
Status GEInitializeImpl(const std::map<string, string> &options) {
GELOGT(TRACE_INIT, "GEInitialize start");
// 0.check init status
if (g_ge_initialized) {
@@ -127,6 +127,26 @@ Status GEInitialize(const std::map<string, string> &options) {
return ret;
}

// Initialize GE, prepare for execution, call GELib::Initialize
Status GEInitialize(const std::map<string, string> &options) {
return GEInitializeImpl(options);
}

Status GEInitialize(const std::map<AscendString, AscendString> &options) {
std::map<std::string, std::string> str_options;
for (auto & option : options) {
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
GELOGE(FAILED, "GEInitialize options is nullptr.");
return FAILED;
}
std::string key = option.first.GetString();
std::string val = option.second.GetString();
str_options[key] = val;
}
return GEInitializeImpl(str_options);
}


// GE finalize, releasing all resources
Status GEFinalize() {
GELOGT(TRACE_INIT, "GEFinalize start");
@@ -202,6 +222,46 @@ Session::Session(const std::map<string, string> &options) {
GELOGT(TRACE_STOP, "Session Constructor finished");
}

Session::Session(const std::map<AscendString, AscendString> &options) {
GELOGT(TRACE_INIT, "Session Constructor start");
// check init status
sessionId_ = 0;
if (!g_ge_initialized) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GE is not initialized.");
return;
}
// call Initialize
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Session Constructor failed");
return;
}

GELOGT(TRACE_RUNNING, "Creating session");
std::map<std::string, std::string> str_options;
for (auto &option : options) {
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
GELOGE(FAILED, "Session options is nullptr.");
return;
}
std::string key = option.first.GetString();
std::string val = option.second.GetString();
str_options[key] = val;
}
uint64_t session_id = 0;
Status ret = instance_ptr->SessionManagerObj().CreateSession(str_options, session_id);
GELOGT(TRACE_RUNNING, "Session id is %lu", session_id);

// check return status, return, update session id if success
if (ret == SUCCESS) {
sessionId_ = session_id;
} else {
GELOGE(ret, "Session constructor failed, session Id not initialized");
return;
}
GELOGT(TRACE_STOP, "Session Constructor finished");
}

// session destructor
Session::~Session() {
GELOGT(TRACE_INIT, "Session Destructor start");
@@ -260,6 +320,34 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<s
return ret;
}

Status Session::AddGraph(uint32_t graph_id, const Graph &graph,
const std::map<AscendString, AscendString> &options) {
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "AddGraph failed in Session.");
return FAILED;
}
GELOGD("Adding graph to session");
std::map<std::string, std::string> str_options;
for (auto &option : options) {
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
GELOGE(FAILED, "AddGraph options is nullptr.");
return FAILED;
}
std::string key = option.first.GetString();
std::string val = option.second.GetString();
str_options[key] = val;
}
Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, str_options);
if (ret != SUCCESS) {
GELOGE(ret, "AddGraph failed in Session.");
return FAILED;
}
GELOGD("AddGraph finished in Session.");
return ret;
}

Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) {
std::map<AscendString, AscendString> options;
return AddGraphWithCopy(graph_id, graph, options);
@@ -387,6 +475,14 @@ Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback);
}

Status Session::RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback) {
std::string str_key;
if (key != nullptr) {
str_key = key;
}
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, str_key, callback);
}

Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) {
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
@@ -436,6 +532,29 @@ Status Session::GetVariables(const std::vector<std::string> &var_names, std::vec
return SUCCESS;
}

Status Session::GetVariables(const std::vector<AscendString> &var_names, std::vector<Tensor> &var_values) {
auto instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed");
return FAILED;
}
GELOGT(TRACE_RUNNING, "Get Variables");
std::vector<ge::string> str_var_names;
for (auto &var_name : var_names) {
if (var_name.GetString() == nullptr) {
GELOGE(FAILED, "GetVariables name is nullptr.");
return FAILED;
}
str_var_names.emplace_back(var_name.GetString());
}
Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, str_var_names, var_values);
if (ret != SUCCESS) {
GELOGE(ret, "SessionManager RunGraphAsync failed");
return FAILED;
}
return SUCCESS;
}

bool Session::IsGraphNeedRebuild(uint32_t graph_id) {
return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id);
}


+ 32
- 3
ge/graph/manager/graph_manager.cc View File

@@ -548,7 +548,7 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr
(void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy);
}
std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this,
compute_graph->GetGraphID(), subgraph, session_id, GetThreadLocalContext());
compute_graph->GetGraphID(), subgraph, compute_graph, session_id, GetThreadLocalContext());
if (!f.valid()) {
GELOGE(FAILED, "Future is invalid");
return FAILED;
@@ -563,7 +563,7 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr
(void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy);
}
std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this,
compute_graph->GetGraphID(), subgraph, session_id,
compute_graph->GetGraphID(), subgraph, compute_graph, session_id,
GetThreadLocalContext());
if (!f.valid()) {
GELOGE(FAILED, "Future is invalid");
@@ -1865,12 +1865,30 @@ Status GraphManager::RegisterCallBackFunc(
return SUCCESS;
}

Status GraphManager::RegisterCallBackFunc(
const std::string &key,
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback) {
std::lock_guard<std::mutex> lock(member_mutex_);
GELOGI("[GraphManager] RegisterCallBackFunc, key=%s.", key.c_str());
callback_map_[key] = callback;
return SUCCESS;
}

Status GraphManager::PushSummaryData2ME(const GraphId &graph_id,
const std::map<std::string, ge::Tensor> &summary_data) {
std::lock_guard<std::mutex> lock(member_mutex_);
GELOGI("[GraphManager] PushSummaryData2ME, dataSize=%zu.", summary_data.size());
auto itr = me_callback_map_.find(kSummary);
if (itr == me_callback_map_.end()) {
auto iter = callback_map_.find(kSummary);
if (iter != callback_map_.end()) {
std::map<AscendString, ge::Tensor> tmp_summary_data;
for (auto &data : summary_data) {
AscendString tmp(data.first.c_str());
tmp_summary_data[tmp] = data.second;
}
return iter->second(graph_id, tmp_summary_data);
}
GELOGE(FAILED, "[GraphManager] PushSummaryData2ME failed, not found summary callback.");
return FAILED;
}
@@ -1882,6 +1900,15 @@ Status GraphManager::PushSaveData2ME(const GraphId &graph_id, const std::map<std
GELOGI("[GraphManager] PushSaveData2ME, dataSize=%zu.", save_data.size());
auto itr = me_callback_map_.find(kSave);
if (itr == me_callback_map_.end()) {
auto iter = callback_map_.find(kSave);
if (iter != callback_map_.end()) {
std::map<AscendString, ge::Tensor> tmp_save_data;
for (auto &data : save_data) {
AscendString tmp(data.first.c_str());
tmp_save_data[tmp] = data.second;
}
return iter->second(graph_id, tmp_save_data);
}
GELOGE(FAILED, "[GraphManager] PushSaveData2ME failed, not found checkpoint callback.");
return FAILED;
}
@@ -2478,7 +2505,8 @@ Status GraphManager::CheckAndReleaseMemory(const GeModelPtr &ge_model, const Gra
}

Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id,
const SubGraphInfoPtr &sub_graph_info_ptr, uint64_t session_id,
const SubGraphInfoPtr &sub_graph_info_ptr,
const ComputeGraphPtr &compute_graph, uint64_t session_id,
const GEThreadLocalContext &ge_context) {
if (sub_graph_info_ptr != nullptr && graph_manager != nullptr) {
GetContext().SetSessionId(session_id);
@@ -2494,6 +2522,7 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager
GE_CHECK_NOTNULL(compute_graph_tmp);
compute_graph_tmp->SetSessionID(session_id);
Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp,
compute_graph,
engine_name);
if (ret != SUCCESS) {
GELOGE(ret, "SubGraph optimize Failed %s", engine_name.c_str());


+ 8
- 1
ge/graph/manager/graph_manager.h View File

@@ -163,6 +163,10 @@ class GraphManager {
const std::string &key,
const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback);

Status RegisterCallBackFunc(
const std::string &key,
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback);

const bool GetTrainFlag() const { return options_.train_graph_flag; }

bool IsGraphNeedRebuild(uint32_t graph_id);
@@ -214,7 +218,8 @@ class GraphManager {
std::shared_ptr<GraphModelListener> GetModelListener() const { return graph_run_listener_; }

static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, GraphId root_graph_id,
const SubGraphInfoPtr &sub_graph_info_ptr, uint64_t session_id,
const SubGraphInfoPtr &sub_graph_info_ptr,
const ComputeGraphPtr &compute_graph, uint64_t session_id,
const GEThreadLocalContext &ge_context);
Status ParseInputsDims(const std::vector<InputTensorInfo> &input_tensor);
Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector<NodePtr> &data_nodes,
@@ -390,6 +395,8 @@ class GraphManager {
// summary and checkpoint callback function list for ME, key is summary or checkpoint
std::map<std::string, std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)>> me_callback_map_;

std::map<std::string, std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)>> callback_map_;

bool init_flag_;

GraphManagerOptions options_;


+ 6
- 1
ge/graph/optimize/graph_optimize.cc View File

@@ -76,7 +76,8 @@ void AddNodeInputProperty(ComputeGraphPtr &compute_graph) {
}
}

Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std::string &engine_name) {
Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const ComputeGraphPtr &parent_graph,
const std::string &engine_name) {
if (compute_graph == nullptr) {
GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeSubGraph]: compute_graph is nullptr.");
return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL;
@@ -105,6 +106,10 @@ Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std
for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) {
Status ret = (*iter)->OptimizeFusedGraphAfterGraphSlice(*(compute_graph));
if (ret != SUCCESS) {
auto root_graph = ge::GraphUtils::FindRootGraph(parent_graph);
if (root_graph != nullptr) {
ErrorManager::GetInstance().SaveMstuneCompileFailedMsg(root_graph->GetName());
}
GELOGE(ret, "[OptimizeSubGraph][OptimizeFusedGraphAfterGraphSlice]: graph optimize failed, ret:%d", ret);
return ret;
}


+ 2
- 1
ge/graph/optimize/graph_optimize.h View File

@@ -42,7 +42,8 @@ class GraphOptimize {
~GraphOptimize() = default;

// subgraph optimize
Status OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std::string &engine_name);
Status OptimizeSubGraph(ComputeGraphPtr &compute_graph, const ComputeGraphPtr &parent_graph,
const std::string &engine_name);

// original graph optimize
Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph);


+ 11
- 4
ge/graph/partition/graph_partition.cc View File

@@ -382,11 +382,18 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr
GELOGW("SetInt anchorIndex failed");)
GE_IF_BOOL_EXEC(!pld_op_desc->SetExtAttr("parentNode", src_node),
GELOGW("SetPldExtAttr parentNode failed");)

OpDescPtr src_node_op_desc = src_node->GetOpDesc();
GE_CHECK_NOTNULL(src_node_op_desc);
GE_IF_BOOL_EXEC(!AttrUtils::SetStr(pld_op_desc, ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME,
src_node_op_desc->GetOpEngineName()), GELOGW("SetStr frontNodeEngineName failed");)
src_node_opdesc->GetOpEngineName()), GELOGW("SetStr frontNodeEngineName failed");)
std::string l2_info_attr;
if (AttrUtils::GetStr(src_node_opdesc, "_task_L2FusionInfo", l2_info_attr)) {
GE_IF_BOOL_EXEC(!AttrUtils::SetStr(pld_op_desc, "_task_L2FusionInfo", l2_info_attr),
GELOGW("SetStr l2_info_attr failed");)
}
int64_t anchor_index_for_lxfusion;
if (AttrUtils::GetInt(src_node_opdesc, "_data_anchor_index_for_lxfusion", anchor_index_for_lxfusion)) {
GE_IF_BOOL_EXEC(!AttrUtils::SetInt(pld_op_desc, "_data_anchor_index_for_lxfusion", anchor_index_for_lxfusion),
GELOGW("SetInt anchor_index_for_lxfusion failed");)
}
// do not care over flow
graph_info_.num_of_pld_end_++;
// replace output_desc of pld with input node's output desc


+ 36
- 0
ge/graph/passes/merge_pass.cc View File

@@ -30,6 +30,11 @@
namespace ge {
const int kValueIndexOutputIndex = 1;

bool IsEmptyTensor(const GeShape &shape) {
const auto &dims = shape.GetDims();
return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; });
}

Status MergePass::Run(NodePtr &node) {
GELOGD("MergePass running");
if (node == nullptr) {
@@ -48,6 +53,11 @@ Status MergePass::Run(NodePtr &node) {
return PARAM_INVALID;
}

if (OptimizeEmptyTensorInput(node) != SUCCESS) {
GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str());
return FAILED;
}

const auto &in_data_nodes = node->GetInDataNodes();
switch (in_data_nodes.size()) {
case 0: {
@@ -197,4 +207,30 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const {
}
return true;
}

Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_data_anchor == nullptr) {
continue;
}
if ((peer_data_anchor->GetOwnerNode() == nullptr) ||
(peer_data_anchor->GetOwnerNode()->GetOpDesc() == nullptr)) {
continue;
}
const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc();
if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) {
if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.",
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(),
node->GetName().c_str(), in_data_anchor->GetIdx());
return FAILED;
}
GELOGD("Remove data edge %s:%d->%s:%d",
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(),
node->GetName().c_str(), in_data_anchor->GetIdx());
}
}
return SUCCESS;
}
} // namespace ge

+ 1
- 0
ge/graph/passes/merge_pass.h View File

@@ -29,6 +29,7 @@ class MergePass : public BaseNodePass {
Status ChangeIndexToConstant(NodePtr &node, int &value_index);
Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc);
bool IsMergeInputNeedOptimized(NodePtr &node) const;
static Status OptimizeEmptyTensorInput(const NodePtr &node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_MERGE_PASS_H_

+ 13
- 2
ge/graph/passes/multi_batch_clone_pass.cc View File

@@ -610,11 +610,17 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
///
Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) {
auto func_desc = case_node_->GetOpDesc();
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr;
auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType());
if (post_func == nullptr) {
GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(),
case_node_->GetType().c_str());
return FAILED;
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS ||
parse_func_v2 == nullptr) {
GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(),
case_node_->GetType().c_str());
return FAILED;
}
}

for (const auto &name : func_desc->GetSubgraphInstanceNames()) {
@@ -629,7 +635,12 @@ Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) {
"Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str());

auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph);
auto ret = post_func(subgraph_name, graph);
Status ret = FAILED;
if (post_func != nullptr) {
ret = post_func(subgraph_name, graph);
} else if (parse_func_v2 != nullptr) {
ret = parse_func_v2(subgraph_name.c_str(), graph);
}
if (ret != SUCCESS) {
GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(),
case_node_->GetName().c_str(), case_node_->GetType().c_str());


+ 59
- 1
ge/ir_build/ge_ir_build.cc View File

@@ -141,7 +141,7 @@ static void LoadOpsProto() {
(void)manager->Initialize(option_tmp);
}

graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) {
graphStatus aclgrphBuildInitializeImpl(std::map<std::string, std::string> &global_options) {
GELOGD("Enter aclgrphInitialize start!");
// check global options
if (CheckGlobalOptions(global_options) != GRAPH_SUCCESS) {
@@ -164,9 +164,34 @@ graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_opt
}
}
GELOGW("gelib has been initialized!");

std::string path_base = ge::GELib::GetPath();
int ret = ErrorManager::GetInstance().Init(path_base);
if (ret != 0) {
DOMI_LOGE("ErrorManager init fail !");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) {
return aclgrphBuildInitializeImpl(global_options);
}

graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &global_options) {
std::map<std::string, std::string> tmp_global_options;
for (auto &option : global_options) {
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
GELOGE(GRAPH_FAILED, "AclgrphBuildInitialize option is nullptr.");
return GRAPH_FAILED;
}
std::string key = option.first.GetString();
std::string val = option.second.GetString();
tmp_global_options[key] = val;
}
return aclgrphBuildInitializeImpl(tmp_global_options);
}

void aclgrphBuildFinalize() {
if (ge::GELib::GetInstance() != nullptr && ge::GELib::GetInstance()->InitFlag()) {
(void)ge::GELib::GetInstance()->Finalize();
@@ -453,6 +478,24 @@ graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string
return builder.BuildModel(graph, build_options, model);
}

graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendString, AscendString> &build_options,
ModelBufferData &model) {
GELOGD("Enter aclmdlBuildModel process!");
std::map<std::string, std::string> tmp_build_options;
for (auto &option : build_options) {
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
GELOGE(GRAPH_FAILED, "AclgrphBuildInitialize option is nullptr.");
return GRAPH_FAILED;
}
std::string key = option.first.GetString();
std::string val = option.second.GetString();
tmp_build_options[key] = val;
}

Impl builder;
return builder.BuildModel(graph, tmp_build_options, model);
}

graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model) {
GELOGD("Enter aclmdlSaveModel process!");
if (model.data.get() == nullptr || model.length == 0) {
@@ -463,6 +506,21 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m
static_cast<uint32_t>(model.length));
}

graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model) {
GELOGD("Enter aclmdlSaveModel process!");
if (model.data.get() == nullptr || model.length == 0) {
GELOGE(GRAPH_PARAM_INVALID, "Input model is illegal");
return GRAPH_PARAM_INVALID;
}
if (output_file == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "Output file is nullptr.");
return GRAPH_PARAM_INVALID;
}
std::string str_output_file = output_file;
return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast<void*>(model.data.get()),
static_cast<uint32_t>(model.length));
}

graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version) {
GELOGD("Enter aclgrphGetIRVersion process!");
GE_CHECK_NOTNULL(major_version);


+ 19
- 0
ge/session/inner_session.cc View File

@@ -254,6 +254,25 @@ Status InnerSession::RegisterCallBackFunc(
return SUCCESS;
}

Status InnerSession::RegisterCallBackFunc(
const std::string &key,
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback) {
std::lock_guard<std::mutex> lock(resource_mutex_);
if (!init_flag_) {
GELOGE(GE_SESS_INIT_FAILED, "[InnerSession:%lu] initialize failed.", session_id_);
return GE_SESS_INIT_FAILED;
}
UpdateThreadContext(std::map<std::string, std::string>{});
Status ret = graph_manager_.RegisterCallBackFunc(key, callback);
if (ret != SUCCESS) {
GELOGE(ret, "[InnerSession:%lu] register %s callback function failed.", session_id_, key.c_str());
return ret;
}

GELOGI("[InnerSession:%lu] register %s callback function success.", session_id_, key.c_str());
return SUCCESS;
}

Status InnerSession::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) {
UpdateThreadContext(graph_id);
GELOGI("[InnerSession:%lu] build graph on session, graph_id=%u.", session_id_, graph_id);


+ 4
- 0
ge/session/inner_session.h View File

@@ -62,6 +62,10 @@ class InnerSession {
const std::string &key,
const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback);

Status RegisterCallBackFunc(
const std::string &key,
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback);

const GraphManager &getGraphManagerObj() const;

bool IsGraphNeedRebuild(uint32_t graph_id);


+ 20
- 0
ge/session/session_manager.cc View File

@@ -276,6 +276,26 @@ Status SessionManager::RegisterCallBackFunc(
return innerSession->RegisterCallBackFunc(key, callback);
}

Status SessionManager::RegisterCallBackFunc(
SessionId session_id, const std::string &key,
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback) {
if (!init_flag_) {
GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
return GE_SESSION_MANAGER_NOT_INIT;
}
SessionPtr innerSession = nullptr;
{
std::lock_guard<std::mutex> lock(mutex_);
std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
if (it == session_manager_map_.end()) {
return GE_SESSION_NOT_EXIST;
} else {
innerSession = it->second;
}
}
return innerSession->RegisterCallBackFunc(key, callback);
}

Status SessionManager::BuildGraph(SessionId session_id, uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) {
if (!init_flag_) {
GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");


+ 3
- 0
ge/session/session_manager.h View File

@@ -158,6 +158,9 @@ class SessionManager {
Status RegisterCallBackFunc(
SessionId session_id, const std::string &key,
const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback);
Status RegisterCallBackFunc(
SessionId session_id, const std::string &key,
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback);

bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id);



+ 34
- 0
inc/external/ge/ge_api.h View File

@@ -29,16 +29,26 @@
namespace ge {
typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map<std::string, ge::Tensor> &params_list);

namespace session {
typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map<AscendString, ge::Tensor> &params_list);
}

// Initialize GE
ATTRIBUTED_DEPRECATED(Status GEInitialize(const std::map<AscendString, AscendString> &))
Status GEInitialize(const std::map<std::string, std::string> &options);

Status GEInitialize(const std::map<AscendString, AscendString> &options);

// Finalize GE, release all resources
Status GEFinalize();

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session {
public:
ATTRIBUTED_DEPRECATED(Session(const std::map<AscendString, AscendString> &))
explicit Session(const std::map<std::string, std::string> &options);

explicit Session(const std::map<AscendString, AscendString> &options);

~Session();

///
@@ -57,10 +67,21 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session {
/// @param [in] options graph options
/// @return Status result of function
///
ATTRIBUTED_DEPRECATED(Status AddGraph(uint32_t, const Graph &, const std::map<AscendString, AscendString> &))
Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options);

///
/// @ingroup client
/// @brief add a graph with a specific graphId and graphOptions
/// @param [in] graphId graph id
/// @param [in] graph the graph
/// @param [in] options graph options
/// @return Status result of function
///
Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<AscendString, AscendString> &options);

///
/// @ingroup client
/// @brief add a copy graph with a specific graphId
/// @param [in] graphId graph id
/// @param [in] graph the graph
@@ -124,10 +145,20 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session {
/// @param [out] var_values: variable values
/// @return Status result of function
///
ATTRIBUTED_DEPRECATED(Status GetVariables(const std::vector<std::string> &, std::vector<Tensor> &))
Status GetVariables(const std::vector<std::string> &var_names, std::vector<Tensor> &var_values);

///
/// @ingroup ge_graph
/// @brief get variables in the session with specific session id
/// @param [in] var_names: variable names
/// @param [out] var_values: variable values
/// @return Status result of function
///
Status GetVariables(const std::vector<AscendString> &var_names, std::vector<Tensor> &var_values);

///
/// @ingroup ge_graph
/// @brief register callback func with specific summary or checkpoint by users
/// @param [in] key: func key
/// @param [in] callback: callback specific summary or checkpoint.
@@ -135,8 +166,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session {
/// Please ensure that the implementation of the function is trusted.
/// @return Status result of function
///
ATTRIBUTED_DEPRECATED(Status RegisterCallBackFunc(const char *, const session::pCallBackFunc &))
Status RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback);

Status RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback);

bool IsGraphNeedRebuild(uint32_t graphId);

private:


+ 18
- 0
inc/external/ge/ge_api_error_codes.h View File

@@ -22,6 +22,12 @@
#include "ge_error_codes.h"

namespace ge {
#ifdef __GNUC__
#define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead.")))
#else
#define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead."))
#endif

class StatusFactory {
public:
static StatusFactory *Instance() {
@@ -37,6 +43,17 @@ class StatusFactory {
err_desc_[err] = desc;
}

void RegisterErrorNo(uint32_t err, const char *desc) {
if (desc == nullptr) {
return;
}
std::string error_desc = desc;
if (err_desc_.find(err) != err_desc_.end()) {
return;
}
err_desc_[err] = error_desc;
}

std::string GetErrDesc(uint32_t err) {
auto iter_find = err_desc_.find(err);
if (iter_find == err_desc_.end()) {
@@ -56,6 +73,7 @@ class StatusFactory {
class ErrorNoRegisterar {
public:
ErrorNoRegisterar(uint32_t err, const std::string &desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); }
ErrorNoRegisterar(uint32_t err, const char *desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); }
~ErrorNoRegisterar() {}
};



+ 45
- 5
inc/external/ge/ge_api_types.h View File

@@ -65,7 +65,47 @@ const char *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOp
// Option key: memory init
const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize";
const char *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize";

namespace configure_option {
const char *const STREAM_NUM = "ge.streamNum";
const char *const HEAD_STREAM = "ge.headStream";
const char *const PERF_LEVEL = "ge.perfLevel";
const char *const ENCRYPT_MODE = "ge.encryptMode";
const char *const EK_FILE = "ge.ekFile";
const char *const CERT_FILE = "ge.certFile";
const char *const HW_KEY_FILE = "ge.hwKeyFile";
const char *const PRIVATE_KEY_FILE = "ge.privateKeyFile";
const char *const FRAMEWORK_TYPE = "ge.frameworkType";
const char *const CALIBRATION_CONF_FILE = "ge.calibrationConfFile";
const char *const INSERT_OP_FILE = "ge.insertOpFile";
const char *const OUTPUT_NODE_NAME = "ge.outputNodeName";
const char *const COMPRESS_FLAG = "ge.compressFlag";
const char *const PRECISION_MODE = "ge.exec.precision_mode";
const char *const SINGLE_OP_FLAG = "ge.exec.single_op";
const char *const TRAIN_FLAG = "ge.trainFlag";
const char *const RUN_FLAG = "ge.runFlag";
const char *const LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop";
const char *const TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path";
const char *const DDK_VERSION_FLAG = "ge.DDK_version";
const char *const GE_FE_FLAG = "ge.feFlag";
const char *const STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum";
const char *const OUTPUT_DATATYPE = "ge.outputDatatype";
const char *const OP_SELECT_IMPL_MODE = "ge.opSelectImplmode";
const char *const OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode";
const char *const HCOM_PARALLEL = "ge.hcomParallel";
const char *const AUTO_TUNE_MODE = "ge.autoTuneMode";
const char *const SOC_VERSION = "ge.socVersion";
const char *const CORE_TYPE = "ge.engineType";
const char *const AICORE_NUM = "ge.aicoreNum";
const char *const L1_FUSION = "ge.l1Fusion";
const char *const BUFFER_OPTIMIZE = "ge.bufferOptimize";
const char *const ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel";
const char *const ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight";
const char *const FUSION_SWITCH_FILE = "ge.fusionSwitchFile";
const char *const SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel";
const char *const ORIGINAL_MODEL_FILE = "ge.originalModelFile";
const char *const INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16";
const char *const OP_DEBUG_LEVEL = "ge.opDebugLevel";
} // namespace configure_option
// Configure stream num by Session constructor options param,
// its value should be int32_t type, default value is "1"
const std::string STREAM_NUM = "ge.streamNum";
@@ -324,8 +364,8 @@ static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c
static const char *const DEBUG_DIR = ge::DEBUG_DIR;
static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR;
static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE;
static const char *const MDL_BANK_PATH_FLAG = ge::MDL_BANK_PATH_FLAG.c_str();
static const char *const OP_BANK_PATH_FLAG = ge::OP_BANK_PATH_FLAG.c_str();
static const char *const MDL_BANK_PATH = ge::MDL_BANK_PATH_FLAG.c_str();
static const char *const OP_BANK_PATH = ge::OP_BANK_PATH_FLAG.c_str();
static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str();

// for interface: aclgrphBuildModel
@@ -347,8 +387,8 @@ const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT,
DEBUG_DIR,
OP_COMPILER_CACHE_DIR,
OP_COMPILER_CACHE_MODE,
MDL_BANK_PATH_FLAG,
OP_BANK_PATH_FLAG};
MDL_BANK_PATH,
OP_BANK_PATH};

// for interface: aclgrphParse
const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT,


+ 12
- 2
inc/external/ge/ge_ir_build.h View File

@@ -44,8 +44,11 @@ struct ModelBufferData {
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &))
graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options);

graphStatus aclgrphBuildInitialize(std::map<AscendString, AscendString> &global_options);

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
@@ -63,9 +66,14 @@ void aclgrphBuildFinalize();
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(graphStatus aclgrphBuildModel(const ge::Graph &, const std::map<AscendString, AscendString> &,
ModelBufferData &))
graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options,
ModelBufferData &model);

graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendString, AscendString> &build_options,
ModelBufferData &model);

/**
* @ingroup AscendCL
* @brief save model buffer to file
@@ -75,8 +83,11 @@ graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(graphStatus aclgrphSaveModel(const char *, const ModelBufferData &))
graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model);

graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &model);

/**
* @ingroup AscendCL
* @brief query IR interface version
@@ -110,6 +121,5 @@ graphStatus aclgrphInferShapeAndType(ge::Graph &graph);
* @retval OtherValues Failure
*/
graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len);
}; // namespace ge

}; // namespace ge
#endif // INC_EXTERNAL_GE_IR_BUILD_H_

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 57e72aac24a35e40799e342fdacca362a66395c4
Subproject commit 0f5ddb10ce79ea2c01b8b9cab5ec3102879610bb

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit bb6424dc6d9252a3ac70650cde2f547761237681
Subproject commit cf60b0c02d1a6e844fcec4202d18a069e9502b0f

Loading…
Cancel
Save