diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index f98297d8..1d2bdeea 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -133,6 +133,7 @@ set(EXECUTOR_SRC_LIST "graph/load/model_manager/task_info/event_record_task_info.cc" "graph/load/model_manager/task_info/event_wait_task_info.cc" "graph/load/model_manager/task_info/ffts_task_info.cc" + "graph/load/model_manager/task_info/ffts_plus_task_info.cc" "graph/load/model_manager/task_info/fusion_start_task_info.cc" "graph/load/model_manager/task_info/fusion_stop_task_info.cc" #"graph/load/model_manager/task_info/hccl_task_info.cc" # Just for runner. diff --git a/ge/common/ge/plugin_manager.h b/ge/common/ge/plugin_manager.h index 0869704f..9c075ea2 100755 --- a/ge/common/ge/plugin_manager.h +++ b/ge/common/ge/plugin_manager.h @@ -120,6 +120,22 @@ class PluginManager { } return SUCCESS; } + + template + void OptionalInvokeAll(const string &func_name, T1 arg1, T2 arg2) { + for (const auto &handle : handles_) { + // If the funcName is existed, signature of realFn can be casted to any type + auto real_fn = (void (*)(T1, T2))mmDlsym(handle.second, const_cast(func_name.c_str())); + if (real_fn == nullptr) { + GELOGI("func %s not exist in so %s", handle.first.c_str(), func_name.c_str()); + continue; + } else { + GELOGI("func %s exists in so %s", handle.first.c_str(), func_name.c_str()); + real_fn(arg1, arg2); + } + } + } + template Status InvokeAll(const string &func_name, T1 arg) { for (const auto &handle : handles_) { diff --git a/ge/engine_manager/dnnengine_manager.cc b/ge/engine_manager/dnnengine_manager.cc index 36f11828..a9bd17c4 100644 --- a/ge/engine_manager/dnnengine_manager.cc +++ b/ge/engine_manager/dnnengine_manager.cc @@ -17,20 +17,15 @@ #include "engine_manager/dnnengine_manager.h" #include -#include #include -#include #include "framework/common/debug/log.h" #include "common/ge/ge_util.h" -#include "common/util/error_manager/error_manager.h" -#include "framework/common/debug/ge_log.h" #include 
"analyzer/analyzer.h" #include "graph/ge_context.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" #include "init/gelib.h" -#include "framework/common/types.h" namespace { const char *const kSchedulerUnits = "schedule_units"; @@ -40,7 +35,7 @@ const char *const kExAttrs = "ex_attrs"; const char *const kIndependent = "independent"; const char *const kSkipAssignStream = "skip_assign_stream"; const char *const kCalEngines = "cal_engines"; -const char *const kAttch = "attach"; +const char *const kAttach = "attach"; const char *const kVectorCore = "VectorCore"; const char *const kVectorEngine = "VectorEngine"; const char *const kAIcoreEngine = "AIcoreEngine"; @@ -51,6 +46,9 @@ const char *const kHostCpuOpKernelLibName = "DNN_VM_HOST_CPU_OP_STORE"; namespace ge { namespace { const std::set kNotCpuOp = {DATA, CONSTANT, CONSTANTOP, VARIABLE, NETOUTPUT}; +const char *const kGetDNNEngineObjs = "GetDNNEngineObjs"; +const char *const kInvalidCompositeEngineName = "InvalidCompositeEngineName"; +constexpr uint32_t kMaxRecursiveDepth = 10; bool ExecOnHostCpu(const OpDescPtr &op_desc) { bool is_host_cpu_op = (kNotCpuOp.find(op_desc->GetType()) == kNotCpuOp.end()); @@ -64,6 +62,11 @@ DNNEngineManager::~DNNEngineManager() { schedulers_.clear(); } +DNNEngineManager &DNNEngineManager::GetInstance() { + static DNNEngineManager instance; + return instance; +} + Status DNNEngineManager::Initialize(const std::map &options) { // Multiple initializations are not supported if (init_flag_) { @@ -72,22 +75,21 @@ Status DNNEngineManager::Initialize(const std::map &op } // Load engine so - std::string so_path = "plugin/nnengine/"; + std::string plugin_so_path = "plugin/nnengine/"; std::string path = PluginManager::GetPath(); - path.append(so_path); - std::string so_api_func = "GetDNNEngineObjs"; - std::vector so_func{so_api_func}; - Status status = plugin_mgr_.Load(path, so_func); + std::string engine_plugin_path = path + plugin_so_path; + std::vector 
so_func{kGetDNNEngineObjs}; + Status status = plugin_mgr_.Load(engine_plugin_path, so_func); if (status != SUCCESS) { GELOGE(status, "[Load][EngineSo]Failed, lib path %s", path.c_str()); - REPORT_CALL_ERROR("E19999", "Load engine so failed, lib path %s", path.c_str()); + REPORT_CALL_ERROR("E19999", "Load engine so failed, lib path %s", engine_plugin_path.c_str()); return status; } - status = plugin_mgr_.InvokeAll &>(so_api_func, engines_map_); + status = plugin_mgr_.InvokeAll &>(kGetDNNEngineObjs, engines_map_); if (status != SUCCESS) { - GELOGE(status, "[Get][DNNEngineObjs]Failed, so_api_func %s", so_api_func.c_str()); - REPORT_CALL_ERROR("E19999", "Get DNNEngineObjs failed, so_api_func %s", so_api_func.c_str()); + GELOGE(status, "[Get][DNNEngineObjs]Failed, so_api_func %s", kGetDNNEngineObjs); + REPORT_CALL_ERROR("E19999", "Get DNNEngineObjs failed, so_api_func %s", kGetDNNEngineObjs); return status; } @@ -117,8 +119,8 @@ Status DNNEngineManager::Initialize(const std::map &op if ((attrs.mem_type.size()) != 1 || (attrs.mem_type[0] != GE_ENGINE_ATTR_MEM_TYPE_HBM)) { GELOGE(GE_ENG_MEMTYPE_ERROR, "[Check][Param]Engine %s in aicore, but the memory type is " "not HBM, mem_type_size %lu", (iter->first).c_str(), attrs.mem_type.size()); - REPORT_INNER_ERROR("E19999", "Engine %s in aicore, but the memory type is not HBM, " - "mem_type_size %lu", (iter->first).c_str(), attrs.mem_type.size()); + REPORT_INNER_ERROR("E19999", "Engine %s in aicore, but the memory type is not HBM, mem_type_size %lu", + (iter->first).c_str(), attrs.mem_type.size()); return GE_ENG_MEMTYPE_ERROR; } } @@ -161,6 +163,7 @@ Status DNNEngineManager::Finalize() { } init_flag_ = false; engines_map_.clear(); + atomic_2_composite_.clear(); return SUCCESS; } @@ -183,7 +186,7 @@ bool DNNEngineManager::IsEngineRegistered(const std::string &name) { return false; } -void DNNEngineManager::InitPerformanceStaistic() { +void DNNEngineManager::InitPerformanceStatistic() { std::lock_guard lock(mutex_); 
checksupport_cost_.clear(); } @@ -201,15 +204,8 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { auto op_desc = node_ptr->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GE_CLI_GE_NOT_INITIALIZED, "DNNEngineManager: op_desc is nullptr"); return ""); - // Use the OpsKernelManager in GELib to get the opInfos for this opCode - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Get][DNNEngineName]Failed, gelib not init before"); - REPORT_INNER_ERROR("E19999", "Get DNNEngineName failed, gelib not init before"); - return ""; - } - OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); - std::vector op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); + // Use the OpsKernelManager to get the opInfos for this opCode + std::vector op_infos = OpsKernelManager::GetInstance().GetOpsKernelInfo(op_desc->GetType()); if (op_infos.empty()) { GELOGI("DNNEngineManager: Can not get op info by op type %s", op_desc->GetType().c_str()); return ""; @@ -221,47 +217,43 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { std::string exclude_core_Type = (ge_core_type == kVectorCore) ? 
kAIcoreEngine : kVectorEngine; GELOGD("engine type will exclude: %s", exclude_core_Type.c_str()); - auto root_graph = ge::GraphUtils::FindRootGraph(node_ptr->GetOwnerComputeGraph()); std::map unsupported_reasons; for (const auto &it : op_infos) { if (it.engine == exclude_core_Type) { continue; } - auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores(); - auto &kernel_name = it.opKernelLib; - auto kernel_info_store = kernel_map.find(kernel_name); - if (kernel_info_store != kernel_map.end()) { - std::string unsupported_reason; - // It will be replaced by engine' checksupport - uint64_t start_time = GetCurrentTimestamp(); - if (kernel_info_store->second->CheckSupported(node_ptr, unsupported_reason)) { - checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; - op_desc->SetOpEngineName(it.engine); - op_desc->SetOpKernelLibName(kernel_name); - // set attrs for taking information when load txt to graph object - if (it.flagAsync) { - GELOGD("Set aicpu blocking op:%s attribute(is_blocking_op):true", op_desc->GetName().c_str()); - (void)AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); - } - (void) AttrUtils::SetStr(op_desc, ATTR_NAME_ENGINE_NAME_FOR_LX, it.engine); - (void) AttrUtils::SetStr(op_desc, ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX, kernel_name); - GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s to op_desc %s", kernel_name.c_str(), - it.engine.c_str(), op_desc->GetName().c_str()); - return it.engine; - } else { - checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; - unsupported_reasons.emplace(kernel_name, unsupported_reason); - GELOGI("DNNEngineManager:Check support failed, kernel_name is %s, op type is %s, op name is %s", - kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); - if (!op_desc->HasAttr("_is_ge_op")) { - ErrorManager::GetInstance().ATCReportErrMessage("W11001", {"opname"}, {op_desc->GetName()}); - } + const auto &kernel_name = it.opKernelLib; + auto 
kernel_info_store = OpsKernelManager::GetInstance().GetOpsKernelInfoStore(kernel_name); + if (kernel_info_store == nullptr) { + GELOGW("DNNEngineManager:Can not find any supported ops kernel info store by kernel_name %s, op type is %s, " + "op name is %s", kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); + return ""; + } + std::string unsupported_reason; + // It will be replaced by engine's check support + uint64_t start_time = GetCurrentTimestamp(); + if (kernel_info_store->CheckSupported(node_ptr, unsupported_reason)) { + checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; + op_desc->SetOpEngineName(it.engine); + op_desc->SetOpKernelLibName(kernel_name); + // set attrs for taking information when load txt to graph object + if (it.flagAsync) { + GELOGD("Set aicpu blocking op:%s attribute(is_blocking_op):true", op_desc->GetName().c_str()); + (void)AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); } + (void) AttrUtils::SetStr(op_desc, ATTR_NAME_ENGINE_NAME_FOR_LX, it.engine); + (void) AttrUtils::SetStr(op_desc, ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX, kernel_name); + GELOGD("DNNEngineManager:Set kernel_lib %s, atomic engine %s, to node %s", kernel_name.c_str(), it.engine.c_str(), + op_desc->GetName().c_str()); + return it.engine; } else { - GELOGW( - "DNNEngineManager:Can not find any supported ops kernel info store by kernel_name %s," - "op type is %s, op name is %s", - kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); + checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; + unsupported_reasons.emplace(kernel_name, unsupported_reason); + GELOGI("DNNEngineManager:Check support failed, kernel_name is %s, op type is %s, op name is %s", + kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); + if (!op_desc->HasAttr("_is_ge_op")) { + ErrorManager::GetInstance().ATCReportErrMessage("W11001", {"opname"}, {op_desc->GetName()}); + } } } @@ -276,6 +268,7 
@@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { op_desc->GetType().c_str(), it.first.c_str(), it.second.c_str()); } + auto root_graph = ge::GraphUtils::FindRootGraph(node_ptr->GetOwnerComputeGraph()); analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), analyzer::CHECKSUPPORT, node_ptr, reason}; // do not change original process @@ -289,6 +282,184 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { return ""; } +std::string DNNEngineManager::GetCompositeEngineName(const ge::NodePtr &node_ptr, uint32_t recursive_depth) { + // op_desc of node should not be null + const auto &op_desc = node_ptr->GetOpDesc(); + if (recursive_depth > kMaxRecursiveDepth) { + REPORT_INNER_ERROR("E19999", "Get CompositeEngineName will be terminated because too many nesting levels(%u) of " + "subgraphs, last node is %s", recursive_depth, op_desc->GetName().c_str()); + GELOGE(PARAM_INVALID, + "[Check][Param] Get CompositeEngineName will be terminated because too many nesting levels(%u) of subgraphs, " + "last node is %s", recursive_depth, op_desc->GetName().c_str()); + return ""; + } + + if (OpsKernelManager::GetInstance().GetCompositeEngines().empty() || + OpsKernelManager::GetInstance().GetCompositeEngineKernelLibNames().empty()) { + return ""; + } + + // composite engine name exist + std::string composite_engine_name; + (void)AttrUtils::GetStr(op_desc, ATTR_NAME_COMPOSITE_ENGINE_NAME, composite_engine_name); + std::string composite_engine_kernel_lib_name; + (void)AttrUtils::GetStr(op_desc, ATTR_NAME_COMPOSITE_ENGINE_KERNEL_LIB_NAME, composite_engine_kernel_lib_name); + if (!composite_engine_name.empty() && !composite_engine_kernel_lib_name.empty()) { + return composite_engine_name; + } + + // normal node without subgraph + if (op_desc->GetSubgraphInstanceNames().empty()) { + return GetCompositeEngine(node_ptr); + } + return GetCompositeEngine(node_ptr, recursive_depth); +} + +std::string 
DNNEngineManager::GetCompositeEngine(const NodePtr &node) { + // op_desc of node should not be null + const auto &op_desc = node->GetOpDesc(); + auto atomic_engine_name = op_desc->GetOpEngineName().empty() ? GetDNNEngineName(node) : op_desc->GetOpEngineName(); + bool gelocal_follow_flag = false; + if (IsStreamAssignSkip(atomic_engine_name)) { + bool in_diff_flag = false; + std::string in_composite_engine_name = kInvalidCompositeEngineName; + for (const auto &in_node : node->GetInAllNodes()) { + std::string tmp_composite_engine_name; + (void)AttrUtils::GetStr(in_node->GetOpDesc(), ATTR_NAME_COMPOSITE_ENGINE_NAME, tmp_composite_engine_name); + if (in_composite_engine_name == kInvalidCompositeEngineName) { + in_composite_engine_name = tmp_composite_engine_name; + } else if (in_composite_engine_name != tmp_composite_engine_name) { + in_diff_flag = true; + break; + } + } + if (!in_diff_flag && + (in_composite_engine_name != kInvalidCompositeEngineName) && + !in_composite_engine_name.empty()) { + gelocal_follow_flag = true; + } + } + std::string composite_engine_name; + if (!gelocal_follow_flag) { + composite_engine_name = GetCompositeEngineName(atomic_engine_name); + } + const auto &composite_engine_kernel_lib_name = GetCompositeEngineKernelLibName(composite_engine_name); + if (composite_engine_name.empty() || composite_engine_kernel_lib_name.empty()) { + (void)op_desc->DelAttr(ATTR_NAME_COMPOSITE_ENGINE_NAME); + (void)op_desc->DelAttr(ATTR_NAME_COMPOSITE_ENGINE_KERNEL_LIB_NAME); + } else { + GELOGI("Assign composite engine %s, kernel lib name %s for node %s.", composite_engine_name.c_str(), + composite_engine_kernel_lib_name.c_str(), op_desc->GetName().c_str()); + (void)AttrUtils::SetStr(op_desc, ATTR_NAME_COMPOSITE_ENGINE_NAME, composite_engine_name); + (void)AttrUtils::SetStr(op_desc, ATTR_NAME_COMPOSITE_ENGINE_KERNEL_LIB_NAME, composite_engine_kernel_lib_name); + } + return composite_engine_name; +} + +std::string DNNEngineManager::GetCompositeEngine(const NodePtr 
&func_node, uint32_t recursive_depth) { + // op_desc of node should not be null + const auto &op_desc = func_node->GetOpDesc(); + bool graph_diff_composite_engine_flag = false; + std::string graph_composite_engine_name = kInvalidCompositeEngineName; + std::vector subgraphs; + if (NodeUtils::GetDirectSubgraphs(func_node, subgraphs) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get subgraphs of node %s failed", op_desc->GetName().c_str()); + GELOGE(FAILED, "[Check][Param] Get subgraphs of node %s failed", op_desc->GetName().c_str()); + return ""; + } + for (const auto &subgraph : subgraphs) { + std::string cur_graph_composite_engine_name = GetCompositeEngine(subgraph, recursive_depth); + if (graph_composite_engine_name == kInvalidCompositeEngineName) { + graph_composite_engine_name = cur_graph_composite_engine_name; + } else if (graph_composite_engine_name != cur_graph_composite_engine_name) { + graph_diff_composite_engine_flag = true; + break; + } + } + + std::string composite_engine_name; + std::string composite_engine_kernel_lib_name = GetCompositeEngineKernelLibName(graph_composite_engine_name); + if (!graph_diff_composite_engine_flag && + (graph_composite_engine_name != kInvalidCompositeEngineName) && + !graph_composite_engine_name.empty() && + !composite_engine_kernel_lib_name.empty()) { + composite_engine_name = graph_composite_engine_name; + GELOGI("Assign composite engine %s, kernel lib name %s for node %s.", composite_engine_name.c_str(), + composite_engine_kernel_lib_name.c_str(), op_desc->GetName().c_str()); + (void)AttrUtils::SetStr(op_desc, ATTR_NAME_COMPOSITE_ENGINE_NAME, composite_engine_name); + (void)AttrUtils::SetStr(op_desc, ATTR_NAME_COMPOSITE_ENGINE_KERNEL_LIB_NAME, composite_engine_kernel_lib_name); + } else { + (void)op_desc->DelAttr(ATTR_NAME_COMPOSITE_ENGINE_NAME); + (void)op_desc->DelAttr(ATTR_NAME_COMPOSITE_ENGINE_KERNEL_LIB_NAME); + } + + return composite_engine_name; +} + +std::string DNNEngineManager::GetCompositeEngine(const 
ComputeGraphPtr &subgraph, uint32_t recursive_depth) { + std::string graph_composite_engine_name; + (void)AttrUtils::GetStr(subgraph, ATTR_NAME_COMPOSITE_ENGINE_NAME, graph_composite_engine_name); + // if subgraph has been assigned + if (!graph_composite_engine_name.empty()) { + return graph_composite_engine_name; + } + + bool node_diff_composite_engine_flag = false; + std::string node_composite_engine_name = kInvalidCompositeEngineName; + uint32_t assigned_node_num = 0; + for (const auto &cur_node : subgraph->GetDirectNode()) { + if (IsNoTask(cur_node)) { + continue; + } + assigned_node_num++; + std::string cur_node_composite_engine_name = GetCompositeEngineName(cur_node, recursive_depth + 1); + if (node_composite_engine_name == kInvalidCompositeEngineName) { + node_composite_engine_name = cur_node_composite_engine_name; + } else if (node_composite_engine_name != cur_node_composite_engine_name) { + node_diff_composite_engine_flag = true; + break; + } + } + if (assigned_node_num == 0) { + GELOGD("all nodes in subgraph %s belongs to ge_local engine", subgraph->GetName().c_str()); + return ""; + } + if (!node_diff_composite_engine_flag && + (node_composite_engine_name != kInvalidCompositeEngineName) && + !node_composite_engine_name.empty()) { + GELOGI("Assign composite engine %s for subgraph %s.", node_composite_engine_name.c_str(), subgraph->GetName().c_str()); + (void)AttrUtils::SetStr(subgraph, ATTR_NAME_COMPOSITE_ENGINE_NAME, node_composite_engine_name); + graph_composite_engine_name = node_composite_engine_name; + } + else { + (void)subgraph->DelAttr(ATTR_NAME_COMPOSITE_ENGINE_NAME); + } + + return graph_composite_engine_name; +} + +std::string DNNEngineManager::GetCompositeEngineName(const string &atomic_engine_name) { + if (atomic_2_composite_.empty()) { + InitAtomicCompositeMapping(); + } + const auto &iter = atomic_2_composite_.find(atomic_engine_name); + if (iter == atomic_2_composite_.end()) { + GELOGW("Composite engine which contains atomic engine %s is 
not registered", atomic_engine_name.c_str()); + return ""; + } + return iter->second; +} + +std::string DNNEngineManager::GetCompositeEngineKernelLibName(const string &composite_engine_name) const { + const auto &composite_engine_2_kernel_lib_name = OpsKernelManager::GetInstance().GetCompositeEngineKernelLibNames(); + const auto &iter = composite_engine_2_kernel_lib_name.find(composite_engine_name); + if (iter == composite_engine_2_kernel_lib_name.end()) { + GELOGW("Kernel lib name of composite engine %s is not registered", composite_engine_name.c_str()); + return ""; + } + return iter->second; +} + std::string DNNEngineManager::GetHostCpuEngineName(const std::vector &op_infos, const OpDescPtr &op_desc) const { for (const auto &it : op_infos) { @@ -422,8 +593,8 @@ Status DNNEngineManager::ParserEngineMessage(const json engines_json, const std: engine_conf_ptr->independent = engines_elems[kIndependent]; } - if (engines_elems.find(kAttch) != engines_elems.end()) { - engine_conf_ptr->attach = engines_elems[kAttch]; + if (engines_elems.find(kAttach) != engines_elems.end()) { + engine_conf_ptr->attach = engines_elems[kAttach]; } if (engines_elems.find(kSkipAssignStream) != engines_elems.end()) { @@ -531,4 +702,59 @@ Status DNNEngineManager::CheckJsonFile() { GELOGD("Check json file success"); return SUCCESS; } + +void DNNEngineManager::InitAtomicCompositeMapping() { + for (const auto &item : OpsKernelManager::GetInstance().GetCompositeEngines()) { + const auto &composite_engine = GetEngine(item.first); + if ((composite_engine == nullptr) || composite_engine->IsAtomic()) { + GELOGW("Composite engine %s is not registered", item.first.c_str()); + } + for (const auto &atomic_engine_name : item.second) { + const auto &atomic_engine = GetEngine(atomic_engine_name); + if ((atomic_engine == nullptr) || !atomic_engine->IsAtomic()) { + GELOGW("Atomic engine %s is not registered", atomic_engine_name.c_str()); + continue; + } + auto iter = 
atomic_2_composite_.find(atomic_engine_name); + if (iter != atomic_2_composite_.end()) { + GELOGW("Atomic engine %s has been contained in composite engine %s, and will be overwritten by engine %s", + atomic_engine_name.c_str(), iter->second.c_str(), item.first.c_str()); + } + atomic_2_composite_[atomic_engine_name] = item.first; + } + } +} + +bool DNNEngineManager::IsNoTask(const NodePtr &node) { + const auto &op_desc = node->GetOpDesc(); + // op_desc of node should not be null + if (op_desc->HasAttr(ATTR_NAME_NOTASK)) { + return true; + } + return IsStreamAssignSkip(node) && op_desc->GetSubgraphInstanceNames().empty(); +} + +bool DNNEngineManager::IsStreamAssignSkip(const NodePtr &node) { + const auto &op_desc = node->GetOpDesc(); + // op_desc of node should not be null + const auto &engine_name = op_desc->GetOpEngineName().empty() ? GetDNNEngineName(node) : op_desc->GetOpEngineName(); + return IsStreamAssignSkip(engine_name); +} + +bool DNNEngineManager::IsStreamAssignSkip(const string &engine_name) { + // Only one scheduler has been supported by now + for (const auto &scheduler : schedulers_) { + const auto &iter = scheduler.second.cal_engines.find(engine_name); + if (iter == scheduler.second.cal_engines.end()) { + GELOGW("No engine found within name %s", engine_name.c_str()); + continue; + } + if (iter->second == nullptr) { + GELOGW("engine configuration of engine %s is null", engine_name.c_str()); + continue; + } + return iter->second->skip_assign_stream; + } + return false; +} } // namespace ge diff --git a/ge/engine_manager/dnnengine_manager.h b/ge/engine_manager/dnnengine_manager.h index 42da3596..379fedc5 100755 --- a/ge/engine_manager/dnnengine_manager.h +++ b/ge/engine_manager/dnnengine_manager.h @@ -60,13 +60,21 @@ using DNNEnginePtr = std::shared_ptr; class DNNEngineManager { public: friend class GELib; + static DNNEngineManager &GetInstance(); std::shared_ptr GetEngine(const std::string &name) const; + const std::map &GetAllEngines() const { return 
engines_map_; } bool IsEngineRegistered(const std::string &name); // If can't find appropriate engine name, return "", report error string GetDNNEngineName(const ge::NodePtr &node_ptr); + string GetCompositeEngineName(const ge::NodePtr &node_ptr, uint32_t recursive_depth = 1); + string GetCompositeEngineName(const string &atomic_engine_name); + string GetCompositeEngineKernelLibName(const string &composite_engine_name) const; const map &GetSchedulers() const; const map &GetCheckSupportCost() const; - void InitPerformanceStaistic(); + void InitPerformanceStatistic(); + bool IsNoTask(const NodePtr &node); + bool IsStreamAssignSkip(const NodePtr &node); + bool IsStreamAssignSkip(const string &engine_name); private: DNNEngineManager(); @@ -79,11 +87,19 @@ class DNNEngineManager { map &engines); Status CheckJsonFile(); std::string GetHostCpuEngineName(const std::vector &op_infos, const OpDescPtr &op_desc) const; + + void InitAtomicCompositeMapping(); + std::string GetCompositeEngine(const NodePtr &node); + std::string GetCompositeEngine(const NodePtr &func_node, uint32_t recursive_depth); + std::string GetCompositeEngine(const ComputeGraphPtr &subgraph, uint32_t recursive_depth); + PluginManager plugin_mgr_; std::map engines_map_; std::map engines_attrs_map_; std::map schedulers_; std::map checksupport_cost_; + // {atomic_engine, composite_engine} + std::map atomic_2_composite_{}; bool init_flag_; mutable std::mutex mutex_; }; diff --git a/ge/engine_manager/engine_conf.json b/ge/engine_manager/engine_conf.json index ad43c9ab..9986c28c 100755 --- a/ge/engine_manager/engine_conf.json +++ b/ge/engine_manager/engine_conf.json @@ -61,6 +61,13 @@ "independent": false, "skip_assign_stream": false, "attach": true + }, + { + "id": "ffts_plus", + "name": "FFTS+", + "independent": false, + "skip_assign_stream": true, + "attach": true } ] } diff --git a/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc index 7cb6d556..f0b15f85 100644 --- 
a/ge/ge_runtime/task/label_goto_task.cc +++ b/ge/ge_runtime/task/label_goto_task.cc @@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() { return false; } - rt_ret = rtLabelListCpy(reinterpret_cast(label_list.data()), label_list.size(), label_info_, label_info_size); + rt_ret = rtLabelListCpy(const_cast(label_list.data()), label_list.size(), label_info_, label_info_size); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); return false; diff --git a/ge/graph/build/label_allocator.cc b/ge/graph/build/label_allocator.cc index f2329769..9408d5fa 100644 --- a/ge/graph/build/label_allocator.cc +++ b/ge/graph/build/label_allocator.cc @@ -18,7 +18,6 @@ #include "framework/common/types.h" #include "framework/common/util.h" -#include "framework/common/ge_inner_error_codes.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/label/label_maker.h" @@ -85,8 +84,9 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::setGetOpDesc() != nullptr && func_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { - GELOGD("Graph[%s] is ffts subgraph, skip label allocator.", graph->GetName().c_str()); + if (func_node->GetOpDesc() != nullptr && (func_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH) || + func_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH))) { + GELOGD("Graph[%s] is ffts/ffts+ subgraph, skip label allocator.", graph->GetName().c_str()); return true; } diff --git a/ge/graph/build/stream_allocator.cc b/ge/graph/build/stream_allocator.cc index 987a77f7..b33a1c54 100644 --- a/ge/graph/build/stream_allocator.cc +++ b/ge/graph/build/stream_allocator.cc @@ -17,18 +17,12 @@ #include "graph/build/stream_allocator.h" #include #include -#include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/fmk_error_codes.h" -#include "framework/common/types.h" #include "graph/build/logical_stream_allocator.h" #include "common/omg_util.h" 
#include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/utils/graph_utils.h" #include "init/gelib.h" -#include "framework/common/string_util.h" -#include "common/util/error_manager/error_manager.h" using std::map; using std::set; @@ -433,7 +427,8 @@ Status StreamAllocator::SetActiveStreamsForSubgraphs() { // Insert the send/recv event id to the graph Status StreamAllocator::InsertSyncEvents() { auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) { - return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH); + return !(node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH) || + node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH)); }; for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag(), nullptr, ffts_filter)) { @@ -536,7 +531,9 @@ Status StreamAllocator::InsertEventsForSubgraph() { for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) { GE_CHECK_NOTNULL(subgraph); const auto parent_node = subgraph->GetParentNode(); - if (parent_node != nullptr && parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { + if (parent_node != nullptr && (parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH) || + parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH) || + parent_node->GetOpDesc()->HasAttr(ATTR_NAME_THREAD_SCOPE_ID))) { GELOGD("Skip ffts subgraph, parent node is %s.", parent_node->GetName().c_str()); continue; } diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index abb409c4..50ba37ad 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -356,7 +356,8 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GE_MAKE_GUARD(release, callback); auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) { - return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH); + return !(node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH) || + 
node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH)); }; for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag(), nullptr, ffts_filter)) { OpDescPtr op_desc = node->GetOpDesc(); @@ -371,7 +372,6 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra continue); GE_CHK_STATUS_RET(UpdateOpIsVarAttr(op_desc, graph->GetSessionID())); - string op_kernel_lib_name = op_desc->GetOpKernelLibName(); // For fusion ddb pass, task def must be continuous. // Part2: Call auto fusion_task_info = @@ -384,13 +384,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str()); continue; } - GE_CHK_BOOL_EXEC_INFO(!op_kernel_lib_name.empty(), continue, - "Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); - auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); - GE_CHECK_NOTNULL(kernel_info_store); GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", name.c_str(), type.c_str()); - if (node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { + if (node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH) || + node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH)) { GE_CHK_STATUS_RET(UpdateAnchorStatusForFfts(node), "[Call][UpdateAnchorStatusForFfts] node:%s(%s) failed", name.c_str(), type.c_str()); } @@ -406,10 +403,56 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra "[Set][KnownShapeStream] node[name:%s(%s), id:%ld] stream id is invalid.", name.c_str(), type.c_str(), op_id); } + std::string op_kernel_lib_name = op_desc->GetOpKernelLibName(); + GE_CHK_BOOL_EXEC_INFO(!op_kernel_lib_name.empty(), continue, + "Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); + auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); + 
GE_CHECK_NOTNULL(kernel_info_store);; + if (op_desc->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH)) { + (void)AttrUtils::GetStr(op_desc, ATTR_NAME_COMPOSITE_ENGINE_KERNEL_LIB_NAME, op_kernel_lib_name); + } GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task.", op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id); GE_TIMESTAMP_RESTART(GenerateTask); - auto ret = OpsKernelBuilderManager::Instance().GenerateTask(*node, run_context, task_def_list); + auto ret = SUCCESS; + if (op_desc->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH)) { + std::vector subgraphs; + if (NodeUtils::GetDirectSubgraphs(node, subgraphs) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get subgraphs of node %s failed", op_desc->GetName().c_str()); + GELOGE(FAILED, "[Check][Param] Get subgraphs of node %s failed", op_desc->GetName().c_str()); + return FAILED; + } + for (const auto &subgraph : subgraphs) { + for (const auto &tmp_node : subgraph->GetAllNodes()) { + bool notask = false; + (void)AttrUtils::GetBool(tmp_node->GetOpDesc(), ATTR_NAME_NOTASK, notask); + if (notask) { + GELOGI("Node[name:%s, type:%s] does not need to generate context.", + tmp_node->GetName().c_str(), tmp_node->GetType().c_str()); + continue; + } + std::string atomic_op_kernel_lib_name = op_desc->GetOpKernelLibName(); + GE_CHK_BOOL_EXEC_INFO(!atomic_op_kernel_lib_name.empty(), continue, + "Node[name:%s, type:%s] does not need to generate task.", + tmp_node->GetName().c_str(), tmp_node->GetType().c_str()); + GE_CHECK_NOTNULL(ops_kernel_manager.GetOpsKernelInfoStore(atomic_op_kernel_lib_name)); + GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] context.", + atomic_op_kernel_lib_name.c_str(), tmp_node->GetName().c_str(), tmp_node->GetType().c_str(), + tmp_node->GetOpDesc()->GetId(), tmp_node->GetOpDesc()->GetStreamId()); + ret = OpsKernelBuilderManager::Instance().GenerateTask(*tmp_node, run_context, task_def_list); + if (ret != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Call 
OpsKernelBuilderManager GenerateTask fail for op:%s(%s)", + tmp_node->GetName().c_str(), tmp_node->GetType().c_str()); + GELOGE(ret, "[Generate][Task] fail for op:%s(%s)", tmp_node->GetName().c_str(), + tmp_node->GetType().c_str()); + return ret; + } + } + } + ret = OpsKernelBuilderManager::Instance().GenerateTask(*node, run_context, task_def_list, false); + } else { + ret = OpsKernelBuilderManager::Instance().GenerateTask(*node, run_context, task_def_list); + } GE_TIMESTAMP_ADD(GenerateTask); if (ret != SUCCESS) { REPORT_CALL_ERROR("E19999", "Call OpsKernelBuilderManager GenerateTask fail for op:%s(%s)", diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index 495ec28e..f5ae32b8 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc @@ -100,9 +100,6 @@ const uint32_t kEndOfSequenceNew = 507005; const int32_t kModelAbortNormal = 0x0704000e; const int32_t kModelAbortNormalNew = 507024; const uint32_t kInteval = 2; -const uint32_t kFftsTbeHandleElementSize = 2; -const uint32_t kNonTailBlock = 0; -const uint32_t kTailBlock = 1; const char *const kModelName = "model_name"; const char *const kModeleId = "model_id"; const char *const kLoadStartTime = "load_start_time"; @@ -132,6 +129,10 @@ const char *const kStubFuncName = "_register_stub_func"; const uint32_t kStringHeadElems = 2; const uint32_t kPlacementHostData = 0; const size_t kAlignment = 64; +const uint32_t kAutoThreadMode = 1; +const std::vector kMixAttrPrefix = { "_mix_aic", "_mix_aiv" }; +const std::string kMixCoreType = "MIX_AIC_AIV"; +const std::string kAutoAttrPrefix = "_thread_"; inline bool IsDataOp(const std::string &node_type) { return (node_type == DATA_TYPE) || (node_type == AIPP_DATA_TYPE) || (node_type == ANN_DATA_TYPE); @@ -951,8 +952,10 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { GE_TIMESTAMP_RESTART(InitTbeHandle); if (IsTbeTask(op_desc)) { - Status status = - 
op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID) ? InitTbeHandleWithFfts(op_desc) : InitTbeHandle(op_desc); + uint32_t thread_mode = 0; + (void)AttrUtils::GetInt(op_desc, ATTR_NAME_THREAD_MODE, thread_mode); + bool is_auto_mode = op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID) && (thread_mode == kAutoThreadMode); + Status status = (is_auto_mode) ? InitTbeHandleInAutoMode(op_desc) : InitTbeHandle(op_desc); if (status != SUCCESS) { GELOGE(status, "[Init][TbeHandle] failed. op:%s", op_desc->GetName().c_str()); return status; @@ -3734,52 +3737,47 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) { GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file!", op_desc->GetName().c_str()); return INTERNAL_ERROR; } - GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file, tbe_kernel, false), "Function register of bin file: %s failed", + GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file, tbe_kernel, UINT32_MAX), "Function register of bin file: %s failed", bin_file.c_str()); return SUCCESS; } -Status DavinciModel::InitTbeHandleWithFfts(const OpDescPtr &op_desc) { +Status DavinciModel::InitTbeHandleInAutoMode(const OpDescPtr &op_desc) { std::vector tbe_kernel; tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel); - GELOGD("Kernel bin ptr vec size is %zu.", tbe_kernel.size()); - if (tbe_kernel.size() != kFftsTbeHandleElementSize) { - REPORT_INNER_ERROR("E19999", "Get tbe_kernel for op:%s(%s) fail, model_id:%u", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_); - GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file, size is %zu when ffts", - op_desc->GetName().c_str(), tbe_kernel.size()); + std::vector bin_file_keys; + (void)AttrUtils::GetListStr(op_desc, kStubFuncName, bin_file_keys); + if (tbe_kernel.size() != bin_file_keys.size()) { + REPORT_INNER_ERROR("E19999", "[%s] number of bin_file != number of file_name, bin_file_num=%zu, file_name_num=%zu", + op_desc->GetName().c_str(), 
tbe_kernel.size(), bin_file_keys.size()); + GELOGE(INTERNAL_ERROR, + "[Check][Param] [%s] number of bin_file != number of file_name, bin_file_num=%zu, file_name_num=%zu", + op_desc->GetName().c_str(), tbe_kernel.size(), bin_file_keys.size()); return INTERNAL_ERROR; } - if (tbe_kernel[0] == nullptr || tbe_kernel[1] == nullptr) { - REPORT_INNER_ERROR("E19999", "Tbe kernel for op:%s is nullptr.", op_desc->GetName().c_str()); - GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: tvm bin file of %s is nullptr when ffts.", op_desc->GetName().c_str()); + if (tbe_kernel.empty()) { + REPORT_INNER_ERROR("E19999", "[%s] tbe kernel is empty", op_desc->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "[Check][Param] [%s] tbe kernel is empty", op_desc->GetName().c_str()); return INTERNAL_ERROR; } - vector bin_file_keys; - (void)AttrUtils::GetListStr(op_desc, kStubFuncName, bin_file_keys); - if (bin_file_keys.size() != kFftsTbeHandleElementSize) { - REPORT_INNER_ERROR("E19999", "Get bin_file for op:%s(%s) fail.", op_desc->GetName().c_str(), - op_desc->GetType().c_str()); - GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find bin file keys, size is %zu when ffts", - op_desc->GetName().c_str(), bin_file_keys.size()); - return INTERNAL_ERROR; + size_t num = tbe_kernel.size(); + GELOGD("Kernel bin num is %zu", num); + for (size_t i = 0; i < num; i++) { + if (tbe_kernel[i] == nullptr) { + REPORT_INNER_ERROR("E19999", "Tbe kernel for op:%s is nullptr.", op_desc->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: tvm bin file of %s is nullptr when ffts.", op_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[i], tbe_kernel[i], i), + "Function register of No. 
%zu bin file %s failed.", i, bin_file_keys[i].c_str()); } - GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kNonTailBlock], tbe_kernel[kNonTailBlock], true, - kNonTailBlock), - "Function register of first bin file %s failed.", bin_file_keys[kNonTailBlock].c_str()); - GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kTailBlock], tbe_kernel[kTailBlock], true, kTailBlock), - "Function register of second bin file %s failed.", bin_file_keys[kTailBlock].c_str()); return SUCCESS; } Status DavinciModel::FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, - bool is_ffts, size_t thread_index) { - if (thread_index > 1) { - GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Thread index: %zu should less than 1.", thread_index); - return INTERNAL_ERROR; - } + uint32_t thread_index) { const char *bin_file_key; - if (is_ffts) { + if (thread_index != UINT32_MAX) { bin_file_key = GetRegisterStub(bin_file, ""); GELOGI("Node:%s inherit func name:%s directly.", op_desc->GetName().c_str(), bin_file_key); } else { @@ -3788,55 +3786,96 @@ Status DavinciModel::FunctionRegister(const OpDescPtr &op_desc, string &bin_file bin_file_key = GetRegisterStub(bin_file, session_graph_model_id); // from set, always valid. 
} - TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); std::lock_guard lock(tvm_bin_mutex_); if (rtQueryFunctionRegistered(bin_file_key) != RT_ERROR_NONE) { - void *bin_handle = nullptr; - if (!kernel_store.FindTBEHandle(bin_file_key, bin_handle)) { - GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key); - - rtDevBinary_t binary; - GE_CHK_STATUS_RET(InitBinaryMagic(op_desc, is_ffts, thread_index, binary), "Init binary magic of %s failed.", - op_desc->GetName().c_str()); - binary.version = 0; - binary.data = tbe_kernel->GetBinData(); - binary.length = tbe_kernel->GetBinDataSize(); - GELOGD("TBE: binary.length: %lu", binary.length); - GE_CHK_RT_RET(rtDevBinaryRegister(&binary, &bin_handle)); - - GE_CHK_STATUS_RET(InitMetaData(op_desc, is_ffts, thread_index, bin_handle), "Init tvm meta data of %s failed.", - op_desc->GetName().c_str()); - kernel_store.StoreTBEHandle(bin_file_key, bin_handle, tbe_kernel); + if (thread_index != UINT32_MAX) { + GE_CHK_STATUS_RET(KernelRegister(op_desc, thread_index, bin_file_key, kAutoAttrPrefix, tbe_kernel), + "Kernel register for auto mode failed, node:%s, thread_index:%u, bin_file:%s, prefix:%s", + op_desc->GetName().c_str(), thread_index, bin_file_key, kAutoAttrPrefix.c_str()); } else { - GELOGI("TBE: find the kernel_name[%s] in HandleMap", bin_file_key); - kernel_store.ReferTBEHandle(bin_file_key); + std::string core_type; + (void)AttrUtils::GetStr(op_desc, ATTR_NAME_CUBE_VECTOR_CORE_TYPE, core_type); + if (core_type == kMixCoreType) { + for (const auto &prefix : kMixAttrPrefix) { + GE_CHK_STATUS_RET(KernelRegister(op_desc, thread_index, bin_file_key, prefix, tbe_kernel), + "Kernel register for mix mode failed, node:%s, thread_index:%u, bin_file:%s, prefix:%s", + op_desc->GetName().c_str(), thread_index, bin_file_key, prefix.c_str()); + } + } else { + GE_CHK_STATUS_RET(KernelRegister(op_desc, thread_index, bin_file_key, "", tbe_kernel), + "Kernel register for normal mode failed, node:%s, 
thread_index:%u, bin_file:%s", + op_desc->GetName().c_str(), thread_index, bin_file_key); + } } - std::string kernel_name; - GE_CHK_STATUS_RET(InitKernelName(op_desc, is_ffts, thread_index, kernel_name), "Init kernel name of %s failed.", - op_desc->GetName().c_str()); - GE_CHK_RT_RET(rtFunctionRegister(bin_handle, bin_file_key, bin_file_key, kernel_name.c_str(), 0)); - used_tbe_handle_map_[bin_file_key] = 1; // Init used num to 1. - return SUCCESS; } // Kernel registed, Increase used num in store. StoreTbeHandle(bin_file_key); return SUCCESS; } -Status DavinciModel::InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, - rtDevBinary_t &binary) { +Status DavinciModel::GetAddrAndPrefCnt(const std::string &kernel_name, void *&addr, uint32_t &pref_cnt) { + const auto &iter = addr_and_pref_cnt_.find(kernel_name); + if (iter == addr_and_pref_cnt_.end()) { + REPORT_INNER_ERROR("E19999", "Get addr and pref cnt failed, kernel_name:%s", kernel_name.c_str()); + GELOGE(INTERNAL_ERROR, "[Check][Param] Get addr and pref cnt failed, kernel_name:%s", kernel_name.c_str()); + return INTERNAL_ERROR; + } + addr = iter->second.first; + pref_cnt = iter->second.second; + return SUCCESS; +} + +Status DavinciModel::KernelRegister(const OpDescPtr &op_desc, uint32_t thread_index, + const char *bin_file_key, const std::string &attr_prefix, + OpKernelBinPtr &tbe_kernel) { + TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); + void *bin_handle = nullptr; + if (!kernel_store.FindTBEHandle(bin_file_key, bin_handle)) { + GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key); + rtDevBinary_t binary; + GE_CHK_STATUS_RET(InitBinaryMagic(op_desc, thread_index, binary, attr_prefix), "Init binary magic of %s failed.", + op_desc->GetName().c_str()); + binary.version = 0; + binary.data = tbe_kernel->GetBinData(); + binary.length = tbe_kernel->GetBinDataSize(); + GELOGD("TBE: binary.length: %lu", binary.length); + 
GE_CHK_RT_RET(rtDevBinaryRegister(&binary, &bin_handle)); + + GE_CHK_STATUS_RET(InitMetaData(op_desc, thread_index, bin_handle, attr_prefix), "Init tvm meta data of %s failed.", + op_desc->GetName().c_str()); + kernel_store.StoreTBEHandle(bin_file_key, bin_handle, tbe_kernel); + } else { + GELOGI("TBE: find the kernel_name[%s] in HandleMap", bin_file_key); + kernel_store.ReferTBEHandle(bin_file_key); + } + std::string kernel_name; + GE_CHK_STATUS_RET(InitKernelName(op_desc, thread_index, kernel_name, attr_prefix), "Init kernel name of %s failed.", + op_desc->GetName().c_str()); + GE_CHK_RT_RET(rtFunctionRegister(bin_handle, bin_file_key, bin_file_key, kernel_name.c_str(), 0)); + void *addr; + uint32_t prefetch_cnt; + GE_CHK_RT_RET(rtGetAddrAndPrefCntWithHandle(bin_handle, kernel_name.c_str(), &addr, &prefetch_cnt)); + GELOGI("Get addr 0x%lx, pref_cnt %u for kernel_name %s", reinterpret_cast(addr), prefetch_cnt, + kernel_name.c_str()); + addr_and_pref_cnt_[kernel_name] = { addr, prefetch_cnt }; + used_tbe_handle_map_[bin_file_key] = 1; // Init used num to 1. + return SUCCESS; +} + +Status DavinciModel::InitBinaryMagic(const OpDescPtr &op_desc, uint32_t thread_index, rtDevBinary_t &binary, + const std::string &prefix) { string json_string; - const string &tvm_magic = is_ffts ? TVM_ATTR_NAME_THREAD_MAGIC : TVM_ATTR_NAME_MAGIC; + const string &tvm_magic = prefix + TVM_ATTR_NAME_MAGIC; const static std::map binary_magics = { {"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU}, {"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF}, {"RT_DEV_BINARY_MAGIC_ELF_AIVEC", RT_DEV_BINARY_MAGIC_ELF_AIVEC}, {"RT_DEV_BINARY_MAGIC_ELF_AICUBE", RT_DEV_BINARY_MAGIC_ELF_AICUBE} }; - if (is_ffts) { + if (thread_index != UINT32_MAX) { vector json_list; (void)AttrUtils::GetListStr(op_desc, tvm_magic, json_list); - if (json_list.size() != kFftsTbeHandleElementSize) { + if (json_list.size() <= thread_index) { GELOGE(INTERNAL_ERROR, "[Check][Param] failed. 
Attr is %s, thread index is %zu, json list size is %zu.", tvm_magic.c_str(), thread_index, json_list.size()); return INTERNAL_ERROR; @@ -3859,13 +3898,14 @@ Status DavinciModel::InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, siz return SUCCESS; } -Status DavinciModel::InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle) { +Status DavinciModel::InitMetaData(const OpDescPtr &op_desc, uint32_t thread_index, void *bin_handle, + const std::string &prefix) { string meta_data; - const string &tvm_metadata = is_ffts ? TVM_ATTR_NAME_THREAD_METADATA : TVM_ATTR_NAME_METADATA; - if (is_ffts) { + const string &tvm_metadata = prefix + TVM_ATTR_NAME_METADATA; + if (thread_index != UINT32_MAX) { vector meta_data_list; (void)AttrUtils::GetListStr(op_desc, tvm_metadata, meta_data_list); - if (meta_data_list.size() != kFftsTbeHandleElementSize) { + if (meta_data_list.size() <= thread_index) { GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, meta data list size is %zu.", tvm_metadata.c_str(), thread_index, meta_data_list.size()); return INTERNAL_ERROR; @@ -3881,8 +3921,9 @@ Status DavinciModel::InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t return SUCCESS; } -Status DavinciModel::InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name) { - if (is_ffts) { +Status DavinciModel::InitKernelName(const OpDescPtr &op_desc, uint32_t thread_index, string &kernel_name, + const std::string &prefix) { + if (thread_index != UINT32_MAX) { // delete prefix, eg: *sgt_graph_nodes*/loss_scale/gradient/fp32_vals/Mean_grad/Tile vector kernel_name_list; auto pos = op_desc->GetName().find("/"); @@ -3892,14 +3933,14 @@ Status DavinciModel::InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size } string attr_kernel_name = op_desc->GetName().substr(pos + 1) + "_thread_kernelname"; (void)AttrUtils::GetListStr(op_desc, attr_kernel_name, kernel_name_list); - if 
(kernel_name_list.size() != kFftsTbeHandleElementSize) { + if (kernel_name_list.size() <= thread_index) { GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, kernel name list size is %zu.", attr_kernel_name.c_str(), thread_index, kernel_name_list.size()); return INTERNAL_ERROR; } kernel_name = kernel_name_list[thread_index]; } else { - string attr_kernel_name = op_desc->GetName() + "_kernelname"; + string attr_kernel_name = prefix + op_desc->GetName() + "_kernelname"; (void)AttrUtils::GetStr(op_desc, attr_kernel_name, kernel_name); } return SUCCESS; diff --git a/ge/graph/load/model_manager/davinci_model.h b/ge/graph/load/model_manager/davinci_model.h index 76b0beef..935b2547 100755 --- a/ge/graph/load/model_manager/davinci_model.h +++ b/ge/graph/load/model_manager/davinci_model.h @@ -586,6 +586,8 @@ class DavinciModel { Status GetEventByStream(const rtStream_t &stream, rtEvent_t &rt_event); Status GetEventIdForBlockingAicpuOp(const OpDescPtr &op_desc, rtStream_t stream, uint32_t &event_id); + Status GetAddrAndPrefCnt(const std::string &kernel_name, void *&addr, uint32_t &pref_cnt); + private: // memory address of weights uint8_t *weights_mem_base_; @@ -772,12 +774,17 @@ class DavinciModel { /// @return Status /// Status InitTbeHandle(const OpDescPtr &op_desc); - Status InitTbeHandleWithFfts(const OpDescPtr &op_desc); - Status FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, bool is_ffts, - size_t thread_index = 0); - Status InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, rtDevBinary_t &binary); - Status InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle); - Status InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name); + Status InitTbeHandleInAutoMode(const OpDescPtr &op_desc); + Status FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, + uint32_t 
thread_index = 0); + Status KernelRegister(const OpDescPtr &op_desc, uint32_t thread_index, const char *bin_file_key, + const std::string &attr_prefix, OpKernelBinPtr &tbe_kernel); + Status InitBinaryMagic(const OpDescPtr &op_desc, uint32_t thread_index, rtDevBinary_t &binary, + const std::string &prefix = ""); + Status InitMetaData(const OpDescPtr &op_desc, uint32_t thread_index, void *bin_handle, + const std::string &prefix = ""); + Status InitKernelName(const OpDescPtr &op_desc, uint32_t thread_index, string &kernel_name, + const std::string &prefix = ""); void StoreTbeHandle(const string &handle_key); void CleanTbeHandle(); @@ -1025,6 +1032,8 @@ class DavinciModel { map used_tbe_handle_map_; + std::map> addr_and_pref_cnt_; + // for profiling task and graph info vector task_desc_info_; diff --git a/ge/graph/load/model_manager/task_info/ffts_plus_task_info.cc b/ge/graph/load/model_manager/task_info/ffts_plus_task_info.cc new file mode 100644 index 00000000..6834c863 --- /dev/null +++ b/ge/graph/load/model_manager/task_info/ffts_plus_task_info.cc @@ -0,0 +1,1078 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/load/model_manager/task_info/ffts_plus_task_info.h" +#include "graph/load/model_manager/davinci_model.h" + +namespace { +constexpr int kSrcSlotNum = 4; +constexpr int kWriteValueNum = 4; +constexpr int kUserDataNum = 9; + +constexpr int kManualIndex = 0; +constexpr int kManualAicAivCtxPcNum = 1; +constexpr int kAutoNonTailIndex = 0; +constexpr int kAutoTailIndex = 1; +constexpr int kAutoAicAivCtxPcNum = 2; +constexpr int kManualAicCtxIndex = 0; +constexpr int kManualAivCtxIndex = 1; +constexpr int kManualMixAicAivCtxPcNum = 2; +constexpr int kAutoNonTailAicCtxIndex = 0; +constexpr int kAutoTailAicCtxIndex = 1; +constexpr int kAutoNonTailAivCtxIndex = 2; +constexpr int kAutoTailAivCtxIndex = 3; +constexpr int kAutoMixAicAivCtxPcNum = 4; + +constexpr uint32_t k1BitMask = 0x00000001; // 1 bit , 0000,0001 +constexpr uint32_t k2BitsMask = 0x00000003; // 2 bits, 0000,0011 +constexpr uint32_t k3BitsMask = 0x00000007; // 3 bits, 0000,0111 +constexpr uint32_t k4BitsMask = 0x0000000F; // 4 bits, 0000,1111 +constexpr uint32_t k5BitsMask = 0x0000001F; // 5 bits, 0001,1111 +constexpr uint32_t k6BitsMask = 0x0000003F; // 6 bits, 0011,1111 +constexpr uint32_t k7BitsMask = 0x0000007F; // 7 bits, 0111,1111 +constexpr uint32_t k8BitsMask = 0x000000FF; // 8 bits, 1111,1111 + +constexpr uint32_t k12BitsMask = 0x00000FFF; // 12 bits, 0000,1111,1111,1111 +constexpr uint32_t k16BitsMask = 0x0000FFFF; // 16 bits, 1111,1111,1111,1111 + +constexpr uint32_t k17BitsMask = 0x0001FFFF; // 17 bits, 0000,0000,0000,0001,1111,1111,1111,1111 +constexpr uint32_t k32BitsMask = 0xFFFFFFFF; // 32 bits, 1111,1111,1111,1111,1111,1111,1111,1111 +} +namespace ge { +FftsPlusTaskInfo::~FftsPlusTaskInfo() { + GE_FREE_RT_LOG(args_); +} + +Status FftsPlusTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GELOGI("Init FftsPlusTaskInfo Start"); + GE_CHECK_NOTNULL(davinci_model); + davinci_model_ = davinci_model; + 
GE_CHK_STATUS_RET_NOLOG(SetStream(task_def.stream_id(), davinci_model_->GetStreamList())); + + const domi::FftsPlusTaskDef &ffts_plus_task_def = task_def.ffts_plus_task(); + OpDescPtr op_desc = davinci_model_->GetOpByIndex(ffts_plus_task_def.op_index()); + GE_CHECK_NOTNULL(op_desc); + + args_size_ = sizeof(void *) * ffts_plus_task_def.addr_size(); + if (args_size_ > 0) { + GE_CHK_RT_RET(rtMalloc(&args_, args_size_, RT_MEMORY_HBM)); + } + + std::vector sqe_buffer(sizeof(rtFftsPlusSqe_t)); + auto ffts_plus_sqe = reinterpret_cast(sqe_buffer.data()); + InitFftsPlusSqe(ffts_plus_task_def.ffts_plus_sqe(), ffts_plus_sqe); + ffts_plus_task_info_.fftsPlusSqe = ffts_plus_sqe; + + int ctx_num = ffts_plus_task_def.ffts_plus_ctx_size(); + ffts_plus_task_info_.descBufLen = sizeof(rtFftsPlusComCtx_t) * ctx_num; + std::vector ctx_buffer(ffts_plus_task_info_.descBufLen); + auto ctx = reinterpret_cast(ctx_buffer.data()); + GELOGI("Init ctx begin, node %s, args_size=%d, ctx_num=%d", op_desc->GetName().c_str(), args_size_, ctx_num); + GE_CHK_STATUS_RET_NOLOG(InitFftsPlusCtx(ffts_plus_task_def, ctx_num, ctx)); + ffts_plus_task_info_.descBuf = reinterpret_cast(ctx_buffer.data()); + + if (!io_addrs_.empty()) { + GE_CHECK_NOTNULL(args_); + auto data_size = sizeof(void *) * io_addrs_.size(); + if (args_size_ < data_size) { + REPORT_INNER_ERROR("E19999", "addr_size %d of FftsPlusTaskInfo is less than number of task_addr %zu", + ffts_plus_task_def.addr_size(), io_addrs_.size()); + GELOGE(FAILED, "[Check][Param] addr_size %d of FftsPlusTaskInfo is less than number of task_addr %zu", + ffts_plus_task_def.addr_size(), io_addrs_.size()); + return FAILED; + } + GELOGI("Memcpy io addrs to 0x%lx, addr_size=%d, len=%d", + reinterpret_cast(args_), args_size_, data_size); + GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs_.data(), data_size, RT_MEMCPY_HOST_TO_DEVICE)); + } + GELOGI("Init FftsPlusTaskInfo success, node: %s", op_desc->GetName().c_str()); + + return SUCCESS; +} + +void 
FftsPlusTaskInfo::InitFftsPlusSqe(const domi::FftsPlusSqeDef &sqe_def, rtFftsPlusSqe_t *&sqe) { + InitFftsPlusSqeHeader(sqe_def.sqe_header(), sqe->sqeHeader); + + sqe->pmg = static_cast(sqe_def.pmg() & k2BitsMask); + sqe->ns = static_cast(sqe_def.ns() & k1BitMask); + sqe->partId = static_cast(sqe_def.part_id() & k8BitsMask); + sqe->qos = static_cast(sqe_def.qos() & k4BitsMask); + + sqe->totalContextNum = static_cast(sqe_def.total_context_num()); + sqe->readyContextNum = static_cast(sqe_def.ready_context_num()); + sqe->preloadContextNum = static_cast(sqe_def.preload_context_num()); + + sqe->dsplitUnit = static_cast(sqe_def.dsplit_unit() & k3BitsMask); + sqe->prefetchOstNum = static_cast(sqe_def.prefetch_ost_num() & k5BitsMask); + sqe->cmaintOstNum = static_cast(sqe_def.cmaint_ost_num() & k5BitsMask); + + sqe->aicPrefetchLower = static_cast(sqe_def.aic_prefetch_lower() & k5BitsMask); + sqe->aicPrefetchUpper = static_cast(sqe_def.aic_prefetch_upper() & k5BitsMask); + sqe->aivPrefetchLower = static_cast(sqe_def.aiv_prefetch_lower() & k5BitsMask); + sqe->aivPrefetchUpper = static_cast(sqe_def.aiv_prefetch_upper() & k5BitsMask); +} + +void FftsPlusTaskInfo::InitFftsPlusSqeHeader(const domi::StarsSqeHeaderDef &sqe_header_def, + rtStarsSqeHeader_t &sqe_header) { + sqe_header.l1Lock = static_cast(sqe_header_def.l1_lock()); + sqe_header.l1Unlock = static_cast(sqe_header_def.l1_unlock()); + sqe_header.blockDim = static_cast(sqe_header_def.block_dim()); +} + +Status FftsPlusTaskInfo::InitFftsPlusCtx(const domi::FftsPlusTaskDef &task_def, int ctx_num, void *&ctx) { + for (auto i = 0; i < ctx_num; i++) { + const domi::FftsPlusCtxDef &ctx_def = task_def.ffts_plus_ctx(i); + const auto &ctx_type = static_cast(ctx_def.context_type()); + GELOGI("Init ctx %d in FftsPlusTask, context_type=%u", i, ctx_type); + auto cur_ctx = reinterpret_cast(ctx) + sizeof(rtFftsPlusComCtx_t) * i; + switch (ctx_type) { + case RT_CTX_TYPE_AICORE: + case RT_CTX_TYPE_AIV: { + auto aic_aiv_ctx = 
reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitAicAivCtx(ctx_def.aic_aiv_ctx(), aic_aiv_ctx), "Init AicAivCtx failed, ctx_index=%d", i); + break; + } + case RT_CTX_TYPE_NOTIFY_WAIT: + case RT_CTX_TYPE_NOTIFY_RECORD: { + auto notify_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitNotifyCtx(ctx_def.notify_ctx(), notify_ctx), "Init NotifyCtx failed, ctx_index=%d", i); + break; + } + case RT_CTX_TYPE_WRITE_VALUE: { + auto write_value_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitWriteValueCtx(ctx_def.write_value_ctx(), write_value_ctx), + "Init WriteValueCtx failed, ctx_index=%d", i); + break; + } + case RT_CTX_TYPE_MIX_AIC: + case RT_CTX_TYPE_MIX_AIV: { + auto mix_aic_aiv_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitMixAicAivCtx(ctx_def.mix_aic_aiv_ctx(), mix_aic_aiv_ctx), + "Init MixAicAivCtx failed, ctx_index=%d", i); + break; + } + case RT_CTX_TYPE_SDMA: { + auto sdma_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitSdmaCtx(ctx_def.sdma_ctx(), sdma_ctx), "Init SdmaCtx failed, ctx_index=%d", i); + break; + } + case RT_CTX_TYPE_FLUSH_DATA: + case RT_CTX_TYPE_INVALIDATE_DATA: + case RT_CTX_TYPE_WRITEBACK_DATA: { + auto data_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitDataCtx(ctx_def.data_ctx(), data_ctx), "Init DataCtx failed, ctx_index=%d", i); + break; + } + case RT_CTX_TYPE_AICPU: { + auto aicpu_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitAicpuCtx(ctx_def.aicpu_ctx(), aicpu_ctx), "Init AicpuCtx failed, ctx_index=%d", i); + break; + } + case RT_CTX_TYPE_COND_SWITCH: { + auto cond_switch_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitCondSwitchCtx(ctx_def.cond_switch_ctx(), cond_switch_ctx), + "Init CondSwitchCtx failed, ctx_index=%d", i); + break; + } + case RT_CTX_TYPE_CASE_SWITCH: { + if (ctx_def.has_case_switch_ctx() == ctx_def.has_case_default_ctx()) { + REPORT_INNER_ERROR("E19999", "case_switch_ctx %s and case_default_ctx %s when software ctx type is case, ctx_index=%d", + 
ctx_def.has_case_switch_ctx() ? "exist" : "not exist", + ctx_def.has_case_default_ctx() ? "exist" : "not exist", i); + GELOGE(FAILED, "[Check][Ctx] case_switch_ctx %s and case_default_ctx %s when software ctx type is case, ctx_index=%d", + ctx_def.has_case_switch_ctx() ? "exist" : "not exist", + ctx_def.has_case_default_ctx() ? "exist" : "not exist", i); + return FAILED; + } + if (ctx_def.has_case_switch_ctx()) { + auto case_switch_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitCaseSwitchCtx(ctx_def.case_switch_ctx(), case_switch_ctx), + "Init CaseSwitchCtx failed, ctx_index=%d", i); + } + if (ctx_def.has_case_default_ctx()) { + auto case_default_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitCaseDefaultCtx(ctx_def.case_default_ctx(), case_default_ctx), + "Init CaseDefaultCtx failed, ctx_index=%d", i); + } + break; + } + case RT_CTX_TYPE_AT_START: { + auto at_start_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitAtStartCtx(ctx_def.at_start_ctx(), at_start_ctx), + "Init AtStartCtx failed, ctx_index=%d", i); + break; + } + case RT_CTX_TYPE_AT_END: { + auto at_end_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitAtEndCtx(ctx_def.at_end_ctx(), at_end_ctx), "Init AtEndCtx failed, ctx_index=%d", i); + break; + } + case RT_CTX_TYPE_LABEL: { + auto label_ctx = reinterpret_cast(cur_ctx); + GE_CHK_STATUS_RET(InitLabelCtx(ctx_def.label_ctx(), label_ctx), "Init LabelCtx failed, ctx_index=%d", i); + break; + } + default: + REPORT_INNER_ERROR("E19999", "Unsupported ctx type %u", ctx_type); + GELOGE(FAILED, "[Check][CtxType] Unsupported ctx type %u", ctx_type); + return FAILED; + } + } + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitAicAivCtx(const domi::FftsPlusAicAivCtxDef &ctx_def, rtFftsPlusAicAivCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + 
if (ctx_def.successor_list_size() > RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_list in FftsPlusAicAivCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_list in FftsPlusAicAivCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_list_size(); i++) { + ctx->successorList[i] = static_cast(ctx_def.successor_list(i)); + } + + ctx->stat = static_cast(ctx_def.stat() & k1BitMask); + ctx->schem = static_cast(ctx_def.schem() & k2BitsMask); + ctx->atm = static_cast(ctx_def.atm() & k1BitMask); + /* Each prefetch bitmap is read from its own field of the ctx def, matching InitMixAicAivCtx. */ + ctx->prefetchEnableBitmap = static_cast(ctx_def.prefetch_enable_bitmap() & k4BitsMask); + ctx->prefetchOnceBitmap = static_cast(ctx_def.prefetch_once_bitmap() & k4BitsMask); + ctx->prefetchConfig = static_cast(ctx_def.prefetch_config()); + + ctx->threadId = static_cast(ctx_def.thread_id()); + ctx->threadDim = static_cast(ctx_def.thread_dim()); + + ctx->nonTailBlockdim = static_cast(ctx_def.non_tail_block_dim()); + ctx->tailBlockdim = static_cast(ctx_def.tail_block_dim()); + + uint64_t task_param_ptr_base = reinterpret_cast(args_) + sizeof(void *) * io_addrs_.size(); + GELOGD("FftsPlusAicAivCtxDef: task param addr is %lu.", task_param_ptr_base); + ctx->taskParamPtrBaseL = static_cast(task_param_ptr_base & k32BitsMask); + ctx->taskParamPtrBaseH = static_cast((task_param_ptr_base >> 32) & k16BitsMask); + ctx->taskParamPtrOffset = static_cast(ctx_def.task_param_ptr_offset()); + + if (ctx->atm == 0) { + GE_CHK_STATUS_RET(InitManualAicAivCtx(ctx_def, ctx), "Init AicAivCtx in manual mode failed"); + } else { + GE_CHK_STATUS_RET(InitAutoAicAivCtx(ctx_def, ctx), "Init AicAivCtx in auto mode failed"); + } + + if (ctx_def.src_slot_size() > kSrcSlotNum) { + REPORT_INNER_ERROR("E19999", "Size of src_slot in FftsPlusAicAivCtxDef should not > %d, but %d exactly", + kSrcSlotNum, ctx_def.src_slot_size()); +
GELOGE(FAILED, "[Check][Param] Size of src_slot in FftsPlusAicAivCtxDef should not > %d, but %d exactly", + kSrcSlotNum, ctx_def.src_slot_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.src_slot_size(); i++) { + ctx->srcSlot[i] = static_cast(ctx_def.src_slot(i)); + } + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitManualAicAivCtx(const domi::FftsPlusAicAivCtxDef &ctx_def, rtFftsPlusAicAivCtx_t *&ctx) { + const auto &rts_param = davinci_model_->GetRuntimeParam(); + for (auto i = 0; i < ctx_def.task_addr_size(); ++i) { + uintptr_t logic_addr = ctx_def.task_addr(i); + uint8_t *io_addr = nullptr; + if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic addr is 0x%lx.", logic_addr); + return INTERNAL_ERROR; + } + GELOGD("index %d, task addr is %ld, logic addr is 0x%lx, io addr is 0x%lx", + i, ctx_def.task_addr(i), logic_addr, reinterpret_cast(io_addr)); + io_addrs_.emplace_back(io_addr); + } + // PcL for low 32 bits of pc, PcH for high 16 bits of pc + if (ctx_def.kernel_name_size() != kManualAicAivCtxPcNum) { + REPORT_INNER_ERROR("E19999", "Size of kernel_name in FftsPlusAicAivCtxDef should be %d, but %d exactly", + kManualAicAivCtxPcNum, ctx_def.kernel_name_size()); + GELOGE(FAILED, "[Check][Param] Size of kernel_name in FftsPlusAicAivCtxDef should be %d, but %d exactly", + kManualAicAivCtxPcNum, ctx_def.kernel_name_size()); + return FAILED; + } + uint32_t i_cache_prefetch_cnt; + void *task_start_pc = nullptr; + GE_CHK_STATUS_RET(davinci_model_->GetAddrAndPrefCnt(ctx_def.kernel_name(kManualIndex), task_start_pc, + i_cache_prefetch_cnt), + "Get addr and pref cnt failed, kernel_name=%s", ctx_def.kernel_name(kManualIndex).c_str()); + ctx->nonTailTaskStartPcL = static_cast(reinterpret_cast(task_start_pc) & k32BitsMask); + ctx->nonTailTaskStartPcH = static_cast((reinterpret_cast(task_start_pc) >> 32) & k16BitsMask); + ctx->tailTaskStartPcL = 
static_cast(reinterpret_cast(task_start_pc) & k32BitsMask); + ctx->tailTaskStartPcH = static_cast((reinterpret_cast(task_start_pc) >> 32) & k16BitsMask); + ctx->icachePrefetchCnt = static_cast(i_cache_prefetch_cnt & k5BitsMask); + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitAutoAicAivCtx(const domi::FftsPlusAicAivCtxDef &ctx_def, rtFftsPlusAicAivCtx_t *&ctx) { + const auto &rts_param = davinci_model_->GetRuntimeParam(); + for (auto i = 0; i < ctx->threadDim - 1; i++) { + GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, ctx_def, i, ctx_def.task_addr_offset_size())); + } + GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, ctx_def, ctx->threadDim - 1, ctx_def.input_output_count())); + for (auto k = 0; k < ctx_def.task_addr_size() - ctx_def.task_addr_offset_size(); ++k) { + auto logic_addr = reinterpret_cast(ctx_def.task_addr(ctx_def.task_addr_offset_size() + k)); + uint8_t *io_addr = nullptr; + if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic addr is 0x%lx.", logic_addr); + return INTERNAL_ERROR; + } + io_addrs_.emplace_back(io_addr); + } + // PcL for low 32 bits of pc, PcH for high 16 bits of pc + if (ctx_def.kernel_name_size() != kAutoAicAivCtxPcNum) { + REPORT_INNER_ERROR("E19999", "Size of kernel_name in FftsPlusAicAivCtxDef should be %d, but %d exactly", + kAutoAicAivCtxPcNum, ctx_def.kernel_name_size()); + GELOGE(FAILED, "[Check][Param] Size of kernel_name in FftsPlusAicAivCtxDef should be %d, but %d exactly", + kAutoAicAivCtxPcNum, ctx_def.kernel_name_size()); + return FAILED; + } + uint32_t non_tail_i_cache_prefetch_cnt; + void *non_tail_task_start_pc = nullptr; + GE_CHK_STATUS_RET(davinci_model_->GetAddrAndPrefCnt(ctx_def.kernel_name(kAutoNonTailIndex), non_tail_task_start_pc, + non_tail_i_cache_prefetch_cnt), + "Get addr and pref cnt failed, kernel_name=%s", ctx_def.kernel_name(kAutoNonTailIndex).c_str()); + ctx->nonTailTaskStartPcL = 
static_cast(reinterpret_cast(non_tail_task_start_pc) & k32BitsMask); + ctx->nonTailTaskStartPcH = static_cast((reinterpret_cast(non_tail_task_start_pc) >> 32) & + k16BitsMask); + uint32_t tail_i_cache_prefetch_cnt; + void *tail_task_start_pc = nullptr; + GE_CHK_STATUS_RET(davinci_model_->GetAddrAndPrefCnt(ctx_def.kernel_name(kAutoTailIndex), tail_task_start_pc, + tail_i_cache_prefetch_cnt), + "Get addr and pref cnt failed, kernel_name=%s", ctx_def.kernel_name(kAutoTailIndex).c_str()); + ctx->tailTaskStartPcL = static_cast(reinterpret_cast(tail_task_start_pc) & k32BitsMask); + ctx->tailTaskStartPcH = static_cast((reinterpret_cast(tail_task_start_pc) >> 32) & k16BitsMask); + uint32_t i_cache_prefetch_cnt = std::min(non_tail_i_cache_prefetch_cnt, tail_i_cache_prefetch_cnt); + ctx->icachePrefetchCnt = static_cast(i_cache_prefetch_cnt & k5BitsMask); + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitNotifyCtx(const domi::FftsPlusNotifyCtxDef &ctx_def, rtFftsPlusNotifyCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if (ctx_def.successor_list_size() > RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_list in FftsPlusNotifyCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_list in FftsPlusNotifyCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_list_size(); i++) { + ctx->successorList[i] = static_cast(ctx_def.successor_list(i)); + } + + ctx->atm = static_cast(ctx_def.atm() & k1BitMask); + ctx->threadId = static_cast(ctx_def.thread_id()); + ctx->threadDim = static_cast(ctx_def.thread_dim()); + ctx->notifyIdBase = static_cast(ctx_def.notify_id_base()); + 
+ return SUCCESS; +} + +Status FftsPlusTaskInfo::InitWriteValueCtx(const domi::FftsPlusWriteValueCtxDef &ctx_def, + rtFftsPlusWriteValueCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if (ctx_def.successor_list_size() > RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_list in FftsPlusWriteValueCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_list in FftsPlusWriteValueCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_list_size(); i++) { + ctx->successorList[i] = static_cast(ctx_def.successor_list(i)); + } + + ctx->atm = static_cast(ctx_def.atm() & k1BitMask); + ctx->threadId = static_cast(ctx_def.thread_id()); + ctx->threadDim = static_cast(ctx_def.thread_dim()); + + ctx->awSize = static_cast(ctx_def.aw_size() & k3BitsMask); + ctx->snoop = static_cast(ctx_def.snoop() & k1BitMask); + ctx->awCache = static_cast(ctx_def.aw_cache() & k4BitsMask); + ctx->awProt = static_cast(ctx_def.aw_prot() & k3BitsMask); + ctx->va = static_cast(ctx_def.va() & k1BitMask); + + const auto &rts_param = davinci_model_->GetRuntimeParam(); + uint8_t *write_addr_base = nullptr; + if (ModelUtils::GetRtAddress(rts_param, ctx_def.write_addr_base(), write_addr_base) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic write addr base is 0x%lx.", ctx_def.write_addr_base()); + return INTERNAL_ERROR; + } + ctx->writeAddressBaseL = static_cast(reinterpret_cast(write_addr_base) & k32BitsMask); + ctx->writeAddressBaseH = static_cast((reinterpret_cast(write_addr_base) >> 32) & k17BitsMask); + ctx->writeAddressOffset = ctx_def.write_addr_offset(); + + if 
(ctx_def.write_value_size() > kWriteValueNum) { + REPORT_INNER_ERROR("E19999", "Size of write_value in FftsPlusWriteValueCtxDef should not > %d, but %d exactly", + kWriteValueNum, ctx_def.write_value_size()); + GELOGE(FAILED, "[Check][Param] Size of write_value in FftsPlusWriteValueCtxDef should not > %d, but %d exactly", + kWriteValueNum, ctx_def.write_value_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.write_value_size(); i++) { + ctx->writeValue[i] = static_cast(ctx_def.write_value(i)); + } + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &ctx_def, + rtFftsPlusMixAicAivCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if (ctx_def.successor_list_size() > RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_list in FftsPlusMixAicAivCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_list in FftsPlusMixAicAivCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_list_size(); i++) { + ctx->successorList[i] = static_cast(ctx_def.successor_list(i)); + } + + ctx->stat = static_cast(ctx_def.stat() & k1BitMask); + ctx->schem = static_cast(ctx_def.schem() & k2BitsMask); + ctx->atm = static_cast(ctx_def.atm() & k1BitMask); + ctx->prefetchEnableBitmap = static_cast(ctx_def.prefetch_enable_bitmap() & k4BitsMask); + ctx->prefetchOnceBitmap = static_cast(ctx_def.prefetch_once_bitmap() & k4BitsMask); + ctx->prefetchConfig = static_cast(ctx_def.prefetch_config()); + + ctx->threadId = static_cast(ctx_def.thread_id()); + ctx->threadDim = static_cast(ctx_def.thread_dim()); + + ctx->nonTailBlockRatioN = 
static_cast(ctx_def.non_tail_block_ratio_n()); + ctx->tailBlockRatioN = static_cast(ctx_def.tail_block_ratio_n()); + + ctx->nonTailBlockdim = static_cast(ctx_def.non_tail_block_dim()); + ctx->tailBlockdim = static_cast(ctx_def.tail_block_dim()); + + uint64_t task_param_ptr_base = reinterpret_cast(args_) + sizeof(void *) * io_addrs_.size(); + GELOGD("FftsPlusMixAicAivCtxDef: task param addr is %lu.", task_param_ptr_base); + ctx->aicTaskParamPtrL = static_cast(task_param_ptr_base & k32BitsMask); + ctx->aicTaskParamPtrH = static_cast((task_param_ptr_base >> 32) & k16BitsMask); + ctx->aivTaskParamPtrL = static_cast(task_param_ptr_base & k32BitsMask); + ctx->aivTaskParamPtrH = static_cast((task_param_ptr_base >> 32) & k16BitsMask); + ctx->aicTaskParamPtrOffset = static_cast(ctx_def.aic_task_param_ptr_offset()); + ctx->aivTaskParamPtrOffset = static_cast(ctx_def.aiv_task_param_ptr_offset()); + + if (ctx->atm == 0) { + GE_CHK_STATUS_RET(InitManualMixAicAivCtx(ctx_def, ctx), "Init MixAicAivCtx in manual mode failed"); + } else { + GE_CHK_STATUS_RET(InitAutoMixAicAivCtx(ctx_def, ctx), "Init MixAicAivCtx in auto mode failed"); + } + + if (ctx_def.src_slot_size() > kSrcSlotNum) { + REPORT_INNER_ERROR("E19999", "Size of src_slot in FftsPlusMixAicAivCtxDef should not > %d, but %d exactly", + kSrcSlotNum, ctx_def.src_slot_size()); + GELOGE(FAILED, "[Check][Param] Size of src_slot in FftsPlusMixAicAivCtxDef should not > %d, but %d exactly", + kSrcSlotNum, ctx_def.src_slot_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.src_slot_size(); i++) { + ctx->srcSlot[i] = static_cast(ctx_def.src_slot(i)); + } + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitManualMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &ctx_def, + rtFftsPlusMixAicAivCtx_t *&ctx) { + const auto &rts_param = davinci_model_->GetRuntimeParam(); + for (auto i = 0; i < ctx_def.task_addr_size(); ++i) { + uintptr_t logic_addr = ctx_def.task_addr(i); + uint8_t *io_addr = nullptr; + if 
(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic addr is 0x%lx.", logic_addr); + return INTERNAL_ERROR; + } + GELOGD("index %d, task addr is %ld, logic addr is 0x%lx, io addr is 0x%lx", + i, ctx_def.task_addr(i), logic_addr, reinterpret_cast(io_addr)); + io_addrs_.emplace_back(io_addr); + } + // PcL for low 32 bits of pc, PcH for high 16 bits of pc + if (ctx_def.kernel_name_size() != kManualMixAicAivCtxPcNum) { + REPORT_INNER_ERROR("E19999", "Size of kernel_name in FftsPlusMixAicAivCtxDef should be %d, but %d exactly", + kManualMixAicAivCtxPcNum, ctx_def.kernel_name_size()); + GELOGE(FAILED, "[Check][Param] Size of kernel_name in FftsPlusMixAicAivCtxDef should be %d, but %d exactly", + kManualMixAicAivCtxPcNum, ctx_def.kernel_name_size()); + return FAILED; + } + uint32_t aic_i_cache_prefetch_cnt; + void *aic_task_start_pc = nullptr; + GE_CHK_STATUS_RET(davinci_model_->GetAddrAndPrefCnt(ctx_def.kernel_name(kManualAicCtxIndex), + aic_task_start_pc, aic_i_cache_prefetch_cnt), + "Get addr and pref cnt failed, kernel_name=%s", ctx_def.kernel_name(kManualAicCtxIndex).c_str()); + ctx->nonTailAicTaskStartPcL = static_cast(reinterpret_cast(aic_task_start_pc) & k32BitsMask); + ctx->nonTailAicTaskStartPcH = static_cast((reinterpret_cast(aic_task_start_pc) >> 32) & + k16BitsMask); + ctx->tailAicTaskStartPcL = static_cast(reinterpret_cast(aic_task_start_pc) & k32BitsMask); + ctx->tailAicTaskStartPcH = static_cast((reinterpret_cast(aic_task_start_pc) >> 32) & + k16BitsMask); + ctx->aicIcachePrefetchCnt = static_cast(aic_i_cache_prefetch_cnt & k5BitsMask); + + uint32_t aiv_i_cache_prefetch_cnt; + void *aiv_task_start_pc = nullptr; + GE_CHK_STATUS_RET(davinci_model_->GetAddrAndPrefCnt(ctx_def.kernel_name(kManualAivCtxIndex), + aiv_task_start_pc, aiv_i_cache_prefetch_cnt), + "Get addr and pref cnt failed, kernel_name=%s", ctx_def.kernel_name(kManualAivCtxIndex).c_str()); + 
ctx->nonTailAivTaskStartPcL = static_cast(reinterpret_cast(aiv_task_start_pc) & k32BitsMask); + ctx->nonTailAivTaskStartPcH = static_cast((reinterpret_cast(aiv_task_start_pc) >> 32) & + k16BitsMask); + ctx->tailAivTaskStartPcL = static_cast(reinterpret_cast(aiv_task_start_pc) & k32BitsMask); + ctx->tailAivTaskStartPcH = static_cast((reinterpret_cast(aiv_task_start_pc) >> 32) & + k16BitsMask); + ctx->aivIcachePrefetchCnt = static_cast(aiv_i_cache_prefetch_cnt & k5BitsMask); + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitAutoMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &ctx_def, + rtFftsPlusMixAicAivCtx_t *&ctx) { + const auto &rts_param = davinci_model_->GetRuntimeParam(); + for (auto i = 0; i < ctx->threadDim - 1; i++) { + GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, ctx_def, i, ctx_def.task_addr_offset_size())); + } + GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, ctx_def, ctx->threadDim - 1, ctx_def.input_output_count())); + int last_thread_workspace_size = ctx_def.task_addr_size() - ctx_def.task_addr_offset_size(); + for (auto k = 0; k < last_thread_workspace_size; ++k) { + uintptr_t logic_addr = ctx_def.task_addr(ctx_def.task_addr_offset_size() + k); + uint8_t *io_addr = nullptr; + if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic addr is 0x%lx.", logic_addr); + return INTERNAL_ERROR; + } + io_addrs_.emplace_back(io_addr); + } + // PcL for low 32 bits of pc, PcH for high 16 bits of pc + if (ctx_def.kernel_name_size() != kAutoMixAicAivCtxPcNum) { + REPORT_INNER_ERROR("E19999", "Size of kernel_name in FftsPlusMixAicAivCtxDef should be %d, but %d exactly", + kAutoMixAicAivCtxPcNum, ctx_def.kernel_name_size()); + GELOGE(FAILED, "[Check][Param] Size of kernel_name in FftsPlusMixAicAivCtxDef should be %d, but %d exactly", + kAutoMixAicAivCtxPcNum, ctx_def.kernel_name_size()); + return FAILED; + } + uint32_t non_tail_aic_i_cache_prefetch_cnt; + void 
*non_tail_aic_task_start_pc = nullptr; + GE_CHK_STATUS_RET(davinci_model_->GetAddrAndPrefCnt(ctx_def.kernel_name(kAutoNonTailAicCtxIndex), + non_tail_aic_task_start_pc, non_tail_aic_i_cache_prefetch_cnt), + "Get addr and pref cnt failed, kernel_name=%s", ctx_def.kernel_name(kAutoNonTailAicCtxIndex).c_str()); + ctx->nonTailAicTaskStartPcL = + static_cast(reinterpret_cast(non_tail_aic_task_start_pc) & k32BitsMask); + ctx->nonTailAicTaskStartPcH = + static_cast((reinterpret_cast(non_tail_aic_task_start_pc) >> 32) & k16BitsMask); + uint32_t tail_aic_i_cache_prefetch_cnt; + void *tail_aic_task_start_pc = nullptr; + GE_CHK_STATUS_RET(davinci_model_->GetAddrAndPrefCnt(ctx_def.kernel_name(kAutoTailAicCtxIndex), + tail_aic_task_start_pc, tail_aic_i_cache_prefetch_cnt), + "Get addr and pref cnt failed, kernel_name=%s", ctx_def.kernel_name(kAutoTailAicCtxIndex).c_str()); + ctx->tailAicTaskStartPcL = static_cast(reinterpret_cast(tail_aic_task_start_pc) & k32BitsMask); + ctx->tailAicTaskStartPcH = + static_cast((reinterpret_cast(tail_aic_task_start_pc) >> 32) & k16BitsMask); + uint32_t aic_i_cache_prefetch_cnt = std::min(non_tail_aic_i_cache_prefetch_cnt, tail_aic_i_cache_prefetch_cnt); + ctx->aicIcachePrefetchCnt = static_cast(aic_i_cache_prefetch_cnt & k5BitsMask); + + uint32_t non_tail_aiv_i_cache_prefetch_cnt; + void *non_tail_aiv_task_start_pc = nullptr; + GE_CHK_STATUS_RET(davinci_model_->GetAddrAndPrefCnt(ctx_def.kernel_name(kAutoNonTailAivCtxIndex), + non_tail_aiv_task_start_pc, non_tail_aiv_i_cache_prefetch_cnt), + "Get addr and pref cnt failed, kernel_name=%s", ctx_def.kernel_name(kAutoNonTailAivCtxIndex).c_str()); + ctx->nonTailAivTaskStartPcL = + static_cast(reinterpret_cast(non_tail_aiv_task_start_pc) & k32BitsMask); + ctx->nonTailAivTaskStartPcH = + static_cast((reinterpret_cast(non_tail_aiv_task_start_pc) >> 32) & k16BitsMask); + uint32_t tail_aiv_i_cache_prefetch_cnt; + void *tail_aiv_task_start_pc = nullptr; + 
GE_CHK_STATUS_RET(davinci_model_->GetAddrAndPrefCnt(ctx_def.kernel_name(kAutoTailAivCtxIndex), + tail_aiv_task_start_pc, tail_aiv_i_cache_prefetch_cnt), + "Get addr and pref cnt failed, kernel_name=%s", ctx_def.kernel_name(kAutoTailAivCtxIndex).c_str()); + ctx->tailAivTaskStartPcL = static_cast(reinterpret_cast(tail_aiv_task_start_pc) & k32BitsMask); + ctx->tailAivTaskStartPcH = + static_cast((reinterpret_cast(tail_aiv_task_start_pc) >> 32) & k16BitsMask); + uint32_t aiv_i_cache_prefetch_cnt = std::min(non_tail_aiv_i_cache_prefetch_cnt, tail_aiv_i_cache_prefetch_cnt); + ctx->aivIcachePrefetchCnt = static_cast(aiv_i_cache_prefetch_cnt & k5BitsMask); + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitSdmaCtx(const domi::FftsPlusSdmaCtxDef &ctx_def, rtFftsPlusSdmaCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if (ctx_def.successor_list_size() > RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_list in FftsPlusSdmaCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_list in FftsPlusSdmaCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_list_size(); i++) { + ctx->successorList[i] = static_cast(ctx_def.successor_list(i)); + } + + ctx->sat = static_cast(ctx_def.sat() & k1BitMask); + ctx->atm = static_cast(ctx_def.atm() & k1BitMask); + + ctx->threadId = static_cast(ctx_def.thread_id()); + ctx->threadDim = static_cast(ctx_def.thread_dim()); + + ctx->sdmaSqeHeader = ctx_def.sdma_sqe_header(); + + ctx->sourceStreamId = static_cast(ctx_def.src_stream_id()); + ctx->sourceSubstreamId = static_cast(ctx_def.src_sub_stream_id()); + + ctx->destinationStreamId = 
static_cast(ctx_def.dst_stream_id()); + ctx->destinationSubstreamId = static_cast(ctx_def.dst_sub_stream_id()); + + const auto &rts_param = davinci_model_->GetRuntimeParam(); + uint8_t *src_addr_base = nullptr; + if (ModelUtils::GetRtAddress(rts_param, ctx_def.src_addr_base(), src_addr_base) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic src addr is 0x%lx.", ctx_def.src_addr_base()); + return INTERNAL_ERROR; + } + ctx->sourceAddressBaseL = static_cast(reinterpret_cast(src_addr_base) & k32BitsMask); + ctx->sourceAddressBaseH = static_cast(reinterpret_cast(src_addr_base) >> 32); + ctx->sourceAddressOffset = ctx_def.src_addr_offset(); + + uint8_t *dst_addr_base = nullptr; + if (ModelUtils::GetRtAddress(rts_param, ctx_def.dst_addr_base(), dst_addr_base) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic dst addr is 0x%lx.", ctx_def.dst_addr_base()); + return INTERNAL_ERROR; + } + ctx->destinationAddressBaseL = static_cast(reinterpret_cast(dst_addr_base) & k32BitsMask); + ctx->destinationAddressBaseH = static_cast(reinterpret_cast(dst_addr_base) >> 32); + ctx->destinationAddressOffset = ctx_def.dst_addr_offset(); + + ctx->nonTailDataLength = ctx_def.non_tail_data_len(); + ctx->tailDataLength = ctx_def.tail_data_len(); + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitDataCtx(const domi::FftsPlusDataCtxDef &ctx_def, rtFftsPlusDataCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + ctx->cntInit = static_cast(ctx_def.cnt_init()); + ctx->cnt = static_cast(ctx_def.cnt()); + + if (ctx_def.successor_list_size() > RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_list in FftsPlusDataCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_list in FftsPlusDataCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, 
ctx_def.successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_list_size(); i++) { + ctx->successorList[i] = static_cast(ctx_def.successor_list(i)); + } + + ctx->atm = static_cast(ctx_def.atm() & k1BitMask); + + ctx->origConsumerCounter = static_cast(ctx_def.orig_consumer_counter()); + ctx->runConsumerCounter = static_cast(ctx_def.run_consumer_counter()); + ctx->threadId = static_cast(ctx_def.thread_id()); + ctx->threadDim = static_cast(ctx_def.thread_dim()); + + const auto &rts_param = davinci_model_->GetRuntimeParam(); + uint8_t *addr_base = nullptr; + if (ModelUtils::GetRtAddress(rts_param, ctx_def.addr_base(), addr_base) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic addr base is 0x%lx.", ctx_def.addr_base()); + return INTERNAL_ERROR; + } + ctx->addressBaseL = static_cast(reinterpret_cast(addr_base) & k32BitsMask); + ctx->addressBaseH = static_cast(reinterpret_cast(addr_base) >> 32); + ctx->addressOffset = ctx_def.addr_offset(); + + ctx->nonTailNumOutter = static_cast(ctx_def.non_tail_num_outter()); + ctx->nonTailNumInner = static_cast(ctx_def.non_tail_num_inner()); + ctx->nonTailLengthInner = ctx_def.non_tail_len_inner(); + ctx->nonTailStrideOutter = ctx_def.non_tail_stride_outter(); + ctx->nonTailStrideInner = ctx_def.non_tail_stride_inner(); + + ctx->tailNumOutter = static_cast(ctx_def.tail_num_outter()); + ctx->tailNumInner = static_cast(ctx_def.tail_num_inner()); + ctx->tailLengthInner = ctx_def.tail_len_inner(); + ctx->tailStrideOutter = ctx_def.tail_stride_outter(); + ctx->tailStrideInner = ctx_def.tail_stride_inner(); + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitAicpuCtx(const domi::FftsPlusAicpuCtxDef &ctx_def, rtFftsPlusAiCpuCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if 
(ctx_def.successor_context_id_size() > RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_context_id in FftsPlusAicpuCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_context_id_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_context_id in FftsPlusAicpuCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_context_id_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_context_id_size(); i++) { + ctx->successorContextID[i] = static_cast(ctx_def.successor_context_id(i)); + } + + ctx->atm = static_cast(ctx_def.atm() & k1BitMask); + + ctx->sqeIndex = static_cast(ctx_def.sqe_index()); + ctx->kernelType = static_cast(ctx_def.kernel_type() & k7BitsMask); + ctx->bm = static_cast(ctx_def.bm() & k1BitMask); + ctx->topicType = static_cast(ctx_def.topic_type() & k4BitsMask); + ctx->qos = static_cast(ctx_def.qos() & k3BitsMask); + + ctx->threadId = static_cast(ctx_def.thread_id()); + ctx->threadDim = static_cast(ctx_def.thread_dim()); + + ctx->nonTailBlockdim = static_cast(ctx_def.non_tail_block_dim()); + ctx->tailBlockdim = static_cast(ctx_def.tail_block_dim()); + + if (ctx_def.user_data_size() > kUserDataNum) { + REPORT_INNER_ERROR("E19999", "Size of user_data in FftsPlusAicpuCtxDef should not > %d, but %d exactly", + kUserDataNum, ctx_def.user_data_size()); + GELOGE(FAILED, "[Check][Param] Size of user_data in FftsPlusAicpuCtxDef should not > %d, but %d exactly", + kUserDataNum, ctx_def.user_data_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.user_data_size(); i++) { + ctx->usrData[i] = static_cast(ctx_def.user_data(i)); + } + + ctx->subtopicId = static_cast(ctx_def.sub_topic_id() & k12BitsMask); + ctx->topicId = static_cast(ctx_def.topic_id() & k6BitsMask); + ctx->groupId = static_cast(ctx_def.group_id() & k6BitsMask); + ctx->usrDataLength = static_cast(ctx_def.user_data_len() & k8BitsMask); + + /* qos was already consumed for ctx->qos above; the offset is read from its own field. */ + ctx->taskParamOffset = ctx_def.task_param_offset(); + + return
SUCCESS; +} + +Status FftsPlusTaskInfo::InitCondSwitchCtx(const domi::FftsPlusCondSwitchCtxDef &ctx_def, + rtFftsPlusCondSwitchCtx_t *&ctx) { + ctx->trueSuccessorNum = static_cast(ctx_def.true_successor_num()); + ctx->falseSuccessorNum = static_cast(ctx_def.false_successor_num() & k7BitsMask); + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + + /* RT_COND_TYPE_MAX and any larger value are rejected before the value is cast into ctx->condition. */ + if (ctx_def.condition() >= RT_COND_TYPE_MAX) { + REPORT_INNER_ERROR("E19999", "Unsupported cond type %u", ctx_def.condition()); + GELOGE(FAILED, "[Check][CtxType] Unsupported cond type %u", ctx_def.condition()); + return FAILED; + } + ctx->condition = static_cast(ctx_def.condition()); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if (ctx_def.true_successor_list_size() > RT_CTX_TRUE_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", + "Size of true_successor_list in FftsPlusCondSwitchCtxDef should not > %d, but %d exactly", + RT_CTX_TRUE_SUCCESSOR_NUM, ctx_def.true_successor_list_size()); + GELOGE(FAILED, + "[Check][Param] Size of true_successor_list in FftsPlusCondSwitchCtxDef should not > %d, but %d exactly", + RT_CTX_TRUE_SUCCESSOR_NUM, ctx_def.true_successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.true_successor_list_size(); i++) { + ctx->trueSuccessorList[i] = static_cast(ctx_def.true_successor_list(i)); + } + + if (ctx_def.false_successor_list_size() > RT_CTX_FALSE_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", + "Size of false_successor_list in FftsPlusCondSwitchCtxDef should not > %d, but %d exactly", + RT_CTX_FALSE_SUCCESSOR_NUM, ctx_def.false_successor_list_size()); + GELOGE(FAILED, + "[Check][Param] Size of false_successor_list in FftsPlusCondSwitchCtxDef should not > %d, but %d exactly", + RT_CTX_FALSE_SUCCESSOR_NUM, ctx_def.false_successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.false_successor_list_size(); i++) { + ctx->falseSuccessorList[i] = static_cast(ctx_def.false_successor_list(i)); +
} + + ctx->atm = static_cast(ctx_def.atm() & k1BitMask); + + ctx->threadId = static_cast(ctx_def.thread_id()); + ctx->threadDim = static_cast(ctx_def.thread_dim()); + + ctx->arSize = static_cast(ctx_def.ar_size() & k3BitsMask); + ctx->snoop = static_cast(ctx_def.snoop() & k1BitMask); + ctx->arCache = static_cast(ctx_def.ar_cache() & k4BitsMask); + ctx->arProt = static_cast(ctx_def.ar_prot() & k3BitsMask); + ctx->va = static_cast(ctx_def.va() & k1BitMask); + + const auto &rts_param = davinci_model_->GetRuntimeParam(); + uint8_t *addr_base_0 = nullptr; + if (ModelUtils::GetRtAddress(rts_param, ctx_def.load_addr0_base(), addr_base_0) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic load addr0 base is 0x%lx.", ctx_def.load_addr0_base()); + return INTERNAL_ERROR; + } + ctx->loadAddress0BaseL = static_cast(reinterpret_cast(addr_base_0) & k32BitsMask); + ctx->loadAddress0BaseH = static_cast((reinterpret_cast(addr_base_0) >> 32) & k17BitsMask); + ctx->ld0En = static_cast(ctx_def.ld0_en() & k1BitMask); + ctx->loadAddress0Offset = ctx_def.load_addr0_offset(); + + uint8_t *addr_base_1 = nullptr; + if (ModelUtils::GetRtAddress(rts_param, ctx_def.load_addr1_base(), addr_base_1) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic load addr1 base is 0x%lx.", ctx_def.load_addr1_base()); + return INTERNAL_ERROR; + } + ctx->loadAddress1BaseL = static_cast(reinterpret_cast(addr_base_1) & k32BitsMask); + ctx->loadAddress1BaseH = static_cast((reinterpret_cast(addr_base_1) >> 32) & k17BitsMask); + ctx->ld1En = static_cast(ctx_def.ld1_en() & k1BitMask); + ctx->loadAddress1Offset = ctx_def.load_addr1_offset(); + + ctx->cmpValue1 = ctx_def.cmp_value_1(); + ctx->cmpValue2 = ctx_def.cmp_value_2(); + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitCaseSwitchCtx(const domi::FftsPlusCaseSwitchCtxDef &ctx_def, + rtFftsPlusCaseSwitchCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->aten = 
static_cast(ctx_def.aten() & k1BitMask); + + ctx->startLabelId = static_cast(ctx_def.successor_num()); + ctx->labelListLen = static_cast(ctx_def.label_list_len()); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if (ctx_def.successor_list_size() > RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_list in FftsPlusCaseDefaultCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_list in FftsPlusCaseDefaultCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_list_size(); i++) { + ctx->successorList[i] = static_cast(ctx_def.successor_list(i)); + } + + ctx->atm = static_cast(ctx_def.atm() & k1BitMask); + + ctx->threadId = static_cast(ctx_def.thread_id()); + ctx->threadDim = static_cast(ctx_def.thread_dim()); + + ctx->arSize = static_cast(ctx_def.ar_size() & k3BitsMask); + ctx->snoop = static_cast(ctx_def.snoop() & k1BitMask); + ctx->arCache = static_cast(ctx_def.ar_cache() & k4BitsMask); + ctx->arProt = static_cast(ctx_def.ar_prot() & k3BitsMask); + ctx->va = static_cast(ctx_def.va() & k1BitMask); + + const auto &rts_param = davinci_model_->GetRuntimeParam(); + uint8_t *addr_base_0 = nullptr; + if (ModelUtils::GetRtAddress(rts_param, ctx_def.load_addr0_base(), addr_base_0) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic load addr0 base is 0x%lx.", ctx_def.load_addr0_base()); + return INTERNAL_ERROR; + } + ctx->loadAddress0BaseL = static_cast(reinterpret_cast(addr_base_0) & k32BitsMask); + ctx->loadAddress0BaseH = static_cast((reinterpret_cast(addr_base_0) >> 32) & k17BitsMask); + ctx->ld0En = static_cast(ctx_def.ld0_en() & k1BitMask); + ctx->loadAddress0Offset = ctx_def.load_addr0_offset(); + + uint8_t *addr_base_1 = nullptr; + if 
(ModelUtils::GetRtAddress(rts_param, ctx_def.load_addr1_base(), addr_base_1) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic load addr1 base is 0x%lx.", ctx_def.load_addr1_base()); + return INTERNAL_ERROR; + } + ctx->loadAddress1BaseL = static_cast(reinterpret_cast(addr_base_1) & k32BitsMask); + ctx->loadAddress1BaseH = static_cast((reinterpret_cast(addr_base_1) >> 32) & k17BitsMask); + ctx->ld1En = static_cast(ctx_def.ld1_en() & k1BitMask); + ctx->loadAddress1Offset = ctx_def.load_addr1_offset(); + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitCaseDefaultCtx(const domi::FftsPlusCaseDefaultCtxDef &ctx_def, + rtFftsPlusCaseDefCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + + ctx->startLabelId = static_cast(ctx_def.successor_num()); + ctx->labelListLen = static_cast(ctx_def.label_list_len()); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if (ctx_def.successor_list_size() > RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_list in FftsPlusCaseDefaultCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_list in FftsPlusCaseDefaultCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_list_size(); i++) { + ctx->successorList[i] = static_cast(ctx_def.successor_list(i)); + } + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitAtStartCtx(const domi::FftsPlusAtStartCtxDef &ctx_def, rtFftsPlusAtStartCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if (ctx_def.successor_list_size() 
> RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_list in FftsPlusAtStartCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_list in FftsPlusAtStartCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_list_size(); i++) { + ctx->successorList[i] = static_cast(ctx_def.successor_list(i)); + } + + ctx->threadId = static_cast(ctx_def.thread_id()); + ctx->threadDim = static_cast(ctx_def.thread_dim()); + + ctx->threadIdInit = static_cast(ctx_def.thread_id_init()); + ctx->threadWindowSize = static_cast(ctx_def.thread_window_size()); + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitAtEndCtx(const domi::FftsPlusAtEndCtxDef &ctx_def, rtFftsPlusAtEndCtx_t *&ctx) { + ctx->atStartSlotNumber = static_cast(ctx_def.at_start_slot_num()); + ctx->outLabelSlotNumber = static_cast(ctx_def.out_label_slot_num() & k7BitsMask); + + ctx->aten = static_cast(ctx_def.aten() & k1BitMask); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if (ctx_def.succ_at_start_slot_size() > RT_CTX_SUCC_AT_START_SLOT_NUM) { + REPORT_INNER_ERROR("E19999", "Size of succ_at_start_slot in FftsPlusAtEndCtxDef should not > %d, but %d exactly", + RT_CTX_SUCC_AT_START_SLOT_NUM, ctx_def.succ_at_start_slot_size()); + GELOGE(FAILED, "[Check][Param] Size of succ_at_start_slot in FftsPlusAtStartCtxDef should not > %d, but %d exactly", + RT_CTX_SUCC_AT_START_SLOT_NUM, ctx_def.succ_at_start_slot_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.succ_at_start_slot_size(); i++) { + ctx->succAtStartSlot[i] = static_cast(ctx_def.succ_at_start_slot(i)); + } + + if (ctx_def.succ_out_label_slot_size() > RT_CTX_SUCC_OUT_LABEL_SLOT_NUM) { + REPORT_INNER_ERROR("E19999", "Size of succ_out_label_slot in FftsPlusAtEndCtxDef should not 
> %d, but %d exactly", + RT_CTX_SUCC_OUT_LABEL_SLOT_NUM, ctx_def.succ_out_label_slot_size()); + GELOGE(FAILED, "[Check][Param] Size of succ_out_label_slot in FftsPlusAtStartCtxDef should not > %d, but %d exactly", + RT_CTX_SUCC_OUT_LABEL_SLOT_NUM, ctx_def.succ_out_label_slot_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.succ_out_label_slot_size(); i++) { + ctx->succOutLabelSlot[i] = static_cast(ctx_def.succ_out_label_slot(i)); + } + + ctx->threadId = static_cast(ctx_def.thread_id()); + + return SUCCESS; +} + +Status FftsPlusTaskInfo::InitLabelCtx(const domi::FftsPlusLabelCtxDef &ctx_def, rtFftsPlusLabelCtx_t *&ctx) { + ctx->successorNum = static_cast(ctx_def.successor_num()); + ctx->predCntInit = static_cast(ctx_def.pred_cnt_init()); + ctx->predCnt = static_cast(ctx_def.pred_cnt()); + + if (ctx_def.successor_list_size() > RT_CTX_SUCCESSOR_NUM) { + REPORT_INNER_ERROR("E19999", "Size of successor_list in FftsPlusLabelCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + GELOGE(FAILED, "[Check][Param] Size of successor_list in FftsPlusLabelCtxDef should not > %d, but %d exactly", + RT_CTX_SUCCESSOR_NUM, ctx_def.successor_list_size()); + return FAILED; + } + for (auto i = 0; i < ctx_def.successor_list_size(); i++) { + ctx->successorList[i] = static_cast(ctx_def.successor_list(i)); + } + + return SUCCESS; +} + +Status FftsPlusTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + return SUCCESS; +} + +Status FftsPlusTaskInfo::UpdateArgs() { + GE_CHECK_NOTNULL(davinci_model_); + std::vector io_addrs = io_addrs_; + davinci_model_->UpdateKnownZeroCopyAddr(io_addrs); + auto addr_size = sizeof(void *) * io_addrs.size(); + GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs.data(), addr_size, RT_MEMCPY_HOST_TO_DEVICE)); + return SUCCESS; +} + +Status FftsPlusTaskInfo::Distribute() { + GELOGI("FftsPlusTaskInfo Distribute Start."); + rtError_t rt_ret = 
rtFftsPlusTaskLaunch(&ffts_plus_task_info_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "[Check][RT_ret] Call rtFftsPlusTaskLaunch failed, ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + + GELOGI("FftsPlusTaskInfo Distribute Success."); + return SUCCESS; +} + +Status FftsPlusTaskInfo::Release() { + if (args_ != nullptr) { + GE_CHK_RT_RET(rtFree(args_)); + args_ = nullptr; + } + return SUCCESS; +} + +template +Status FftsPlusTaskInfo::InitIoAddrs(const RuntimeParam &rts_param, const T &ctx_def, int thread_id, int addr_count) { + for (auto i = 0; i < addr_count; ++i) { + uintptr_t logic_addr = ctx_def.task_addr(i) + thread_id * ctx_def.task_addr_offset(i); + uint8_t *io_addr = nullptr; + if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress] failed, logic addr is 0x%lx.", logic_addr); + return INTERNAL_ERROR; + } + GELOGD("task base addr is %ld, offset is %ld, thread id is %d, logic addr is 0x%lx, io addr is 0x%lx", + ctx_def.task_addr(i), ctx_def.task_addr_offset(i), thread_id, logic_addr, + reinterpret_cast(io_addr)); + io_addrs_.emplace_back(io_addr); + } + return SUCCESS; +} + +REGISTER_TASK_INFO(RT_MODEL_TASK_FFTS_PLUS_TASK, FftsPlusTaskInfo); +} // namespace ge diff --git a/ge/graph/load/model_manager/task_info/ffts_plus_task_info.h b/ge/graph/load/model_manager/task_info/ffts_plus_task_info.h new file mode 100644 index 00000000..941e2562 --- /dev/null +++ b/ge/graph/load/model_manager/task_info/ffts_plus_task_info.h @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_PLUS_TASK_INFO_H_ +#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_PLUS_TASK_INFO_H_ + +#include "graph/load/model_manager/task_info/task_info.h" +#include "graph/op_desc.h" + +namespace ge { +class FftsPlusTaskInfo : public TaskInfo { + public: + FftsPlusTaskInfo() = default; + ~FftsPlusTaskInfo() override; + + Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + Status Distribute() override; + Status Release() override; + Status UpdateArgs() override; + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + + private: + void InitFftsPlusSqe(const domi::FftsPlusSqeDef &sqe_def, rtFftsPlusSqe_t *&sqe); + void InitFftsPlusSqeHeader(const domi::StarsSqeHeaderDef &sqe_header_def, rtStarsSqeHeader_t &sqe_header); + Status InitFftsPlusCtx(const domi::FftsPlusTaskDef &task_def, int ctx_num, void *&ctx); + + Status InitAicAivCtx(const domi::FftsPlusAicAivCtxDef &ctx_def, rtFftsPlusAicAivCtx_t *&ctx); + Status InitManualAicAivCtx(const domi::FftsPlusAicAivCtxDef &ctx_def, rtFftsPlusAicAivCtx_t *&ctx); + Status InitAutoAicAivCtx(const domi::FftsPlusAicAivCtxDef &ctx_def, rtFftsPlusAicAivCtx_t *&ctx); + Status InitNotifyCtx(const domi::FftsPlusNotifyCtxDef &ctx_def, rtFftsPlusNotifyCtx_t *&ctx); + Status InitWriteValueCtx(const domi::FftsPlusWriteValueCtxDef &ctx_def, rtFftsPlusWriteValueCtx_t *&ctx); + Status InitMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &ctx_def, rtFftsPlusMixAicAivCtx_t *&ctx); + Status 
InitManualMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &ctx_def, rtFftsPlusMixAicAivCtx_t *&ctx); + Status InitAutoMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &ctx_def, rtFftsPlusMixAicAivCtx_t *&ctx); + Status InitSdmaCtx(const domi::FftsPlusSdmaCtxDef &ctx_def, rtFftsPlusSdmaCtx_t *&ctx); + Status InitDataCtx(const domi::FftsPlusDataCtxDef &ctx_def, rtFftsPlusDataCtx_t *&ctx); + Status InitAicpuCtx(const domi::FftsPlusAicpuCtxDef &ctx_def, rtFftsPlusAiCpuCtx_t *&ctx); + + Status InitCondSwitchCtx(const domi::FftsPlusCondSwitchCtxDef &ctx_def, rtFftsPlusCondSwitchCtx_t *&ctx); + Status InitCaseSwitchCtx(const domi::FftsPlusCaseSwitchCtxDef &ctx_def, rtFftsPlusCaseSwitchCtx_t *&ctx); + Status InitCaseDefaultCtx(const domi::FftsPlusCaseDefaultCtxDef &ctx_def, rtFftsPlusCaseDefCtx_t *&ctx); + + Status InitAtStartCtx(const domi::FftsPlusAtStartCtxDef &ctx_def, rtFftsPlusAtStartCtx_t *&ctx); + Status InitAtEndCtx(const domi::FftsPlusAtEndCtxDef &ctx_def, rtFftsPlusAtEndCtx_t *&ctx); + Status InitLabelCtx(const domi::FftsPlusLabelCtxDef &ctx_def, rtFftsPlusLabelCtx_t *&ctx); + + template + Status InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, int thread_id, int addr_count); + + DavinciModel *davinci_model_{nullptr}; + rtFftsPlusTaskInfo_t ffts_plus_task_info_; + std::vector io_addrs_; + void *args_{nullptr}; // runtime args memory + int args_size_{0}; // runtime args memory length +}; +} // namespace ge +#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_PLUS_TASK_INFO_H_ diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index fa140bfe..79a50081 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -808,6 +808,14 @@ Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_gr GELOGE(ret, "[Call][OptimizeSubGraphWithMultiThreads] failed, ret:%d, session_id:%lu", ret, session_id); return ret; } + for (const auto &item : sub_graph_map) { + for (const 
auto &subgraph_info : item.second) { + const auto &subgraph = subgraph_info->GetSubGraph(); + for (const auto &new_graph : subgraph->GetAllSubgraphs()) { + compute_graph->AddSubGraph(new_graph); + } + } + } return SUCCESS; } @@ -881,8 +889,8 @@ Status GraphManager::PreRunAfterOptimizeSubGraph(const GraphNodePtr &graph_node, CompilerStages &stages = GetCompilerStages(graph_node->GetGraphId()); GM_RUN_AND_DUMP_PERF("OptimizeWholeGraph", stages.optimizer.OptimizeWholeGraph, compute_graph); GM_RUN_AND_DUMP_PERF("Optimize2", OptimizeStage2, compute_graph); - GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuildForRts", - GetCompilerStages(graph_node->GetGraphId()).optimizer.OptimizeGraphBeforeBuildForRts, + GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuild", + GetCompilerStages(graph_node->GetGraphId()).optimizer.OptimizeGraphBeforeBuild, compute_graph); Status ret = compute_graph->TopologicalSorting(); @@ -2837,20 +2845,53 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra GELOGE(ret, "[Call][Partition] for Graph:%s by dynamic shape Failed", compute_graph->GetName().c_str()); return ret; } - bool dynamic_shape_partitioned = false; - if (!AttrUtils::GetBool(*compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, dynamic_shape_partitioned)) { - REPORT_INNER_ERROR("E19999", "Get Attr:%s from graph:%s(id:%u) fail", - ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED.c_str(), compute_graph->GetName().c_str(), - compute_graph->GetGraphID()); - GELOGE(FAILED, "[Get][Attr] %s from graph:%u failed", - ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED.c_str(), compute_graph->GetGraphID()); + if (!compute_graph->HasAttr(ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED)) { + REPORT_INNER_ERROR("E19999", "Attr:%s not exist in graph:%s(id:%u)", ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED.c_str(), + compute_graph->GetName().c_str(), compute_graph->GetGraphID()); + GELOGE(FAILED, "[Get][Attr] Attr %s not exist in graph:%u", ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED.c_str(), + compute_graph->GetGraphID()); return 
FAILED; } GE_TIMESTAMP_EVENT_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); GE_DUMP(compute_graph, "AfterDynamicShapePartition"); + + GE_TIMESTAMP_START(SubgraphPartitionAndOptimization_CompositeEngine); + ret = SubgraphPartitionAndOptimization(graph_node, compute_graph, session_id, + GraphPartitioner::kCompositeEnginePartitioning); + if (ret != SUCCESS) { + GELOGE(ret, "[SubgraphPartitionAndOptimization][CompositeEngine] for graph:%s failed", + compute_graph->GetName().c_str()); + return ret; + } + GE_TIMESTAMP_EVENT_END(SubgraphPartitionAndOptimization_CompositeEngine, + "OptimizeSubgraph::SubgraphPartitionAndOptimization::CompositeEngine"); + GE_DUMP(compute_graph, "MergedComputeGraphAfterCompositeEnginePartition"); + + GE_TIMESTAMP_START(SubgraphPartitionAndOptimization_AtomicEngine); + ret = SubgraphPartitionAndOptimization(graph_node, compute_graph, session_id, + GraphPartitioner::kAtomicEnginePartitioning); + if (ret != SUCCESS) { + GELOGE(ret, "[SubgraphPartitionAndOptimization][AtomicEngine] for graph:%s failed", + compute_graph->GetName().c_str()); + return ret; + } + GE_TIMESTAMP_EVENT_END(SubgraphPartitionAndOptimization_AtomicEngine, + "OptimizeSubgraph::SubgraphPartitionAndOptimization::AtomicEngine"); + GE_DUMP(compute_graph, "MergedComputeGraphAfterAtomicEnginePartition"); + + return SUCCESS; +} + +Status GraphManager::SubgraphPartitionAndOptimization(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, + uint64_t session_id, GraphPartitioner::Mode mode) { + if ((mode == GraphPartitioner::kCompositeEnginePartitioning) && + OpsKernelManager::GetInstance().GetCompositeEngines().empty()) { + GELOGI("No composite engine registers, ignore subgraph partition and optimization for composite engine"); + return SUCCESS; + } GE_TIMESTAMP_START(GraphPartition); GraphPartitioner &partitioner = GetCompilerStages(graph_node->GetGraphId()).partitioner; - ret = partitioner.Partition(compute_graph, 
GraphPartitioner::kPartitioning); + Status ret = partitioner.Partition(compute_graph, mode); if (ret != SUCCESS) { GELOGE(ret, "[Call][Partition] for Graph:%s Failed", compute_graph->GetName().c_str()); return ret; @@ -2863,24 +2904,24 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra return ret; } GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); - std::set build_steps = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB}; - if ((options_.build_mode == BUILD_MODE_TUNING) && (build_steps.count(options_.build_step) > 0)) { - GE_TIMESTAMP_START(ConvertGraphToFile); - std::string tuning_path; - (void) GetContext().GetOption(TUNING_PATH, tuning_path); - Status ret = ConvertGraphToFile(compute_graph, partitioner, tuning_path, - (options_.build_step == BUILD_STEP_AFTER_BUILDER)); - if (ret != SUCCESS) { - GELOGE(ret, "[Convert][Graph] [%s] to file failed", compute_graph->GetName().c_str()); - return ret; + if (mode == GraphPartitioner::kAtomicEnginePartitioning) { + std::set build_steps = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB}; + if ((options_.build_mode == BUILD_MODE_TUNING) && (build_steps.count(options_.build_step) > 0)) { + GE_TIMESTAMP_START(ConvertGraphToFile); + std::string tuning_path; + (void) GetContext().GetOption(TUNING_PATH, tuning_path); + Status ret = ConvertGraphToFile(compute_graph, partitioner, tuning_path, + (options_.build_step == BUILD_STEP_AFTER_BUILDER)); + if (ret != SUCCESS) { + GELOGE(ret, "[Convert][Graph] [%s] to file failed", compute_graph->GetName().c_str()); + return ret; + } + GE_TIMESTAMP_EVENT_END(ConvertGraphToFile, "OptimizeSubgraph::ConvertGraphToFile"); + return SUCCESS; } - GE_TIMESTAMP_EVENT_END(ConvertGraphToFile, "OptimizeSubgraph::ConvertGraphToFile"); - return SUCCESS; } ComputeGraphPtr merged_compute_graph = nullptr; - std::vector merged_sub_graph_list; - GE_TIMESTAMP_START(MergeSubgraph); ret = 
MergeSubGraph(merged_compute_graph, compute_graph, graph_node->GetGraphId()); if (ret != SUCCESS) { @@ -2896,27 +2937,31 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra sub_graph->SetSessionID(session_id); sub_graph->SetGraphID(graph_node->GetGraphId()); } - bool off_superkernel = false; - if (AttrUtils::GetBool(compute_graph, ATTR_NAME_OFF_SUPERKERNEL_ATTR, off_superkernel)) { - GELOGI("Compute graph %s get superkernel flag %d.", compute_graph->GetName().c_str(), off_superkernel); - if (!AttrUtils::SetBool(merged_compute_graph, ATTR_NAME_OFF_SUPERKERNEL_ATTR, off_superkernel)) { - REPORT_INNER_ERROR("E19999", "Set Attr:%s to graph:%u fail", + bool off_super_kernel = false; + if (AttrUtils::GetBool(compute_graph, ATTR_NAME_OFF_SUPERKERNEL_ATTR, off_super_kernel)) { + GELOGI("Compute graph %s get super kernel flag %d.", compute_graph->GetName().c_str(), off_super_kernel); + if (!AttrUtils::SetBool(merged_compute_graph, ATTR_NAME_OFF_SUPERKERNEL_ATTR, off_super_kernel)) { + REPORT_INNER_ERROR("E19999", "Set Attr:%s to graph:%u failed", ATTR_NAME_OFF_SUPERKERNEL_ATTR.c_str(), compute_graph->GetGraphID()); - GELOGE(FAILED, "[Set][Attr] %s to graph:%u fail", + GELOGE(FAILED, "[Set][Attr] %s to graph:%u failed", ATTR_NAME_OFF_SUPERKERNEL_ATTR.c_str(), compute_graph->GetGraphID()); return FAILED; } } + bool dynamic_shape_partitioned = false; + if (AttrUtils::GetBool(compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, dynamic_shape_partitioned)) { + GELOGI("Compute graph %s get super kernel flag %d.", compute_graph->GetName().c_str(), dynamic_shape_partitioned); + if (!AttrUtils::SetBool(merged_compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, dynamic_shape_partitioned)) { + REPORT_INNER_ERROR("E19999", "Set Attr:%s to graph:%u failed", + ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED.c_str(), compute_graph->GetGraphID()); + GELOGE(FAILED, "[Set][Attr] %s to graph:%u failed", + ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED.c_str(), 
compute_graph->GetGraphID()); + return FAILED; + } + } GE_TIMESTAMP_EVENT_END(MergeSubgraph, "OptimizeSubgraph::MergeSubGraph"); - GE_DUMP(merged_compute_graph, "mergedComputeGraph"); compute_graph = merged_compute_graph; - if (!AttrUtils::SetBool(*compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, dynamic_shape_partitioned)) { - REPORT_INNER_ERROR("E19999", "Set Attr:%s to graph:%u fail", - ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED.c_str(), compute_graph->GetGraphID()); - GELOGE(FAILED, "[Set][Attr] %s to graph:%u fail", - ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED.c_str(), compute_graph->GetGraphID()); - return FAILED; - } + return SUCCESS; } diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index e7cd88a9..ea041871 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -243,6 +243,9 @@ class GraphManager { Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id); + Status SubgraphPartitionAndOptimization(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, + uint64_t session_id, GraphPartitioner::Mode mode); + Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, GeRootModelPtr &ge_root_model, uint64_t session_id); diff --git a/ge/graph/optimize/graph_optimize.cc b/ge/graph/optimize/graph_optimize.cc index a321ed43..f7b869ac 100644 --- a/ge/graph/optimize/graph_optimize.cc +++ b/ge/graph/optimize/graph_optimize.cc @@ -17,10 +17,10 @@ #include "graph/optimize/graph_optimize.h" #include "graph/ge_context.h" -#include "common/local_context.h" #include "graph/passes/dimension_adjust_pass.h" #include "inc/pass_manager.h" #include "init/gelib.h" +#include "graph/partition/engine_place.h" namespace { const char *const kVectorCore = "VectorCore"; @@ -85,20 +85,9 @@ Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; } - Status ret = SUCCESS; vector 
graph_optimizer; - - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - REPORT_INNER_ERROR("E19999", "Gelib not init before, check invalid, graph:%s", - compute_graph->GetName().c_str()); - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Get][GELib] Gelib not init before, graph:%s", - compute_graph->GetName().c_str()); - return GE_CLI_GE_NOT_INITIALIZED; - } - - if (instance_ptr->DNNEngineManagerObj().IsEngineRegistered(engine_name)) { - instance_ptr->OpsKernelManagerObj().GetGraphOptimizerByEngine(engine_name, graph_optimizer); + if (DNNEngineManager::GetInstance().IsEngineRegistered(engine_name)) { + OpsKernelManager::GetInstance().GetGraphOptimizerByEngine(engine_name, graph_optimizer); AddNodeInputProperty(compute_graph); if (compute_graph->GetDirectNode().size() == 0) { @@ -123,7 +112,7 @@ Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std } for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { - ret = (*iter)->OptimizeFusedGraph(*(compute_graph)); + Status ret = (*iter)->OptimizeFusedGraph(*(compute_graph)); if (ret != SUCCESS) { REPORT_INNER_ERROR("E19999", "Call OptimizeFusedGraph failed, ret:%d, engine_name:%s, " "graph_name:%s", ret, engine_name.c_str(), @@ -137,7 +126,7 @@ Status GraphOptimize::OptimizeSubGraph(ComputeGraphPtr &compute_graph, const std GELOGI("Engine: %s is not registered. 
do nothing in subGraph Optimize by ATC.", engine_name.c_str()); } - return ret; + return SUCCESS; } Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { @@ -269,28 +258,32 @@ Status GraphOptimize::OptimizeOriginalGraphForQuantize(ComputeGraphPtr &compute_ return ret; } -Status GraphOptimize::OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_graph) { +Status GraphOptimize::OptimizeGraphBeforeBuild(ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { REPORT_INNER_ERROR("E19999", "Param compute_graph is nullptr, check invalid"); GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[Check][Param] compute_graph is nullptr."); return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; } - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - REPORT_INNER_ERROR("E19999", "Gelib not init before, check invalid, graph:%s.", - compute_graph->GetName().c_str()); - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Get][GELib] Gelib not init before, graph:%s.", - compute_graph->GetName().c_str()); - return GE_CLI_GE_NOT_INITIALIZED; + EnginePlacer engine_place(compute_graph); + Status ret = engine_place.Run(); + if (ret != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Assign atomic engine for graph %s failed", compute_graph->GetName().c_str()); + GELOGE(ret, "[Assign][Engine] Assign atomic engine for graph %s failed", compute_graph->GetName().c_str()); + return ret; + } + ret = engine_place.AssignCompositeEngine(); + if (ret != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Assign composite engine for graph %s failed", compute_graph->GetName().c_str()); + GELOGE(ret, "[Assign][Engine] Assign composite engine for graph %s failed", compute_graph->GetName().c_str()); + return ret; } - auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); + auto graph_optimizer = OpsKernelManager::GetInstance().GetAllGraphOptimizerObjsByPriority(); GELOGD("optimize by opskernel in graph 
optimize before build phase. num of graph_optimizer is %zu.", graph_optimizer.size()); - Status ret = SUCCESS; string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; - GELOGD("[OptimizeGraphBeforeBuildForRts]: engine type will exclude: %s, core_type_: %s", + GELOGD("[OptimizeGraphBeforeBuild]: engine type will exclude: %s, core_type_: %s", exclude_core_Type.c_str(), core_type_.c_str()); if (graph_optimizer.size() != 0) { for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { @@ -308,7 +301,7 @@ Status GraphOptimize::OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_gr } } } - return ret; + return SUCCESS; } Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) { diff --git a/ge/graph/optimize/graph_optimize.h b/ge/graph/optimize/graph_optimize.h index a3d359b6..ef7182ee 100755 --- a/ge/graph/optimize/graph_optimize.h +++ b/ge/graph/optimize/graph_optimize.h @@ -55,8 +55,8 @@ class GraphOptimize { // for engine to optimize merged whole graph before ge Optimize2 Status OptimizeWholeGraph(ComputeGraphPtr &compute_graph); - // for rts optimize before build to add attr and insert memcpy op - Status OptimizeGraphBeforeBuildForRts(ComputeGraphPtr &compute_graph); + // for optimize before build + Status OptimizeGraphBeforeBuild(ComputeGraphPtr &compute_graph); // optimize whole graph, using after stage1 Status OptimizeAfterStage1(ComputeGraphPtr &graph); diff --git a/ge/graph/partition/engine_place.cc b/ge/graph/partition/engine_place.cc index 8639f015..0821e505 100755 --- a/ge/graph/partition/engine_place.cc +++ b/ge/graph/partition/engine_place.cc @@ -16,19 +16,12 @@ #include "graph/partition/engine_place.h" -#include -#include -#include -#include #include #include "framework/common/op/ge_op_utils.h" -#include "common/util/error_manager/error_manager.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "init/gelib.h" -#include 
"opskernel_manager/ops_kernel_manager.h" -#include "analyzer/analyzer.h" namespace ge { namespace { @@ -40,7 +33,7 @@ Status EnginePlacer::Check() const { GELOGE(GE_GRAPH_NULL_INPUT, "[Check][Param] compute_graph_ is nullptr."); return FAILED; } - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + std::shared_ptr instance_ptr = GELib::GetInstance(); if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { REPORT_INNER_ERROR("E19999", "GELib instance is nullptr or it is not InitFlag, check invalid."); GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Get][GELib] Run enginePlacer failed, because GELib is invalid."); @@ -49,7 +42,7 @@ Status EnginePlacer::Check() const { return SUCCESS; } -Status EnginePlacer::Run() { +Status EnginePlacer::Run(bool direct_node_flag) { std::lock_guard lock(check_support_cost_mutex); GELOGD("Engine placer starts."); @@ -58,8 +51,8 @@ Status EnginePlacer::Run() { } bool is_check_support_success = true; // Assign engine for each node in the graph - ge::GELib::GetInstance()->DNNEngineManagerObj().InitPerformanceStaistic(); - for (const auto &node_ptr : compute_graph_->GetDirectNode()) { + DNNEngineManager::GetInstance().InitPerformanceStatistic(); + for (const auto &node_ptr : compute_graph_->GetNodes(direct_node_flag)) { GE_CHECK_NOTNULL(node_ptr); auto op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -73,9 +66,7 @@ Status EnginePlacer::Run() { bool use_exist_engine_name = !op_desc->GetOpKernelLibName().empty() || (has_kernel_attr && has_engine_attr); if (use_exist_engine_name) { if (op_desc->GetOpEngineName().empty()) { - GELOGI("Op %s set engine_name %s engine_name %s from attrs", - op_desc->GetName().c_str(), - engine_name.c_str(), + GELOGI("Op %s set engine_name %s engine_name %s from attrs", op_desc->GetName().c_str(), engine_name.c_str(), kernel_name.c_str()); op_desc->SetOpEngineName(engine_name); op_desc->SetOpKernelLibName(kernel_name); @@ -83,7 +74,7 @@ Status EnginePlacer::Run() { engine_name = 
op_desc->GetOpEngineName(); } else { // Call placer cost model to get the "best" engine for this node - engine_name = ge::GELib::GetInstance()->DNNEngineManagerObj().GetDNNEngineName(node_ptr); + engine_name = DNNEngineManager::GetInstance().GetDNNEngineName(node_ptr); // If can't get op's engine name, keep check support finish and return failed if (engine_name.empty()) { is_check_support_success = false; @@ -94,34 +85,48 @@ Status EnginePlacer::Run() { continue; } } - if (AssignEngineAndLog(node_ptr, engine_name) != SUCCESS) { - GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "[Call][AssignEngineAndLog] FAILED, node:%s", op_desc->GetName().c_str()); - return FAILED; - } + + // Record the node assigned atomic_engine name + GELOGD("Assigning DNNEngine %s to node %s, op type %s", engine_name.c_str(), node_ptr->GetName().c_str(), + node_ptr->GetType().c_str()); + node_atomic_engine_map_.insert(std::make_pair(node_ptr, engine_name)); } - for (auto &it : ge::GELib::GetInstance()->DNNEngineManagerObj().GetCheckSupportCost()) { + for (auto &it : DNNEngineManager::GetInstance().GetCheckSupportCost()) { GEEVENT("The time cost of %s::CheckSupported is [%lu] micro second.", it.first.c_str(), it.second); } GELOGD("Engine placer ends."); return is_check_support_success ? 
SUCCESS : FAILED; } -Status EnginePlacer::AssignEngineAndLog(ge::ConstNodePtr node_ptr, const std::string &engine_name) { - if ((node_ptr == nullptr) || (node_ptr->GetOpDesc() == nullptr)) { - REPORT_INNER_ERROR("E19999", "Param node_ptr is nullptr or it's opdesc is nullptr, check invalid."); - GELOGE(FAILED, "[Check][Param] node_ptr is nullptr."); +Status EnginePlacer::AssignCompositeEngine() { + if (OpsKernelManager::GetInstance().GetCompositeEngines().empty()) { + GELOGI("No composite engine registers, ignore assign composite engine"); + return SUCCESS; + } + std::vector subgraphs; + if (GraphUtils::GetSubgraphsRecursively(compute_graph_, subgraphs) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get subgraphs contained in graph %s failed", compute_graph_->GetName().c_str()); + GELOGE(FAILED, "[Get][Subgraphs] Get subgraphs contained in graph %s failed", compute_graph_->GetName().c_str()); return FAILED; } - - // private function, promise node_ptr->GetOpDesc() not null - GELOGD("Assigning DNNEngine %s to node %s, op type %s", engine_name.c_str(), node_ptr->GetName().c_str(), - node_ptr->GetOpDesc()->GetType().c_str()); - - // Record the node assigned engine name - node_engine_map_.insert(std::make_pair(node_ptr, engine_name)); - + for (const auto &subgraph : subgraphs) { + (void)subgraph->DelAttr(ATTR_NAME_COMPOSITE_ENGINE_NAME); + } + std::reverse(subgraphs.begin(), subgraphs.end()); + subgraphs.emplace_back(compute_graph_); + for (const auto &subgraph : subgraphs) { + for (const auto &node : subgraph->GetDirectNode()) { + std::string composite_engine_name = DNNEngineManager::GetInstance().GetCompositeEngineName(node, 1); + GELOGD("Assign composite engine %s to node %s, op type %s", composite_engine_name.c_str(), + node->GetName().c_str(), node->GetType().c_str()); + node_composite_engine_map_.insert(std::make_pair(node, composite_engine_name)); + } + } return SUCCESS; } -} // namespace ge +const NodeEngineMap &EnginePlacer::GetNodeEngineMap(bool 
is_composite_engine_mode) const { + return is_composite_engine_mode ? node_composite_engine_map_ : node_atomic_engine_map_; +} +} // namespace ge diff --git a/ge/graph/partition/engine_place.h b/ge/graph/partition/engine_place.h index 125babb6..813ebf4a 100755 --- a/ge/graph/partition/engine_place.h +++ b/ge/graph/partition/engine_place.h @@ -17,7 +17,6 @@ #ifndef GE_GRAPH_PARTITION_ENGINE_PLACE_H_ #define GE_GRAPH_PARTITION_ENGINE_PLACE_H_ -#include #include #include "framework/common/ge_inner_error_codes.h" @@ -37,19 +36,20 @@ class EnginePlacer { EnginePlacer() = default; ~EnginePlacer() = default; - Status Run(); + Status Run(bool direct_node_flag = true); + Status AssignCompositeEngine(); // Get the unique node-engine map - const NodeEngineMap *GetNodeEngineMap() const { return &node_engine_map_; } + const NodeEngineMap &GetNodeEngineMap(bool is_composite_engine_mode) const; void SetComputeGraph(const ComputeGraphPtr &compute_graph) { compute_graph_ = compute_graph; } private: - Status AssignEngineAndLog(ConstNodePtr node_ptr, const std::string &engine_name); Status Check() const; ComputeGraphPtr compute_graph_; - NodeEngineMap node_engine_map_; + NodeEngineMap node_atomic_engine_map_; + NodeEngineMap node_composite_engine_map_; }; } // namespace ge diff --git a/ge/graph/partition/graph_partition.cc b/ge/graph/partition/graph_partition.cc index 86c9f1fd..2b245e04 100755 --- a/ge/graph/partition/graph_partition.cc +++ b/ge/graph/partition/graph_partition.cc @@ -23,17 +23,12 @@ #include #include "analyzer/analyzer.h" -#include "common/ge/ge_util.h" #include "framework/common/op/ge_op_utils.h" -#include "framework/common/types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/manager/graph_manager_utils.h" #include "common/ge_call_wrapper.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" -#include "opskernel_manager/ops_kernel_manager.h" namespace { const char 
*const kEngineDefaultData = "ENGINE_DEFAULT_DATA"; @@ -386,7 +381,8 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr dst_node_op_desc->GetOpEngineName()), GELOGW("SetStr rearNodeEngineName failed");) // replace input_desc of end with owner node's desc int output_index = ge::AnchorUtils::GetIdx(out_anchor); - bool is_need_update_desc = (output_index >= 0) && (graph_info_.mode_ == kPartitioning); + bool is_need_update_desc = (output_index >= 0) && ((graph_info_.mode_ == kAtomicEnginePartitioning) || + (graph_info_.mode_ == kCompositeEnginePartitioning)); if (is_need_update_desc) { if (UpdateEndOpDesc(src_node, output_index, end_op_desc) != SUCCESS) { GELOGE(GRAPH_PARAM_INVALID, "[Update][EndOpDesc] failed, input index:%d, end_op_desc:%s", @@ -464,7 +460,8 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr graph_info_.num_of_pld_end_++; // replace output_desc of pld with input node's output desc int input_index = ge::AnchorUtils::GetIdx(peer_in_anchor); - is_need_update_desc = (input_index >= 0) && (graph_info_.mode_ == kPartitioning); + is_need_update_desc = (input_index >= 0) && ((graph_info_.mode_ == kAtomicEnginePartitioning) || + (graph_info_.mode_ == kCompositeEnginePartitioning)); if (is_need_update_desc) { if (UpdatePldOpDesc(dst_node, input_index, pld_op_desc) != SUCCESS) { GELOGE(GRAPH_PARAM_INVALID, "[Update][PldOpDesc] failed, output index:%d, pld_op_desc:%s", @@ -629,18 +626,8 @@ bool ge::GraphPartitioner::HasNoInput(ge::NodePtr node) { Status ge::GraphPartitioner::Initialize(ge::ComputeGraphPtr compute_graph) { GELOGI("Initialize starts."); - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || compute_graph == nullptr) { - REPORT_INNER_ERROR("E19999", "compute_graph or instance_ptr of GELib is nullptr, check invalid."); - GELOGE(GE_GRAPH_NOT_INIT, "[Check][Param] compute_graph or instance_ptr of GELib is nullptr."); - return FAILED; - } - 
graph_info_.engine_placer_.SetComputeGraph(compute_graph); - if (graph_info_.engine_placer_.Run() != SUCCESS) { - GELOGE(FAILED, "[Call][Run] Engine placer run failed, graph:%s.", compute_graph->GetName().c_str()); - return FAILED; - } - const NodeEngineMap *node_engine_map = graph_info_.engine_placer_.GetNodeEngineMap(); + GE_CHECK_NOTNULL(compute_graph); + const auto &node_engine_map = GetNodeEngineMap(); size_t temp_index = 0; // travese nodes by topo order one by one for (const auto &node : compute_graph->GetDirectNode()) { @@ -654,12 +641,12 @@ Status ge::GraphPartitioner::Initialize(ge::ComputeGraphPtr compute_graph) { ClusterPtr cluster = MakeShared(temp_index, kEngineDefaultData, temp_stream); new_cluster = cluster; } else { - if (node_engine_map->count(node) == 0) { + if (node_engine_map.count(node) == 0) { REPORT_INNER_ERROR("E19999", "node:%s not find in node_engine_map", node->GetName().c_str()); GELOGE(FAILED, "[Check][Param] node[%s] does not owner engine!", node->GetName().c_str()); return FAILED; } - ClusterPtr cluster = MakeShared(temp_index, node_engine_map->at(node), temp_stream); + ClusterPtr cluster = MakeShared(temp_index, node_engine_map.at(node), temp_stream); new_cluster = cluster; } if (new_cluster == nullptr) { @@ -999,6 +986,25 @@ bool ge::GraphPartitioner::HasSecondPath(size_t src, size_t dst, size_t upper_bo } Status ge::GraphPartitioner::Partition(ge::ComputeGraphPtr compute_graph, Mode mode) { + if (compute_graph->TopologicalSorting() != SUCCESS) { + REPORT_CALL_ERROR("E19999", "TopologicalSorting for graph:%s failed", + compute_graph->GetName().c_str()); + GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "[Call][TopologicalSorting] for subGraph:%s failed", + compute_graph->GetName().c_str()); + return FAILED; + } + graph_info_.engine_placer_.SetComputeGraph(compute_graph); + if (graph_info_.engine_placer_.Run(false) != SUCCESS) { + GELOGE(FAILED, "[Call][Run] Engine placer run failed, graph:%s.", compute_graph->GetName().c_str()); + return FAILED; 
+ } + if (mode == GraphPartitioner::kCompositeEnginePartitioning) { + if (graph_info_.engine_placer_.AssignCompositeEngine() != SUCCESS) { + GELOGE(FAILED, "[Partition][SubGraph] Assign composite engine for graph %s failed", + compute_graph->GetName().c_str()); + return FAILED; + } + } ClearAllPartitionData(); auto real_ret = SUCCESS; auto ret = PartitionSubGraph(compute_graph, mode); @@ -1043,14 +1049,6 @@ Status ge::GraphPartitioner::PartitionSubGraph(ge::ComputeGraphPtr compute_graph return FAILED; } GELOGI("Graph Partition starts, graph nodes size is %zu", compute_graph->GetDirectNodesSize()); - Status ret = compute_graph->TopologicalSorting(); - if (ret != SUCCESS) { - REPORT_CALL_ERROR("E19999", "TopologicalSorting for graph:%s failed", - compute_graph->GetName().c_str()); - GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "[Call][TopologicalSorting] for subGraph:%s failed", - compute_graph->GetName().c_str()); - return FAILED; - } GE_TIMESTAMP_START(PartitionSubGraphInitialize); if (Initialize(compute_graph) != SUCCESS) { GELOGE(GE_GRAPH_INIT_FAILED, "[Call][Initialize] for graph:%s failed", compute_graph->GetName().c_str()); @@ -1234,4 +1232,8 @@ void ge::GraphPartitioner::ClearAllPartitionData() { GELOGD("Clear all partition data success."); return; } + +const NodeEngineMap &GraphPartitioner::GetNodeEngineMap() const { + return graph_info_.engine_placer_.GetNodeEngineMap(graph_info_.mode_ == kCompositeEnginePartitioning); +} } // namespace ge diff --git a/ge/graph/partition/graph_partition.h b/ge/graph/partition/graph_partition.h index 6c21fabe..3ec36481 100644 --- a/ge/graph/partition/graph_partition.h +++ b/ge/graph/partition/graph_partition.h @@ -56,7 +56,12 @@ class GraphPartitioner { /// Partition() can only be called in Partition mode. /// MergeAfterSubGraphOptimization() can only be called in Merge mode. /// After Partition(), change to Merge mode. 
After MergeAfterSubGraphOptimization(), change to Partition mode - enum Mode { kPartitioning, kSecondPartitioning, kMerging }; + enum Mode { + kAtomicEnginePartitioning, + kCompositeEnginePartitioning, + kSecondPartitioning, + kMerging + }; GraphPartitioner() : partition_times_(0){}; ~GraphPartitioner() = default; @@ -136,6 +141,8 @@ class GraphPartitioner { void ClearAllPartitionData(); void SetMergedGraphId(ComputeGraphPtr &output_merged_compute_graph); + const NodeEngineMap &GetNodeEngineMap() const; + struct GraphPartitionInfo { EnginePlacer engine_placer_; PartitionMap partitions_; // sub-graphs after partition @@ -165,12 +172,12 @@ class GraphPartitioner { pld_2_end_.clear(); end_2_pld_.clear(); if (mode_ == kMerging) { - mode_ = kPartitioning; + mode_ = kAtomicEnginePartitioning; } else { mode_ = mode; } } - GraphPartitionInfo() : num_of_pld_end_(0), input_size_(0), output_size_(0), mode_(kPartitioning) {} + GraphPartitionInfo() : num_of_pld_end_(0), input_size_(0), output_size_(0), mode_(kAtomicEnginePartitioning) {} ~GraphPartitionInfo() = default; }; std::unordered_map graph_2_graph_partition_info_; @@ -178,8 +185,10 @@ class GraphPartitioner { Graph2InputNodesSubGraphInfo graph_2_input_subgraph_; GraphPartitionInfo graph_info_; uint32_t partition_times_; // times of call partition - std::map mode_2_str_ = {{kPartitioning, "Partitioning"}, - {kSecondPartitioning, "SecondPartitioning"}, {kMerging, "Merging"}}; + std::map mode_2_str_ = {{ kAtomicEnginePartitioning, "AtomicEnginePartitioning" }, + { kCompositeEnginePartitioning, "CompositeEnginePartitioning" }, + { kSecondPartitioning, "SecondPartitioning" }, + { kMerging, "Merging" }}; friend class GraphManager; }; } // namespace ge diff --git a/ge/graph/partition/stage_partition.cc b/ge/graph/partition/stage_partition.cc index 68b4209f..41adcbcd 100644 --- a/ge/graph/partition/stage_partition.cc +++ b/ge/graph/partition/stage_partition.cc @@ -93,15 +93,15 @@ Status StagePartitioner::SplitStageLevel() { 
auto node = nodes.top(); nodes.pop(); GE_CHECK_NOTNULL(node->GetOpDesc()); - uint32_t tmp_level = cur_stage_level; - (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_STAGE_LEVEL, tmp_level); - if (tmp_level != cur_stage_level) { - continue; - } for (const auto &in_node : node->GetInAllNodes()) { if (visited_stage_nodes.count(in_node) != 0) { continue; } + uint32_t tmp_level = cur_stage_level; + (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_STAGE_LEVEL, tmp_level); + if (tmp_level != cur_stage_level) { + continue; + } if (!AttrUtils::SetInt(in_node->GetOpDesc(), ATTR_STAGE_LEVEL, cur_stage_level)) { REPORT_CALL_ERROR("E19999", "Set Attr %s on node %s failed.", ATTR_STAGE_LEVEL.c_str(), in_node->GetName().c_str()); @@ -128,315 +128,27 @@ Status StagePartitioner::SplitStageLevel() { Status StagePartitioner::StagePartition() { for (const auto &stage : stage_nodes_) { - StageInfo stage_info(stage.first); - FindStageIO(stage.second, stage_info); - - std::string subgraph_name = "Subgraph_Level_" + std::to_string(stage.first); - NodePtr graph_node = BuildSubgraphNode(subgraph_name, stage_info); - if (graph_node == nullptr) { - GELOGE(FAILED, "[Build][SubgraphNode] for stage %u failed, graph name:%s.", stage.first, subgraph_name.c_str()); + const std::string &subgraph_name = "Subgraph_Level_" + std::to_string(stage.first); + const auto &stage_subgraph = GraphUtils::BuildSubgraphWithNodes(root_graph_, stage.second, subgraph_name); + if (stage_subgraph == nullptr) { + REPORT_CALL_ERROR("E19999", "Build subgraph %s failed.", subgraph_name.c_str()); + GELOGE(FAILED, "[Build][Subgraph] %s failed.", subgraph_name.c_str()); return FAILED; } - - ComputeGraphPtr subgraph = BuildStageGraph(graph_node, stage_info); - if (subgraph == nullptr) { - GELOGE(FAILED, "[Build][StageGraph] %s for stage %u failed.", graph_node->GetName().c_str(), stage.first); + if (!AttrUtils::SetInt(stage_subgraph, ATTR_STAGE_LEVEL, stage.first)) { + REPORT_CALL_ERROR("E19999", "Set attr %s on graph %s 
failed.", ATTR_STAGE_LEVEL.c_str(), + stage_subgraph->GetName().c_str()); + GELOGE(FAILED, "[Set][Attr] %s on graph %s failed.", ATTR_STAGE_LEVEL.c_str(), stage_subgraph->GetName().c_str()); return FAILED; } - if (root_graph_->AddSubgraph(subgraph) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "add subgraph:%s in root graph:%s of stage %u failed.", - subgraph->GetName().c_str(), root_graph_->GetName().c_str(), stage.first); - GELOGE(FAILED, "[Add][SubGraph] %s in root graph:%s of stage %u failed.", - subgraph->GetName().c_str(), root_graph_->GetName().c_str(), stage.first); + const auto &parent_node = stage_subgraph->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + if (!AttrUtils::SetInt(parent_node->GetOpDesc(), ATTR_STAGE_LEVEL, stage.first)) { + REPORT_CALL_ERROR("E19999", "Set attr %s on node %s failed", ATTR_STAGE_LEVEL.c_str(), + parent_node->GetName().c_str()); + GELOGE(FAILED, "[Set][Attr] %s on node %s failed", ATTR_STAGE_LEVEL.c_str(), parent_node->GetName().c_str()); return FAILED; } - - if ((RelinkDataEdges(graph_node, stage_info) != SUCCESS) || - (RelinkCtrlEdges(graph_node, stage_info) != SUCCESS)) { - GELOGE(FAILED, "[ReLink][Edges] for stage %u failed, graph_node:%s.", stage.first, graph_node->GetName().c_str()); - return FAILED; - } - - for (const auto &stage_node : stage.second) { - if (GraphUtils::RemoveNodeWithoutRelink(root_graph_, stage_node) != GRAPH_SUCCESS) { - GELOGW("Remove node %s failed.", stage_node->GetName().c_str()); - } - } - } - - return SUCCESS; -} - -void StagePartitioner::FindStageIO(const std::unordered_set &stage_nodes, StageInfo &stage_info) { - for (const auto &node : stage_nodes) { - // stage nodes - stage_info.stage_nodes.emplace(node); - // in data nodes - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - continue; - } - if (stage_nodes.count(peer_out_anchor->GetOwnerNode()) == 0) { - 
stage_info.data_inputs.emplace_back(std::make_pair(peer_out_anchor, in_data_anchor)); - } else { - stage_info.inner_data_edges.emplace_back(std::make_pair(peer_out_anchor, in_data_anchor)); - } - } - // out data nodes - std::list peer_data_anchors; - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - peer_data_anchors.clear(); - for (const auto &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - if (stage_nodes.count(peer_in_anchor->GetOwnerNode()) == 0) { - peer_data_anchors.emplace_back(peer_in_anchor); - } - } - if (!peer_data_anchors.empty()) { - stage_info.data_outputs.emplace_back(std::make_pair(out_data_anchor, peer_data_anchors)); - } - } - // in ctrl nodes - for (const auto &in_ctrl_node : node->GetInControlNodes()) { - if (stage_nodes.count(in_ctrl_node) == 0) { - stage_info.ctrl_inputs.emplace_back(in_ctrl_node->GetOutControlAnchor(), node->GetInControlAnchor()); - } else { - stage_info.inner_ctrl_edges.emplace_back(std::make_pair(in_ctrl_node->GetOutControlAnchor(), - node->GetInControlAnchor())); - } - } - // out ctrl nodes - for (const auto &out_ctrl_node : node->GetOutControlNodes()) { - if (stage_nodes.count(out_ctrl_node) == 0) { - stage_info.ctrl_outputs.emplace_back(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()); - } - } - } -} - -NodePtr StagePartitioner::BuildSubgraphNode(const std::string &graph_name, const StageInfo &stage_info) { - OpDescBuilder op_desc_builder(graph_name, PARTITIONEDCALL); - size_t input_num = stage_info.data_inputs.size(); - for (size_t i = 0; i < input_num; i++) { - auto input_desc = stage_info.data_inputs[i].second->GetOwnerNode()->GetOpDesc(); - if (input_desc == nullptr) { - GELOGE(PARAM_INVALID, "[Check][Param] op_desc is null, node:%s", - stage_info.data_inputs[i].second->GetOwnerNode()->GetName().c_str()); - return nullptr; - } - op_desc_builder.AddInput("args" + std::to_string(i), - input_desc->GetInputDesc(stage_info.data_inputs[i].second->GetIdx())); - } - size_t 
output_num = stage_info.data_outputs.size(); - for (size_t i = 0; i < output_num; i++) { - auto output_desc = stage_info.data_outputs[i].first->GetOwnerNode()->GetOpDesc(); - if (output_desc == nullptr) { - GELOGE(PARAM_INVALID, "[Check][Param] op_desc is null, node:%s", - stage_info.data_outputs[i].first->GetOwnerNode()->GetName().c_str()); - return nullptr; - } - op_desc_builder.AddOutput("output" + std::to_string(i), - output_desc->GetOutputDesc(stage_info.data_outputs[i].first->GetIdx())); - } - - OpDescPtr op_desc = op_desc_builder.Build(); - if (op_desc == nullptr) { - GELOGE(FAILED, "[Create][OpDesc] for subgraph node failed, name:%s.", graph_name.c_str()); - return nullptr; - } - - op_desc->AddSubgraphName("f"); - op_desc->SetSubgraphInstanceName(0, graph_name); - - if (!AttrUtils::SetInt(op_desc, ATTR_STAGE_LEVEL, stage_info.stage_level)) { - REPORT_CALL_ERROR("E19999", "set attr %s on node %s failed", ATTR_STAGE_LEVEL.c_str(), op_desc->GetName().c_str()); - GELOGE(INTERNAL_ERROR, "[Set][Attr] %s on node %s failed", ATTR_STAGE_LEVEL.c_str(), op_desc->GetName().c_str()); - return nullptr; - } - - NodePtr subgraph_node = root_graph_->AddNode(op_desc); - if (subgraph_node == nullptr) { - REPORT_CALL_ERROR("E19999", "add node:%s in graph:%s failed.", - op_desc->GetName().c_str(), root_graph_->GetName().c_str()); - GELOGE(FAILED, "[Add][Node] %s in graph:%s failed.", op_desc->GetName().c_str(), root_graph_->GetName().c_str()); - return nullptr; - } - if (subgraph_node->SetOwnerComputeGraph(root_graph_) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "SetOwnerComputeGraph for node %s failed, grpah:%s.", - subgraph_node->GetName().c_str(), root_graph_->GetName().c_str()); - GELOGE(FAILED, "[Set][OwnerGraph] for node %s failed, grpah:%s.", - subgraph_node->GetName().c_str(), root_graph_->GetName().c_str()); - return nullptr; - } - - return subgraph_node; -} - -ComputeGraphPtr StagePartitioner::BuildStageGraph(const NodePtr &subgraph_node, const StageInfo 
&stage_info) { - CompleteGraphBuilder graph_builder(subgraph_node->GetName(), false); - // Add parent node - graph_builder.SetParentNode(subgraph_node); - - // Add node - for (const auto &node : stage_info.stage_nodes) { - graph_builder.AddNode(AttrUtils::CopyOpDesc(node->GetOpDesc())); - } - - // Set Input - size_t data_input_num = stage_info.data_inputs.size(); - for (size_t i = 0; i < data_input_num; i++) { - graph_builder.SetInput(i, { stage_info.data_inputs[i].second->GetOwnerNode()->GetName() }, - { static_cast(stage_info.data_inputs[i].second->GetIdx()) }); - } - - // Add Outputs - size_t data_output_num = stage_info.data_outputs.size(); - for (uint32_t i = 0; i < data_output_num; i++) { - graph_builder.AddOutput(stage_info.data_outputs[i].first->GetOwnerNode()->GetName(), - stage_info.data_outputs[i].first->GetIdx()); - } - - // Add Data Edges - for (const auto &data_edge : stage_info.inner_data_edges) { - graph_builder.AddDataLink(data_edge.first->GetOwnerNode()->GetName(), data_edge.first->GetIdx(), - data_edge.second->GetOwnerNode()->GetName(), data_edge.second->GetIdx()); - } - - // Add Ctrl Edges - for (const auto &ctrl_edge : stage_info.inner_ctrl_edges) { - graph_builder.AddControlLink(ctrl_edge.first->GetOwnerNode()->GetName(), - ctrl_edge.second->GetOwnerNode()->GetName()); - } - - // Add Input-Mapping - std::map input_mapping; - for (size_t i = 0; i < data_input_num; i++) { - input_mapping[i] = i; - } - graph_builder.SetInputMapping(input_mapping); - - // Add outputMapping - std::map output_mapping; - for (size_t i = 0; i < data_output_num; i++) { - output_mapping[i] = i; - } - graph_builder.SetOutputMapping(output_mapping); - - graphStatus error_code = GRAPH_SUCCESS; - std::string error_msg; - ComputeGraphPtr subgraph = graph_builder.Build(error_code, error_msg); - if (subgraph == nullptr) { - GELOGE(error_code, "[Build][Subgraph] %s failed:%s.", subgraph_node->GetName().c_str(), error_msg.c_str()); - return nullptr; - } - if 
(!AttrUtils::SetInt(subgraph, ATTR_STAGE_LEVEL, stage_info.stage_level)) { - REPORT_CALL_ERROR("E19999", "set attr %s on graph %s failed.", - ATTR_STAGE_LEVEL.c_str(), subgraph->GetName().c_str()); - GELOGE(FAILED, "[Set][Attr] %s on graph %s failed.", ATTR_STAGE_LEVEL.c_str(), subgraph->GetName().c_str()); - return nullptr; - } - - return subgraph; -} - -Status StagePartitioner::RelinkDataEdges(const NodePtr &subgraph_node, const StageInfo &stage_info) { - // in data nodes - for (size_t i = 0; i < stage_info.data_inputs.size(); i++) { - if (stage_info.data_inputs[i].first->Unlink(stage_info.data_inputs[i].second) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "remove data edge from %s:%d to %s:%d failed", - stage_info.data_inputs[i].first->GetOwnerNode()->GetName().c_str(), - stage_info.data_inputs[i].first->GetIdx(), - stage_info.data_inputs[i].second->GetOwnerNode()->GetName().c_str(), - stage_info.data_inputs[i].second->GetIdx()); - GELOGE(INTERNAL_ERROR, "[Remove][DataEdge] %s:%d->%s:%d failed.", - stage_info.data_inputs[i].first->GetOwnerNode()->GetName().c_str(), - stage_info.data_inputs[i].first->GetIdx(), - stage_info.data_inputs[i].second->GetOwnerNode()->GetName().c_str(), - stage_info.data_inputs[i].second->GetIdx()); - return INTERNAL_ERROR; - } - if (stage_info.data_inputs[i].first->LinkTo(subgraph_node->GetInDataAnchor(i)) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "add data edge from %s:%d to %s:%zu failed.", - stage_info.data_inputs[i].first->GetOwnerNode()->GetName().c_str(), - stage_info.data_inputs[i].first->GetIdx(), - subgraph_node->GetName().c_str(), i); - GELOGE(INTERNAL_ERROR, "[Add][DataEdge] %s:%d->%s:%zu failed.", - stage_info.data_inputs[i].first->GetOwnerNode()->GetName().c_str(), - stage_info.data_inputs[i].first->GetIdx(), - subgraph_node->GetName().c_str(), i); - return INTERNAL_ERROR; - } - } - // out data nodes - for (size_t i = 0; i < stage_info.data_outputs.size(); i++) { - const auto &out_data_anchor = 
subgraph_node->GetOutDataAnchor(i); - GE_CHECK_NOTNULL(out_data_anchor); - for (const auto &peer_in_anchor : stage_info.data_outputs[i].second) { - if (stage_info.data_outputs[i].first->Unlink(peer_in_anchor) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Remove data edge from %s:%d to %s:%d failed.", - stage_info.data_outputs[i].first->GetOwnerNode()->GetName().c_str(), - stage_info.data_outputs[i].first->GetIdx(), - peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx()); - GELOGE(INTERNAL_ERROR, "[Remove][DataEdge] %s:%d->%s:%d failed.", - stage_info.data_outputs[i].first->GetOwnerNode()->GetName().c_str(), - stage_info.data_outputs[i].first->GetIdx(), - peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx()); - return INTERNAL_ERROR; - } - if (out_data_anchor->LinkTo(peer_in_anchor) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Add data edge from %s:%zu to %s:%d failed.", subgraph_node->GetName().c_str(), i, - peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx()); - GELOGE(INTERNAL_ERROR, "[Add][DataEdge] %s:%zu->%s:%d failed.", subgraph_node->GetName().c_str(), i, - peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx()); - return INTERNAL_ERROR; - } - } - } - - return SUCCESS; -} - -Status StagePartitioner::RelinkCtrlEdges(const NodePtr &subgraph_node, const StageInfo &stage_info) { - // in ctrl nodes - for (const auto &ctrl_input : stage_info.ctrl_inputs) { - if (ctrl_input.first->Unlink(ctrl_input.second) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Remove ctrl edge %s->%s failed.", - ctrl_input.first->GetOwnerNode()->GetName().c_str(), - ctrl_input.second->GetOwnerNode()->GetName().c_str()); - GELOGE(INTERNAL_ERROR, "[Remove][CtrlEdge] %s->%s failed.", - ctrl_input.first->GetOwnerNode()->GetName().c_str(), ctrl_input.second->GetOwnerNode()->GetName().c_str()); - return INTERNAL_ERROR; - } - if 
(!ctrl_input.first->IsLinkedWith(subgraph_node->GetInControlAnchor())) { - if (ctrl_input.first->LinkTo(subgraph_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Add ctrl edge %s->%s failed.", - ctrl_input.first->GetOwnerNode()->GetName().c_str(), subgraph_node->GetName().c_str()); - GELOGE(INTERNAL_ERROR, "[Add][CtrlEdge] %s->%s failed.", - ctrl_input.first->GetOwnerNode()->GetName().c_str(), subgraph_node->GetName().c_str()); - return INTERNAL_ERROR; - } - } - } - // out ctrl nodes - for (const auto &ctrl_output : stage_info.ctrl_outputs) { - if (ctrl_output.first->Unlink(ctrl_output.second) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Remove ctrl edge %s->%s failed.", - ctrl_output.first->GetOwnerNode()->GetName().c_str(), - ctrl_output.second->GetOwnerNode()->GetName().c_str()); - GELOGE(INTERNAL_ERROR, "[Remove][CtrlEdge] %s->%s failed.", - ctrl_output.first->GetOwnerNode()->GetName().c_str(), - ctrl_output.second->GetOwnerNode()->GetName().c_str()); - return INTERNAL_ERROR; - } - if (!subgraph_node->GetOutControlAnchor()->IsLinkedWith(ctrl_output.second)) { - if (subgraph_node->GetOutControlAnchor()->LinkTo(ctrl_output.second) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Add ctrl edge %s->%s failed.", - subgraph_node->GetName().c_str(), ctrl_output.second->GetOwnerNode()->GetName().c_str()); - GELOGE(INTERNAL_ERROR, "[Add][CtrlEdge] %s->%s failed.", - subgraph_node->GetName().c_str(), ctrl_output.second->GetOwnerNode()->GetName().c_str()); - return INTERNAL_ERROR; - } - } } return SUCCESS; diff --git a/ge/graph/partition/stage_partition.h b/ge/graph/partition/stage_partition.h index 99aac2b9..9b71f198 100644 --- a/ge/graph/partition/stage_partition.h +++ b/ge/graph/partition/stage_partition.h @@ -17,26 +17,10 @@ #ifndef GE_GRAPH_PARTITION_STAGE_PARTITION_H_ #define GE_GRAPH_PARTITION_STAGE_PARTITION_H_ -#include -#include -#include -#include #include "framework/common/ge_inner_error_codes.h" #include 
"graph/compute_graph.h" namespace ge { -struct StageInfo { - explicit StageInfo(uint32_t level) : stage_level(level) {} - uint32_t stage_level; - std::unordered_set stage_nodes; - std::vector> data_inputs; - std::vector>> data_outputs; - std::list> ctrl_inputs; - std::list> ctrl_outputs; - std::list> inner_data_edges; - std::list> inner_ctrl_edges; -}; - class StagePartitioner { public: explicit StagePartitioner(ComputeGraphPtr graph) : root_graph_(std::move(graph)) {} @@ -49,18 +33,8 @@ class StagePartitioner { Status StagePartition(); - static void FindStageIO(const std::unordered_set &stage_nodes, StageInfo &stage_info); - - NodePtr BuildSubgraphNode(const std::string &graph_name, const StageInfo &stage_info); - - static ComputeGraphPtr BuildStageGraph(const NodePtr &subgraph_node, const StageInfo &stage_info); - - static Status RelinkDataEdges(const NodePtr &subgraph_node, const StageInfo &stage_info); - - static Status RelinkCtrlEdges(const NodePtr &subgraph_node, const StageInfo &stage_info); - ComputeGraphPtr root_graph_; - std::map> stage_nodes_; + std::map> stage_nodes_; }; } // namespace ge diff --git a/ge/graph/passes/end_of_sequence_add_control_pass.cc b/ge/graph/passes/end_of_sequence_add_control_pass.cc index 0aee7b03..b1e81968 100755 --- a/ge/graph/passes/end_of_sequence_add_control_pass.cc +++ b/ge/graph/passes/end_of_sequence_add_control_pass.cc @@ -20,41 +20,30 @@ #include #include "init/gelib.h" -#include "graph/node.h" namespace ge { Status EndOfSequenceAddControlPass::Run(ComputeGraphPtr graph) { - if (graph == nullptr) { - REPORT_INNER_ERROR("E19999", "Param graph is nullptr, check invalid"); - GELOGE(PARAM_INVALID, "[Check][Param] param [graph] must not be null."); - return PARAM_INVALID; - } if (graph->GetParentGraph() != nullptr) { return SUCCESS; } - NodePtr end_of_sequence = GetEndOfSequence(graph); + const auto &end_of_sequence = graph->FindFirstNodeMatchType(ENDOFSEQUENCE); if (end_of_sequence == nullptr) { return SUCCESS; } - 
GELOGI("EndOfSequenceAddControlPass begin."); + GELOGI("EndOfSequenceAddControlPass begin."); std::vector target_nodes; for (NodePtr &node : graph->GetDirectNode()) { - if (node == nullptr) { - GELOGW("node is nullptr."); - continue; - } - string stream_label; - (void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); - if (!stream_label.empty() || IsDataLikeNode(node)) { + // op_desc of node should not be null + if (node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL) || + DNNEngineManager::GetInstance().IsStreamAssignSkip(node)) { continue; } // Save the nodes whose pre-nodes are all data-like node - auto in_data_nodes = node->GetInDataNodes(); bool flag = false; - for (auto in_node : in_data_nodes) { - if (!IsDataLikeNode(in_node)) { + for (const auto &in_node : node->GetInDataNodes()) { + if (!DNNEngineManager::GetInstance().IsStreamAssignSkip(in_node)) { flag = true; break; } @@ -64,83 +53,20 @@ Status EndOfSequenceAddControlPass::Run(ComputeGraphPtr graph) { } target_nodes.push_back(node); } - // Insert control edge - Status status = AddControlEdge(end_of_sequence, target_nodes); - if (status != SUCCESS) { - GELOGE(FAILED, "[Add][ControlEdge] Graph add EndOfSequence op:%s out ctrl edge failed.", - end_of_sequence->GetName().c_str()); - return FAILED; - } - GELOGI("EndOfSequenceAddControlPass end."); - return SUCCESS; -} -Status EndOfSequenceAddControlPass::AddControlEdge(NodePtr &end_of_sequence, std::vector &target_nodes) { - auto out_ctrl_anchor = end_of_sequence->GetOutControlAnchor(); - for (NodePtr &node : target_nodes) { - auto in_ctrl_anchor = node->GetInControlAnchor(); - if (in_ctrl_anchor == nullptr) { - continue; - } - Status status = GraphUtils::AddEdge(out_ctrl_anchor, in_ctrl_anchor); - if (status != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", - end_of_sequence->GetName().c_str(), end_of_sequence->GetType().c_str(), - node->GetName().c_str(), 
node->GetType().c_str()); - GELOGE(FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed", - end_of_sequence->GetName().c_str(), end_of_sequence->GetType().c_str(), - node->GetName().c_str(), node->GetType().c_str()); + // Insert control edge + for (const auto &node : target_nodes) { + GELOGI("Add ctrl edge between %s and %s", end_of_sequence->GetName().c_str(), node->GetName().c_str()); + if (GraphUtils::AddEdge(end_of_sequence->GetOutControlAnchor(), node->GetInControlAnchor()) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Add ctrl edge between %s and %s failed", end_of_sequence->GetName().c_str(), + node->GetName().c_str()); + GELOGE(FAILED, "[Add][CtrlEdge] between %s and %s failed", end_of_sequence->GetName().c_str(), + node->GetName().c_str()); return FAILED; } - GELOGI("Graph add EndOfSequence op out ctrl edge, dst node: %s.", node->GetName().c_str()); } - return SUCCESS; -} -inline NodePtr EndOfSequenceAddControlPass::GetEndOfSequence(const ComputeGraphPtr &graph) const { - // Internal function, guaranteeing graph non-null - for (NodePtr &node : graph->GetDirectNode()) { - if (node->GetType() == ENDOFSEQUENCE) { - return node; - } - } - return nullptr; -} - -bool EndOfSequenceAddControlPass::IsDataLikeNode(const NodePtr &node) { - std::shared_ptr instance_ptr = GELib::GetInstance(); - if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { - GELOGW("GELib not initialized"); - return false; - } - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - return false; - } - string engine_name = op_desc->GetOpEngineName(); - if (engine_name.empty()) { - engine_name = instance_ptr->DNNEngineManagerObj().GetDNNEngineName(node); - } - const map schedulers = instance_ptr->DNNEngineManagerObj().GetSchedulers(); - // Only one scheduler has been supported by now - for (auto schedulers_iter = schedulers.begin(); schedulers_iter != schedulers.end(); ++schedulers_iter) { - const map cal_engines = schedulers_iter->second.cal_engines; 
- auto cal_engines_iter = cal_engines.find(engine_name); - if (cal_engines_iter == cal_engines.end()) { - GELOGW("No cal_engines found within engine %s, node name %s", engine_name.c_str(), node->GetName().c_str()); - continue; - } - EngineConfPtr engine_conf_ptr = cal_engines_iter->second; - if (engine_conf_ptr == nullptr) { - GELOGW("engine_conf_ptr within engine %s, node name %s is null", engine_name.c_str(), node->GetName().c_str()); - continue; - } - bool skip_assign_stream = engine_conf_ptr->skip_assign_stream; - if (skip_assign_stream) { - return true; - } - return false; - } - return false; + GELOGI("EndOfSequenceAddControlPass end."); + return SUCCESS; } } // namespace ge diff --git a/ge/graph/passes/end_of_sequence_add_control_pass.h b/ge/graph/passes/end_of_sequence_add_control_pass.h index 32ee0b25..b36ad0a8 100644 --- a/ge/graph/passes/end_of_sequence_add_control_pass.h +++ b/ge/graph/passes/end_of_sequence_add_control_pass.h @@ -30,26 +30,6 @@ class EndOfSequenceAddControlPass : public GraphPass { ~EndOfSequenceAddControlPass() override {} Status Run(ComputeGraphPtr graph) override; - - private: - /** - * Get EndOfSequence node in graph, nullptr if not exist. - * @param graph - * @return EndOfSequence node - */ - inline NodePtr GetEndOfSequence(const ComputeGraphPtr &graph) const; - /** - * Check whether this node is a data-like node. - * @param node - * @return - */ - bool IsDataLikeNode(const NodePtr &node); - /** - * Check whether this node is a data-like node. 
- * @param node - * @return - */ - Status AddControlEdge(NodePtr &end_of_sequence, std::vector &target_nodes); }; } // namespace ge diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 446af9bf..ed6d5680 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -23,7 +23,6 @@ #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_transpose.h" #include "common/formats/utils/formats_trans_utils.h" -#include "common/util/error_manager/error_manager.h" #include "framework/common/helper/model_helper.h" #include "common/math/math_util.h" #include "framework/common/op/ge_op_utils.h" @@ -39,7 +38,6 @@ #include "graph/passes/addn_pass.h" #include "graph/passes/aicpu_constant_folding_pass.h" #include "graph/passes/assert_pass.h" -#include "external/ge/ge_api_types.h" #include "graph/passes/common_subexpression_elimination_pass.h" #include "graph/passes/cond_pass.h" #include "graph/passes/cond_remove_pass.h" @@ -774,7 +772,12 @@ Status UpdateSubgraphDataOfCase(NodePtr &mbatch_node, DataType &dt_set, int32_t return SUCCESS; } - auto subgraphs = NodeUtils::GetAllSubgraphs(*mbatch_node); + std::vector subgraphs; + if (NodeUtils::GetDirectSubgraphs(mbatch_node, subgraphs) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get subgraphs of node %s failed", mbatch_node->GetName().c_str()); + GELOGE(FAILED, "[Check][Param] Get subgraphs of node %s failed", mbatch_node->GetName().c_str()); + return FAILED; + } for (const auto &subgraph : subgraphs) { GE_CHECK_NOTNULL(subgraph); for (auto &sub_node : subgraph->GetDirectNode()) { diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 44115240..c89fbc42 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -60,7 +60,6 @@ const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; 
const char *const kForceInfershape = "_force_infershape_when_running"; const std::set kExecutionDependentTypes{ IF, STATELESSIF, CASE, STREAMSWITCH }; -const std::set kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP }; const std::set kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; Status SetOutputNameAttr(ComputeGraph &graph) { @@ -519,170 +518,6 @@ Status HybridModelBuilder::UpdateAnchorStatus(const NodePtr &node) { return SUCCESS; } -Status HybridModelBuilder::DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, - const InDataAnchorPtr &in_data_anchor) { - GE_CHK_GRAPH_STATUS_RET(out_data_anchor->Unlink(in_data_anchor), - "[Invoke][Unlink] failed to unlink %s:%d from %s:%d", - out_data_anchor->GetOwnerNode()->GetName().c_str(), out_data_anchor->GetIdx(), - in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetIdx()); - - GELOGD("Succeeded in unlinking %s:%d from %s:%d", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - out_data_anchor->GetIdx(), - in_data_anchor->GetOwnerNode()->GetName().c_str(), - in_data_anchor->GetIdx()); - return SUCCESS; -} - -Status HybridModelBuilder::DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor) { - GE_CHK_GRAPH_STATUS_RET(out_data_anchor->LinkTo(in_data_anchor), "[Invoke][LinkTo]Failed to link %s:%d to %s:%d", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - out_data_anchor->GetIdx(), - in_data_anchor->GetOwnerNode()->GetName().c_str(), - in_data_anchor->GetIdx()); - - GELOGD("Succeeded in linking %s:%d to %s:%d", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - out_data_anchor->GetIdx(), - in_data_anchor->GetOwnerNode()->GetName().c_str(), - in_data_anchor->GetIdx()); - return SUCCESS; -} - -Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { - const auto &wrapped_node = graph.GetParentNode(); - std::set root_nodes; - for (const auto &node : graph.GetDirectNode()) { - GE_CHECK_NOTNULL(node); - 
if (node->GetType() != DATA_TYPE) { - if (node->GetInDataNodes().empty()) { - root_nodes.emplace(node); - } - - continue; - } - - auto data_op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(data_op_desc); - - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGE(FAILED, "[Invoke][GetInt] failed, node:[%s] attr:[%s]", - data_op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); - REPORT_CALL_ERROR("E19999", "GetInt failed, node:[%s] attr:[%s]", - data_op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); - return FAILED; - } - - auto wrapped_node_in_anchor = wrapped_node->GetInDataAnchor(parent_index); - GE_CHECK_NOTNULL(wrapped_node_in_anchor); - auto src_out_anchor = wrapped_node_in_anchor->GetPeerOutAnchor(); - if (src_out_anchor == nullptr || src_out_anchor->GetOwnerNode() == nullptr) { - continue; - } - wrapped_node_in_anchor->UnlinkAll(); - - // link src to outputs of DataNode - for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { - GE_CHECK_NOTNULL(out_data_anchor); - for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - auto dst_node = peer_in_data_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(dst_node); - const auto in_nodes = dst_node->GetInDataNodes(); - if (std::all_of(in_nodes.begin(), in_nodes.end(), [](const NodePtr &n) { return n->GetType() == DATA; })) { - root_nodes.emplace(dst_node); - } - GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor)); - GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor)); - } - } - } - - // transfer in control edges to all root nodes - for (auto &root_node : root_nodes) { - auto in_nodes = root_node->GetInAllNodes(); - std::set in_node_set(in_nodes.begin(), in_nodes.end()); - for (auto &in_control_node : wrapped_node->GetInControlNodes()) { - if (in_node_set.count(in_control_node) == 0 && kMergeInputSkipTypes.count(root_node->GetType()) == 0) { - 
GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); - GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); - (void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); - } - } - } - - wrapped_node->GetInControlAnchor()->UnlinkAll(); - return SUCCESS; -} - -Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { - const auto &parent_node = graph.GetParentNode(); - const NodePtr &net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT); - if (net_output_node == nullptr) { - GELOGD("Graph has no netoutput no need to merge"); - return SUCCESS; - } - const auto &net_output_desc = net_output_node->GetOpDesc(); - GE_CHECK_NOTNULL(net_output_desc); - - auto all_in_nodes = net_output_node->GetInAllNodes(); - auto all_out_nodes = parent_node->GetOutAllNodes(); - net_output_node->GetInControlAnchor()->UnlinkAll(); - parent_node->GetOutControlAnchor()->UnlinkAll(); - - for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { - auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(src_out_anchor); - GE_CHECK_NOTNULL(src_out_anchor->GetOwnerNode()); - GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(src_out_anchor, in_data_anchor)); - - auto index = in_data_anchor->GetIdx(); - auto input_desc = net_output_desc->MutableInputDesc(index); - if (input_desc == nullptr) { - GELOGE(INTERNAL_ERROR, "[Invoke][MutableInputDesc][%s] Failed to get input desc[%d]", - net_output_desc->GetName().c_str(), index); - REPORT_CALL_ERROR("E19999", "[%s] Failed to get input desc[%d].", net_output_desc->GetName().c_str(), index); - return INTERNAL_ERROR; - } - - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGW("SubGraph: %s NetOutput input tensor %d, attr %s not found.", - graph.GetName().c_str(), index, ATTR_NAME_PARENT_NODE_INDEX.c_str()); - continue; - } - - const OutDataAnchorPtr 
&parent_out_anchor = parent_node->GetOutDataAnchor(parent_index); - GE_CHECK_NOTNULL(parent_out_anchor); - for (InDataAnchorPtr &dst_in_anchor : parent_out_anchor->GetPeerInDataAnchors()) { - if (dst_in_anchor == nullptr) { - continue; - } - - GE_CHECK_NOTNULL(dst_in_anchor->GetOwnerNode()); - GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(parent_out_anchor, dst_in_anchor)); - GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, dst_in_anchor)); - } - } - - // transfer out control edges - std::set in_node_set(all_in_nodes.begin(), all_in_nodes.end()); - std::set out_node_set(all_out_nodes.begin(), all_out_nodes.end()); - for (auto &src_node : in_node_set) { - GELOGD("[%s] process in node.", src_node->GetName().c_str()); - auto out_nodes = src_node->GetOutAllNodes(); - std::set node_set(out_nodes.begin(), out_nodes.end()); - for (auto &dst_node : out_node_set) { - if (node_set.count(dst_node) == 0) { - src_node->GetOutControlAnchor()->LinkTo(dst_node->GetInControlAnchor()); - GELOGD("[%s] Restore control edge to [%s]", src_node->GetName().c_str(), dst_node->GetName().c_str()); - } - } - } - - return SUCCESS; -} - Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { merged_graph = MakeShared("MergedGraph"); merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag()); @@ -716,9 +551,21 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeG } } } - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), + + const auto &filter = [](const ComputeGraphPtr &graph) { + const auto &parent_node = graph->GetParentNode(); + if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { + return false; + } + if ((parent_node->GetType() != PARTITIONEDCALL) || + (parent_node->GetOpDesc()->GetSubgraphInstanceNames().size() != 1)) { + return false; + } + return graph->GetGraphUnknownFlag(); + }; + GE_CHK_GRAPH_STATUS_RET(GraphUtils::UnfoldSubgraph(subgraph, filter), 
"[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph.", - subgraph->GetName().c_str()); + subgraph->GetName().c_str()) } // invoke before adding subgraphs. in case modify node id in known-shaped subgraphs. @@ -744,56 +591,6 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeG return SUCCESS; } -Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, - ComputeGraphPtr &parent_graph, - ComputeGraph &sub_graph) { - auto parent_node = sub_graph.GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - - GE_CHK_STATUS_RET(MergeInputNodes(sub_graph), - "[Invoke][MergeInputNodes][%s] Failed to merge data nodes for subgraph", - sub_graph.GetName().c_str()); - GE_CHK_STATUS_RET(MergeNetOutputNode(sub_graph), - "[Invoke][MergeNetOutputNode][%s] Failed to merge net output nodes for subgraph", - sub_graph.GetName().c_str()); - GELOGD("[%s] Done merging subgraph inputs and outputs successfully", sub_graph.GetName().c_str()); - - for (auto &sub_node : sub_graph.GetDirectNode()) { - auto sub_op_type = sub_node->GetType(); - if (sub_op_type == DATA_TYPE || sub_op_type == NETOUTPUT) { - continue; - } - if (sub_op_type == PARTITIONEDCALL) { - auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, kSubgraphIndex); - GE_CHECK_NOTNULL(sub_sub_graph); - if (sub_sub_graph->GetGraphUnknownFlag()) { - GE_CHK_STATUS_RET(UnfoldSubgraph(root_graph, parent_graph, *sub_sub_graph), - "[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph", - sub_sub_graph->GetName().c_str()); - continue; - } - } - - if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { - for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { - auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); - GE_CHECK_NOTNULL(sub_sub_graph); - sub_sub_graph->SetParentGraph(parent_graph); - } - } - parent_graph->AddNode(sub_node); - GELOGD("[%s::%s] added to parent graph: [%s].", - sub_graph.GetName().c_str(), - sub_node->GetName().c_str(), - 
parent_graph->GetName().c_str()); - sub_node->SetOwnerComputeGraph(parent_graph); - } - - GELOGD("[%s] Done merging subgraph. remove it from root graph", sub_graph.GetName().c_str()); - root_graph->RemoveSubgraph(sub_graph.GetName()); - return SUCCESS; -} - Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, const NodeItem &node_item, bool is_root_graph) { diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 3592d3d2..52d519ef 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -39,16 +39,11 @@ class HybridModelBuilder { private: static Status UpdateAnchorStatus(const NodePtr &node); - static Status DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, const InDataAnchorPtr &in_data_anchor); - static Status DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor); static NodePtr GetPeerNode(const InDataAnchorPtr &in_data_anchor); static Status GetParentNodeOutputIndex(const OpDesc &op_desc, int index, uint32_t &out_index); static Status GetPeerNodeAcrossSubGraphs(const NodePtr &data_node, NodePtr &peer_node, int &peer_out_index); static Status HandleDtString(const GeTensor &tensor, void *var_addr); - static Status MergeInputNodes(ComputeGraph &compute_graph); - static Status MergeNetOutputNode(ComputeGraph &compute_graph); static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); - static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); static Status BuildInputMapping(GraphItem &graph_item, std::vector &data_nodes, bool is_root_graph); diff --git a/ge/init/gelib.cc b/ge/init/gelib.cc index 2491715b..e4de6b64 100644 --- a/ge/init/gelib.cc +++ b/ge/init/gelib.cc @@ -125,7 +125,7 @@ Status GELib::InnerInitialize(const map &options) { ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kEngineInit); 
GELOGI("engineManager initial."); GE_TIMESTAMP_START(EngineInitialize); - Status initEmStatus = engineManager_.Initialize(options); + Status initEmStatus = DNNEngineManager::GetInstance().Initialize(options); GE_TIMESTAMP_END(EngineInitialize, "InnerInitialize::EngineInitialize"); if (initEmStatus != SUCCESS) { GELOGE(initEmStatus, "[Init][EngineManager]GE engine manager initial failed."); @@ -137,7 +137,7 @@ Status GELib::InnerInitialize(const map &options) { ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOpsKernelInit); GELOGI("opsManager initial."); GE_TIMESTAMP_START(OpsManagerInitialize); - Status initOpsStatus = opsManager_.Initialize(options); + Status initOpsStatus = OpsKernelManager::GetInstance().Initialize(options); GE_TIMESTAMP_END(OpsManagerInitialize, "InnerInitialize::OpsManagerInitialize"); if (initOpsStatus != SUCCESS) { GELOGE(initOpsStatus, "[Init][OpsManager]GE ops manager initial failed."); @@ -247,7 +247,7 @@ Status GELib::SetRTSocVersion(const map &options, map GELib::GetInstance() { return instancePtr_; } void GELib::RollbackInit() { - if (engineManager_.init_flag_) { - (void)engineManager_.Finalize(); + if (DNNEngineManager::GetInstance().init_flag_) { + (void)DNNEngineManager::GetInstance().Finalize(); } - if (opsManager_.init_flag_) { - (void)opsManager_.Finalize(); + if (OpsKernelManager::GetInstance().init_flag_) { + (void)OpsKernelManager::GetInstance().Finalize(); } MemManager::Instance().Finalize(); HostMemManager::Instance().Finalize(); diff --git a/ge/init/gelib.h b/ge/init/gelib.h index 226dd4c8..8322142a 100644 --- a/ge/init/gelib.h +++ b/ge/init/gelib.h @@ -53,10 +53,10 @@ class GE_FUNC_VISIBILITY GELib { Status Finalize(); // get DNNEngineManager object - DNNEngineManager &DNNEngineManagerObj() { return engineManager_; } + DNNEngineManager &DNNEngineManagerObj() { return DNNEngineManager::GetInstance(); } // get OpsKernelManager object - OpsKernelManager &OpsKernelManagerObj() { return 
opsManager_; } + OpsKernelManager &OpsKernelManagerObj() { return OpsKernelManager::GetInstance(); } // get Initial flag bool InitFlag() const { return init_flag_; } @@ -84,8 +84,6 @@ class GE_FUNC_VISIBILITY GELib { void SetDumpModelOptions(const map &options); void SetOpDebugOptions(const map &options); - DNNEngineManager engineManager_; - OpsKernelManager opsManager_; std::mutex status_mutex_; bool init_flag_ = false; Options options_; diff --git a/ge/opskernel_manager/ops_kernel_builder_manager.cc b/ge/opskernel_manager/ops_kernel_builder_manager.cc index 9f981302..736e620c 100644 --- a/ge/opskernel_manager/ops_kernel_builder_manager.cc +++ b/ge/opskernel_manager/ops_kernel_builder_manager.cc @@ -154,12 +154,16 @@ Status OpsKernelBuilderManager::CalcOpRunningParam(Node &node) const { return SUCCESS; } -Status OpsKernelBuilderManager::GenerateTask(const Node &node, - RunContext &context, - std::vector &tasks) const { +Status OpsKernelBuilderManager::GenerateTask(const Node &node, RunContext &context, std::vector &tasks, + bool atomic_engine_flag) const { auto op_desc = node.GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - const std::string &lib_name = op_desc->GetOpKernelLibName(); + std::string lib_name; + if (atomic_engine_flag) { + lib_name = op_desc->GetOpKernelLibName(); + } else { + (void)AttrUtils::GetStr(op_desc, ATTR_NAME_COMPOSITE_ENGINE_KERNEL_LIB_NAME, lib_name); + } auto it = ops_kernel_builders_.find(lib_name); if (it == ops_kernel_builders_.end()) { GELOGE(INTERNAL_ERROR, "[Find][LibName]fail for libName = %s, node:%s", lib_name.c_str(), diff --git a/ge/opskernel_manager/ops_kernel_builder_manager.h b/ge/opskernel_manager/ops_kernel_builder_manager.h index 8e1dec28..d117a068 100644 --- a/ge/opskernel_manager/ops_kernel_builder_manager.h +++ b/ge/opskernel_manager/ops_kernel_builder_manager.h @@ -43,8 +43,8 @@ class GE_FUNC_VISIBILITY OpsKernelBuilderManager { Status CalcOpRunningParam(Node &node) const; - Status GenerateTask(const Node &node, RunContext 
&context, - std::vector &tasks) const; + Status GenerateTask(const Node &node, RunContext &context, std::vector &tasks, + bool atomic_engine_flag = true) const; private: OpsKernelBuilderManager() = default; diff --git a/ge/opskernel_manager/ops_kernel_manager.cc b/ge/opskernel_manager/ops_kernel_manager.cc index 60958872..88c9192d 100644 --- a/ge/opskernel_manager/ops_kernel_manager.cc +++ b/ge/opskernel_manager/ops_kernel_manager.cc @@ -24,6 +24,7 @@ const char *const kInitialize = "Initialize"; const char *const kGetOpsKernelInfoStores = "GetOpsKernelInfoStores"; const char *const kGetGraphOptimizerObjs = "GetGraphOptimizerObjs"; const char *const kFinalize = "Finalize"; +const char *const kGetCompositeEngines = "GetCompositeEngines"; std::mutex ops_kernel_info_mutex; } // namespace @@ -35,9 +36,19 @@ OpsKernelManager::OpsKernelManager() OpsKernelManager::~OpsKernelManager() { graph_optimizers_.clear(); ops_kernel_store_.clear(); + atomic_graph_optimizers_.clear(); + composite_graph_optimizers_.clear(); + atomic_graph_optimizers_by_priority_.clear(); + atomic_first_optimizers_by_priority_.clear(); + composite_engines_.clear(); ops_kernel_info_.clear(); } +OpsKernelManager &OpsKernelManager::GetInstance() { + static OpsKernelManager instance; + return instance; +} + Status OpsKernelManager::Initialize(const map &options_const) { if (init_flag_) { GELOGW("OpsKernelManager has been initialized."); @@ -70,53 +81,48 @@ Status OpsKernelManager::Initialize(const map &options_const) { GELOGI("OPTION_EXEC_EXTERN_PLUGIN_PATH=%s.", extern_engine_path.c_str()); op_tiling_manager_.LoadSo(); - ret = plugin_manager_.LoadSo(extern_engine_path, func_check_list); - if (ret == SUCCESS) { - initialize_ = options; - Status rst0 = plugin_manager_.InvokeAll &, Status>(kInitialize, initialize_); - if (rst0 == FAILED) { - GELOGE(GE_OPS_GET_NO_VALID_SO, "[Invoke][OpsKernelInfo]PluginManager InvokeAll failed."); - REPORT_INNER_ERROR("E19999", "PluginManager InvokeAll failed."); - return 
GE_OPS_GET_NO_VALID_SO; - } - Status rst1 = - plugin_manager_.InvokeAll &>(kGetOpsKernelInfoStores, ops_kernel_store_); - if (rst1 != SUCCESS) { - GELOGW("Initialize OpsKernelInfo failed."); - } - Status rst2 = - plugin_manager_.InvokeAll &>(kGetGraphOptimizerObjs, graph_optimizers_); - if (rst2 != SUCCESS) { - GELOGW("Initialize GraphOptimizerObjs failed."); - } - - ret = CheckPluginPtr(); - if (ret != SUCCESS) { - return ret; - } - ret = InitOpKernelInfoStores(options); - if (ret != SUCCESS) { - return ret; - } - InitOpsKernelInfo(); - ret = InitGraphOptimzers(options); - if (ret != SUCCESS) { - return ret; - } - ret = InitGraphOptimizerPriority(); - if ((ret != SUCCESS)) { - GELOGE(ret, "[Init][GraphOptimizerPriority] failed."); - REPORT_CALL_ERROR("E19999", "InitGraphOptimizerPriority failed."); - return ret; - } - init_flag_ = true; - return SUCCESS; - } else { + if (ret != SUCCESS) { GELOGE(ret, "[Check][SoFile] not find any valid so file."); REPORT_INNER_ERROR("E19999", "OpsKernelManager::Initialize failed for not find any valid so file."); return ret; } + + initialize_ = options; + if (plugin_manager_.InvokeAll &, Status>(kInitialize, initialize_) == FAILED) { + GELOGE(GE_OPS_GET_NO_VALID_SO, "[Invoke][OpsKernelInfo]PluginManager InvokeAll failed."); + REPORT_INNER_ERROR("E19999", "PluginManager InvokeAll failed."); + return GE_OPS_GET_NO_VALID_SO; + } + if (plugin_manager_.InvokeAll &>(kGetOpsKernelInfoStores, + ops_kernel_store_) != SUCCESS) { + GELOGW("Initialize OpsKernelInfo failed."); + } + if (plugin_manager_.InvokeAll &>(kGetGraphOptimizerObjs, + graph_optimizers_) != SUCCESS) { + GELOGW("Initialize GraphOptimizerObjs failed."); + } + plugin_manager_. 
+ OptionalInvokeAll> &, std::map &>( + kGetCompositeEngines, composite_engines_, composite_engine_kernel_lib_names_); + + ret = CheckPluginPtr(); + if (ret != SUCCESS) { + return ret; + } + ret = InitOpKernelInfoStores(options); + if (ret != SUCCESS) { + return ret; + } + InitOpsKernelInfo(); + ret = InitGraphOptimizers(options); + if (ret != SUCCESS) { + return ret; + } + ClassifyGraphOptimizers(); + InitGraphOptimizerPriority(); + init_flag_ = true; + return SUCCESS; } void OpsKernelManager::GetExternalEnginePath(std::string &extern_engine_path, @@ -264,7 +270,7 @@ void OpsKernelManager::InitOpsKernelInfo() { REPORT_INNER_ERROR("E19999", "InitOpsKernelInfo failed for new GELib."); return; } - // sort opinfo of ops_kernel_info_ + // sort op_info of ops_kernel_info_ for (auto &it : ops_kernel_info_) { if (it.second.empty()) { continue; @@ -293,31 +299,24 @@ void OpsKernelManager::InitOpsKernelInfo() { GELOGI("Init opsKernelInfo finished, size is %zu", ops_kernel_info_.size()); } -Status OpsKernelManager::InitGraphOptimzers(const map &options) { +Status OpsKernelManager::InitGraphOptimizers(const map &options) { GELOGI("Init graph optimizers options count %zu", options.size()); for (const auto &option : options) { GELOGI("Init graph optimizers option %s: %s", option.first.c_str(), option.second.c_str()); } - GELOGI("The number of GraphOptimzerObjs are %zu.", graph_optimizers_.size()); + GELOGI("The number of GraphOptimizerObjs are %zu.", graph_optimizers_.size()); for (const auto &it : graph_optimizers_) { - GELOGI("GraphOptimzer name: %s.", (it.first).c_str()); + GELOGI("GraphOptimizer name: %s.", (it.first).c_str()); GraphOptimizerAttribute attrs; GE_CHK_STATUS_RET(it.second->GetAttributes(attrs)) - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Get][GELib]malloc instance_ptr failed."); - REPORT_INNER_ERROR("E19999", "InitGraphOptimzers failed for new GELib."); - return 
GE_CLI_GE_NOT_INITIALIZED; - } - if (!instance_ptr->DNNEngineManagerObj().IsEngineRegistered(attrs.engineName)) { + if (!DNNEngineManager::GetInstance().IsEngineRegistered(attrs.engineName)) { GELOGW("Engine: %s is not registered.", attrs.engineName.c_str()); continue; } - Status ret = it.second->Initialize(options); - if (ret != SUCCESS) { - GELOGE(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, - "[Init][GraphOptimzer]GraphOptimzer: %s initialize failed.", (it.first).c_str()); - REPORT_CALL_ERROR("E19999", "InitGraphOptimzers failed. %s initialize failed.", (it.first).c_str()); + if (it.second->Initialize(options) != SUCCESS) { + GELOGE(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, + "[Init][GraphOptimizer] GraphOptimizer: %s initialize failed.", (it.first).c_str()); + REPORT_CALL_ERROR("E19999", "InitGraphOptimizers failed. %s initialize failed.", (it.first).c_str()); return GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED; } } @@ -340,11 +339,11 @@ Status OpsKernelManager::Finalize() { } } for (auto iter = graph_optimizers_.begin(); iter != graph_optimizers_.end(); ++iter) { - GELOGI("GraphOptimzers finalize, name: %s.", (iter->first).c_str()); + GELOGI("GraphOptimizer finalize, name: %s.", (iter->first).c_str()); Status status = iter->second->Finalize(); if (status != SUCCESS) { - GELOGE(status, "[Check][Status]GraphOptimzers finalize failed, name: %s.", (iter->first).c_str()); - REPORT_CALL_ERROR("E19999", "GraphOptimzers finalize failed, name: %s.", (iter->first).c_str()); + GELOGE(status, "[Check][Status] GraphOptimizer finalize failed, name: %s.", (iter->first).c_str()); + REPORT_CALL_ERROR("E19999", "GraphOptimizer finalize failed, name: %s.", (iter->first).c_str()); return status; } } @@ -398,20 +397,16 @@ const map &OpsKernelManager::GetAllOpsKernelInfoS const map &OpsKernelManager::GetAllGraphOptimizerObjs() const { return graph_optimizers_; } -const vector> &OpsKernelManager::GetAllGraphOptimizerObjsByPriority() const { - return graph_optimizers_by_priority_; -} - void 
OpsKernelManager::GetGraphOptimizerByEngine(const std::string &engine_name, vector &graph_optimizer) { for (const auto &it : graph_optimizers_) { GraphOptimizerAttribute attrs; if (it.second->GetAttributes(attrs) != SUCCESS) { - GELOGW("Get GraphOptimzer name: %s attributes failed.", (it.first).c_str()); + GELOGW("Get GraphOptimizer name: %s attributes failed.", (it.first).c_str()); continue; } if (attrs.engineName == engine_name) { - GELOGD("GetGraphOptimizerByEngine GraphOptimzer name: %s, engineName: %s", (it.first).c_str(), + GELOGD("GetGraphOptimizerByEngine GraphOptimizer name: %s, engineName: %s", (it.first).c_str(), attrs.engineName.c_str()); graph_optimizer.push_back(it.second); } @@ -428,39 +423,64 @@ bool OpsKernelManager::GetEnableAICPUFlag() const { return enable_aicpu_flag_; } bool OpsKernelManager::GetEnablePluginFlag() const { return (enable_fe_flag_ || enable_aicpu_flag_); } -Status OpsKernelManager::InitGraphOptimizerPriority() { +void OpsKernelManager::ClassifyGraphOptimizers() { + if (composite_engines_.empty()) { + GELOGI("No composite engine registers"); + atomic_graph_optimizers_ = graph_optimizers_; + composite_graph_optimizers_.clear(); + return; + } + for (const auto &item : graph_optimizers_) { + GraphOptimizerAttribute attrs; + if (item.second->GetAttributes(attrs) != SUCCESS) { + GELOGW("Get GraphOptimizer attributes failed, name: %s.", (item.first).c_str()); + continue; + } + if (composite_engines_.find(attrs.engineName) != composite_engines_.end()) { + GELOGI("Engine of optimizer %s is %s, which is composited.", item.first.c_str(), attrs.engineName.c_str()); + composite_graph_optimizers_.emplace(item); + } else { + GELOGI("Engine of optimizer %s is %s, which is atomic.", item.first.c_str(), attrs.engineName.c_str()); + atomic_graph_optimizers_.emplace(item); + } + } +} + +void OpsKernelManager::InitGraphOptimizerPriority() { string priority_conf_path = "plugin/opskernel/optimizer_priority.pbtxt"; string path = 
PluginManager::GetPath(); path.append(priority_conf_path); optimizers::Priority optimizerPriority; - bool ret = ReadProtoFromText(path.c_str(), &optimizerPriority); - if (!ret) { + if (!ReadProtoFromText(path.c_str(), &optimizerPriority)) { GELOGW("Read priority file failed. Follow loading sequence."); - return SUCCESS; + return; } auto priorities = optimizerPriority.optimizer(); if (priorities.empty()) { GELOGI("No priority file config. Follow loading sequence."); - return SUCCESS; + return; } // sort optimizer map by priority std::stringstream priority_seq; for (const auto optimizer_name : priorities) { - auto name_to_optimizer_pair = graph_optimizers_.find(optimizer_name); - if (name_to_optimizer_pair != graph_optimizers_.end()) { - graph_optimizers_by_priority_.emplace_back(*name_to_optimizer_pair); + auto name_to_optimizer_pair = atomic_graph_optimizers_.find(optimizer_name); + if (name_to_optimizer_pair != atomic_graph_optimizers_.end()) { + atomic_graph_optimizers_by_priority_.emplace_back(*name_to_optimizer_pair); priority_seq << optimizer_name.c_str() << ' '; } else { GELOGW("Unknown optimizer %s show up in priority config file. Please check.", optimizer_name.c_str()); } } - GELOGI("Graph Optimizers priority initialized. The sequence will follow : %s.", priority_seq.str().c_str()); - return SUCCESS; + GELOGI("Atomic graph Optimizers priority initialized. 
The sequence will follow : %s.", priority_seq.str().c_str()); + atomic_first_optimizers_by_priority_ = atomic_graph_optimizers_by_priority_; + for (const auto &item : composite_graph_optimizers_) { + atomic_first_optimizers_by_priority_.emplace_back(std::make_pair(item.first, item.second)); + } } Status OpsKernelManager::FinalizeOpsKernel() { - GELOGI("ge invoke ops kernal finalize."); + GELOGI("ge invoke ops kernel finalize."); Status ret = plugin_manager_.InvokeAll(kFinalize); if (ret != SUCCESS) { GELOGE(ret, "[Finalize][Check][Status] invoke Fe finalize failed."); diff --git a/ge/opskernel_manager/ops_kernel_manager.h b/ge/opskernel_manager/ops_kernel_manager.h index 5a72dc50..a9f041ff 100644 --- a/ge/opskernel_manager/ops_kernel_manager.h +++ b/ge/opskernel_manager/ops_kernel_manager.h @@ -18,6 +18,7 @@ #define GE_OPSKERNEL_MANAGER_OPS_KERNEL_MANAGER_H_ #include +#include #include #include #include @@ -44,7 +45,7 @@ using OpsKernelInfoStorePtr = std::shared_ptr; class GE_FUNC_VISIBILITY OpsKernelManager { public: friend class GELib; - + static OpsKernelManager &GetInstance(); // get opsKernelInfo by type const vector &GetOpsKernelInfo(const string &op_type); @@ -61,7 +62,17 @@ class GE_FUNC_VISIBILITY OpsKernelManager { const map &GetAllGraphOptimizerObjs() const; // get all graph_optimizer by priority - const vector> &GetAllGraphOptimizerObjsByPriority() const; + const vector> &GetAllGraphOptimizerObjsByPriority() const { + return atomic_first_optimizers_by_priority_; + } + + const map> &GetCompositeEngines() const { + return composite_engines_; + } + + const map &GetCompositeEngineKernelLibNames() const { + return composite_engine_kernel_lib_names_; + } // get subgraphOptimizer by engine name void GetGraphOptimizerByEngine(const std::string &engine_name, vector &graph_optimizer); @@ -93,15 +104,15 @@ class GE_FUNC_VISIBILITY OpsKernelManager { void InitOpsKernelInfo(); - Status InitGraphOptimzers(const map &options); + Status InitGraphOptimizers(const map 
&options); Status InitPluginOptions(const map &options); Status ParsePluginOptions(const map &options, const string &plugin_name, bool &enable_flag); - Status LoadGEGraphOptimizer(map& graphOptimizer); + void ClassifyGraphOptimizers(); - Status InitGraphOptimizerPriority(); + void InitGraphOptimizerPriority(); // Finalize other ops kernel resource Status FinalizeOpsKernel(); @@ -112,8 +123,18 @@ class GE_FUNC_VISIBILITY OpsKernelManager { map ops_kernel_store_{}; // graph_optimizer map graph_optimizers_{}; - // ordered graph_optimzer - vector> graph_optimizers_by_priority_{}; + // composite_graph_optimizer + map composite_graph_optimizers_{}; + // atomic_graph_optimizer + map atomic_graph_optimizers_{}; + // ordered atomic_graph_optimizer + vector> atomic_graph_optimizers_by_priority_{}; + // atomic_first graph_optimizer + vector> atomic_first_optimizers_by_priority_{}; + // {composite_engine, {containing atomic_engine_names}} + map> composite_engines_{}; + // {composite_engine, composite_engine_kernel_lib_name} + map composite_engine_kernel_lib_names_{}; // opsKernelInfo map> ops_kernel_info_{}; diff --git a/ge/plugin/engine/dnnengines.cc b/ge/plugin/engine/dnnengines.cc index 5b06310c..45c7e25c 100755 --- a/ge/plugin/engine/dnnengines.cc +++ b/ge/plugin/engine/dnnengines.cc @@ -16,9 +16,7 @@ #include "plugin/engine/dnnengines.h" -#include #include -#include namespace ge { AICoreDNNEngine::AICoreDNNEngine(const std::string &engine_name) { @@ -29,14 +27,6 @@ AICoreDNNEngine::AICoreDNNEngine(const std::string &engine_name) { engine_attribute_.engine_output_format = FORMAT_RESERVED; } -AICoreDNNEngine::AICoreDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } - -Status AICoreDNNEngine::Initialize(const std::map &options) { return SUCCESS; } - -Status AICoreDNNEngine::Finalize() { return SUCCESS; } - -void AICoreDNNEngine::GetAttributes(DNNEngineAttribute &attrs) const { attrs = engine_attribute_; } - 
VectorCoreDNNEngine::VectorCoreDNNEngine(const std::string &engine_name) { engine_attribute_.engine_name = engine_name; engine_attribute_.compute_cost = COST_1; @@ -45,14 +35,6 @@ VectorCoreDNNEngine::VectorCoreDNNEngine(const std::string &engine_name) { engine_attribute_.engine_output_format = FORMAT_RESERVED; } -VectorCoreDNNEngine::VectorCoreDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } - -Status VectorCoreDNNEngine::Initialize(const std::map &options) { return SUCCESS; } - -Status VectorCoreDNNEngine::Finalize() { return SUCCESS; } - -void VectorCoreDNNEngine::GetAttributes(DNNEngineAttribute &attrs) const { attrs = engine_attribute_; } - AICpuDNNEngine::AICpuDNNEngine(const std::string &engine_name) { engine_attribute_.engine_name = engine_name; engine_attribute_.compute_cost = COST_2; @@ -61,14 +43,6 @@ AICpuDNNEngine::AICpuDNNEngine(const std::string &engine_name) { engine_attribute_.engine_output_format = FORMAT_RESERVED; } -AICpuDNNEngine::AICpuDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } - -Status AICpuDNNEngine::Initialize(const std::map &options) { return SUCCESS; } - -Status AICpuDNNEngine::Finalize() { return SUCCESS; } - -void AICpuDNNEngine::GetAttributes(DNNEngineAttribute &attrs) const { attrs = engine_attribute_; } - AICpuTFDNNEngine::AICpuTFDNNEngine(const std::string &engine_name) { engine_attribute_.engine_name = engine_name; engine_attribute_.compute_cost = COST_3; @@ -77,28 +51,12 @@ AICpuTFDNNEngine::AICpuTFDNNEngine(const std::string &engine_name) { engine_attribute_.engine_output_format = FORMAT_RESERVED; } -AICpuTFDNNEngine::AICpuTFDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } - -Status AICpuTFDNNEngine::Initialize(const std::map &options) { return SUCCESS; } - -Status AICpuTFDNNEngine::Finalize() { return SUCCESS; } - -void AICpuTFDNNEngine::GetAttributes(DNNEngineAttribute &attrs) const { attrs = engine_attribute_; } - 
GeLocalDNNEngine::GeLocalDNNEngine(const std::string &engine_name) { engine_attribute_.engine_name = engine_name; engine_attribute_.engine_input_format = FORMAT_RESERVED; engine_attribute_.engine_output_format = FORMAT_RESERVED; } -GeLocalDNNEngine::GeLocalDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } - -Status GeLocalDNNEngine::Initialize(const std::map &options) { return SUCCESS; } - -Status GeLocalDNNEngine::Finalize() { return SUCCESS; } - -void GeLocalDNNEngine::GetAttributes(DNNEngineAttribute &attrs) const { attrs = engine_attribute_; } - HostCpuDNNEngine::HostCpuDNNEngine(const std::string &engine_name) { engine_attribute_.engine_name = engine_name; engine_attribute_.compute_cost = COST_10; @@ -107,39 +65,21 @@ HostCpuDNNEngine::HostCpuDNNEngine(const std::string &engine_name) { engine_attribute_.engine_output_format = FORMAT_RESERVED; } -HostCpuDNNEngine::HostCpuDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } - -Status HostCpuDNNEngine::Initialize(const std::map &options) { return SUCCESS; } - -Status HostCpuDNNEngine::Finalize() { return SUCCESS; } - -void HostCpuDNNEngine::GetAttributes(DNNEngineAttribute &attrs) const { attrs = engine_attribute_; } - RtsDNNEngine::RtsDNNEngine(const std::string &engine_name) { engine_attribute_.engine_name = engine_name; engine_attribute_.engine_input_format = FORMAT_RESERVED; engine_attribute_.engine_output_format = FORMAT_RESERVED; } -RtsDNNEngine::RtsDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } - -Status RtsDNNEngine::Initialize(const std::map &options) { return SUCCESS; } - -Status RtsDNNEngine::Finalize() { return SUCCESS; } - -void RtsDNNEngine::GetAttributes(DNNEngineAttribute &attrs) const { attrs = engine_attribute_; } - HcclDNNEngine::HcclDNNEngine(const std::string &engine_name) { engine_attribute_.engine_name = engine_name; engine_attribute_.engine_input_format = FORMAT_RESERVED; engine_attribute_.engine_output_format = 
FORMAT_RESERVED; } -HcclDNNEngine::HcclDNNEngine(const DNNEngineAttribute &attrs) { engine_attribute_ = attrs; } - -Status HcclDNNEngine::Initialize(const std::map &options) { return SUCCESS; } - -Status HcclDNNEngine::Finalize() { return SUCCESS; } - -void HcclDNNEngine::GetAttributes(DNNEngineAttribute &attrs) const { attrs = engine_attribute_; } +FftsPlusDNNEngine::FftsPlusDNNEngine(const std::string &engine_name) { + engine_attribute_.engine_name = engine_name; + engine_attribute_.engine_input_format = FORMAT_RESERVED; + engine_attribute_.engine_output_format = FORMAT_RESERVED; +} } // namespace ge diff --git a/ge/plugin/engine/dnnengines.h b/ge/plugin/engine/dnnengines.h index 829c83f1..22b27313 100644 --- a/ge/plugin/engine/dnnengines.h +++ b/ge/plugin/engine/dnnengines.h @@ -27,123 +27,66 @@ namespace ge { class GE_FUNC_VISIBILITY AICoreDNNEngine : public DNNEngine { public: - AICoreDNNEngine() = default; explicit AICoreDNNEngine(const std::string &engine_name); - explicit AICoreDNNEngine(const DNNEngineAttribute &attrs); - ~AICoreDNNEngine() = default; - - Status Initialize(const std::map &options); - Status Finalize(); - void GetAttributes(DNNEngineAttribute &attr) const; - - private: - DNNEngineAttribute engine_attribute_; + explicit AICoreDNNEngine(const DNNEngineAttribute &attrs) : DNNEngine(attrs) {} + ~AICoreDNNEngine() override = default; }; class GE_FUNC_VISIBILITY VectorCoreDNNEngine : public DNNEngine { public: - VectorCoreDNNEngine() = default; explicit VectorCoreDNNEngine(const std::string &engine_name); - explicit VectorCoreDNNEngine(const DNNEngineAttribute &attrs); - ~VectorCoreDNNEngine() = default; - - Status Initialize(const std::map &options); - Status Finalize(); - void GetAttributes(DNNEngineAttribute &attr) const; - - private: - DNNEngineAttribute engine_attribute_; + explicit VectorCoreDNNEngine(const DNNEngineAttribute &attrs) : DNNEngine(attrs) {} + ~VectorCoreDNNEngine() override = default; }; class GE_FUNC_VISIBILITY 
AICpuDNNEngine : public DNNEngine { public: - AICpuDNNEngine() = default; explicit AICpuDNNEngine(const std::string &engine_name); - explicit AICpuDNNEngine(const DNNEngineAttribute &attrs); - ~AICpuDNNEngine() = default; - - Status Initialize(const std::map &options); - Status Finalize(); - void GetAttributes(DNNEngineAttribute &attr) const; - - private: - DNNEngineAttribute engine_attribute_; + explicit AICpuDNNEngine(const DNNEngineAttribute &attrs) : DNNEngine(attrs) {} + ~AICpuDNNEngine() override = default; }; class GE_FUNC_VISIBILITY AICpuTFDNNEngine : public DNNEngine { public: - AICpuTFDNNEngine() = default; explicit AICpuTFDNNEngine(const std::string &engine_name); - explicit AICpuTFDNNEngine(const DNNEngineAttribute &attrs); - ~AICpuTFDNNEngine() = default; - - Status Initialize(const std::map &options); - Status Finalize(); - void GetAttributes(DNNEngineAttribute &attr) const; - - private: - DNNEngineAttribute engine_attribute_; + explicit AICpuTFDNNEngine(const DNNEngineAttribute &attrs) : DNNEngine(attrs) {} + ~AICpuTFDNNEngine() override = default; }; class GE_FUNC_VISIBILITY GeLocalDNNEngine : public DNNEngine { public: - GeLocalDNNEngine() = default; explicit GeLocalDNNEngine(const std::string &engine_name); - explicit GeLocalDNNEngine(const DNNEngineAttribute &attrs); - ~GeLocalDNNEngine() = default; - - Status Initialize(const std::map &options); - Status Finalize(); - void GetAttributes(DNNEngineAttribute &attr) const; - - private: - DNNEngineAttribute engine_attribute_; + explicit GeLocalDNNEngine(const DNNEngineAttribute &attrs) : DNNEngine(attrs) {} + ~GeLocalDNNEngine() override = default; }; class GE_FUNC_VISIBILITY HostCpuDNNEngine : public DNNEngine { public: - HostCpuDNNEngine() = default; explicit HostCpuDNNEngine(const std::string &engine_name); - explicit HostCpuDNNEngine(const DNNEngineAttribute &attrs); - ~HostCpuDNNEngine() = default; - - Status Initialize(const std::map &options); - Status Finalize(); - void 
GetAttributes(DNNEngineAttribute &attr) const; - -private: - DNNEngineAttribute engine_attribute_; + explicit HostCpuDNNEngine(const DNNEngineAttribute &attrs) : DNNEngine(attrs) {} + ~HostCpuDNNEngine() override = default; }; class GE_FUNC_VISIBILITY RtsDNNEngine : public DNNEngine { public: - RtsDNNEngine() = default; explicit RtsDNNEngine(const std::string &engine_name); - explicit RtsDNNEngine(const DNNEngineAttribute &attrs); - ~RtsDNNEngine() = default; - - Status Initialize(const std::map &options); - Status Finalize(); - void GetAttributes(DNNEngineAttribute &attr) const; - - private: - DNNEngineAttribute engine_attribute_; + explicit RtsDNNEngine(const DNNEngineAttribute &attrs) : DNNEngine(attrs) {} + ~RtsDNNEngine() override = default; }; class GE_FUNC_VISIBILITY HcclDNNEngine : public DNNEngine { public: - HcclDNNEngine() = default; explicit HcclDNNEngine(const std::string &engine_name); - explicit HcclDNNEngine(const DNNEngineAttribute &attrs); - ~HcclDNNEngine() = default; - - Status Initialize(const std::map &options); - Status Finalize(); - void GetAttributes(DNNEngineAttribute &attr) const; + explicit HcclDNNEngine(const DNNEngineAttribute &attrs) : DNNEngine(attrs) {} + ~HcclDNNEngine() override = default; +}; - private: - DNNEngineAttribute engine_attribute_; +class GE_FUNC_VISIBILITY FftsPlusDNNEngine : public DNNEngine { + public: + explicit FftsPlusDNNEngine(const std::string &engine_name); + explicit FftsPlusDNNEngine(const DNNEngineAttribute &attrs) : DNNEngine(attrs) {} + ~FftsPlusDNNEngine() override = default; }; } // namespace ge #endif // GE_PLUGIN_ENGINE_DNNENGINES_H_ diff --git a/ge/plugin/engine/engine_manage.cc b/ge/plugin/engine/engine_manage.cc index 0e129526..c38c63e5 100644 --- a/ge/plugin/engine/engine_manage.cc +++ b/ge/plugin/engine/engine_manage.cc @@ -63,7 +63,13 @@ void RegisterAiCoreEngine() { const std::string ai_core = "AIcoreEngine"; std::vector mem_type_aicore; 
mem_type_aicore.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); - DNNEngineAttribute attr_aicore = {ai_core, mem_type_aicore, COST_0, DEVICE, FORMAT_RESERVED, FORMAT_RESERVED}; + DNNEngineAttribute attr_aicore = { ai_core, + mem_type_aicore, + COST_0, + DEVICE, + FORMAT_RESERVED, + FORMAT_RESERVED, + true }; DNNEnginePtr aicore_engine_ptr = MakeShared(attr_aicore); if (aicore_engine_ptr == nullptr) { GELOGE(ge::FAILED, "[Register][AiCoreEngine] failed, as malloc shared_ptr failed."); @@ -79,8 +85,13 @@ void RegisterVectorEngine() { const std::string vector_core = "VectorEngine"; std::vector mem_type_aivcore; mem_type_aivcore.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); - DNNEngineAttribute attr_vector_core = {vector_core, mem_type_aivcore, COST_1, - DEVICE, FORMAT_RESERVED, FORMAT_RESERVED}; + DNNEngineAttribute attr_vector_core = { vector_core, + mem_type_aivcore, + COST_1, + DEVICE, + FORMAT_RESERVED, + FORMAT_RESERVED, + true }; DNNEnginePtr vectorcore_engine_ptr = MakeShared(attr_vector_core); if (vectorcore_engine_ptr == nullptr) { GELOGE(ge::FAILED, "[Register][VectorEngine] failed, as malloc shared_ptr failed."); @@ -97,7 +108,13 @@ void RegisterAiCpuEngine() { std::vector mem_type_aicpu; mem_type_aicpu.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); - DNNEngineAttribute attr_aicpu = {vm_aicpu, mem_type_aicpu, COST_2, DEVICE, FORMAT_RESERVED, FORMAT_RESERVED}; + DNNEngineAttribute attr_aicpu = { vm_aicpu, + mem_type_aicpu, + COST_2, + DEVICE, + FORMAT_RESERVED, + FORMAT_RESERVED, + true }; DNNEnginePtr vm_engine_ptr = MakeShared(attr_aicpu); if (vm_engine_ptr == nullptr) { @@ -115,8 +132,13 @@ void RegisterAiCpuTFEngine() { std::vector mem_type_aicpu_tf; mem_type_aicpu_tf.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); - DNNEngineAttribute attr_aicpu_tf = {vm_aicpu_tf, mem_type_aicpu_tf, COST_3, DEVICE, FORMAT_RESERVED, FORMAT_RESERVED}; - + DNNEngineAttribute attr_aicpu_tf = { vm_aicpu_tf, + mem_type_aicpu_tf, + COST_3, + DEVICE, + FORMAT_RESERVED, + FORMAT_RESERVED, + true 
}; DNNEnginePtr vm_engine_ptr = MakeShared(attr_aicpu_tf); if (vm_engine_ptr == nullptr) { GELOGE(ge::FAILED, "[Register][AiCpuTFEngine]make vm_engine_ptr failed"); @@ -133,7 +155,13 @@ void RegisterGeLocalEngine() { std::vector mem_type_ge_local; mem_type_ge_local.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); // GeLocal use minimum priority, set it as 9 - DNNEngineAttribute attr_ge_local = {vm_ge_local, mem_type_ge_local, COST_9, DEVICE, FORMAT_RESERVED, FORMAT_RESERVED}; + DNNEngineAttribute attr_ge_local = { vm_ge_local, + mem_type_ge_local, + COST_9, + DEVICE, + FORMAT_RESERVED, + FORMAT_RESERVED, + true }; DNNEnginePtr ge_local_engine = MakeShared(attr_ge_local); if (ge_local_engine == nullptr) { GELOGE(ge::FAILED, "[Register][GeLocalEngine] failed, as malloc shared_ptr failed."); @@ -150,8 +178,13 @@ void RegisterHostCpuEngine() { std::vector mem_type_host_cpu; mem_type_host_cpu.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); // HostCpu use minimum priority, set it as 10 - DNNEngineAttribute attr_host_cpu = {vm_host_cpu, mem_type_host_cpu, COST_10, - HOST, FORMAT_RESERVED, FORMAT_RESERVED}; + DNNEngineAttribute attr_host_cpu = { vm_host_cpu, + mem_type_host_cpu, + COST_10, + HOST, + FORMAT_RESERVED, + FORMAT_RESERVED, + true }; DNNEnginePtr host_cpu_engine = MakeShared(attr_host_cpu); if (host_cpu_engine == nullptr) { GELOGE(ge::FAILED, "[Register][HostCpuEngine] failed, as malloc shared_ptr failed."); @@ -167,7 +200,13 @@ void RegisterRtsEngine() { const std::string vm_rts = "DNN_VM_RTS"; std::vector mem_type_rts; mem_type_rts.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); - DNNEngineAttribute attr_rts = {vm_rts, mem_type_rts, COST_1, DEVICE, FORMAT_RESERVED, FORMAT_RESERVED}; + DNNEngineAttribute attr_rts = { vm_rts, + mem_type_rts, + COST_1, + DEVICE, + FORMAT_RESERVED, + FORMAT_RESERVED, + true }; DNNEnginePtr rts_engine = MakeShared(attr_rts); if (rts_engine == nullptr) { GELOGE(ge::FAILED, "[Register][RtsEngine] failed, as malloc shared_ptr failed."); @@ -183,7 
+222,13 @@ void RegisterHcclEngine() { const std::string dnn_hccl = "DNN_HCCL"; std::vector mem_type_hccl; mem_type_hccl.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); - DNNEngineAttribute attr_hccl = {dnn_hccl, mem_type_hccl, COST_1, DEVICE, FORMAT_RESERVED, FORMAT_RESERVED}; + DNNEngineAttribute attr_hccl = { dnn_hccl, + mem_type_hccl, + COST_1, + DEVICE, + FORMAT_RESERVED, + FORMAT_RESERVED, + true }; DNNEnginePtr hccl_engine = MakeShared(attr_hccl); if (hccl_engine == nullptr) { GELOGE(ge::FAILED, "[Register][HcclEngine] failed, as malloc shared_ptr failed."); @@ -195,6 +240,28 @@ void RegisterHcclEngine() { } } +void RegisterFftsPlusEngine() { + const std::string dnn_ffts_plus = "ffts_plus"; + std::vector mem_type_ffts_plus; + mem_type_ffts_plus.emplace_back(GE_ENGINE_ATTR_MEM_TYPE_HBM); + DNNEngineAttribute attr_ffts_plus = { dnn_ffts_plus, + mem_type_ffts_plus, + COST_0, + DEVICE, + FORMAT_RESERVED, + FORMAT_RESERVED, + false }; + DNNEnginePtr ffts_plus_engine = MakeShared(attr_ffts_plus); + if (ffts_plus_engine == nullptr) { + GELOGE(ge::FAILED, "[Register][FftsPlusDNNEngine] failed, as malloc shared_ptr failed."); + REPORT_INNER_ERROR("E19999", "RegisterFftsPlusEngine failed for new DNNEnginePtr failed."); + return; + } + if (EngineManager::RegisterEngine(dnn_ffts_plus, ffts_plus_engine) != SUCCESS) { + GELOGW("register ffts_plus_engine failed"); + } +} + void GetDNNEngineObjs(std::map &engines) { RegisterAiCoreEngine(); RegisterVectorEngine(); @@ -204,6 +271,7 @@ void GetDNNEngineObjs(std::map &engines) { RegisterHostCpuEngine(); RegisterRtsEngine(); RegisterHcclEngine(); + RegisterFftsPlusEngine(); for (auto it = EngineManager::engine_map_->begin(); it != EngineManager::engine_map_->end(); ++it) { GELOGI("get engine %s from engine plugin.", it->first.c_str()); diff --git a/inc/framework/engine/dnnengine.h b/inc/framework/engine/dnnengine.h index 8a0f3b65..b5f02ebe 100644 --- a/inc/framework/engine/dnnengine.h +++ b/inc/framework/engine/dnnengine.h @@ -43,14 
+43,31 @@ struct DNNEngineAttribute { // If engine input format must be specific, set this attribute, else set FORMAT_RESERVED Format engine_input_format; Format engine_output_format; + bool atomic_engine_flag; }; class GE_FUNC_VISIBILITY DNNEngine { public: + DNNEngine() = default; + explicit DNNEngine(const DNNEngineAttribute &attrs) { + engine_attribute_ = attrs; + } virtual ~DNNEngine() = default; - virtual Status Initialize(const std::map &options) = 0; - virtual Status Finalize() = 0; - virtual void GetAttributes(DNNEngineAttribute &attr) const = 0; + Status Initialize(const std::map &options) { + return SUCCESS; + } + Status Finalize() { + return SUCCESS; + } + void GetAttributes(DNNEngineAttribute &attr) const { + attr = engine_attribute_; + } + bool IsAtomic() const { + return engine_attribute_.atomic_engine_flag; + } + + protected: + DNNEngineAttribute engine_attribute_; }; } // namespace ge diff --git a/metadef b/metadef index a725349b..8f2c4395 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit a725349b65aef2940555af2ddb7b9461fbe0d5fd +Subproject commit 8f2c4395c346af026c470b47a7c52f2ab5b51f90 diff --git a/parser b/parser index 7a2daaa2..72d6fcd7 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit 7a2daaa2625505e1a15e1faa46c90df1a23dd6fa +Subproject commit 72d6fcd776ea2eba8000249fd02c8948042e9856 diff --git a/tests/depends/runtime/src/runtime_stub.cc b/tests/depends/runtime/src/runtime_stub.cc index 32df7552..b50dec98 100644 --- a/tests/depends/runtime/src/runtime_stub.cc +++ b/tests/depends/runtime/src/runtime_stub.cc @@ -538,6 +538,14 @@ rtError_t rtFftsTaskLaunch(rtFftsTaskInfo_t *fftsTaskInfo, rtStream_t stream) { return RT_ERROR_NONE; } +rtError_t rtGetAddrAndPrefCntWithHandle(void *handle, const void *devFunc, void **addr, uint32_t *prefetchCnt) { + return RT_ERROR_NONE; +} + +rtError_t rtFftsPlusTaskLaunch(rtFftsPlusTaskInfo_t *fftsPlusTaskInfo, rtStream_t stream) { + return RT_ERROR_NONE; +} + rtError_t 
rtKernelLaunchFwk(const char *opName, void *args, uint32_t argSize, uint32_t flags, rtStream_t rtStream) { return RT_ERROR_NONE; } diff --git a/tests/framework/cmake/graphengine.cmake b/tests/framework/cmake/graphengine.cmake index d83203b4..3a32c96f 100644 --- a/tests/framework/cmake/graphengine.cmake +++ b/tests/framework/cmake/graphengine.cmake @@ -45,7 +45,7 @@ file(GLOB_RECURSE METADEF_REGISTER_SRCS CONFIGURE_DEPENDS "${GE_CODE_DIR}/metadef/register/*.cpp" ) -file(GLOB_RECURSE PARSER_SRCS CONFIGURE_DEPENDS +file(GLOB_RECURSE PARSER_SRCS CONFIGURE_DEPENDS "${GE_CODE_DIR}/parser/parser/common/*.cc" ) @@ -114,7 +114,6 @@ list(APPEND INCLUDE_DIRECTORIES list(APPEND STUB_LIBS c_sec slog_stub - cce_ge_stub runtime_stub profiler_stub hccl_stub @@ -226,7 +225,7 @@ add_custom_command( add_library(graphengine STATIC ${PARSER_SRCS} ${GE_SRCS}) target_include_directories(graphengine - PUBLIC + PUBLIC "${INCLUDE_DIRECTORIES}" "${GE_CODE_DIR}/ge/host_cpu_engine" ) diff --git a/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc index 19dfa4a5..4302183d 100644 --- a/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc +++ b/tests/framework/ge_graph_dsl/src/op_desc/op_desc_cfg_repo.cc @@ -16,7 +16,6 @@ #include "ge_graph_dsl/op_desc/op_desc_cfg_repo.h" #include "framework/common/types.h" -#include "graph/debug/ge_attr_define.h" #include "ge_graph_dsl/op_desc/op_desc_cfg.h" GE_NS_BEGIN @@ -39,6 +38,8 @@ static std::map cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_F OP_CFG(EXIT, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), OP_CFG(NEXTITERATION, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), OP_CFG(NETOUTPUT, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(CONSTANTOP, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), + OP_CFG(GETNEXT, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), OP_CFG(VARIABLE, 1, 1)}; } // namespace diff --git a/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc 
b/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc index 533e8198..4873060d 100644 --- a/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc +++ b/tests/framework/ge_graph_dsl/tests/stub/optype_stub.cc @@ -15,7 +15,6 @@ */ #include "framework/common/types.h" -#include "graph/debug/ge_attr_define.h" #include "ge_graph_dsl/ge.h" GE_NS_BEGIN @@ -32,9 +31,10 @@ REGISTER_OPTYPE_DEFINE(ADD, "Add"); REGISTER_OPTYPE_DEFINE(WHILE, "While"); REGISTER_OPTYPE_DEFINE(ENTER, "Enter"); REGISTER_OPTYPE_DEFINE(MERGE, "Merge"); -REGISTER_OPTYPE_DEFINE(LOOPCOND, "Loopcond"); +REGISTER_OPTYPE_DEFINE(LOOPCOND, "LoopCond"); REGISTER_OPTYPE_DEFINE(SWITCH, "Switch"); REGISTER_OPTYPE_DEFINE(EXIT, "Exit"); -REGISTER_OPTYPE_DEFINE(NEXTITERATION, "Nextiteration"); +REGISTER_OPTYPE_DEFINE(NEXTITERATION, "NextIteration"); +REGISTER_OPTYPE_DEFINE(GETNEXT, "GetNext"); GE_NS_END diff --git a/tests/framework/ge_running_env/include/ge_running_env/env_installer.h b/tests/framework/ge_running_env/include/ge_running_env/env_installer.h index 79b65137..29420471 100644 --- a/tests/framework/ge_running_env/include/ge_running_env/env_installer.h +++ b/tests/framework/ge_running_env/include/ge_running_env/env_installer.h @@ -20,6 +20,7 @@ #include "fake_ns.h" #include "opskernel_manager/ops_kernel_manager.h" #include "register/ops_kernel_builder_registry.h" +#include "plugin/engine/engine_manage.h" FAKE_NS_BEGIN @@ -27,6 +28,9 @@ struct EnvInstaller { virtual void InstallTo(std::map&) const {} virtual void InstallTo(std::map&) const {} virtual void InstallTo(std::map&) const {} + virtual void InstallTo(std::map>&) const {} + virtual void InstallTo(std::map&) const {} + virtual void InstallTo(std::map&) const {} virtual void Install() const {} }; diff --git a/tests/framework/ge_running_env/include/ge_running_env/fake_atomic_optimizer.h b/tests/framework/ge_running_env/include/ge_running_env/fake_atomic_optimizer.h new file mode 100644 index 00000000..b47597a6 --- /dev/null +++ 
b/tests/framework/ge_running_env/include/ge_running_env/fake_atomic_optimizer.h @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_4DCD71AA72F8492D8594C49094B92528 +#define INC_4DCD71AA72F8492D8594C49094B92528 + +#include "ge_running_env/fake_ns.h" +#include "common/optimizer/graph_optimizer.h" + +FAKE_NS_BEGIN + +struct FakeAtomicOptimizer : GraphOptimizer { + explicit FakeAtomicOptimizer(const std::string &engine_name) : engine_name_(engine_name) {} + private: + Status Initialize(const map &options) override; + Status Finalize() override; + Status OptimizeOriginalGraph(ComputeGraph &graph) override; + Status OptimizeFusedGraph(ComputeGraph &graph) override; + Status OptimizeWholeGraph(ComputeGraph &graph) override; + Status GetAttributes(GraphOptimizerAttribute &attrs) const override; + + protected: + std::string engine_name_; +}; + +FAKE_NS_END +#endif diff --git a/tests/framework/ge_running_env/include/ge_running_env/fake_composite_engine.h b/tests/framework/ge_running_env/include/ge_running_env/fake_composite_engine.h new file mode 100644 index 00000000..189957ee --- /dev/null +++ b/tests/framework/ge_running_env/include/ge_running_env/fake_composite_engine.h @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INC_897A92FE9414452E8912FC7204E018A8 +#define INC_897A92FE9414452E8912FC7204E018A8 + +#include "ge_running_env/fake_ns.h" +#include "ge_running_env/fake_engine.h" +#include "common/optimizer/graph_optimizer.h" + +FAKE_NS_BEGIN + +struct FakeCompositeEngine : FakeEngine { + FakeCompositeEngine(const std::string &name, const std::set &sub_engines) : FakeEngine(name), + sub_engines_(sub_engines) {} + + private: + void InstallTo(std::map&) const override; + void InstallTo(std::map&) const override; + void InstallTo(std::map>&) const override; + void InstallTo(std::map&) const override; + void InstallTo(std::map&) const override; + private: + std::set sub_engines_; +}; + +FAKE_NS_END +#endif diff --git a/tests/framework/ge_running_env/include/ge_running_env/fake_composite_optimizer.h b/tests/framework/ge_running_env/include/ge_running_env/fake_composite_optimizer.h new file mode 100644 index 00000000..c0806257 --- /dev/null +++ b/tests/framework/ge_running_env/include/ge_running_env/fake_composite_optimizer.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_8E85C90AF30E4DBF9EF50467846EDA88 +#define INC_8E85C90AF30E4DBF9EF50467846EDA88 + +#include "ge_running_env/fake_ns.h" +#include "ge_running_env/fake_atomic_optimizer.h" + +FAKE_NS_BEGIN + +struct FakeCompositeOptimizer : FakeAtomicOptimizer { + public: + explicit FakeCompositeOptimizer(const std::string &engine_name) : FakeAtomicOptimizer(engine_name) {} + private: + Status OptimizeFusedGraph(ComputeGraph &graph) override; + static uint32_t thread_scope_id_; +}; + +FAKE_NS_END +#endif diff --git a/tests/framework/ge_running_env/include/ge_running_env/fake_engine.h b/tests/framework/ge_running_env/include/ge_running_env/fake_engine.h index c4207223..75fac87b 100644 --- a/tests/framework/ge_running_env/include/ge_running_env/fake_engine.h +++ b/tests/framework/ge_running_env/include/ge_running_env/fake_engine.h @@ -39,14 +39,16 @@ struct FakeEngine : EnvInstaller { private: void InstallTo(std::map&) const override; void InstallTo(std::map&) const override; + void InstallTo(std::map&) const override; - private: template void InstallFor(std::map& maps, const std::map>&) const; - private: + protected: std::string engine_name_; std::set info_store_names_; + + private: std::map custom_builders_; std::map custom_info_stores_; }; diff --git a/tests/framework/ge_running_env/include/ge_running_env/ge_running_env_faker.h b/tests/framework/ge_running_env/include/ge_running_env/ge_running_env_faker.h index 6d325c6a..93543363 100644 --- a/tests/framework/ge_running_env/include/ge_running_env/ge_running_env_faker.h +++ b/tests/framework/ge_running_env/include/ge_running_env/ge_running_env_faker.h @@ -38,6 +38,9 @@ struct GeRunningEnvFaker { std::map &ops_kernel_info_stores_; std::map &ops_kernel_optimizers_; std::map &ops_kernel_builders_; + std::map> &composite_engines_; + std::map &composite_engine_kernel_lib_names_; + std::map &engine_map_; }; 
FAKE_NS_END diff --git a/tests/framework/ge_running_env/src/engine/fake_atomic_optimizer.cc b/tests/framework/ge_running_env/src/engine/fake_atomic_optimizer.cc new file mode 100644 index 00000000..8f7db3d5 --- /dev/null +++ b/tests/framework/ge_running_env/src/engine/fake_atomic_optimizer.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_running_env/fake_atomic_optimizer.h" + +FAKE_NS_BEGIN + +Status FakeAtomicOptimizer::Initialize(const map &options) { + return SUCCESS; +}; + +Status FakeAtomicOptimizer::Finalize() { + return SUCCESS; +} + +Status FakeAtomicOptimizer::OptimizeOriginalGraph(ComputeGraph &graph) { + return SUCCESS; +} + +Status FakeAtomicOptimizer::OptimizeFusedGraph(ComputeGraph& graph) { + return SUCCESS; +} + +Status FakeAtomicOptimizer::OptimizeWholeGraph(ComputeGraph &graph) { + return SUCCESS; +} + +Status FakeAtomicOptimizer::GetAttributes(GraphOptimizerAttribute &attrs) const { + attrs.engineName = engine_name_; + return SUCCESS; +} + +FAKE_NS_END diff --git a/tests/framework/ge_running_env/src/engine/fake_composite_engine.cc b/tests/framework/ge_running_env/src/engine/fake_composite_engine.cc new file mode 100644 index 00000000..2932af3d --- /dev/null +++ b/tests/framework/ge_running_env/src/engine/fake_composite_engine.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_running_env/fake_composite_engine.h" +#include "ge_running_env/fake_composite_optimizer.h" + +FAKE_NS_BEGIN + +void FakeCompositeEngine::InstallTo(std::map &graph_optimizers) const { + auto optimizer = std::make_shared(engine_name_); + graph_optimizers[engine_name_] = optimizer; +} + +void FakeCompositeEngine::InstallTo(std::map&) const { +} + +void FakeCompositeEngine::InstallTo(std::map> &composite_engines) const { + composite_engines[engine_name_] = sub_engines_; +} + +void FakeCompositeEngine::InstallTo(std::map &composite_engine_kernel_lib_names) const { + if (info_store_names_.size() != 1) { + return; + } + composite_engine_kernel_lib_names[engine_name_] = *info_store_names_.begin(); +} + +void FakeCompositeEngine::InstallTo(std::map &engines) const { + DNNEngineAttribute attr; + attr.engine_name = engine_name_; + attr.atomic_engine_flag = false; + engines[engine_name_] = MakeShared(attr); +} + +FAKE_NS_END diff --git a/tests/framework/ge_running_env/src/engine/fake_composite_optimizer.cc b/tests/framework/ge_running_env/src/engine/fake_composite_optimizer.cc new file mode 100644 index 00000000..7c43a08d --- /dev/null +++ b/tests/framework/ge_running_env/src/engine/fake_composite_optimizer.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "ge_running_env/fake_composite_optimizer.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/debug/ge_attr_define.h" +#include "framework/common/types.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" + +FAKE_NS_BEGIN + +uint32_t FakeCompositeOptimizer::thread_scope_id_ = 0; + +Status FakeCompositeOptimizer::OptimizeFusedGraph(ComputeGraph& graph) { + std::set nodes; + for (const auto &node : graph.GetDirectNode()) { + const auto &type = NodeUtils::GetNodeType(node); + if ((type != PLACEHOLDER) && (type != END)) { + nodes.emplace(node); + } + } + if (nodes.size() == 1) { + return SUCCESS; + } + + const std::string &subgraph_name = "PartitionedCall_" + std::to_string(thread_scope_id_); + const auto &subgraph = GraphUtils::BuildSubgraphWithNodes(graph, nodes, subgraph_name); + if (subgraph == nullptr) { + GELOGE(FAILED, "Build subgraph %s failed", subgraph_name.c_str()); + return FAILED; + } + const auto &parent_node = subgraph->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + (void)AttrUtils::SetStr(parent_node->GetOpDesc(), ATTR_NAME_FFTS_PLUS_SUB_GRAPH, subgraph_name); + for (const auto &node : subgraph->GetAllNodes()) { + (void)AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_THREAD_SCOPE_ID, thread_scope_id_); + } + + thread_scope_id_++; + + return SUCCESS; +} + +FAKE_NS_END diff --git a/tests/framework/ge_running_env/src/engine/fake_engine.cc b/tests/framework/ge_running_env/src/engine/fake_engine.cc index 
4b8fedbc..be6304c9 100644 --- a/tests/framework/ge_running_env/src/engine/fake_engine.cc +++ b/tests/framework/ge_running_env/src/engine/fake_engine.cc @@ -15,9 +15,6 @@ */ #include "ge_running_env/fake_engine.h" -#include "ge_running_env/fake_ops_kernel_builder.h" -#include "ge_running_env/fake_ops_kernel_info_store.h" -#include "opskernel_manager/ops_kernel_manager.h" FAKE_NS_BEGIN @@ -78,4 +75,11 @@ void FakeEngine::InstallTo(std::map &ops_kernel_bui InstallFor(ops_kernel_builders, custom_builders_); } +void FakeEngine::InstallTo(std::map &engines) const { + DNNEngineAttribute attr; + attr.engine_name = engine_name_; + attr.atomic_engine_flag = true; + engines[engine_name_] = MakeShared(attr); +} + FAKE_NS_END diff --git a/tests/framework/ge_running_env/src/env/ge_default_running_env.cc b/tests/framework/ge_running_env/src/env/ge_default_running_env.cc index ab705f55..46875186 100644 --- a/tests/framework/ge_running_env/src/env/ge_default_running_env.cc +++ b/tests/framework/ge_running_env/src/env/ge_default_running_env.cc @@ -32,6 +32,8 @@ std::vector fake_ops = { FakeOp(SWITCH).InfoStoreAndBuilder("RTSLib"), FakeOp(LOOPCOND).InfoStoreAndBuilder("RTSLib"), FakeOp(STREAMMERGE).InfoStoreAndBuilder("RTSLib"), FakeOp(STREAMSWITCH).InfoStoreAndBuilder("RTSLib"), FakeOp(STREAMACTIVE).InfoStoreAndBuilder("RTSLib"), FakeOp(EXIT).InfoStoreAndBuilder("RTSLib"), + FakeOp(SEND).InfoStoreAndBuilder("RTSLib"), FakeOp(RECV).InfoStoreAndBuilder("RTSLib"), + FakeOp(IDENTITY).InfoStoreAndBuilder("RTSLib"), FakeOp(IDENTITYN).InfoStoreAndBuilder("RTSLib"), FakeOp(LESS).InfoStoreAndBuilder("AiCoreLib"), FakeOp(NEXTITERATION).InfoStoreAndBuilder("AiCoreLib"), FakeOp(CAST).InfoStoreAndBuilder("AiCoreLib"), FakeOp(TRANSDATA).InfoStoreAndBuilder("AiCoreLib"), @@ -53,4 +55,4 @@ void GeDefaultRunningEnv::InstallTo(GeRunningEnvFaker& ge_env) { } } -FAKE_NS_END \ No newline at end of file +FAKE_NS_END diff --git a/tests/framework/ge_running_env/src/env/ge_running_env_faker.cc 
b/tests/framework/ge_running_env/src/env/ge_running_env_faker.cc index 2977f6b2..f591270b 100644 --- a/tests/framework/ge_running_env/src/env/ge_running_env_faker.cc +++ b/tests/framework/ge_running_env/src/env/ge_running_env_faker.cc @@ -15,34 +15,31 @@ */ #include -#include #include "external/ge/ge_api.h" #include "opskernel_manager/ops_kernel_builder_manager.h" #include "init/gelib.h" -#include "utility" +#include "plugin/engine/engine_manage.h" #include "ge_running_env/ge_running_env_faker.h" #include "ge_default_running_env.h" -#include "ge_running_env/env_installer.h" #include "op/fake_op_repo.h" FAKE_NS_BEGIN namespace { -OpsKernelManager& getKernelManger() { - std::shared_ptr instancePtr = ge::GELib::GetInstance(); - return instancePtr->OpsKernelManagerObj(); -} - struct InitEnv { static InitEnv& GetInstance() { static InitEnv instance; return instance; } - void reset(std::map& ops_kernel_info_stores, - std::map& builders) { + void reset(std::map &ops_kernel_info_stores, + std::map &builders, + std::map &ops_kernel_optimizers, + std::map> &composite_engines, + std::map &composite_engine_kernel_lib_names, + std::map &engines) { std::set remove_info_names; - for (auto iter : ops_kernel_info_stores) { + for (auto iter : builders) { if (kernel_info_names.find(iter.first) == kernel_info_names.end()) { remove_info_names.insert(iter.first); } @@ -50,12 +47,16 @@ struct InitEnv { for (auto info_name : remove_info_names) { ops_kernel_info_stores.erase(info_name); builders.erase(info_name); + ops_kernel_optimizers.erase(info_name); + composite_engines.erase(info_name); + composite_engine_kernel_lib_names.erase(info_name); + engines.erase(info_name); } } private: InitEnv() { - for (auto iter : getKernelManger().GetAllOpsKernelInfoStores()) { + for (auto iter : OpsKernelManager::GetInstance().GetAllOpsKernelInfoStores()) { kernel_info_names.insert(iter.first); } } @@ -66,20 +67,27 @@ struct InitEnv { } // namespace GeRunningEnvFaker::GeRunningEnvFaker() - : 
op_kernel_info_(const_cast>&>(getKernelManger().GetAllOpsKernelInfo())), - ops_kernel_info_stores_( - const_cast&>(getKernelManger().GetAllOpsKernelInfoStores())), - ops_kernel_optimizers_( - const_cast&>(getKernelManger().GetAllGraphOptimizerObjs())), - ops_kernel_builders_(const_cast&>( - OpsKernelBuilderManager::Instance().GetAllOpsKernelBuilders())) { + : op_kernel_info_(const_cast>&>( + OpsKernelManager::GetInstance().GetAllOpsKernelInfo())), + ops_kernel_info_stores_(const_cast&>( + OpsKernelManager::GetInstance().GetAllOpsKernelInfoStores())), + ops_kernel_optimizers_(const_cast&>( + OpsKernelManager::GetInstance().GetAllGraphOptimizerObjs())), + ops_kernel_builders_(const_cast&>( + OpsKernelBuilderManager::Instance().GetAllOpsKernelBuilders())), + composite_engines_(const_cast>&>( + OpsKernelManager::GetInstance().GetCompositeEngines())), + composite_engine_kernel_lib_names_(const_cast&>( + OpsKernelManager::GetInstance().GetCompositeEngineKernelLibNames())), + engine_map_(const_cast&>(DNNEngineManager::GetInstance().GetAllEngines())) { Reset(); } GeRunningEnvFaker& GeRunningEnvFaker::Reset() { InitEnv& init_env = InitEnv::GetInstance(); FakeOpRepo::Reset(); - init_env.reset(ops_kernel_info_stores_, ops_kernel_builders_); + init_env.reset(ops_kernel_info_stores_, ops_kernel_builders_, ops_kernel_optimizers_, composite_engines_, + composite_engine_kernel_lib_names_, engine_map_); flush(); return *this; } @@ -91,13 +99,17 @@ GeRunningEnvFaker& GeRunningEnvFaker::Install(const EnvInstaller& installer) { installer.InstallTo(ops_kernel_info_stores_); installer.InstallTo(ops_kernel_optimizers_); installer.InstallTo(ops_kernel_builders_); + installer.InstallTo(composite_engines_); + installer.InstallTo(composite_engine_kernel_lib_names_); + installer.InstallTo(engine_map_); + flush(); return *this; } void GeRunningEnvFaker::flush() { op_kernel_info_.clear(); - getKernelManger().GetOpsKernelInfo(""); + OpsKernelManager::GetInstance().GetOpsKernelInfo(""); } 
GeRunningEnvFaker& GeRunningEnvFaker::InstallDefault() { diff --git a/tests/framework/ge_running_env/tests/test_ge_running_env_faker.cc b/tests/framework/ge_running_env/tests/test_ge_running_env_faker.cc index 4429f4a7..0cd007ba 100644 --- a/tests/framework/ge_running_env/tests/test_ge_running_env_faker.cc +++ b/tests/framework/ge_running_env/tests/test_ge_running_env_faker.cc @@ -20,9 +20,10 @@ #include "external/ge/ge_api.h" #include "opskernel_manager/ops_kernel_builder_manager.h" #include "ge_running_env/fake_ops_kernel_builder.h" -#include "ge_running_env/fake_ns.h" #include "ge_running_env/ge_running_env_faker.h" #include "ge_running_env/fake_op.h" +#include "ge_running_env/fake_composite_engine.h" + FAKE_NS_BEGIN #define ASSERT_OPS_LIST_SIZE(list_size) \ @@ -33,8 +34,9 @@ FAKE_NS_BEGIN class GeRunningEvnFakerTest : public testing::Test { protected: void SetUp() {} - OpsKernelManager &kernel_manager = ge::GELib::GetInstance()->OpsKernelManagerObj(); + OpsKernelManager &kernel_manager = OpsKernelManager::GetInstance(); OpsKernelBuilderManager &builder_manager = OpsKernelBuilderManager::Instance(); + DNNEngineManager &dnnengine_manager = DNNEngineManager::GetInstance(); }; TEST_F(GeRunningEvnFakerTest, test_reset_running_env_is_success) { @@ -142,7 +144,31 @@ TEST_F(GeRunningEvnFakerTest, test_install_default_fake_engine_success) { ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 7); ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 7); - ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 66); + ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 68); +} + +TEST_F(GeRunningEvnFakerTest, test_install_fake_engine_with_optimizer_success) { + GeRunningEnvFaker ge_env; + ge_env.Install(FakeEngine("DNN_VM_AICPU")); + + ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 2); + ASSERT_EQ(kernel_manager.GetAllGraphOptimizerObjs().size(), 0); + ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 2); +} + 
+TEST_F(GeRunningEvnFakerTest, test_install_fake_engine_with_sub_engines_success) { + GeRunningEnvFaker ge_env; + ge_env.Install(FakeEngine("DNN_VM_AICPU")) + .Install(FakeEngine("AIcoreEngine")) + .Install(FakeCompositeEngine("ffts_plus", {"DNN_VM_AICPU", "AIcoreEngine"}).KernelInfoStore("ffts_plus")); + + ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 3); + ASSERT_EQ(kernel_manager.GetAllGraphOptimizerObjs().size(), 1); + ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 4); + ASSERT_EQ(kernel_manager.GetCompositeEngines().size(), 1); + + ASSERT_EQ(OpsKernelManager::GetInstance().GetCompositeEngines().size(), 1); + ASSERT_EQ(OpsKernelManager::GetInstance().GetCompositeEngineKernelLibNames().size(), 1); } FAKE_NS_END diff --git a/tests/st/testcase/test_ffts_plus.cc b/tests/st/testcase/test_ffts_plus.cc new file mode 100644 index 00000000..e9a94aa4 --- /dev/null +++ b/tests/st/testcase/test_ffts_plus.cc @@ -0,0 +1,152 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "init/gelib.h" +#include "opskernel_manager/ops_kernel_builder_manager.h" +#include "external/ge/ge_api.h" +#include "ge_running_env/ge_running_env_faker.h" +#include "ge_graph_dsl/graph_dsl.h" +#include "ge_running_env/fake_composite_engine.h" +#include "ge_running_env/fake_op.h" + +#include "easy_graph/layout/graph_layout.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" + +#include "ge_graph_dsl/assert/graph_assert.h" + +using namespace std; +using namespace ge; + +namespace { +bool IfNodeExist(const ComputeGraphPtr &graph, std::function filter, + bool direct_node_flag = true) { + for (const auto &node : graph->GetNodes(direct_node_flag)) { + if (filter(node)) { + return true; + } + } + return false; +} + +void GetSubgraphsWithFilter(const ComputeGraphPtr &graph, std::function filter, + std::vector &subgraphs) { + for (const auto &subgraph : graph->GetAllSubgraphs()) { + if (filter(subgraph)) { + subgraphs.emplace_back(subgraph); + } + } +} + +bool IsAllNodeMatch(const ComputeGraphPtr &graph, std::function filter) { + for (const auto &node : graph->GetAllNodes()) { + if (!filter(node)) { + return false; + } + } + return true; +} +} + +class TestFftsPlus : public testing::Test { + protected: + GeRunningEnvFaker ge_env; + EG_NS::GraphEasyExecutor executor; + void SetUp() { + EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); + ge_env.InstallDefault() + .Install(FakeCompositeEngine("ffts_plus", {"AIcoreEngine", "DNN_VM_AICPU"}).KernelInfoStore("ffts_plus")) + .Install(FakeOp(GETNEXT).InfoStoreAndBuilder("AicpuLib")) + .Install(FakeOp(HCOMREDUCE).InfoStoreAndBuilder("HcclLib")); + } + void TearDown() {} +}; + +/* + * g1 + * + * ┌──────────┐ (0,1) ┌────────┐ (0,0) ┌────────┐ + * │ const │ ───────> │ less │ ───────> │ reduce │ + * └──────────┘ └────────┘ └────────┘ + * ∧ + * │ (0,0) + * │ + * ┌──────────┐ (0,0) ┌────────┐ (0,1) ┌────────┐ + 
* │ get_next │ ───────> │ add │ <─────── │ data1 │ + * └──────────┘ └────────┘ └────────┘ + * + */ +TEST_F(TestFftsPlus, test_ffts_plus_no_func_node) { + auto tensor = std::make_shared(); + uint32_t value = 0; + tensor->SetData((uint8_t *)&value, sizeof(uint32_t)); + DEF_GRAPH(g1) { + CHAIN(NODE("get_next", GETNEXT)->NODE("add", ADD)); + CHAIN(NODE("data1", DATA)->NODE("add")->NODE("less", LESS)->NODE("reduce", HCOMREDUCE)); + CHAIN(NODE("const", OP_CFG(CONSTANTOP).Attr("value", tensor))->Node("less")); + }; + + auto graph = ToGeGraph(g1); + // new session & add graph + map options; + Session session(options); + auto ret = session.AddGraph(1, graph, options); + EXPECT_EQ(ret, SUCCESS); + + // build input tensor + std::vector inputs; + // build_graph through session + ret = session.BuildGraph(1, inputs); + EXPECT_EQ(ret, SUCCESS); + + CHECK_GRAPH(PreRunAfterBuild) { + // node exist + ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "get_next"; })); + ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "add"; })); + ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "less"; })); + ASSERT_TRUE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetType() == PARTITIONEDCALL; })); + + // subgraph exit + ASSERT_EQ(graph->GetAllSubgraphs().size(), 1); + std::vector subgraphs; + GetSubgraphsWithFilter(graph, + [](const ComputeGraphPtr &graph) { + const auto &parent_node = graph->GetParentNode(); + if ((parent_node == nullptr) || (parent_node->GetOpDesc() == nullptr)) { + return false; + } + return parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH); }, + subgraphs); + ASSERT_EQ(subgraphs.size(), 1); + + // subgraph node check + const auto &subgraph = subgraphs[0]; + ASSERT_TRUE(subgraph != nullptr); + ASSERT_TRUE(IsAllNodeMatch(subgraph, + [](const NodePtr &node) { + return node->GetOpDesc()->HasAttr(ATTR_NAME_THREAD_SCOPE_ID); + })); + const auto 
&parent_node = subgraph->GetParentNode(); + ASSERT_TRUE(parent_node != nullptr); + ASSERT_TRUE(parent_node->GetOpDesc() != nullptr); + int64_t stream_id = parent_node->GetOpDesc()->GetStreamId(); + ASSERT_TRUE(IsAllNodeMatch(subgraph, + [stream_id](const NodePtr &node) { + return node->GetOpDesc()->GetStreamId() == stream_id; + })); + }; +} diff --git a/tests/st/testcase/test_framework_dummy.cc b/tests/st/testcase/test_framework_dummy.cc index 8f13bb78..b7494b5e 100644 --- a/tests/st/testcase/test_framework_dummy.cc +++ b/tests/st/testcase/test_framework_dummy.cc @@ -19,6 +19,11 @@ #include "graph/debug/ge_attr_define.h" #include "framework/common/types.h" #include "ge_running_env/ge_running_env_faker.h" + +#include "easy_graph/layout/graph_layout.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" + #include "ge_graph_dsl/graph_dsl.h" #include "ge_graph_dsl/assert/graph_assert.h" @@ -94,9 +99,13 @@ Graph BuildV1ControlFlowGraph() { } } // namespace class FrameworkTest : public testing::Test { + EG_NS::GraphEasyExecutor executor; protected: GeRunningEnvFaker ge_env; - void SetUp() { ge_env.InstallDefault(); } + void SetUp() { + ge_env.InstallDefault(); + EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); + } void TearDown() {} }; diff --git a/tests/st/testcase/test_ge_opt_info.cc b/tests/st/testcase/test_ge_opt_info.cc index 2e8e5382..1f2b23ef 100644 --- a/tests/st/testcase/test_ge_opt_info.cc +++ b/tests/st/testcase/test_ge_opt_info.cc @@ -21,11 +21,21 @@ #include "framework/common/types.h" #include "graph/ge_local_context.h" #include "ge_graph_dsl/graph_dsl.h" +#include "ge_running_env/ge_running_env_faker.h" + +#include "easy_graph/layout/graph_layout.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" +#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" namespace ge { class STEST_opt_info : public testing::Test { protected: 
- void SetUp() {} + GeRunningEnvFaker ge_env; + EG_NS::GraphEasyExecutor executor; + void SetUp() { + EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); + ge_env.InstallDefault(); + } void TearDown() {} }; diff --git a/tests/ut/common/graph/testcase/ge_graph/graph_builder_utils.cc b/tests/ut/common/graph/testcase/ge_graph/graph_builder_utils.cc index 4044d670..5f74721c 100644 --- a/tests/ut/common/graph/testcase/ge_graph/graph_builder_utils.cc +++ b/tests/ut/common/graph/testcase/ge_graph/graph_builder_utils.cc @@ -39,11 +39,11 @@ NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, return graph_->AddNode(op_desc); } -void GraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { +void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) { GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); } -void GraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { +void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); } } // namespace ut -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/tests/ut/common/graph/testcase/ge_graph/graph_builder_utils.h b/tests/ut/common/graph/testcase/ge_graph/graph_builder_utils.h index 45c75b28..eb236230 100644 --- a/tests/ut/common/graph/testcase/ge_graph/graph_builder_utils.h +++ b/tests/ut/common/graph/testcase/ge_graph/graph_builder_utils.h @@ -35,8 +35,8 @@ class GraphBuilder { NodePtr AddNDNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt) { return AddNode(name, type, in_cnt, out_cnt, FORMAT_ND, DT_FLOAT, {1, 1, 224, 224}); } - void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); - void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); + void AddDataEdge(const 
NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx); + void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node); ComputeGraphPtr GetGraph() { graph_->TopologicalSorting(); return graph_; diff --git a/tests/ut/ge/graph/build/task_generator_unittest.cc b/tests/ut/ge/graph/build/task_generator_unittest.cc index 7be20fa1..5ebfa2a7 100644 --- a/tests/ut/ge/graph/build/task_generator_unittest.cc +++ b/tests/ut/ge/graph/build/task_generator_unittest.cc @@ -209,11 +209,8 @@ TEST_F(UtestTaskGeneratorTest, GenerateTask) { Status ret = ge::GELib::Initialize(options); EXPECT_EQ(ret, SUCCESS); - shared_ptr instance_ptr = ge::GELib::GetInstance(); - EXPECT_NE(instance_ptr, nullptr); - OpsKernelInfoStorePtr ops_kernel_info_store_ptr = MakeShared(); - instance_ptr->opsManager_.ops_kernel_store_.insert(make_pair(kKernelInfoNameHccl, ops_kernel_info_store_ptr)); + OpsKernelManager::GetInstance().ops_kernel_store_.insert(make_pair(kKernelInfoNameHccl, ops_kernel_info_store_ptr)); OpsKernelBuilderManager &builder_manager_instance_ptr = ge::OpsKernelBuilderManager::Instance(); OpsKernelBuilderPtr fake_builder = MakeShared(); @@ -230,4 +227,4 @@ TEST_F(UtestTaskGeneratorTest, GenerateTask) { EXPECT_EQ(task_generator.GenerateTask(run_context, graph, task_def_list, op_name_map), SUCCESS); EXPECT_EQ(task_def_list.size(), 1); EXPECT_EQ(task_def_list[0].ops_kernel_store_ptr(), reinterpret_cast(ops_kernel_info_store_ptr.get())); -} \ No newline at end of file +} diff --git a/tests/ut/ge/graph/load/davinci_model_unittest.cc b/tests/ut/ge/graph/load/davinci_model_unittest.cc index 62204f6c..a624785f 100644 --- a/tests/ut/ge/graph/load/davinci_model_unittest.cc +++ b/tests/ut/ge/graph/load/davinci_model_unittest.cc @@ -1082,13 +1082,13 @@ TEST_F(UtestDavinciModel, init_tbe_handle) { EXPECT_EQ(model.used_tbe_handle_map_.size(), 0); } -// test InitTbeHandleWithFfts +// test InitTbeHandleInAutoMode TEST_F(UtestDavinciModel, init_tbe_handle_with_ffts) { DavinciModel 
model(0, nullptr); OpDescPtr op_desc = CreateOpDesc("data", DATA); model.ge_model_ = make_shared(); // without tbe_kernel - EXPECT_EQ(model.InitTbeHandleWithFfts(op_desc), INTERNAL_ERROR); + EXPECT_EQ(model.InitTbeHandleInAutoMode(op_desc), INTERNAL_ERROR); std::vector tbe_kernel; vector buffer; @@ -1099,14 +1099,14 @@ TEST_F(UtestDavinciModel, init_tbe_handle_with_ffts) { tbe_kernel.push_back(tbe_kernel_ptr1); op_desc->SetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel); // without _register_stub_func - EXPECT_EQ(model.InitTbeHandleWithFfts(op_desc), INTERNAL_ERROR); + EXPECT_EQ(model.InitTbeHandleInAutoMode(op_desc), INTERNAL_ERROR); vector bin_file_keys; bin_file_keys.emplace_back(op_desc->GetName() + "_0"); bin_file_keys.emplace_back(op_desc->GetName() + "_1"); AttrUtils::SetListStr(op_desc, "_register_stub_func", bin_file_keys); - EXPECT_EQ(model.InitTbeHandleWithFfts(op_desc), SUCCESS); + EXPECT_EQ(model.InitTbeHandleInAutoMode(op_desc), SUCCESS); // rtQueryFunctionRegistered(bin_file_key) failed EXPECT_EQ(model.used_tbe_handle_map_.size(), 0); } @@ -1116,18 +1116,17 @@ TEST_F(UtestDavinciModel, init_binary_magic) { DavinciModel model(0, nullptr); rtDevBinary_t binary; OpDescPtr op_desc = CreateOpDesc("data", DATA); - bool is_ffts = true; vector json_list; AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); // without tvm_magic - EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), INTERNAL_ERROR); + EXPECT_EQ(model.InitBinaryMagic(op_desc, 0, binary, "_thread_"), INTERNAL_ERROR); json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_AICPU"); json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF"); op_desc->DelAttr(TVM_ATTR_NAME_THREAD_MAGIC); AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); - EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), SUCCESS); + EXPECT_EQ(model.InitBinaryMagic(op_desc, 0, binary, "_thread_"), SUCCESS); EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AICPU); - 
EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 1, binary), SUCCESS); + EXPECT_EQ(model.InitBinaryMagic(op_desc, 1, binary, "_thread_"), SUCCESS); EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF); json_list.clear(); @@ -1135,9 +1134,9 @@ TEST_F(UtestDavinciModel, init_binary_magic) { json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_AICUBE"); op_desc->DelAttr(TVM_ATTR_NAME_THREAD_MAGIC); AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); - EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), SUCCESS); + EXPECT_EQ(model.InitBinaryMagic(op_desc, 0, binary, "_thread_"), SUCCESS); EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AIVEC); - EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 1, binary), SUCCESS); + EXPECT_EQ(model.InitBinaryMagic(op_desc, 1, binary, "_thread_"), SUCCESS); EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AICUBE); // with invalid json type @@ -1146,13 +1145,12 @@ TEST_F(UtestDavinciModel, init_binary_magic) { json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_INVALID"); op_desc->DelAttr(TVM_ATTR_NAME_THREAD_MAGIC); AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); - EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), PARAM_INVALID); + EXPECT_EQ(model.InitBinaryMagic(op_desc, 0, binary, "_thread_"), PARAM_INVALID); // test unffts - is_ffts = false; string json_string = "RT_DEV_BINARY_MAGIC_ELF_AIVEC"; AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, json_string); - EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), SUCCESS); + EXPECT_EQ(model.InitBinaryMagic(op_desc, UINT32_MAX, binary), SUCCESS); EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AIVEC); } @@ -1161,19 +1159,17 @@ TEST_F(UtestDavinciModel, init_meta_data) { DavinciModel model(0, nullptr); void *bin_handle; OpDescPtr op_desc = CreateOpDesc("data", DATA); - bool is_ffts = true; vector meta_data_list; // with empty meta_data - EXPECT_EQ(model.InitMetaData(op_desc, is_ffts, 0, bin_handle), INTERNAL_ERROR); + 
EXPECT_EQ(model.InitMetaData(op_desc, 0, bin_handle, "_thread_"), INTERNAL_ERROR); meta_data_list.emplace_back("meta_data_0"); meta_data_list.emplace_back("meta_data_1"); AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_METADATA, meta_data_list); - EXPECT_EQ(model.InitMetaData(op_desc, is_ffts, 0, bin_handle), SUCCESS); + EXPECT_EQ(model.InitMetaData(op_desc, 0, bin_handle, "_thread_"), SUCCESS); - is_ffts = false; string meta_data = "meta_data"; AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_METADATA, meta_data); - EXPECT_EQ(model.InitMetaData(op_desc, is_ffts, 0, bin_handle), SUCCESS); + EXPECT_EQ(model.InitMetaData(op_desc, UINT32_MAX, bin_handle), SUCCESS); } // test InitKernelName @@ -1181,25 +1177,23 @@ TEST_F(UtestDavinciModel, init_kernel_name) { DavinciModel model(0, nullptr); string kernel_name; OpDescPtr op_desc = CreateOpDesc("data", DATA); - bool is_ffts = true; // failed when name is invalid - EXPECT_EQ(model.InitKernelName(op_desc, is_ffts, 0, kernel_name), INTERNAL_ERROR); + EXPECT_EQ(model.InitKernelName(op_desc, 0, kernel_name, "_thread_"), INTERNAL_ERROR); OpDescPtr op_desc1 = CreateOpDesc("sgt_graph_nodes/loss_scale", SCALE); string attr_kernel_name = "loss_scale_thread_kernelname"; vector kernel_name_list; AttrUtils::SetListStr(op_desc, attr_kernel_name, kernel_name_list); // failed without kernel_name - EXPECT_EQ(model.InitKernelName(op_desc, is_ffts, 0, kernel_name), INTERNAL_ERROR); + EXPECT_EQ(model.InitKernelName(op_desc, 0, kernel_name, "_thread_"), INTERNAL_ERROR); kernel_name_list.emplace_back("kernel_name_0"); kernel_name_list.emplace_back("kernel_name_1"); AttrUtils::SetListStr(op_desc1, attr_kernel_name, kernel_name_list); - EXPECT_EQ(model.InitKernelName(op_desc1, is_ffts, 0, kernel_name), SUCCESS); + EXPECT_EQ(model.InitKernelName(op_desc1, 0, kernel_name, "_thread_"), SUCCESS); // without ffts - is_ffts = false; attr_kernel_name = "data_kernelname"; kernel_name = "kernel_name"; AttrUtils::SetStr(op_desc, attr_kernel_name, 
kernel_name); - EXPECT_EQ(model.InitKernelName(op_desc, is_ffts, 0, kernel_name), SUCCESS); + EXPECT_EQ(model.InitKernelName(op_desc, UINT32_MAX, kernel_name), SUCCESS); } } // namespace ge diff --git a/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc b/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc index 7f26aa8c..7a05c754 100644 --- a/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc +++ b/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc @@ -74,6 +74,8 @@ class UtestGraphOptimizeTest : public testing::Test { void TearDown() { DeleteFile(config_file_); DeleteFile(config_dir_); + DNNEngineManager::GetInstance().schedulers_.clear(); + OpsKernelManager::GetInstance().atomic_first_optimizers_by_priority_.clear(); } private: @@ -128,10 +130,8 @@ TEST_F(UtestGraphOptimizeTest, test_OptimizeAfterStage1_succ) { Status ret = ge::GELib::Initialize(options); EXPECT_EQ(ret, SUCCESS); - shared_ptr instance_ptr = ge::GELib::GetInstance(); - EXPECT_NE(instance_ptr, nullptr); GraphOptimizerPtr graph_opt = MakeShared(); - instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); + OpsKernelManager::GetInstance().atomic_first_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); ComputeGraphPtr compute_graph = MakeShared("test_graph"); GraphOptimize base_optimize; @@ -142,6 +142,8 @@ TEST_F(UtestGraphOptimizeTest, test_OptimizeAfterStage1_succ) { ret = base_optimize.OptimizeAfterStage1(compute_graph); EXPECT_EQ(ret, SUCCESS); + shared_ptr instance_ptr = ge::GELib::GetInstance(); + EXPECT_NE(instance_ptr, nullptr); ret = instance_ptr->Finalize(); EXPECT_EQ(ret, SUCCESS); } @@ -164,13 +166,13 @@ TEST_F(UtestGraphOptimizeTest, test_OptimizeAfterStage1_fail) { ret = ge::GELib::Initialize(options); EXPECT_EQ(ret, SUCCESS); - shared_ptr instance_ptr = ge::GELib::GetInstance(); - EXPECT_NE(instance_ptr, nullptr); GraphOptimizerPtr graph_opt = MakeShared(); - 
instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); + OpsKernelManager::GetInstance().atomic_first_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); ret = base_optimize.OptimizeAfterStage1(compute_graph); EXPECT_EQ(ret, FAILED); + shared_ptr instance_ptr = ge::GELib::GetInstance(); + EXPECT_NE(instance_ptr, nullptr); ret = instance_ptr->Finalize(); EXPECT_EQ(ret, SUCCESS); } @@ -180,10 +182,8 @@ TEST_F(UtestGraphOptimizeTest, test_optimizers_succ) { Status ret = ge::GELib::Initialize(options); EXPECT_EQ(ret, SUCCESS); - shared_ptr instance_ptr = ge::GELib::GetInstance(); - EXPECT_NE(instance_ptr, nullptr); GraphOptimizerPtr graph_opt = MakeShared(); - instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); + OpsKernelManager::GetInstance().atomic_first_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); ComputeGraphPtr compute_graph = MakeShared("test_graph"); GraphOptimize base_optimize; @@ -197,12 +197,14 @@ TEST_F(UtestGraphOptimizeTest, test_optimizers_succ) { ret = base_optimize.OptimizeOriginalGraphForQuantize(compute_graph); EXPECT_EQ(ret, SUCCESS); - ret = base_optimize.OptimizeGraphBeforeBuildForRts(compute_graph); + ret = base_optimize.OptimizeGraphBeforeBuild(compute_graph); EXPECT_EQ(ret, SUCCESS); ret = base_optimize.OptimizeWholeGraph(compute_graph); EXPECT_EQ(ret, SUCCESS); + shared_ptr instance_ptr = ge::GELib::GetInstance(); + EXPECT_NE(instance_ptr, nullptr); ret = instance_ptr->Finalize(); EXPECT_EQ(ret, SUCCESS); } @@ -212,10 +214,8 @@ TEST_F(UtestGraphOptimizeTest, test_optimizers_fail) { Status ret = ge::GELib::Initialize(options); EXPECT_EQ(ret, SUCCESS); - shared_ptr instance_ptr = ge::GELib::GetInstance(); - EXPECT_NE(instance_ptr, nullptr); GraphOptimizerPtr graph_opt = MakeShared(); - instance_ptr->opsManager_.graph_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); + 
OpsKernelManager::GetInstance().atomic_first_optimizers_by_priority_.push_back(make_pair("AIcoreEngine", graph_opt)); ComputeGraphPtr compute_graph = MakeShared("test_graph"); GraphOptimize base_optimize; @@ -229,12 +229,14 @@ TEST_F(UtestGraphOptimizeTest, test_optimizers_fail) { ret = base_optimize.OptimizeOriginalGraphForQuantize(compute_graph); EXPECT_EQ(ret, FAILED); - ret = base_optimize.OptimizeGraphBeforeBuildForRts(compute_graph); + ret = base_optimize.OptimizeGraphBeforeBuild(compute_graph); EXPECT_EQ(ret, FAILED); ret = base_optimize.OptimizeWholeGraph(compute_graph); EXPECT_EQ(ret, FAILED); + shared_ptr instance_ptr = ge::GELib::GetInstance(); + EXPECT_NE(instance_ptr, nullptr); ret = instance_ptr->Finalize(); EXPECT_EQ(ret, SUCCESS); } diff --git a/tests/ut/ge/graph/passes/graph_builder_utils.cc b/tests/ut/ge/graph/passes/graph_builder_utils.cc index 9904e731..bd164fae 100644 --- a/tests/ut/ge/graph/passes/graph_builder_utils.cc +++ b/tests/ut/ge/graph/passes/graph_builder_utils.cc @@ -37,10 +37,10 @@ NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, return graph_->AddNode(op_desc); } -void GraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { +void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) { GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); } -void GraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { +void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); } diff --git a/tests/ut/ge/graph/passes/graph_builder_utils.h b/tests/ut/ge/graph/passes/graph_builder_utils.h index d024beb4..eada46bf 100644 --- a/tests/ut/ge/graph/passes/graph_builder_utils.h +++ b/tests/ut/ge/graph/passes/graph_builder_utils.h @@ -32,8 +32,8 @@ class GraphBuilder { NodePtr 
AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, std::vector shape = {1, 1, 224, 224}); - void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); - void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); + void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx); + void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node); ComputeGraphPtr GetGraph() { graph_->TopologicalSorting(); return graph_; diff --git a/third_party/fwkacllib/inc/runtime/kernel.h b/third_party/fwkacllib/inc/runtime/kernel.h index 9b0221c7..aeddf6e1 100644 --- a/third_party/fwkacllib/inc/runtime/kernel.h +++ b/third_party/fwkacllib/inc/runtime/kernel.h @@ -356,7 +356,7 @@ RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void * * @return RT_ERROR_INVALID_VALUE for error input */ RTS_API rtError_t rtKernelLaunchWithHandle(void *handle, const void *devFunc, uint32_t blockDim, void *args, uint32_t argsSize, - rtSmDesc_t *smDesc, rtStream_t stream_, const void *kernelInfo); + rtSmDesc_t *smDesc, rtStream_t stream_, const void *kernelInfo); /** * @ingroup rt_kernel @@ -652,4 +652,3 @@ RTS_API rtError_t rtStopMDCProfiler(void *addr); #endif #endif // __CCE_RUNTIME_KERNEL_H__ - diff --git a/third_party/fwkacllib/inc/runtime/rt.h b/third_party/fwkacllib/inc/runtime/rt.h index aa394eea..1d696be4 100644 --- a/third_party/fwkacllib/inc/runtime/rt.h +++ b/third_party/fwkacllib/inc/runtime/rt.h @@ -28,5 +28,7 @@ #include "rt_model.h" #include "stream.h" #include "rt_ffts.h" +#include "rt_ffts_plus.h" +#include "rt_ffts_plus_define.h" #endif // __CCE_RUNTIME_RT_H__ diff --git a/third_party/fwkacllib/inc/runtime/rt_ffts_plus.h b/third_party/fwkacllib/inc/runtime/rt_ffts_plus.h new file mode 100644 index 00000000..8b26099a --- /dev/null +++ b/third_party/fwkacllib/inc/runtime/rt_ffts_plus.h @@ -0,0 +1,33 @@ +/* + * 
Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved. + * Description: ffts plus interface + */ + +#ifndef __CCE_RUNTIME_FFTS_PLUS_H +#define __CCE_RUNTIME_FFTS_PLUS_H + +#include "base.h" +#include "rt_stars_define.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +#pragma pack(push) +#pragma pack (1) + +typedef struct tagFftsPlusTaskInfo { + const rtFftsPlusSqe_t *fftsPlusSqe; + const void *descBuf; // include total context + size_t descBufLen; // the length of descBuf +} rtFftsPlusTaskInfo_t; + +#pragma pack(pop) + +RTS_API rtError_t rtGetAddrAndPrefCntWithHandle(void *handle, const void *devFunc, void **addr, uint32_t *prefetchCnt); +RTS_API rtError_t rtFftsPlusTaskLaunch(rtFftsPlusTaskInfo_t *fftsPlusTaskInfo, rtStream_t stream); + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif +#endif // __CCE_RUNTIME_FFTS_PLUS_H diff --git a/third_party/fwkacllib/inc/runtime/rt_ffts_plus_define.h b/third_party/fwkacllib/inc/runtime/rt_ffts_plus_define.h new file mode 100644 index 00000000..fb871451 --- /dev/null +++ b/third_party/fwkacllib/inc/runtime/rt_ffts_plus_define.h @@ -0,0 +1,689 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved. 
+ * Description: the definition of ffts plus + */ + +#ifndef __CCE_RUNTIME_FFTS_PLUS_DEFINE_H +#define __CCE_RUNTIME_FFTS_PLUS_DEFINE_H + +#include "base.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +#pragma pack(push) +#pragma pack (1) + +// hardware context type +typedef enum tagFftsPlusHwType { + RT_HW_CTX_TYPE_AIC = 0, + RT_HW_CTX_TYPE_AIV = 1, + RT_HW_CTX_TYPE_NOTIFY_WAIT = 3, + RT_HW_CTX_TYPE_NOTIFY_RECORD = 4, + RT_HW_CTX_TYPE_WRITE_VALUE = 5, + RT_HW_CTX_TYPE_MIX_AIC = 6, + RT_HW_CTX_TYPE_MIX_AIV = 7, + RT_HW_CTX_TYPE_SDMA = 8, + RT_HW_CTX_TYPE_FLUSH_DATA = 9, + RT_HW_CTX_TYPE_INVALIDATE_DATA = 10, + RT_HW_CTX_TYPE_WRITEBACK_DATA = 11, + RT_HW_CTX_TYPE_AICPU = 12, + RT_HW_CTX_TYPE_LOAD = 13, + RT_HW_CTX_TYPE_MAX, +}rtFftsPlusHwType_t; + +// software context type +typedef enum tagFftsPlusSoftType { + RT_SOFT_CTX_TYPE_COND_SWITCH = 1, + RT_SOFT_CTX_TYPE_CASE_SWITCH = 2, + RT_SOFT_CTX_TYPE_AT_START = 3, + RT_SOFT_CTX_TYPE_AT_END = 4, + RT_SOFT_CTX_TYPE_LABEL = 5, + RT_SOFT_CTX_TYPE_MAX, +}rtFftsPlusSoftType_t; + +typedef enum tagFftsPlusContextType { + RT_CTX_TYPE_AICORE = 0x0000, + RT_CTX_TYPE_AIV = 0x0001, + RT_CTX_TYPE_NOTIFY_WAIT = 0x0003, + RT_CTX_TYPE_NOTIFY_RECORD = 0x0004, + RT_CTX_TYPE_WRITE_VALUE = 0x0005, + RT_CTX_TYPE_MIX_AIC = 0x0006, + RT_CTX_TYPE_MIX_AIV = 0x0007, + RT_CTX_TYPE_SDMA = 0x0008, + RT_CTX_TYPE_FLUSH_DATA = 0x0009, + RT_CTX_TYPE_INVALIDATE_DATA = 0x000A, + RT_CTX_TYPE_WRITEBACK_DATA = 0x000B, + RT_CTX_TYPE_AICPU = 0x000C, + RT_CTX_TYPE_COND_SWITCH = 0x010D, + RT_CTX_TYPE_CASE_SWITCH = 0x020D, + RT_CTX_TYPE_AT_START = 0x0300, + RT_CTX_TYPE_AT_END = 0x0400, + RT_CTX_TYPE_LABEL = 0x0500, +}rtFftsPlusContextType_t; + +// condition type +typedef enum tagFftsPlusCondType { + RT_COND_TYPE_EQUAL = 0, + RT_COND_TYPE_NOTEQUAL = 1, + RT_COND_TYPE_GREATER = 2, + RT_COND_TYPE_GREATER_OR_EQUAL = 3, + RT_COND_TYPE_LESS = 4, + RT_COND_TYPE_LESS_OR_EQUAL = 5, + RT_COND_TYPE_MAX, +}rtFftsPlusCondType_t; + +// 
the definition of ffts plus context + +#define RT_CTX_SUCCESSOR_NUM 26 + +// ffts plus common context +typedef struct tagFftsPlusComCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1 : 7; + uint8_t aten : 1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-71 + uint32_t res5[2]; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-127 + uint32_t res6[13]; +} rtFftsPlusComCtx_t; + +// aic/aiv context +typedef struct tagFftsPlusAicAivCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1: 7; + uint8_t aten: 1; + // 4-7 + uint8_t prefetchConfig; + uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t stat: 1; + uint16_t schem: 2; + uint16_t icachePrefetchCnt: 5; + uint16_t res5: 7; + uint16_t atm: 1; + uint16_t prefetchEnableBitmap: 4; + uint16_t res6: 4; + uint16_t prefetchOnceBitmap: 4; + uint16_t res7: 4; + // 68-71 + uint32_t res8; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t nonTailBlockdim; + uint16_t tailBlockdim; + // 80-83 + uint32_t taskParamPtrBaseL; + // 84-87 + uint16_t taskParamPtrBaseH; + uint16_t taskParamPtrOffset; + // 88-95 + uint32_t res9; + uint32_t res10; + // 96-103 + uint32_t nonTailTaskStartPcL; + uint16_t nonTailTaskStartPcH; + uint16_t res11; + // 104-111 + uint32_t tailTaskStartPcL; + uint16_t tailTaskStartPcH; + uint16_t res12; + // 112-119 + uint32_t res13; + uint32_t res14; + // 120-127 + uint16_t srcSlot[4]; // src_slot0-3(context ID for source data which is out of subgraph) +} rtFftsPlusAicAivCtx_t; + +// mix aic/aiv context +typedef struct tagFftsPlusMixAicAivCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1: 7; + uint8_t aten: 1; + // 4-7 + uint8_t prefetchConfig; + 
uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t stat: 1; + uint16_t schem: 2; + uint16_t aicIcachePrefetchCnt: 5; + uint16_t aivIcachePrefetchCnt: 5; + uint16_t res5: 2; + uint16_t atm: 1; + uint16_t prefetchEnableBitmap: 4; + uint16_t res6: 4; + uint16_t prefetchOnceBitmap: 4; + uint16_t res7: 4; + // 68-71 + uint16_t res8; + uint8_t nonTailBlockRatioN; + uint8_t tailBlockRatioN; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t nonTailBlockdim; + uint16_t tailBlockdim; + // 80-87 + uint32_t aicTaskParamPtrL; + uint16_t aicTaskParamPtrH; + uint16_t aicTaskParamPtrOffset; + // 88-95 + uint32_t aivTaskParamPtrL; + uint16_t aivTaskParamPtrH; + uint16_t aivTaskParamPtrOffset; + // 96-103 + uint32_t nonTailAicTaskStartPcL; + uint16_t nonTailAicTaskStartPcH; + uint16_t tailAicTaskStartPcH; + // 104-111 + uint32_t tailAicTaskStartPcL; + uint32_t nonTailAivTaskStartPcL; + // 112-119 + uint16_t nonTailAivTaskStartPcH; + uint16_t tailAivTaskStartPcH; + uint32_t tailAivTaskStartPcL; + // 120-127 + uint16_t srcSlot[4]; // src_slot0-3(context ID for source data which is out of subgraph) +} rtFftsPlusMixAicAivCtx_t; + +// sdma context +typedef struct tagFftsPlusSdmaCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1: 7; + uint8_t aten: 1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint8_t sat: 1; + uint8_t res5: 7; + uint8_t res6: 7; + uint8_t atm: 1; + uint16_t res7; + // 68-71 + uint32_t res8; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint32_t sdmaSqeHeader; // (FORMAT/MPAMNS/PARTID/DRO/SRO/QOS/DNS/SNS/DSSV/SSSV/IE/UPCODE) + // 80-83 + uint16_t sourceStreamId; + uint16_t sourceSubstreamId; + // 84-87 + uint16_t destinationStreamId; + uint16_t 
destinationSubstreamId; + // 88-127 + uint32_t sourceAddressBaseL; + uint32_t sourceAddressBaseH; + uint32_t sourceAddressOffset; + uint32_t destinationAddressBaseL; + uint32_t destinationAddressBaseH; + uint32_t destinationAddressOffset; + uint32_t nonTailDataLength; + uint32_t tailDataLength; + uint32_t res9[2]; +} rtFftsPlusSdmaCtx_t; + +// ffts plus notify record/wait context +typedef struct tagFftsPlusNotifyCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1: 7; + uint8_t aten: 1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t res5: 15; + uint16_t atm: 1; + uint16_t res6; + // 68-71 + uint32_t res7; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t notifyIdBase; + uint16_t res8; + // 80-127 + uint32_t res9[12]; +} rtFftsPlusNotifyCtx_t; + +// write Value context +typedef struct tagFftsPlusWriteValueCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1: 7; + uint8_t aten: 1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t res5: 15; + uint16_t atm: 1; + uint16_t res6; + // 68-71 + uint32_t res7; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint8_t awSize: 3; + uint8_t snoop: 1; + uint8_t res8: 4; + uint8_t awCache: 4; + uint8_t awProt: 3; + uint8_t va: 1; + uint16_t res9; + // 80-83 + uint32_t writeAddressBaseL; + // 84-87 + uint32_t writeAddressBaseH: 17; + uint32_t res10: 15; + // 88-91 + uint32_t writeAddressOffset; + // 92-95 + uint32_t res11; + // 96-111 + uint32_t writeValue[4]; // write_value_00 -> write_value_03 + // 112-127 + uint32_t res12[4]; +} rtFftsPlusWriteValueCtx_t; + +// ai cpu context +typedef struct tagFftsPlusAiCpuCtx { + // 0-3 bytes + uint16_t 
contextType; + uint8_t successorNum; + uint8_t res1: 7; + uint8_t aten: 1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorContextID[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t sat: 1; + uint16_t res5: 14; + uint16_t atm: 1; + uint16_t res6; + // 68-71 + uint16_t sqeIndex; + uint8_t kernelType: 7; + uint8_t bm: 1; + uint8_t topicType: 4; + uint8_t qos: 3; + uint8_t res7: 1; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t nonTailBlockdim; + uint16_t tailBlockdim; + // 80-115 + uint32_t usrData[9]; // usr_data0 -> usr_data8 usr_data2(task_param_base_l) usr_data3(task_param_base_h) + // 116-119 + uint32_t res8; + // 120-123 + uint32_t subtopicId: 12; + uint32_t topicId: 6; + uint32_t groupId: 6; + uint32_t usrDataLength: 8; + // 124-127 + uint32_t taskParamOffset; +} rtFftsPlusAiCpuCtx_t; + +// data context +typedef struct tagFftsPlusDataCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1: 7; + uint8_t aten: 1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t cntInit; // cons_cnt_init / prod_cnt_init + uint8_t cnt; // cons_cnt / prod_cnt + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t res5: 15; + uint16_t atm: 1; + uint16_t res6; + // 68-71 + uint16_t origConsumerCounter; + uint16_t runConsumerCounter; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint32_t res7; + // 80-83 + uint32_t addressBaseL; + // 84-87 + uint32_t addressBaseH; + // 88-91 + uint32_t addressOffset; + // 92-95 + uint32_t res8; + // 96-99 + uint16_t nonTailNumOutter; + uint16_t nonTailNumInner; + // 100-103 + uint32_t nonTailLengthInner; + // 104-107 + uint32_t nonTailStrideOutter; + // 108-111 + uint32_t nonTailStrideInner; + // 112-115 + uint16_t tailNumOutter; + uint16_t tailNumInner; + // 116-119 + uint32_t tailLengthInner; + // 120-123 + uint32_t 
tailStrideOutter; + // 124-127 + uint32_t tailStrideInner; +} rtFftsPlusDataCtx_t; + +// at start context +typedef struct tagFftsPlusAtStartCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1: 7; + uint8_t aten: 1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t res5; + uint16_t res6; + // 68-71 + uint16_t res7; + uint16_t res8; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t threadIdInit; + uint16_t threadWindowSize; + // 80-127 NOTE(review): uint16_t res9[12] spans only 24 bytes (80-103), so this struct totals 104 bytes; confirm whether uint32_t res9[12] was intended + uint16_t res9[12]; +} rtFftsPlusAtStartCtx_t; + +// at end context +#define RT_CTX_SUCC_AT_START_SLOT_NUM 12 +#define RT_CTX_SUCC_OUT_LABEL_SLOT_NUM 12 + +typedef struct tagFftsPlusAtEndCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t atStartSlotNumber; + uint8_t outLabelSlotNumber: 7; + uint8_t aten: 1; + // 4-7 + uint8_t res1; + uint8_t res2; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res3; + // 12-59 + uint16_t succAtStartSlot[RT_CTX_SUCC_AT_START_SLOT_NUM]; + uint16_t succOutLabelSlot[RT_CTX_SUCC_OUT_LABEL_SLOT_NUM]; + // 60-63 + uint16_t res4; + uint16_t res5; + // 64-67 + uint16_t res6; + uint16_t res7; + // 68-71 + uint16_t res8; + uint16_t res9; + // 72-75 + uint16_t threadId; + uint16_t res10; + // 76-79 + uint16_t res11; + uint16_t res12; + // 80-127 + uint32_t res13[12]; +} rtFftsPlusAtEndCtx_t; + +// label context +typedef struct tagFftsPlusLabelCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-79 + uint16_t res5[8]; + // 80-127 + uint32_t res6[12]; +} rtFftsPlusLabelCtx_t; + +// case switch context +typedef struct tagFftsPlusCaseSwitchCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t 
successorNum; + uint8_t res1: 7; + uint8_t aten: 1; + // 4-7 + uint8_t startLabelId; + uint8_t labelListLen; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res2; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t res3: 15; + uint16_t atm: 1; + uint16_t res4; + // 68-71 + uint32_t res5; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint8_t arSize: 3; + uint8_t snoop: 1; + uint8_t res6: 4; + uint8_t arCache: 4; + uint8_t arProt: 3; + uint8_t va: 1; + uint16_t res7; + // 80-83 + uint32_t loadAddress0BaseL; + // 84-87 + uint32_t loadAddress0BaseH: 17; + uint32_t res8: 14; + uint32_t ld0En: 1; + // 88-91 + uint32_t loadAddress0Offset; + // 92-95 + uint32_t res9; + // 96-99 + uint32_t loadAddress1BaseL; + // 100-103 + uint32_t loadAddress1BaseH: 17; + uint32_t res10: 14; + uint32_t ld1En: 1; + // 104-107 + uint32_t loadAddress1Offset; + // 108-127 + uint32_t res11[5]; +} rtFftsPlusCaseSwitchCtx_t; + +// case default context +typedef struct tagFftsPlusCaseDefCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1: 7; + uint8_t aten: 1; + // 4-7 + uint8_t startLabelId; + uint8_t labelListLen; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res2; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t res3; + uint16_t res4; + // 68-127 + uint32_t res5[15]; +} rtFftsPlusCaseDefCtx_t; + +// condition switch context +#define RT_CTX_TRUE_SUCCESSOR_NUM 12 +#define RT_CTX_FALSE_SUCCESSOR_NUM 14 + +typedef struct tagFftsPlusCondSwitchCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t trueSuccessorNum; + uint8_t falseSuccessorNum: 7; + uint8_t aten: 1; + // 4-7 + uint8_t condition; + uint8_t res1; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res2; + // 12-63 + uint16_t trueSuccessorList[RT_CTX_TRUE_SUCCESSOR_NUM]; + uint16_t falseSuccessorList[RT_CTX_FALSE_SUCCESSOR_NUM]; + // 64-67 + uint16_t res3: 15; + uint16_t atm: 1; + 
uint16_t res4; + // 68-71 + uint32_t res5; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint8_t arSize: 3; + uint8_t snoop: 1; + uint8_t res6: 4; + uint8_t arCache: 4; + uint8_t arProt: 3; + uint8_t va: 1; + uint16_t res7; + // 80-83 + uint32_t loadAddress0BaseL; + // 84-87 + uint32_t loadAddress0BaseH: 17; + uint32_t res8: 14; + uint32_t ld0En: 1; + // 88-91 + uint32_t loadAddress0Offset; + // 92-95 + uint32_t res9; + // 96-99 + uint32_t loadAddress1BaseL; + // 100-103 + uint32_t loadAddress1BaseH: 17; + uint32_t res10: 14; + uint32_t ld1En: 1; + // 104-107 + uint32_t loadAddress1Offset; + // 108-127 + uint32_t res11[3]; + uint32_t cmpValue1; + uint32_t cmpValue2; +} rtFftsPlusCondSwitchCtx_t; + +#pragma pack(pop) + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif +#endif // __CCE_RUNTIME_FFTS_PLUS_DEFINE_H diff --git a/third_party/fwkacllib/inc/runtime/rt_model.h b/third_party/fwkacllib/inc/runtime/rt_model.h index a7618b45..d5aa860e 100644 --- a/third_party/fwkacllib/inc/runtime/rt_model.h +++ b/third_party/fwkacllib/inc/runtime/rt_model.h @@ -53,6 +53,7 @@ typedef enum tagModelTaskType { RT_MODEL_TASK_ALL_KERNEL, RT_MODEL_TASK_PROFILER_TRACE_EX, RT_MODEL_TASK_FFTS_TASK, + RT_MODEL_TASK_FFTS_PLUS_TASK, } rtModelTaskType_t; typedef enum tagModelStreamType { diff --git a/third_party/fwkacllib/inc/runtime/rt_stars_define.h b/third_party/fwkacllib/inc/runtime/rt_stars_define.h new file mode 100644 index 00000000..a5de3aac --- /dev/null +++ b/third_party/fwkacllib/inc/runtime/rt_stars_define.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved. 
+ * Description: the definition of stars + */ + +#ifndef __CCE_RUNTIME_STARS_DEFINE__H +#define __CCE_RUNTIME_STARS_DEFINE__H + +#include "base.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +#pragma pack(push) +#pragma pack (1) + +typedef struct tagStarsSqeHeader { + uint8_t type: 6; + uint8_t l1Lock: 1; + uint8_t l1Unlock: 1; + + uint8_t ie: 2; + uint8_t preP: 2; + uint8_t postP: 2; + uint8_t wrCqe: 1; + uint8_t reserved: 1; + + uint16_t blockDim; + + uint16_t rtStreamId; + uint16_t taskId; +} rtStarsSqeHeader_t; + +// ffts+ type +typedef enum tagFftsPlusType { + RT_FFTS_PLUS_TYPE_RES1 = 2, // Reserved + RT_FFTS_PLUS_TYPE_RES2 = 3, // Reserved + RT_FFTS_PLUS_TYPE = 4, // FFTS+ mode +} rtFftsPlusType_t; + +// ffts+ sqe +typedef struct tagFftsPlusSqe { + // 0-7 bytes + rtStarsSqeHeader_t sqeHeader; + // 8-11 bytes + uint16_t fftsType: 3; + uint16_t reserved1: 13; + uint16_t reserved2; + // 12-15 bytes + uint16_t pmg: 2; + uint16_t ns: 1; + uint16_t partId: 8; + uint16_t reserved3: 1; + uint16_t qos: 4; + uint8_t kernelCredit; + uint8_t reserved4; + // 16-23 bytes + uint32_t stackPhyBaseL; + uint32_t stackPhyBaseH; + // 24-31 bytes + uint16_t totalContextNum; + uint16_t readyContextNum; + uint16_t preloadContextNum; + uint16_t reserved5; + // 32-35 bytes + uint16_t reserved6: 8; + uint16_t reserved7: 4; + uint16_t dsplitUnit: 3; + uint16_t reserved8: 1; + uint16_t prefetchOstNum: 5; + uint16_t reserved9: 3; + uint16_t cmaintOstNum: 5; + uint16_t reserved10: 3; + // 36-39 bytes + uint16_t aicPrefetchLower: 5; + uint16_t reserved11: 3; + uint16_t aicPrefetchUpper: 5; + uint16_t Reserved12: 3; + uint16_t aivPrefetchLower: 5; + uint16_t Reserved13: 3; + uint16_t aivPrefetchUpper: 5; + uint16_t Reserved14: 3; + // 40-47 bytes + uint32_t contextAddressBaseL; + uint32_t contextAddressBaseH:17; + uint32_t reserved15:15; + // 48-63 bytes + uint32_t reserved16[4]; +} rtFftsPlusSqe_t; + +#pragma pack(pop) + +#if defined(__cplusplus) && 
!defined(COMPILE_OMG_PACKAGE) +} +#endif +#endif // __CCE_RUNTIME_STARS_DEFINE__H