@@ -174,6 +174,7 @@ set(TRAIN_SRC_LIST
   "graph/load/model_manager/task_info/model_exit_task_info.cc"
   "graph/load/model_manager/task_info/event_record_task_info.cc"
   "graph/load/model_manager/task_info/event_wait_task_info.cc"
+  "graph/load/model_manager/task_info/ffts_task_info.cc"
   "graph/load/model_manager/task_info/fusion_start_task_info.cc"
   "graph/load/model_manager/task_info/fusion_stop_task_info.cc"
   "graph/load/model_manager/task_info/hccl_task_info.cc"
@@ -662,6 +663,7 @@ set(INFER_SRC_LIST
   "graph/load/model_manager/task_info/task_info.cc"
   "graph/load/model_manager/task_info/event_record_task_info.cc"
   "graph/load/model_manager/task_info/event_wait_task_info.cc"
+  "graph/load/model_manager/task_info/ffts_task_info.cc"
   "graph/load/model_manager/task_info/fusion_start_task_info.cc"
   "graph/load/model_manager/task_info/fusion_stop_task_info.cc"
   "graph/load/model_manager/task_info/kernel_ex_task_info.cc"
@@ -37,6 +37,7 @@ set(SRC_LIST
   "../graph/load/model_manager/task_info/task_info.cc"
   "../graph/load/model_manager/task_info/event_record_task_info.cc"
   "../graph/load/model_manager/task_info/event_wait_task_info.cc"
+  "../graph/load/model_manager/task_info/ffts_task_info.cc"
   "../graph/load/model_manager/task_info/fusion_start_task_info.cc"
   "../graph/load/model_manager/task_info/fusion_stop_task_info.cc"
   "../graph/load/model_manager/task_info/kernel_ex_task_info.cc"
@@ -86,6 +86,11 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::set<Node
     return false;
   }
+  if (func_node->GetOpDesc() != nullptr && func_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
+    GELOGD("Graph[%s] is ffts subgraph, skip label allocator.", graph->GetName().c_str());
+    return true;
+  }
   ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph();
   if (owner_graph == nullptr) {
     REPORT_INNER_ERROR("E19999", "ComputeGraph owner not set in node:%s(%s), graph:%s",
@@ -474,6 +474,11 @@ Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector<Subgr
   for (ge::NodePtr &node : graph->GetDirectNode()) {
     auto op_desc = node->GetOpDesc();
     GE_CHECK_NOTNULL(op_desc);
+    if (op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID)) {
+      op_desc->SetStreamId(kInvalidStream);
+      GELOGI("Ffts node %s of type %s reassign to invalid stream.", node->GetName().c_str(), node->GetType().c_str());
+      continue;
+    }
     int64_t stream_id = op_desc->GetStreamId();
     if (ops_without_label.find(op_desc) != ops_without_label.end()) {
       if (AreAllPredStreamsInvalid(node) && op_desc->GetSubgraphInstanceNames().empty()) {
@@ -432,7 +432,11 @@ Status StreamAllocator::SetActiveStreamsForSubgraphs() {
 // Insert the send/recv event id to the graph
 Status StreamAllocator::InsertSyncEvents() {
-  for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) {
+  auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) {
+    return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH);
+  };
+  for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag(), nullptr, ffts_filter)) {
     // Take the adjacent points, then judge whether need to insert the event
     for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) {
       for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
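Note on the hunk above: `ffts_filter` (and the identical lambda added to TaskGenerator::GenerateTask further down) is simply a node predicate handed to GetNodes, so ops that carry ATTR_NAME_FFTS_SUB_GRAPH are left out of per-node event insertion and task generation. A minimal standalone sketch of the same filtering idea, using an invented FakeNode type and attribute key rather than the real GE graph classes:

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a GE node: a name plus a set of attribute keys.
struct FakeNode {
  std::string name;
  std::set<std::string> attrs;
  bool HasAttr(const std::string &key) const { return attrs.count(key) != 0; }
};

int main() {
  const std::string kFftsSubGraphAttr = "_ffts_sub_graph";  // assumed key, for illustration only
  const std::vector<FakeNode> nodes = {
      {"conv1", {}}, {"ffts_partition_call", {kFftsSubGraphAttr}}, {"relu1", {}}};

  // Same shape as ffts_filter: keep a node only if it does NOT carry the FFTS attribute.
  auto ffts_filter = [&](const FakeNode &node) { return !node.HasAttr(kFftsSubGraphAttr); };

  for (const auto &node : nodes) {
    if (!ffts_filter(node)) {
      continue;  // skipped here; such nodes are covered by the FFTS-specific handling in this change
    }
    std::cout << "would insert sync events / generate task for " << node.name << "\n";
  }
  return 0;
}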
@@ -531,6 +535,11 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const
 Status StreamAllocator::InsertEventsForSubgraph() {
   for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) {
     GE_CHECK_NOTNULL(subgraph);
+    const auto parent_node = subgraph->GetParentNode();
+    if (parent_node != nullptr && parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
+      GELOGD("Skip ffts subgraph, parent node is %s.", parent_node->GetName().c_str());
+      continue;
+    }
     for (const auto &node : subgraph->GetDirectNode()) {
       auto op_desc = node->GetOpDesc();
       GE_CHECK_NOTNULL(op_desc);
@@ -354,7 +354,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
   };
   GE_MAKE_GUARD(release, callback);
-  for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
+  auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) {
+    return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH);
+  };
+  for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag(), nullptr, ffts_filter)) {
     OpDescPtr op_desc = node->GetOpDesc();
     GE_CHECK_NOTNULL(op_desc);
     node_index++;
@@ -380,10 +383,8 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
       GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str());
       continue;
     }
-    if (op_kernel_lib_name.empty()) {
-      GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str());
-      continue;
-    }
+    GE_CHK_BOOL_EXEC_INFO(!op_kernel_lib_name.empty(), continue,
+                          "Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str());
     auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name);
     if (kernel_info_store == nullptr) {
       REPORT_INNER_ERROR("E19999", "Get ops kernel info store failed for op:%s(%s), op_kernel_name:%s",
@@ -394,6 +395,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
     }
     GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", name.c_str(),
                       type.c_str());
+    if (node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
+      GE_CHK_STATUS_RET(UpdateAnchorStatusForFfts(node), "[Call][UpdateAnchorStatusForFfts] node:%s(%s) failed",
+                        name.c_str(), type.c_str());
+    }
     // Profiling task
     size_t task_list_size_before = task_def_list.size();
     GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list));
@@ -571,7 +576,24 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info
   return ret;
 }

+Status TaskGenerator::UpdateAnchorStatusForFfts(const NodePtr &node) {
+  GELOGD("Start UpdateAnchorStatusForFfts for %s.", node->GetName().c_str());
+  if (!node->GetOpDesc()->GetSubgraphInstanceNames().empty()) {
+    for (size_t i = 0; i < node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) {
+      auto sub_graph = NodeUtils::GetSubgraph(*node, i);
+      GE_CHECK_NOTNULL(sub_graph);
+      GELOGD("Start update anchor status for %s.", sub_graph->GetName().c_str());
+      for (auto &ffts_node : sub_graph->GetDirectNode()) {
+        GE_CHK_STATUS_RET(UpdateAnchorStatus(ffts_node), "[Call][UpdateAnchorStatus] node:%s(%s) failed",
+                          ffts_node->GetName().c_str(), ffts_node->GetType().c_str());
+      }
+    }
+  }
+  return SUCCESS;
+}
+
 Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) {
+  GELOGD("Start UpdateAnchorStatus for %s.", node->GetName().c_str());
   if (NodeUtils::SetAllAnchorStatus(node) != GRAPH_SUCCESS) {
     REPORT_CALL_ERROR("E19999", "SetAllAnchorStatus fail for op:%s(%s)",
                       node->GetName().c_str(), node->GetType().c_str());
@@ -80,6 +80,7 @@ class TaskGenerator {
   Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
                                 std::vector<uint32_t> &all_reduce_nodes);
  private:
+  Status UpdateAnchorStatusForFfts(const NodePtr &node);
   Status UpdateAnchorStatus(const NodePtr &node);

   Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id);
@@ -99,6 +99,9 @@ const uint32_t kEndOfSequenceNew = 507005;
 const int32_t kModelAbortNormal = 0x0704000e;
 const int32_t kModelAbortNormalNew = 507024;
 const uint32_t kInteval = 2;
+const uint32_t kFftsTbeHandleElementSize = 2;
+const uint32_t kNonTailBlock = 0;
+const uint32_t kTailBlock = 1;
 const char *const kModelName = "model_name";
 const char *const kModeleId = "model_id";
 const char *const kLoadStartTime = "load_start_time";
@@ -116,14 +119,15 @@ const char *const kWorkSpaceSize = "workspace_size";
 const char *const kTotalSize = "total_size";
 const char *const kTaskCount = "task_count";
 const char *const kTaskId = "task_id";
-const char* const kRequestId = "request_id";
-const char* const kThreadId = "thread_id";
-const char* const kInputBeginTime = "input_begin_time";
-const char* const kInputEndTime = "input_end_time";
-const char* const kInferBeginTime = "infer_begin_time";
-const char* const kInferEndTime = "infer_end_time";
-const char* const kOutputBeginTime = "output_start_time";
-const char* const kOutputEndTime = "output_end_time";
+const char *const kRequestId = "request_id";
+const char *const kThreadId = "thread_id";
+const char *const kInputBeginTime = "input_begin_time";
+const char *const kInputEndTime = "input_end_time";
+const char *const kInferBeginTime = "infer_begin_time";
+const char *const kInferEndTime = "infer_end_time";
+const char *const kOutputBeginTime = "output_start_time";
+const char *const kOutputEndTime = "output_end_time";
+const char *const kStubFuncName = "_register_stub_func";
 const uint32_t kStringHeadElems = 2;
 const uint32_t kPlacementHostData = 0;
 const size_t kAlignment = 64;
@@ -902,10 +906,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {
     SetLabelForDynamic(node);
     auto it = op_desc_handle.find(op_desc->GetType());
     if (it != op_desc_handle.end()) {
-      if ((this->*it->second)(op_desc) != SUCCESS) {
-        GELOGE(PARAM_INVALID, "[Init][Node] failed, Name:%s", op_desc->GetName().c_str());
-        return PARAM_INVALID;
-      }
+      GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((this->*it->second)(op_desc) != SUCCESS, return PARAM_INVALID,
+                                     "[Init][Node] failed, Name:%s", op_desc->GetName().c_str());
       continue;
     }
@@ -935,7 +937,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {
     GE_TIMESTAMP_RESTART(InitTbeHandle);
     if (IsTbeTask(op_desc)) {
-      Status status = InitTbeHandle(op_desc);
+      Status status =
+          op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID) ? InitTbeHandleWithFfts(op_desc) : InitTbeHandle(op_desc);
       if (status != SUCCESS) {
         GELOGE(status, "[Init][TbeHandle] failed. op:%s", op_desc->GetName().c_str());
         return status;
@@ -3700,6 +3703,7 @@ Status DavinciModel::InitConstant(const OpDescPtr &op_desc) {
 /// @return Status
 ///
 Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
+  string bin_file = op_desc->GetName();
   auto kernel = ge_model_->GetTBEKernelStore().FindKernel(op_desc->GetName());
   auto tbe_kernel = (kernel != nullptr) ? kernel : op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
   if (tbe_kernel == nullptr) {
@@ -3708,12 +3712,61 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
     GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file!", op_desc->GetName().c_str());
     return INTERNAL_ERROR;
   }
+  GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file, tbe_kernel, false), "Function register of bin file: %s failed",
+                    bin_file.c_str());
+  return SUCCESS;
+}
-  std::string session_graph_model_id;
-  GetUniqueId(op_desc, session_graph_model_id);
-  const char *bin_file_key = GetRegisterStub(op_desc->GetName(), session_graph_model_id);  // from set, always valid.
-  TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();
+
+Status DavinciModel::InitTbeHandleWithFfts(const OpDescPtr &op_desc) {
+  std::vector<OpKernelBinPtr> tbe_kernel;
+  tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel);
+  GELOGD("Kernel bin ptr vec size is %zu.", tbe_kernel.size());
+  if (tbe_kernel.size() != kFftsTbeHandleElementSize) {
+    REPORT_INNER_ERROR("E19999", "Get tbe_kernel for op:%s(%s) fail, model_id:%u",
+                       op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
+    GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file, size is %zu when ffts",
+           op_desc->GetName().c_str(), tbe_kernel.size());
+    return INTERNAL_ERROR;
+  }
+  if (tbe_kernel[0] == nullptr || tbe_kernel[1] == nullptr) {
+    REPORT_INNER_ERROR("E19999", "Tbe kernel for op:%s is nullptr.", op_desc->GetName().c_str());
+    GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: tvm bin file of %s is nullptr when ffts.", op_desc->GetName().c_str());
+    return INTERNAL_ERROR;
+  }
+  vector<string> bin_file_keys;
+  (void)AttrUtils::GetListStr(op_desc, kStubFuncName, bin_file_keys);
+  if (bin_file_keys.size() != kFftsTbeHandleElementSize) {
+    REPORT_INNER_ERROR("E19999", "Get bin_file for op:%s(%s) fail.", op_desc->GetName().c_str(),
+                       op_desc->GetType().c_str());
+    GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find bin file keys, size is %zu when ffts",
+           op_desc->GetName().c_str(), bin_file_keys.size());
+    return INTERNAL_ERROR;
+  }
+  GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kNonTailBlock], tbe_kernel[kNonTailBlock], true,
+                                     kNonTailBlock),
+                    "Function register of first bin file %s failed.", bin_file_keys[kNonTailBlock].c_str());
+  GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kTailBlock], tbe_kernel[kTailBlock], true, kTailBlock),
+                    "Function register of second bin file %s failed.", bin_file_keys[kTailBlock].c_str());
+  return SUCCESS;
+}
+
+Status DavinciModel::FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel,
+                                      bool is_ffts, size_t thread_index) {
+  if (thread_index > 1) {
+    GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Thread index: %zu must not be greater than 1.", thread_index);
+    return INTERNAL_ERROR;
+  }
+  const char *bin_file_key;
+  if (is_ffts) {
+    bin_file_key = GetRegisterStub(bin_file, "");
+    GELOGI("Node:%s inherit func name:%s directly.", op_desc->GetName().c_str(), bin_file_key);
+  } else {
+    std::string session_graph_model_id;
+    GetUniqueId(op_desc, session_graph_model_id);
+    bin_file_key = GetRegisterStub(bin_file, session_graph_model_id);  // from set, always valid.
+  }
+  TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();
   std::lock_guard<std::mutex> lock(tvm_bin_mutex_);
   if (rtQueryFunctionRegistered(bin_file_key) != RT_ERROR_NONE) {
     void *bin_handle = nullptr;
@@ -3721,59 +3774,115 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
       GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key);
       rtDevBinary_t binary;
-      std::string json_string;
-      GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_MAGIC, json_string),
-                      GELOGD("Get original type of session_graph_id."));
-      if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICPU") {
-        binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICPU;
-      } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF") {
-        binary.magic = RT_DEV_BINARY_MAGIC_ELF;
-      } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") {
-        binary.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC;
-      } else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") {
-        binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE;
-      } else {
-        REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
-                           TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
-                           op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
-        GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
-               TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
-               op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
-        return PARAM_INVALID;
-      }
+      GE_CHK_STATUS_RET(InitBinaryMagic(op_desc, is_ffts, thread_index, binary), "Init binary magic of %s failed.",
+                        op_desc->GetName().c_str());
       binary.version = 0;
       binary.data = tbe_kernel->GetBinData();
       binary.length = tbe_kernel->GetBinDataSize();
       GELOGD("TBE: binary.length: %lu", binary.length);
       GE_CHK_RT_RET(rtDevBinaryRegister(&binary, &bin_handle));
-      std::string meta_data;
-      GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_METADATA, meta_data),
-                      GELOGI("Get original type of json_string"));
-      GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str());
-      GE_IF_BOOL_EXEC(!meta_data.empty(), GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str())));
+      GE_CHK_STATUS_RET(InitMetaData(op_desc, is_ffts, thread_index, bin_handle), "Init tvm meta data of %s failed.",
+                        op_desc->GetName().c_str());
       kernel_store.StoreTBEHandle(bin_file_key, bin_handle, tbe_kernel);
     } else {
       GELOGI("TBE: find the kernel_name[%s] in HandleMap", bin_file_key);
       kernel_store.ReferTBEHandle(bin_file_key);
     }
     std::string kernel_name;
-    GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name),
-                    GELOGD("Get original type of kernel_name"));
+    GE_CHK_STATUS_RET(InitKernelName(op_desc, is_ffts, thread_index, kernel_name), "Init kernel name of %s failed.",
+                      op_desc->GetName().c_str());
     GE_CHK_RT_RET(rtFunctionRegister(bin_handle, bin_file_key, bin_file_key, kernel_name.c_str(), 0));
     used_tbe_handle_map_[bin_file_key] = 1;  // Init used num to 1.
     return SUCCESS;
   }
   // Kernel registed, Increase used num in store.
   StoreTbeHandle(bin_file_key);
   return SUCCESS;
 }

+Status DavinciModel::InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index,
+                                     rtDevBinary_t &binary) {
+  string json_string;
+  const string &tvm_magic = is_ffts ? TVM_ATTR_NAME_THREAD_MAGIC : TVM_ATTR_NAME_MAGIC;
+  const static std::map<std::string, uint32_t> binary_magics = {
+    {"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU},
+    {"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF},
+    {"RT_DEV_BINARY_MAGIC_ELF_AIVEC", RT_DEV_BINARY_MAGIC_ELF_AIVEC},
+    {"RT_DEV_BINARY_MAGIC_ELF_AICUBE", RT_DEV_BINARY_MAGIC_ELF_AICUBE}
+  };
+  if (is_ffts) {
+    vector<string> json_list;
+    (void)AttrUtils::GetListStr(op_desc, tvm_magic, json_list);
+    if (json_list.size() != kFftsTbeHandleElementSize) {
+      GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Attr is %s, thread index is %zu, json list size is %zu.",
+             tvm_magic.c_str(), thread_index, json_list.size());
+      return INTERNAL_ERROR;
+    }
+    json_string = json_list[thread_index];
+  } else {
+    (void)AttrUtils::GetStr(op_desc, tvm_magic, json_string);
+  }
+  auto iter = binary_magics.find(json_string);
+  if (iter == binary_magics.end()) {
+    REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
+                       tvm_magic.c_str(), json_string.c_str(), op_desc->GetName().c_str(),
+                       op_desc->GetType().c_str(), model_id_);
+    GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
+           tvm_magic.c_str(), json_string.c_str(),
+           op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
+    return PARAM_INVALID;
+  }
+  binary.magic = iter->second;
+  return SUCCESS;
+}
+
+Status DavinciModel::InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle) {
+  string meta_data;
+  const string &tvm_metadata = is_ffts ? TVM_ATTR_NAME_THREAD_METADATA : TVM_ATTR_NAME_METADATA;
+  if (is_ffts) {
+    vector<string> meta_data_list;
+    (void)AttrUtils::GetListStr(op_desc, tvm_metadata, meta_data_list);
+    if (meta_data_list.size() != kFftsTbeHandleElementSize) {
+      GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, meta data list size is %zu.",
+             tvm_metadata.c_str(), thread_index, meta_data_list.size());
+      return INTERNAL_ERROR;
+    }
+    meta_data = meta_data_list[thread_index];
+  } else {
+    (void)AttrUtils::GetStr(op_desc, tvm_metadata, meta_data);
+  }
+  GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str());
+  if (!meta_data.empty()) {
+    GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str()));
+  }
+  return SUCCESS;
+}
+
+Status DavinciModel::InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name) {
+  if (is_ffts) {
+    // delete prefix, eg: *sgt_graph_nodes*/loss_scale/gradient/fp32_vals/Mean_grad/Tile
+    vector<string> kernel_name_list;
+    auto pos = op_desc->GetName().find("/");
+    if (pos == std::string::npos) {
+      GELOGE(INTERNAL_ERROR, "[Check][Param] failed, subgraph node name: %s.", op_desc->GetName().c_str());
+      return INTERNAL_ERROR;
+    }
+    string attr_kernel_name = op_desc->GetName().substr(pos + 1) + "_thread_kernelname";
+    (void)AttrUtils::GetListStr(op_desc, attr_kernel_name, kernel_name_list);
+    if (kernel_name_list.size() != kFftsTbeHandleElementSize) {
+      GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, kernel name list size is %zu.",
+             attr_kernel_name.c_str(), thread_index, kernel_name_list.size());
+      return INTERNAL_ERROR;
+    }
+    kernel_name = kernel_name_list[thread_index];
+  } else {
+    string attr_kernel_name = op_desc->GetName() + "_kernelname";
+    (void)AttrUtils::GetStr(op_desc, attr_kernel_name, kernel_name);
+  }
+  return SUCCESS;
+}
+
 void DavinciModel::StoreTbeHandle(const std::string &handle_key) {
   // Online mode FE may call rtFunctionRegister.
   TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();
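One detail of the new InitKernelName that is easy to miss: for FFTS ops the attribute key is derived from the op name with its first path segment stripped and a `_thread_kernelname` suffix appended, and the resulting list holds one kernel name per thread block (non-tail and tail, hence kFftsTbeHandleElementSize == 2). A standalone sketch of just that string handling, with made-up names, to show which key gets looked up:

#include <iostream>
#include <string>
#include <vector>

// Mirrors the FFTS branch of InitKernelName (illustrative only; names are invented).
int main() {
  const std::string op_name = "sgt_graph_nodes/loss_scale/gradient/fp32_vals/Mean_grad/Tile";

  const auto pos = op_name.find('/');
  if (pos == std::string::npos) {
    std::cerr << "unexpected subgraph node name: " << op_name << "\n";
    return 1;
  }
  // Drop the leading "sgt_graph_nodes/" segment, then append the thread suffix.
  const std::string attr_kernel_name = op_name.substr(pos + 1) + "_thread_kernelname";
  std::cout << "attr key: " << attr_kernel_name << "\n";

  // The attribute is expected to hold exactly two entries: [non-tail block, tail block].
  const std::vector<std::string> kernel_name_list = {"te_tile_non_tail", "te_tile_tail"};  // made-up values
  for (size_t thread_index = 0; thread_index < kernel_name_list.size(); ++thread_index) {
    std::cout << "thread " << thread_index << " -> kernel " << kernel_name_list[thread_index] << "\n";
  }
  return 0;
}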
@@ -771,6 +771,12 @@ class DavinciModel {
   /// @return Status
   ///
   Status InitTbeHandle(const OpDescPtr &op_desc);
+  Status InitTbeHandleWithFfts(const OpDescPtr &op_desc);
+  Status FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, bool is_ffts,
+                          size_t thread_index = 0);
+  Status InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, rtDevBinary_t &binary);
+  Status InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle);
+  Status InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name);

   void StoreTbeHandle(const string &handle_key);
   void CleanTbeHandle();
@@ -0,0 +1,393 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "graph/load/model_manager/task_info/ffts_task_info.h"
+
+#include <vector>
+
+#include "graph/load/model_manager/davinci_model.h"
+
+namespace {
+constexpr uint32_t kAddrLen = sizeof(void *);
+}
+
+namespace ge {
+FftsTaskInfo::~FftsTaskInfo() {
+  GE_FREE_RT_LOG(args_);
+}
+
+Status FftsTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
+  GELOGI("FftsTaskInfo Init Start.");
+  GE_CHECK_NOTNULL(davinci_model);
+  davinci_model_ = davinci_model;
+  GE_CHK_STATUS_RET_NOLOG(SetStream(task_def.stream_id(), davinci_model_->GetStreamList()));
+
+  const domi::FftsTaskDef &ffts_task_def = task_def.ffts_task();
+  OpDescPtr op_desc = davinci_model_->GetOpByIndex(ffts_task_def.op_index());
+  GE_CHECK_NOTNULL(op_desc);
+
+  if ((ffts_task_def.sub_task_size() > static_cast<int>(RT_FFTS_MAX_SUB_TASK_NUM)) ||
+      (ffts_task_def.ticket_cache_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_NUM))) {
+    GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Node: %s, sub task desc size: %d, ticket cache size: %d",
+           op_desc->GetName().c_str(), ffts_task_def.sub_task_size(), ffts_task_def.ticket_cache_size());
+    return INTERNAL_ERROR;
+  }
+  args_size_ = kAddrLen * ffts_task_def.addr_size();
+  GE_CHK_RT_RET(rtMalloc(&args_, args_size_, RT_MEMORY_HBM));
+
+  InitFftsDescInfo(ffts_task_def.ffts_desc(), sub_task_info_.fftsDesc);
+  sub_task_info_.fftsType = static_cast<rtFftsType_t>(ffts_task_def.ffts_type());
+  sub_task_info_.subTaskNum = ffts_task_def.sub_task_size();
+  for (int idx = 0; idx < ffts_task_def.sub_task_size(); ++idx) {
+    GE_CHK_STATUS_RET_NOLOG(InitSubTaskInfo(ffts_task_def.sub_task(idx), sub_task_info_.subTask[idx]));
+  }
+
+  sub_task_info_.tickCacheNum = ffts_task_def.ticket_cache_size();
+  for (int idx = 0; idx < ffts_task_def.ticket_cache_size(); ++idx) {
+    GE_CHK_STATUS_RET_NOLOG(InitTicketCache(ffts_task_def.ticket_cache(idx), sub_task_info_.ticketCache[idx]));
+  }
+
+  size_t data_size = kAddrLen * io_addrs_.size();
+  GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs_.data(), data_size, RT_MEMCPY_HOST_TO_DEVICE));
+  GELOGI("FftsTaskInfo::Init Success. Node: %s, input/output size: %zu", op_desc->GetName().c_str(), io_addrs_.size());
+  return SUCCESS;
+}
+
+void FftsTaskInfo::InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc) {
+  ffts_desc.tm = static_cast<uint8_t>(ffts_desc_def.tm());
+  ffts_desc.di = static_cast<uint8_t>(ffts_desc_def.di());
+  ffts_desc.dw = static_cast<uint8_t>(ffts_desc_def.dw());
+  ffts_desc.df = static_cast<uint8_t>(ffts_desc_def.df());
+  ffts_desc.dataSplitUnit = static_cast<uint8_t>(ffts_desc_def.data_split_unit());
+  ffts_desc.prefetchOstNum = static_cast<uint8_t>(ffts_desc_def.prefetch_ost_num());
+  ffts_desc.cacheMaintainOstNum = static_cast<uint8_t>(ffts_desc_def.cache_maintain_ost_num());
+  ffts_desc.aicPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_upper());
+  ffts_desc.aicPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_lower());
+  ffts_desc.aivPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_upper());
+  ffts_desc.aivPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_lower());
+}
+
+Status FftsTaskInfo::InitSubTaskInfo(const domi::FftsSubTaskDef &sub_task_def, rtFftsSubTaskInfo_t &sub_task_desc) {
+  if ((sub_task_def.dst_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) ||
+      (sub_task_def.src_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) {
+    GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, dst tick cache id size: %d, src tick cache id size: %d",
+           sub_task_def.dst_tick_cache_id_size(), sub_task_def.src_tick_cache_id_size());
+    return FAILED;
+  }
+  if (sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv()) {
+    GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, auto thread aic/aiv: %d, manual thread aic/aiv: %d",
+           sub_task_def.has_auto_thread_aic_aiv(), sub_task_def.has_manual_thread_aic_aiv());
+    return FAILED;
+  }
+  thread_dim_ = sub_task_def.thread_dim();
+  GE_CHK_BOOL_RET_STATUS(thread_dim_ != 0, FAILED, "[Get][thread_dim] failed, Invalid thread dim: %u!", thread_dim_);
+
+  sub_task_desc.subTaskType = static_cast<rtFftsSubTaskType_t>(sub_task_def.sub_task_type());
+  sub_task_desc.threadDim = sub_task_def.thread_dim();
+  sub_task_desc.dstTickCacheVldBitmap = sub_task_def.dst_tick_cache_vld_bitmap();
+  sub_task_desc.srcTickCacheVldBitmap = sub_task_def.src_tick_cache_vld_bitmap();
+  sub_task_desc.srcDataOutOfSubGraphBitmap = sub_task_def.src_data_out_of_subgraph_bitmap();
+
+  for (int idx = 0; idx < sub_task_def.dst_tick_cache_id_size(); ++idx) {
+    sub_task_desc.dstTickCacheID[idx] = sub_task_def.dst_tick_cache_id(idx);
+  }
+  for (int idx = 0; idx < sub_task_def.src_tick_cache_id_size(); ++idx) {
+    sub_task_desc.srcTickCacheID[idx] = sub_task_def.src_tick_cache_id(idx);
+  }
+  if (sub_task_def.has_auto_thread_aic_aiv()) {
+    GE_CHK_STATUS_RET_NOLOG(InitAutoAicAiv(sub_task_def.auto_thread_aic_aiv(), sub_task_desc.custom.autoThreadAicAiv));
+  }
+  if (sub_task_def.has_manual_thread_aic_aiv()) {
+    GE_CHK_STATUS_RET_NOLOG(
+        InitManualAicAiv(sub_task_def.manual_thread_aic_aiv(), sub_task_desc.custom.manualThreadAicAiv));
+  }
+  if (sub_task_def.has_manual_thread_nop()) {
+    GE_CHK_STATUS_RET_NOLOG(InitManualNop(sub_task_def.manual_thread_nop(), sub_task_desc.custom.manualThreadNop));
+  }
+  return SUCCESS;
+}
+
+Status FftsTaskInfo::InitTicketCache(const domi::TicketCacheDef &ticket_cache_def, rtTicketCache_t &ticket_cache) {
+  if (ticket_cache_def.has_auto_thread_cache() == ticket_cache_def.has_manual_thread_cache()) {
+    GELOGE(FAILED, "[Check][Param] Invalid TicketCacheDef, has auto thread cache: %d, has manual thread cache: %d",
+           ticket_cache_def.has_auto_thread_cache(), ticket_cache_def.has_manual_thread_cache());
+    return FAILED;
+  }
+  ticket_cache.cacheOption = static_cast<rtCacheOp_t>(ticket_cache_def.cache_option());
+  ticket_cache.ticketCacheWindow = ticket_cache_def.ticket_cache_window();
+  if (ticket_cache_def.has_auto_thread_cache()) {
+    InitAutoCacheInfo(ticket_cache_def.auto_thread_cache(), ticket_cache.custom.autoThreadCache);
+  }
+  if (ticket_cache_def.has_manual_thread_cache()) {
+    GE_CHK_STATUS_RET_NOLOG(
+        InitManualCacheInfo(ticket_cache_def.manual_thread_cache(), ticket_cache.custom.manualThreadCache));
+  }
+  return SUCCESS;
+}
+
+// task_addr = {0, 200, 700, 1000, 2000, 3500}
+// task_addr_offset = {20, 40, 2, 100, 200}
+template <typename T>
+Status FftsTaskInfo::InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim,
+                                 uint32_t addr_count) {
+  for (uint32_t i = 0; i < addr_count; ++i) {
+    uintptr_t logic_addr = aic_aiv_def.task_addr(i) + thread_dim * aic_aiv_def.task_addr_offset(i);
+    uint8_t *io_addr = nullptr;
+    if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) {
+      GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress]GetRtAddress failed.");
+      return INTERNAL_ERROR;
+    }
+    GELOGD("aic_aiv_def task base addr is %ld, offset is %ld, thread is %d, logic addrs is 0x%lx, io addr is %p",
+           aic_aiv_def.task_addr(i), aic_aiv_def.task_addr_offset(i), thread_dim, logic_addr, io_addr);
+    io_addrs_.emplace_back(io_addr);
+  }
+  return SUCCESS;
+}
+
+Status FftsTaskInfo::InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv) {
+  if (aic_aiv_def.src_prefetch_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) {
+    GELOGE(FAILED, "[Check][Param] Invalid AutoThreadAicAivInfo, prefetch size: %d", aic_aiv_def.src_prefetch_size());
+    return FAILED;
+  }
+  aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size();
+  GELOGD("AutoThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr);
+
+  const auto &rts_param = davinci_model_->GetRuntimeParam();
+  for (uint32_t i = 0; i < thread_dim_ - 1; ++i) {
+    GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i,
+                                        static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size())));
+  }
+  GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count()));
+  int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size();
+  for (int k = 0; k < last_thread_workspace_size; ++k) {
+    uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k);
+    uint8_t *io_addr = nullptr;
+    GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr));
+    GELOGD("logic addr is 0x%lx, io addr is %p.", logic_addr, io_addr);
+    io_addrs_.emplace_back(io_addr);
+  }
+  aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset();
+  GELOGD("args_: %p, io_addrs size: %zu, task param offset: %u.", args_, io_addrs_.size(), aic_aiv.taskParamOffset);
+
+  aic_aiv.satMode = aic_aiv_def.sat_mode();
+  aic_aiv.scheduleMode = aic_aiv_def.schedule_mode();
+  aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt();
+  aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap();
+  aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap();
+  aic_aiv.tailBlkDim = aic_aiv_def.tail_blk_dim();
+  aic_aiv.nonTailBlkDim = aic_aiv_def.non_tail_blk_dim();
+
+  aic_aiv.nonTailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.non_tail_task_func_stub(), "");
+  aic_aiv.tailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.tail_task_func_stub(), "");
+  GELOGI("Set func name[%s][%s] succ.", aic_aiv.nonTailTaskFuncStub, aic_aiv.tailTaskFuncStub);
+  for (int idx = 0; idx < aic_aiv_def.src_prefetch_size(); ++idx) {
+    InitAutoPrefetch(aic_aiv_def.src_prefetch(idx), aic_aiv.srcPrefetch[idx]);
+  }
+  return SUCCESS;
+}
+
+void FftsTaskInfo::InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache) {
+  cache.dataAddr = cache_def.data_addr();
+  cache.dataAddrOffset = cache_def.data_addr_offset();
+  cache.nonTailDataLen = cache_def.non_tail_data_len();
+  cache.tailDataLen = cache_def.tail_data_len();
+  cache.ticketCacheRefCnt = cache_def.ticket_cache_ref_cnt();
+}
+
+void FftsTaskInfo::InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch) {
+  prefetch.dataAddr = prefetch_def.data_addr();
+  prefetch.dataAddrOffset = prefetch_def.data_addr_offset();
+  prefetch.nonTailDataLen = prefetch_def.non_tail_data_len();
+  prefetch.tailDataLen = prefetch_def.tail_data_len();
+}
+
+Status FftsTaskInfo::InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def,
+                                      rtManualThreadAicAivInfo_t &aic_aiv) {
+  if ((aic_aiv_def.thread_prefetch_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
+      (aic_aiv_def.thread_blk_dim_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
+      (aic_aiv_def.thread_task_func_stub_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
+      (aic_aiv_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) {
+    GELOGE(FAILED, "[Check][Param] Invalid ManualThreadAicAivInfo, thread prefetch dmu desc size: %d, "
+           "thread blk dim size: %d, thread task func stub size: %d, src dep tbl size: %d",
+           aic_aiv_def.thread_prefetch_dmu_idx_size(), aic_aiv_def.thread_blk_dim_size(),
+           aic_aiv_def.thread_task_func_stub_size(), aic_aiv_def.src_dep_tbl_size());
+    return FAILED;
+  }
+  aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size();
+  GELOGD("ManualThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr);
+
+  const auto &rts_param = davinci_model_->GetRuntimeParam();
+  for (uint32_t i = 0; i < thread_dim_ - 1; ++i) {
+    GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i,
+                                        static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size())));
+  }
+  GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count()));
+  int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size();
+  for (int k = 0; k < last_thread_workspace_size; ++k) {
+    uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k);
+    uint8_t *io_addr = nullptr;
+    GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr));
+    io_addrs_.emplace_back(io_addr);
+  }
+  aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset();
+
+  aic_aiv.satMode = aic_aiv_def.sat_mode();
+  aic_aiv.scheduleMode = aic_aiv_def.schedule_mode();
+  aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt();
+  aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap();  // 8 bit bitmap 1 0 1 0
+  aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap();  // 8 bit bitmap 1 0 1 0
+  aic_aiv.prefetchOnceDmuNum = aic_aiv_def.prefetch_once_dmu_num();
+
+  for (int idx = 0; idx < aic_aiv_def.thread_prefetch_dmu_idx_size(); ++idx) {
+    aic_aiv.threadPrefetchDmuIdx[idx] = aic_aiv_def.thread_prefetch_dmu_idx(idx);
+  }
+  for (int idx = 0; idx < aic_aiv_def.thread_blk_dim_size(); ++idx) {
+    aic_aiv.threadBlkDim[idx] = aic_aiv_def.thread_blk_dim(idx);
+  }
+  for (int idx = 0; idx < aic_aiv_def.thread_task_func_stub_size(); ++idx) {
+    aic_aiv.threadTaskFuncStub[idx] = aic_aiv_def.thread_task_func_stub(idx).c_str();
+  }
+
+  InitManualDmuInfo(aic_aiv_def, aic_aiv.prefetchList);
+  for (int idx = 0; idx < aic_aiv_def.src_dep_tbl_size(); ++idx) {
+    GE_CHK_STATUS_RET_NOLOG(InitManualDependency(aic_aiv_def.src_dep_tbl(idx), aic_aiv.srcDepTbl[idx]));
+  }
+  return SUCCESS;
+}
+
+Status FftsTaskInfo::InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def,
+                                         rtManualThreadCacheInfo_t &cache_info) {
+  if ((cache_def.slice_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
+      (cache_def.ticket_cache_ref_cnt_tbl_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM))) {
+    GELOGE(FAILED, "[Check][Param] Invalid ManualThreadCacheInfo slice dum desc index %d, ticket cache ref cnt %d",
+           cache_def.slice_dmu_idx_size(), cache_def.ticket_cache_ref_cnt_tbl_size());
+    return FAILED;
+  }
+  InitManualDmuInfo(cache_def, cache_info.dmuList);
+  for (int idx = 0; idx < cache_def.slice_dmu_idx_size(); ++idx) {
+    cache_info.sliceDmuIdx[idx] = cache_def.slice_dmu_idx(idx);
+  }
+  for (int idx = 0; idx < cache_def.ticket_cache_ref_cnt_tbl_size(); ++idx) {
+    cache_info.ticketCacheRefCntTbl[idx] = cache_def.ticket_cache_ref_cnt_tbl(idx);
+  }
+  return SUCCESS;
+}
+
+Status FftsTaskInfo::InitManualDependency(const domi::ManualThreadDependencyDef &dependency_def,
+                                          rtManualThreadDependency_t &dependency) {
+  if (dependency_def.dependency_size() > static_cast<int>(RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN)) {
+    GELOGE(FAILED, "[Check][Param] Invalid ManualThreadDependency size: %d", dependency_def.dependency_size());
+    return FAILED;
+  }
+  for (int idx = 0; idx < dependency_def.dependency_size(); ++idx) {
+    dependency.dependency[idx] = dependency_def.dependency(idx);
+  }
+  return SUCCESS;
+}
+
+Status FftsTaskInfo::InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop_info) {
+  if (nop_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) {
+    GELOGE(FAILED, "[Check][Param] Invalid ManualThreadNopInfo, src dep tbl size: %d", nop_def.src_dep_tbl_size());
+    return FAILED;
+  }
+  for (int idx = 0; idx < nop_def.src_dep_tbl_size(); ++idx) {
+    GE_CHK_STATUS_RET_NOLOG(InitManualDependency(nop_def.src_dep_tbl(idx), nop_info.srcDepTbl[idx]));
+  }
+  return SUCCESS;
+}
+
+void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu) {
+  if (aic_aiv_def.prefetch_list().empty()) {
+    return;
+  }
+  std::vector<uint8_t> buffer(sizeof(rtManualThreadDmuInfo_t) * aic_aiv_def.prefetch_list_size());
+  dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(buffer.data());
+  for (int idx = 0; idx < aic_aiv_def.prefetch_list_size(); ++idx) {
+    InitManualDmuInfo(aic_aiv_def.prefetch_list(idx), dmu[idx]);
+  }
+}
+
+void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu) {
+  if (cache_def.dmu_list().empty()) {
+    return;
+  }
+  std::vector<uint8_t> buffer(sizeof(rtManualThreadDmuInfo_t) * cache_def.dmu_list_size());
+  dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(buffer.data());
+  for (int idx = 0; idx < cache_def.dmu_list_size(); ++idx) {
+    InitManualDmuInfo(cache_def.dmu_list(idx), dmu[idx]);
+  }
+}
+
+void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu) {
+  dmu.dataAddr = dmu_def.data_addr();
+  dmu.numOuter = dmu_def.num_outer();
+  dmu.numInner = dmu_def.num_inner();
+  dmu.strideOuter = dmu_def.stride_outer();
+  dmu.lenInner = dmu_def.len_inner();
+  dmu.strideInner = dmu_def.stride_inner();
+}
+
+Status FftsTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
+  return SUCCESS;
+}
+
+Status FftsTaskInfo::UpdateArgs() {
+  GE_CHECK_NOTNULL(davinci_model_);
+  std::vector<void *> io_addrs = io_addrs_;
+  davinci_model_->UpdateKnownZeroCopyAddr(io_addrs);
+  auto addr_size = kAddrLen * io_addrs.size();
+  GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs.data(), addr_size, RT_MEMCPY_HOST_TO_DEVICE));
+  return SUCCESS;
+}
+
+Status FftsTaskInfo::Distribute() {
+  GELOGI("FftsTaskInfo Distribute Start.");
+  rtError_t rt_ret = rtFftsTaskLaunch(&sub_task_info_, stream_);
+  if (rt_ret != RT_ERROR_NONE) {
+    GELOGE(RT_FAILED, "[Check][RT_ret] Call rtFftsTaskLaunch failed, ret: 0x%X", rt_ret);
+    return RT_ERROR_TO_GE_STATUS(rt_ret);
+  }
+  GELOGI("FftsTaskInfo Distribute Success.");
+  return SUCCESS;
+}
+
+REGISTER_TASK_INFO(RT_MODEL_TASK_FFTS_TASK, FftsTaskInfo);
+}  // namespace ge
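The comment above FftsTaskInfo::InitIoAddrs is easiest to read as a worked example: thread t resolves its i-th logical I/O address as task_addr[i] + t * task_addr_offset[i], and the task_addr entries beyond the offset list are workspace addresses appended once for the last thread. A standalone sketch of that arithmetic using the numbers from the comment (logical addresses only; the real code additionally maps each one through ModelUtils::GetRtAddress, and the thread count here is assumed):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Values taken from the comment above FftsTaskInfo::InitIoAddrs.
  const std::vector<uint64_t> task_addr = {0, 200, 700, 1000, 2000, 3500};
  const std::vector<uint64_t> task_addr_offset = {20, 40, 2, 100, 200};
  const uint32_t thread_dim = 3;  // assumed thread count, for illustration only

  std::vector<uint64_t> logic_addrs;
  // One slot per task_addr_offset entry for each thread, shifted by t * offset.
  // (In the real code the last thread uses input_output_count(), assumed equal to the offset count here.)
  for (uint32_t t = 0; t < thread_dim; ++t) {
    for (size_t i = 0; i < task_addr_offset.size(); ++i) {
      logic_addrs.push_back(task_addr[i] + t * task_addr_offset[i]);
    }
  }
  // task_addr entries beyond the offset list (here: 3500) are last-thread workspace addresses, appended once.
  for (size_t k = task_addr_offset.size(); k < task_addr.size(); ++k) {
    logic_addrs.push_back(task_addr[k]);
  }

  for (uint64_t addr : logic_addrs) {
    std::cout << addr << ' ';
  }
  std::cout << '\n';
  // Prints: 0 200 700 1000 2000 20 240 702 1100 2200 40 280 704 1200 2400 3500
  return 0;
}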
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_
+#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_
+
+#include "graph/load/model_manager/task_info/task_info.h"
+#include "graph/op_desc.h"
+
+namespace ge {
+class FftsTaskInfo : public TaskInfo {
+ public:
+  FftsTaskInfo() = default;
+  ~FftsTaskInfo() override;
+
+  Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;
+
+  Status Distribute() override;
+
+  Status UpdateArgs() override;
+
+  Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;
+
+ private:
+  void InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc);
+  Status InitSubTaskInfo(const domi::FftsSubTaskDef &task_def, rtFftsSubTaskInfo_t &task);
+  Status InitTicketCache(const domi::TicketCacheDef &cache_def, rtTicketCache_t &cache);
+
+  Status InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv);
+  void InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache);
+  void InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch);
+
+  Status InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadAicAivInfo_t &aic_aiv);
+  Status InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadCacheInfo_t &cache);
+  Status InitManualDependency(const domi::ManualThreadDependencyDef &depend_def, rtManualThreadDependency_t &depend);
+  Status InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop);
+
+  void InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu);
+  void InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu);
+  void InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu);
+
+  template<typename T>
+  Status InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim, uint32_t addr_count);
+
+  DavinciModel *davinci_model_{nullptr};
+  rtFftsTaskInfo_t sub_task_info_;
+  std::vector<void *> io_addrs_;
+  uint32_t thread_dim_{0};
+  void *args_{nullptr};  // runtime args memory
+  uint32_t args_size_{0};  // runtime args memory length
+};
+}  // namespace ge
+#endif  // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_
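Between the header and the .cc, the args_ handling is the subtle part: Init sizes the device buffer as kAddrLen * addr_size(), each sub-task records taskParamAddr = args_ + kAddrLen * io_addrs_.size() just before appending its own addresses (i.e. the byte offset of its slice in the shared buffer), and the whole io_addrs_ vector is copied into args_ with one rtMemcpy at the end. A small sketch of that bookkeeping with plain vectors and no runtime calls (the base address, counts, and pointer values below are invented):

#include <cstdint>
#include <iostream>
#include <vector>

constexpr uint32_t kAddrLen = sizeof(void *);

int main() {
  // Pretend two sub-tasks contribute 5 and 3 I/O addresses respectively.
  const std::vector<std::vector<const void *>> per_subtask_addrs = {
      {reinterpret_cast<const void *>(0x1000), nullptr, nullptr, nullptr, nullptr},
      {nullptr, nullptr, nullptr}};

  const uintptr_t args_base = 0x20000;  // stands in for the rtMalloc'ed args_ pointer
  std::vector<const void *> io_addrs;   // mirrors FftsTaskInfo::io_addrs_

  for (size_t sub = 0; sub < per_subtask_addrs.size(); ++sub) {
    // Recorded before appending: start of this sub-task's slice inside the args buffer.
    const uintptr_t task_param_addr = args_base + kAddrLen * io_addrs.size();
    std::cout << "subtask " << sub << " taskParamAddr offset = "
              << (task_param_addr - args_base) << " bytes\n";
    io_addrs.insert(io_addrs.end(), per_subtask_addrs[sub].begin(), per_subtask_addrs[sub].end());
  }
  // Finally the whole io_addrs vector would be copied into args_ in a single rtMemcpy.
  std::cout << "total args bytes copied = " << kAddrLen * io_addrs.size() << "\n";
  return 0;
}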
@@ -179,6 +179,7 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr
     GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret);
   }
   GE_CHECK_NOTNULL(original_compute_graph);
+  output_merged_compute_graph->SetName(original_compute_graph->GetName());
   // partition sub graph
   for (const auto &sub_graph : original_compute_graph->GetAllSubgraphs()) {
     ComputeGraphPtr merged_sub_graph = nullptr;
@@ -188,8 +189,16 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr
       GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret);
       continue;
     }
+    // this means subgraph added in optimize subgraph and without partitions, so just add to root graph
+    if (merged_sub_graph == sub_graph) {
+      GELOGI("Just add subgraph %s (parent node is %s) to root graph %s.", sub_graph->GetName().c_str(),
+             sub_graph->GetParentNode()->GetName().c_str(), output_merged_compute_graph->GetName().c_str());
+      sub_graph->SetParentGraph(sub_graph->GetParentNode()->GetOwnerComputeGraph());
+      GE_IF_BOOL_EXEC(output_merged_compute_graph->AddSubgraph(sub_graph->GetName(), merged_sub_graph) != SUCCESS,
+                      return FAILED;)
+      continue;
+    }
     // add sub graph
-    output_merged_compute_graph->SetName(original_compute_graph->GetName());
     merged_sub_graph->SetName(sub_graph->GetName());
     merged_sub_graph->SetInputSize(sub_graph->GetInputSize());
     merged_sub_graph->SetOutputSize(sub_graph->GetOutputSize());
@@ -245,12 +254,9 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co
   }
   if ((graph_2_graph_partition_info_.find(original_compute_graph) == graph_2_graph_partition_info_.end()) ||
       (graph_2_subgraph_list_.find(original_compute_graph) == graph_2_subgraph_list_.end())) {
-    REPORT_INNER_ERROR("E19999", "original_compute_graph:%s is not find in graph_2_graph_partition_info_.",
-                       original_compute_graph->GetName().c_str());
-    GELOGE(GE_GRAPH_NULL_INPUT,
-           "[Check][Param] original_compute_graph:%s is not find in graph_2_graph_partition_info_.",
-           original_compute_graph->GetName().c_str());
-    return FAILED;
GELOGW("[GraphPartition]: compute_graph has not found, just return original."); | |||||
+    output_merged_compute_graph = original_compute_graph;
+    return SUCCESS;
   }
   GraphPartitionInfo &subgraph_info = graph_2_graph_partition_info_[original_compute_graph];
   const auto &sub_graph_list = graph_2_subgraph_list_[original_compute_graph];
@@ -708,6 +714,7 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr
     }
     auto &engine_name = graph_info_.partitions_.at(sub_graph);
     (void)AttrUtils::SetStr(sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName());
+    (void)sub_graph->SetExtAttr("part_src_graph", compute_graph);
     GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(),
            compute_graph->GetName().c_str());
     GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]);
@@ -1 +1 @@ | |||||
Subproject commit c6030152c6dc05515115765babb5d64fde649df4 | |||||
Subproject commit 00c0c12eede6c7bce93a1eda5f0bb437ae80a7ec |
@@ -1 +1 @@ | |||||
Subproject commit 155d3262ba17f800094abb58b6a809b041cf0a74 | |||||
Subproject commit 3073129b68c0fae12a8b7531d60782e39128a28c |
@@ -456,6 +456,10 @@ rtError_t rtDebugRegisterForStream(rtStream_t stream, uint32_t flag, const void | |||||
rtError_t rtDebugUnRegisterForStream(rtStream_t stream) { | rtError_t rtDebugUnRegisterForStream(rtStream_t stream) { | ||||
return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
} | } | ||||
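// stub for the new FFTS launch interface; returns success unconditionally | |||||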
rtError_t rtFftsTaskLaunch(rtFftsTaskInfo_t *fftsTaskInfo, rtStream_t stream) { | |||||
return RT_ERROR_NONE; | |||||
} | |||||
#ifdef __cplusplus | #ifdef __cplusplus | ||||
} | } | ||||
#endif | #endif |
@@ -437,6 +437,7 @@ set(DISTINCT_GRAPH_LOAD_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/stream_active_task_info.cc" | "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/stream_active_task_info.cc" | ||||
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/end_graph_task_info.cc" | "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/end_graph_task_info.cc" | ||||
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/model_exit_task_info.cc" | "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/model_exit_task_info.cc" | ||||
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/ffts_task_info.cc" | |||||
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc" | "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc" | ||||
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc" | "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc" | ||||
"${GE_CODE_DIR}/ge/model/ge_model.cc" | "${GE_CODE_DIR}/ge/model/ge_model.cc" | ||||
@@ -649,6 +650,7 @@ set(DISTINCT_GRAPH_LOAD_TEST_FILES | |||||
"graph/load/hccl_task_info_unittest.cc" | "graph/load/hccl_task_info_unittest.cc" | ||||
"graph/load/kernel_ex_task_info_unittest.cc" | "graph/load/kernel_ex_task_info_unittest.cc" | ||||
"graph/load/kernel_task_info_unittest.cc" | "graph/load/kernel_task_info_unittest.cc" | ||||
"graph/load/ffts_task_info_unittest.cc" | |||||
"graph/load/memcpy_addr_async_task_info_unittest.cc" | "graph/load/memcpy_addr_async_task_info_unittest.cc" | ||||
"graph/load/memcpy_async_task_info_unittest.cc" | "graph/load/memcpy_async_task_info_unittest.cc" | ||||
"graph/load/cpu_queue_schedule_unittest.cc" | "graph/load/cpu_queue_schedule_unittest.cc" | ||||
@@ -1059,4 +1059,144 @@ TEST_F(UtestDavinciModel, get_total_memsize_exclude_zero_copy) { | |||||
EXPECT_EQ(model.GetTotalMemSizeExcludeZeroCopy(total_useful_size), SUCCESS); | EXPECT_EQ(model.GetTotalMemSizeExcludeZeroCopy(total_useful_size), SUCCESS); | ||||
EXPECT_EQ(total_useful_size, 512); | EXPECT_EQ(total_useful_size, 512); | ||||
} | } | ||||
// test InitTbeHandle | |||||
TEST_F(UtestDavinciModel, init_tbe_handle) { | |||||
DavinciModel model(0, nullptr); | |||||
OpDescPtr op_desc = CreateOpDesc("data", DATA); | |||||
model.ge_model_ = make_shared<GeModel>(); | |||||
// without kernel | |||||
EXPECT_EQ(model.InitTbeHandle(op_desc), INTERNAL_ERROR); | |||||
vector<char> buffer; | |||||
string key = op_desc->GetName(); | |||||
TBEKernelPtr tbe_kernel_ptr = std::make_shared<ge::OpKernelBin>(key, std::move(buffer)); | |||||
op_desc->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); | |||||
string attr_kernel_name = op_desc->GetName() + "_kernelname"; | |||||
string kernel_name = "kernel_name"; | |||||
AttrUtils::SetStr(op_desc, attr_kernel_name, kernel_name); | |||||
EXPECT_EQ(model.InitTbeHandle(op_desc), SUCCESS); | |||||
// rtQueryFunctionRegistered(bin_file_key) failed | |||||
EXPECT_EQ(model.used_tbe_handle_map_.size(), 0); | |||||
} | |||||
// test InitTbeHandleWithFfts | |||||
TEST_F(UtestDavinciModel, init_tbe_handle_with_ffts) { | |||||
DavinciModel model(0, nullptr); | |||||
OpDescPtr op_desc = CreateOpDesc("data", DATA); | |||||
model.ge_model_ = make_shared<GeModel>(); | |||||
// without tbe_kernel | |||||
EXPECT_EQ(model.InitTbeHandleWithFfts(op_desc), INTERNAL_ERROR); | |||||
std::vector<OpKernelBinPtr> tbe_kernel; | |||||
vector<char> buffer; | |||||
string key = op_desc->GetName(); | |||||
OpKernelBinPtr tbe_kernel_ptr0 = std::make_shared<ge::OpKernelBin>(key, std::move(buffer)); | |||||
OpKernelBinPtr tbe_kernel_ptr1 = std::make_shared<ge::OpKernelBin>(key, std::move(buffer)); | |||||
tbe_kernel.push_back(tbe_kernel_ptr0); | |||||
tbe_kernel.push_back(tbe_kernel_ptr1); | |||||
op_desc->SetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel); | |||||
// without _register_stub_func | |||||
EXPECT_EQ(model.InitTbeHandleWithFfts(op_desc), INTERNAL_ERROR); | |||||
vector<string> bin_file_keys; | |||||
bin_file_keys.emplace_back(op_desc->GetName() + "_0"); | |||||
bin_file_keys.emplace_back(op_desc->GetName() + "_1"); | |||||
AttrUtils::SetListStr(op_desc, "_register_stub_func", bin_file_keys); | |||||
EXPECT_EQ(model.InitTbeHandleWithFfts(op_desc), SUCCESS); | |||||
// rtQueryFunctionRegistered(bin_file_key) failed | |||||
EXPECT_EQ(model.used_tbe_handle_map_.size(), 0); | |||||
} | |||||
// test InitBinaryMagic | |||||
TEST_F(UtestDavinciModel, init_binary_magic) { | |||||
DavinciModel model(0, nullptr); | |||||
rtDevBinary_t binary; | |||||
OpDescPtr op_desc = CreateOpDesc("data", DATA); | |||||
bool is_ffts = true; | |||||
vector<string> json_list; | |||||
AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); | |||||
// without tvm_magic | |||||
EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), INTERNAL_ERROR); | |||||
json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_AICPU"); | |||||
json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF"); | |||||
op_desc->DelAttr(TVM_ATTR_NAME_THREAD_MAGIC); | |||||
AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); | |||||
EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), SUCCESS); | |||||
EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AICPU); | |||||
EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 1, binary), SUCCESS); | |||||
EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF); | |||||
json_list.clear(); | |||||
json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_AIVEC"); | |||||
json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_AICUBE"); | |||||
op_desc->DelAttr(TVM_ATTR_NAME_THREAD_MAGIC); | |||||
AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); | |||||
EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), SUCCESS); | |||||
EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AIVEC); | |||||
EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 1, binary), SUCCESS); | |||||
EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AICUBE); | |||||
// with an invalid magic string | |||||
json_list.clear(); | |||||
json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_INVALID"); | |||||
json_list.emplace_back("RT_DEV_BINARY_MAGIC_ELF_INVALID"); | |||||
op_desc->DelAttr(TVM_ATTR_NAME_THREAD_MAGIC); | |||||
AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_MAGIC, json_list); | |||||
EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), PARAM_INVALID); | |||||
// test the non-ffts path | |||||
is_ffts = false; | |||||
string json_string = "RT_DEV_BINARY_MAGIC_ELF_AIVEC"; | |||||
AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_MAGIC, json_string); | |||||
EXPECT_EQ(model.InitBinaryMagic(op_desc, is_ffts, 0, binary), SUCCESS); | |||||
EXPECT_EQ(binary.magic, RT_DEV_BINARY_MAGIC_ELF_AIVEC); | |||||
} | |||||
// test InitMetaData | |||||
TEST_F(UtestDavinciModel, init_meta_data) { | |||||
DavinciModel model(0, nullptr); | |||||
void *bin_handle; | |||||
OpDescPtr op_desc = CreateOpDesc("data", DATA); | |||||
bool is_ffts = true; | |||||
vector<string> meta_data_list; | |||||
// with empty meta_data | |||||
EXPECT_EQ(model.InitMetaData(op_desc, is_ffts, 0, bin_handle), INTERNAL_ERROR); | |||||
meta_data_list.emplace_back("meta_data_0"); | |||||
meta_data_list.emplace_back("meta_data_1"); | |||||
AttrUtils::SetListStr(op_desc, TVM_ATTR_NAME_THREAD_METADATA, meta_data_list); | |||||
EXPECT_EQ(model.InitMetaData(op_desc, is_ffts, 0, bin_handle), SUCCESS); | |||||
is_ffts = false; | |||||
string meta_data = "meta_data"; | |||||
AttrUtils::SetStr(op_desc, TVM_ATTR_NAME_METADATA, meta_data); | |||||
EXPECT_EQ(model.InitMetaData(op_desc, is_ffts, 0, bin_handle), SUCCESS); | |||||
} | |||||
// test InitKernelName | |||||
TEST_F(UtestDavinciModel, init_kernel_name) { | |||||
DavinciModel model(0, nullptr); | |||||
string kernel_name; | |||||
OpDescPtr op_desc = CreateOpDesc("data", DATA); | |||||
bool is_ffts = true; | |||||
// failed when name is invalid | |||||
EXPECT_EQ(model.InitKernelName(op_desc, is_ffts, 0, kernel_name), INTERNAL_ERROR); | |||||
OpDescPtr op_desc1 = CreateOpDesc("sgt_graph_nodes/loss_scale", SCALE); | |||||
string attr_kernel_name = "loss_scale_thread_kernelname"; | |||||
vector<string> kernel_name_list; | |||||
AttrUtils::SetListStr(op_desc, attr_kernel_name, kernel_name_list); | |||||
// failed without kernel_name | |||||
EXPECT_EQ(model.InitKernelName(op_desc, is_ffts, 0, kernel_name), INTERNAL_ERROR); | |||||
kernel_name_list.emplace_back("kernel_name_0"); | |||||
kernel_name_list.emplace_back("kernel_name_1"); | |||||
AttrUtils::SetListStr(op_desc1, attr_kernel_name, kernel_name_list); | |||||
EXPECT_EQ(model.InitKernelName(op_desc1, is_ffts, 0, kernel_name), SUCCESS); | |||||
// without ffts | |||||
is_ffts = false; | |||||
attr_kernel_name = "data_kernelname"; | |||||
kernel_name = "kernel_name"; | |||||
AttrUtils::SetStr(op_desc, attr_kernel_name, kernel_name); | |||||
EXPECT_EQ(model.InitKernelName(op_desc, is_ffts, 0, kernel_name), SUCCESS); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -0,0 +1,212 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#define private public | |||||
#define protected public | |||||
#include "graph/load/model_manager/task_info/ffts_task_info.h" | |||||
#include "cce/aicpu_engine_struct.h" | |||||
#include "common/ge/ge_util.h" | |||||
#include "common/properties_manager.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "framework/common/fmk_error_codes.h" | |||||
#include "graph/attr_value.h" | |||||
#include "graph/load/model_manager/davinci_model.h" | |||||
#include "graph/load/model_manager/model_manager.h" | |||||
#include "runtime/rt_ffts.h" | |||||
namespace ge { | |||||
extern OpDescPtr CreateOpDesc(string name, string type); | |||||
class UtestFftsTaskInfo : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
public: | |||||
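// builds a minimal auto-thread FftsTaskDef (op index 0) on task_def and wires davinci_model and ffts_task_info up for Init() | |||||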
void CreateFftsTaskInfo(DavinciModel &davinci_model, domi::TaskDef &task_def, FftsTaskInfo &ffts_task_info) { | |||||
rtStream_t stream = nullptr; | |||||
rtStreamCreate(&stream, 0); | |||||
davinci_model.stream_list_ = { stream }; | |||||
task_def.set_stream_id(0); | |||||
domi::FftsTaskDef *ffts_task_def = task_def.mutable_ffts_task(); | |||||
davinci_model.op_list_[0] = CreateOpDesc("test", PARTITIONEDCALL); | |||||
ffts_task_def->set_op_index(0); | |||||
ffts_task_def->set_addr_size(2); | |||||
domi::FftsDescInfoDef *ffts_desc = ffts_task_def->mutable_ffts_desc(); | |||||
ffts_desc->set_tm(0); | |||||
rtFftsTaskInfo_t sub_task_info; | |||||
ffts_task_info.sub_task_info_ = sub_task_info; | |||||
ffts_task_def->set_ffts_type(RT_FFTS_TYPE_AUTO_THREAD); | |||||
} | |||||
}; | |||||
// test FftsTaskInfo Init with no subtask and no ticket cache | |||||
TEST_F(UtestFftsTaskInfo, success_ffts_task_info_without_subtask) { | |||||
DavinciModel davinci_model(0, nullptr); | |||||
rtStream_t stream = nullptr; | |||||
rtStreamCreate(&stream, 0); | |||||
davinci_model.stream_list_ = { stream }; | |||||
domi::TaskDef task_def; | |||||
task_def.set_stream_id(0); | |||||
domi::FftsTaskDef *ffts_task_def = task_def.mutable_ffts_task(); | |||||
FftsTaskInfo ffts_task_info; | |||||
// Init fails when the model has no matching op_desc | |||||
EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), PARAM_INVALID); | |||||
davinci_model.op_list_[0] = CreateOpDesc("test", PARTITIONEDCALL); | |||||
ffts_task_def->set_op_index(0); | |||||
ffts_task_def->set_addr_size(2); | |||||
domi::FftsDescInfoDef *ffts_desc = ffts_task_def->mutable_ffts_desc(); | |||||
ffts_desc->set_tm(0); | |||||
rtFftsTaskInfo_t sub_task_info; | |||||
ffts_task_info.sub_task_info_ = sub_task_info; | |||||
ffts_task_def->set_ffts_type(RT_FFTS_TYPE_AUTO_THREAD); | |||||
ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; | |||||
EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); | |||||
} | |||||
// test FftsTaskInfo Init with subtask and no ticket cache: AutoThreadAicAivDef | |||||
TEST_F(UtestFftsTaskInfo, success_ffts_task_info_with_auto_thread_subgraph) { | |||||
DavinciModel davinci_model(0, nullptr); | |||||
domi::TaskDef task_def; | |||||
FftsTaskInfo ffts_task_info; | |||||
CreateFftsTaskInfo(davinci_model, task_def, ffts_task_info); | |||||
domi::FftsSubTaskDef *ffts_sub_task_def = task_def.mutable_ffts_task()->add_sub_task(); | |||||
ffts_sub_task_def->set_thread_dim(static_cast<uint32_t>(1)); | |||||
// Init fails when sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv() | |||||
EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), FAILED); | |||||
domi::AutoThreadAicAivDef *auto_thread_aic_aiv_def = ffts_sub_task_def->mutable_auto_thread_aic_aiv(); | |||||
domi::AutoThreadPrefetchDef *src_prefetch = auto_thread_aic_aiv_def->add_src_prefetch(); | |||||
// without InitIoAddrs | |||||
ffts_task_info.thread_dim_ = 0; | |||||
RuntimeParam runtime_param; | |||||
ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; | |||||
EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); | |||||
} | |||||
// test FftsTaskInfo Init with subtask and no ticket cache: ManualThreadAicAivDef | |||||
TEST_F(UtestFftsTaskInfo, success_ffts_task_info_with_manual_thread_subgraph) { | |||||
DavinciModel davinci_model(0, nullptr); | |||||
domi::TaskDef task_def; | |||||
FftsTaskInfo ffts_task_info; | |||||
CreateFftsTaskInfo(davinci_model, task_def, ffts_task_info); | |||||
domi::FftsSubTaskDef *ffts_sub_task_def = task_def.mutable_ffts_task()->add_sub_task(); | |||||
ffts_sub_task_def->set_thread_dim(static_cast<uint32_t>(1)); | |||||
// sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv() | |||||
domi::ManualThreadAicAivDef *manual_thread_aic_aiv_def = ffts_sub_task_def->mutable_manual_thread_aic_aiv(); | |||||
manual_thread_aic_aiv_def->add_thread_prefetch_dmu_idx(static_cast<uint32_t>(0)); | |||||
manual_thread_aic_aiv_def->add_thread_blk_dim(static_cast<uint32_t>(0)); | |||||
manual_thread_aic_aiv_def->add_thread_task_func_stub("ffts"); | |||||
domi::ManualThreadDmuDef *prefetch_list = manual_thread_aic_aiv_def->add_prefetch_list(); | |||||
prefetch_list->set_data_addr(static_cast<uint64_t>(0)); | |||||
// without InitIoAddrs | |||||
ffts_task_info.thread_dim_ = 0; | |||||
RuntimeParam runtime_param; | |||||
ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; | |||||
EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); | |||||
} | |||||
// test FftsTaskInfo Init with subtask and no ticket cache: ManualThreadNopDef | |||||
TEST_F(UtestFftsTaskInfo, success_ffts_task_info_with_manual_thread_nop_subgraph) { | |||||
DavinciModel davinci_model(0, nullptr); | |||||
domi::TaskDef task_def; | |||||
FftsTaskInfo ffts_task_info; | |||||
CreateFftsTaskInfo(davinci_model, task_def, ffts_task_info); | |||||
domi::FftsSubTaskDef *ffts_sub_task_def = task_def.mutable_ffts_task()->add_sub_task(); | |||||
ffts_sub_task_def->set_thread_dim(static_cast<uint32_t>(1)); | |||||
domi::AutoThreadAicAivDef *auto_thread_aic_aiv_def = ffts_sub_task_def->mutable_auto_thread_aic_aiv(); | |||||
domi::ManualThreadNopDef *manual_thread_nop = ffts_sub_task_def->mutable_manual_thread_nop(); | |||||
domi::ManualThreadDependencyDef *src_dep_tbl = manual_thread_nop->add_src_dep_tbl(); | |||||
src_dep_tbl->add_dependency(static_cast<uint32_t>(0)); | |||||
// without InitIoAddrs | |||||
ffts_task_info.thread_dim_ = 0; | |||||
RuntimeParam runtime_param; | |||||
ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; | |||||
EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); | |||||
} | |||||
// test FftsTaskInfo Init with no subtask and ticket cache: AutoThreadCacheDef | |||||
TEST_F(UtestFftsTaskInfo, success_ffts_task_info_with_auto_thread_ticket_cache) { | |||||
DavinciModel davinci_model(0, nullptr); | |||||
domi::TaskDef task_def; | |||||
FftsTaskInfo ffts_task_info; | |||||
CreateFftsTaskInfo(davinci_model, task_def, ffts_task_info); | |||||
domi::TicketCacheDef *ticket_cache_def = task_def.mutable_ffts_task()->add_ticket_cache(); | |||||
// Init fails when ticket_cache_def.has_auto_thread_cache() == ticket_cache_def.has_manual_thread_cache() | |||||
EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), FAILED); | |||||
domi::AutoThreadCacheDef *auto_thread_cache = ticket_cache_def->mutable_auto_thread_cache(); | |||||
ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; | |||||
EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); | |||||
} | |||||
// test FftsTaskInfo Init with no subtask and ticket cache: ManualThreadCacheDef | |||||
TEST_F(UtestFftsTaskInfo, success_ffts_task_info_with_manual_thread_ticket_cache) { | |||||
DavinciModel davinci_model(0, nullptr); | |||||
domi::TaskDef task_def; | |||||
FftsTaskInfo ffts_task_info; | |||||
CreateFftsTaskInfo(davinci_model, task_def, ffts_task_info); | |||||
domi::TicketCacheDef *ticket_cache_def = task_def.mutable_ffts_task()->add_ticket_cache(); | |||||
domi::ManualThreadCacheDef *manual_thread_cache = ticket_cache_def->mutable_manual_thread_cache(); | |||||
manual_thread_cache->add_slice_dmu_idx(static_cast<uint32_t>(0)); | |||||
manual_thread_cache->add_ticket_cache_ref_cnt_tbl(static_cast<uint32_t>(0)); | |||||
domi::ManualThreadDmuDef *dmu_list = manual_thread_cache->add_dmu_list(); | |||||
ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; | |||||
EXPECT_EQ(ffts_task_info.Init(task_def, &davinci_model), SUCCESS); | |||||
} | |||||
// test FftsTaskInfo UpdateArgs | |||||
TEST_F(UtestFftsTaskInfo, success_ffts_task_info_update_args) { | |||||
DavinciModel davinci_model(0, nullptr); | |||||
FftsTaskInfo ffts_task_info; | |||||
ffts_task_info.davinci_model_ = &davinci_model; | |||||
ffts_task_info.io_addrs_ = { (void*)0x12345678, (void*)0x22345678 }; | |||||
EXPECT_EQ(ffts_task_info.UpdateArgs(), SUCCESS); | |||||
} | |||||
// test FftsTaskInfo CalculateArgs | |||||
TEST_F(UtestFftsTaskInfo, success_ffts_task_info_calculate_args) { | |||||
DavinciModel davinci_model(0, nullptr); | |||||
domi::TaskDef task_def; | |||||
FftsTaskInfo ffts_task_info; | |||||
EXPECT_EQ(ffts_task_info.CalculateArgs(task_def, &davinci_model), SUCCESS); | |||||
} | |||||
// test FftsTaskInfo Distribute | |||||
TEST_F(UtestFftsTaskInfo, success_ffts_task_info_distribute) { | |||||
DavinciModel davinci_model(0, nullptr); | |||||
FftsTaskInfo ffts_task_info; | |||||
rtFftsTaskInfo_t sub_task_info; | |||||
ffts_task_info.sub_task_info_ = sub_task_info; | |||||
rtStream_t stream = nullptr; | |||||
rtStreamCreate(&stream, 0); | |||||
ffts_task_info.stream_ = stream; | |||||
EXPECT_EQ(ffts_task_info.Distribute(), SUCCESS); | |||||
} | |||||
} // namespace ge |
@@ -27,5 +27,6 @@ | |||||
#include "mem.h" | #include "mem.h" | ||||
#include "rt_model.h" | #include "rt_model.h" | ||||
#include "stream.h" | #include "stream.h" | ||||
#include "rt_ffts.h" | |||||
#endif // __CCE_RUNTIME_RT_H__ | #endif // __CCE_RUNTIME_RT_H__ |
@@ -0,0 +1,185 @@ | |||||
/* | |||||
* Copyright (c) Huawei Technologies Co. , Ltd. 2021. All rights reserved. | |||||
* Description: ffts interface | |||||
*/ | |||||
#ifndef __CCE_RUNTIME_FFTS_H | |||||
#define __CCE_RUNTIME_FFTS_H | |||||
#include "base.h" | |||||
#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) | |||||
extern "C" { | |||||
#endif | |||||
#define RT_FFTS_MAX_SUB_TASK_NUM 32U | |||||
#define RT_FFTS_MAX_TICKET_CACHE_NUM 64U | |||||
#define RT_FFTS_MAX_MANUAL_THREAD_NUM 16U | |||||
#define RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK 8U | |||||
#define RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN 32U | |||||
typedef enum tagFftsType { | |||||
RT_FFTS_TYPE_AUTO_THREAD = 2, // ffts auto thread mode, same as ffts define | |||||
RT_FFTS_TYPE_MANUAL_THREAD = 3, // ffts manual thread mode, same as ffts define | |||||
} rtFftsType_t; | |||||
typedef enum tagFftsSubTaskType { | |||||
RT_FFTS_SUB_TASK_TYPE_AIC = 0, | |||||
RT_FFTS_SUB_TASK_TYPE_AIV = 1, | |||||
RT_FFTS_SUB_TASK_TYPE_NOP = 2, | |||||
RT_FFTS_SUB_TASK_TYPE_NOTIFY_WAIT = 3, | |||||
RT_FFTS_SUB_TASK_TYPE_NOTIFY_RECORD = 4, | |||||
RT_FFTS_SUB_TASK_TYPE_WRITE_VALUE = 5, | |||||
RT_FFTS_SUB_TASK_TYPE_MIX_AIC = 6, | |||||
RT_FFTS_SUB_TASK_TYPE_MIX_AIV = 7, | |||||
RT_FFTS_SUB_TASK_TYPE_SDMA = 8, | |||||
RT_FFTS_SUB_TASK_TYPE_RESERVED, | |||||
} rtFftsSubTaskType_t; | |||||
typedef struct tagManualThreadDmuInfo { | |||||
uint64_t dataAddr; // device mem | |||||
uint16_t numOuter; | |||||
uint16_t numInner; | |||||
uint32_t strideOuter; | |||||
uint32_t lenInner; | |||||
uint32_t strideInner; | |||||
} rtManualThreadDmuInfo_t; | |||||
typedef struct tagManualThreadDependency { | |||||
uint8_t dependency[RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN]; | |||||
} rtManualThreadDependency_t; | |||||
typedef struct tagManualThreadAicAivInfo { | |||||
uint64_t taskParamAddr; // device mem | |||||
uint16_t taskParamOffset; | |||||
// when satMode = 1 and an FP16 computation with non-INF inputs overflows/underflows, the results will be +/-INF of FP16 | |||||
// when satMode = 0 and an FP16 computation with non-INF inputs overflows/underflows, | |||||
// the results will be saturated to +/-MAX of FP16 | |||||
uint8_t satMode; | |||||
uint8_t scheduleMode; // 0:normal mode, 1:batch mode, 2:sync mode, 3: reserved | |||||
uint8_t iCachePrefetchCnt; // unit is 2 KB | |||||
uint8_t prefetchEnableBitmap; // 8-bit bitmap, e.g. 1 0 1 0 | |||||
uint8_t prefetchOnceBitmap; // 8-bit bitmap, e.g. 1 0 1 0 | |||||
uint16_t prefetchOnceDmuNum; // prefetch_once_dmu_descriptor_index in ffts | |||||
// num: thread0_prefetch_dmu_descriptor_index - prefetch_once_dmu_descriptor_index | |||||
uint16_t threadPrefetchDmuIdx[RT_FFTS_MAX_MANUAL_THREAD_NUM]; // max valid is threadDim | |||||
uint16_t threadBlkDim[RT_FFTS_MAX_MANUAL_THREAD_NUM]; | |||||
const char *threadTaskFuncStub[RT_FFTS_MAX_MANUAL_THREAD_NUM]; | |||||
rtManualThreadDmuInfo_t *prefetchList; // DMU descriptors (0-64K); length is given by the last entry, threadPrefetchDmuIdx[threadDim - 1] | |||||
rtManualThreadDependency_t srcDepTbl[RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK]; | |||||
} rtManualThreadAicAivInfo_t; | |||||
typedef struct tagAutoThreadPrefetch { | |||||
uint64_t dataAddr; // device mem | |||||
uint32_t dataAddrOffset; | |||||
uint32_t nonTailDataLen; | |||||
uint32_t tailDataLen; | |||||
} rtAutoThreadPrefetch_t; | |||||
typedef struct tagAutoThreadAicAivInfo { | |||||
uint64_t taskParamAddr; // device mem | |||||
uint16_t taskParamOffset; | |||||
// when satMode = 1 and an FP16 computation with non-INF inputs overflows/underflows, the results will be +/-INF of FP16 | |||||
// when satMode = 0 and an FP16 computation with non-INF inputs overflows/underflows, | |||||
// the results will be saturated to +/-MAX of FP16 | |||||
uint8_t satMode; | |||||
uint8_t scheduleMode; // 0:normal mode, 1:batch mode, 2:sync mode, 3: reserved | |||||
uint8_t iCachePrefetchCnt; // unit is 2 KB | |||||
uint8_t prefetchEnableBitmap; // 8-bit bitmap | |||||
uint8_t prefetchOnceBitmap; // 8-bit bitmap | |||||
uint16_t tailBlkDim; | |||||
uint16_t nonTailBlkDim; | |||||
const char *nonTailTaskFuncStub; | |||||
const char *tailTaskFuncStub; | |||||
// for prefetch, the valid entry count equals the number of set bits in prefetchEnableBitmap | |||||
// e.g. if prefetchEnableBitmap = '00010011', three prefetches are needed, so only srcPrefetch[0], [1] and [2] are valid | |||||
rtAutoThreadPrefetch_t srcPrefetch[RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK]; | |||||
} rtAutoThreadAicAivInfo_t; | |||||
typedef struct tagAutoThreadCacheInfo { | |||||
uint64_t dataAddr; // device mem | |||||
uint32_t dataAddrOffset; | |||||
uint32_t nonTailDataLen; | |||||
uint32_t tailDataLen; | |||||
uint16_t ticketCacheRefCnt; | |||||
} rtAutoThreadCacheInfo_t; | |||||
typedef struct tagManualThreadCacheInfo { | |||||
rtManualThreadDmuInfo_t *dmuList; // 0-64k | |||||
uint16_t dmuNum; | |||||
uint16_t sliceDmuIdx[RT_FFTS_MAX_MANUAL_THREAD_NUM]; | |||||
uint16_t ticketCacheRefCntTbl[RT_FFTS_MAX_MANUAL_THREAD_NUM]; | |||||
} rtManualThreadCacheInfo_t; | |||||
typedef enum tagCacheOp { | |||||
RT_CACHE_OP_NONE = 0, | |||||
RT_CACHE_OP_FLUSH = 1, | |||||
RT_CACHE_OP_INVALIDATE = 2, | |||||
RT_CACHE_OP_WRITE_BACK = 3, | |||||
} rtCacheOp_t; | |||||
typedef struct tagTicketCache { | |||||
rtCacheOp_t cacheOption; | |||||
uint8_t ticketCacheWindow; | |||||
union { | |||||
rtAutoThreadCacheInfo_t autoThreadCache; | |||||
rtManualThreadCacheInfo_t manualThreadCache; | |||||
} custom; | |||||
} rtTicketCache_t; | |||||
typedef struct tagManualThreadNopInfo { | |||||
// depends on srcTickCacheVldBitmap in rtFftsSubTaskInfo_t | |||||
rtManualThreadDependency_t srcDepTbl[RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK]; | |||||
} rtManualThreadNopInfo_t; | |||||
typedef struct tagFftsSubTaskInfo { | |||||
rtFftsSubTaskType_t subTaskType; | |||||
uint16_t threadDim; | |||||
uint8_t dstTickCacheVldBitmap; | |||||
uint8_t srcTickCacheVldBitmap; | |||||
uint8_t srcDataOutOfSubGraphBitmap; | |||||
uint8_t dstTickCacheID[RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK]; | |||||
uint8_t srcTickCacheID[RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK]; | |||||
union { | |||||
rtAutoThreadAicAivInfo_t autoThreadAicAiv; | |||||
rtManualThreadAicAivInfo_t manualThreadAicAiv; | |||||
rtManualThreadNopInfo_t manualThreadNop; | |||||
} custom; | |||||
} rtFftsSubTaskInfo_t; | |||||
typedef struct tagFftsDescInfo { | |||||
uint8_t tm; // thread sub-task kick-off mode, 0: in order, 1: out of order | |||||
uint8_t di; // discard invalidate | |||||
uint8_t dw; // discard write back | |||||
uint8_t df; // discard flush | |||||
uint8_t dataSplitUnit; // split source or ticket cache by 2~dataSplitUnit MB | |||||
uint8_t prefetchOstNum; | |||||
uint8_t cacheMaintainOstNum; | |||||
uint8_t aicPrefetchUpper; | |||||
uint8_t aicPrefetchLower; | |||||
uint8_t aivPrefetchUpper; | |||||
uint8_t aivPrefetchLower; | |||||
} rtFftsDescInfo_t; | |||||
typedef struct tagFftsTaskInfo { | |||||
rtFftsType_t fftsType; | |||||
uint16_t subTaskNum; | |||||
uint16_t tickCacheNum; | |||||
rtFftsDescInfo_t fftsDesc; | |||||
// sub-task descriptors; the valid count is subTaskNum | |||||
rtFftsSubTaskInfo_t subTask[RT_FFTS_MAX_SUB_TASK_NUM]; | |||||
// ticket caches; the valid count is ticketCacheNum | |||||
rtTicketCache_t ticketCache[RT_FFTS_MAX_TICKET_CACHE_NUM]; | |||||
} rtFftsTaskInfo_t; | |||||
RTS_API rtError_t rtFftsTaskLaunch(rtFftsTaskInfo_t *fftsTaskInfo, rtStream_t stream); | |||||
#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) | |||||
} | |||||
#endif | |||||
#endif //__CCE_RUNTIME_FFTS_H |
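For orientation, a minimal caller sketch against the header above (a sketch only: the kernel stub name, block dimensions, and task parameter address are placeholder assumptions, and a valid rtStream_t is assumed to have been created elsewhere):

#include "rt_ffts.h"

// Launches a single auto-thread AIC sub-task with no ticket caches (illustrative values only).
static rtError_t LaunchSingleAutoThreadFftsTask(rtStream_t stream, uint64_t taskParamAddr) {
  rtFftsTaskInfo_t taskInfo = {0};
  taskInfo.fftsType = RT_FFTS_TYPE_AUTO_THREAD;
  taskInfo.subTaskNum = 1U;        // only subTask[0] is valid
  taskInfo.tickCacheNum = 0U;      // no ticket caches in this sketch
  taskInfo.fftsDesc.tm = 0U;       // kick off thread sub-tasks in order

  rtFftsSubTaskInfo_t *subTask = &taskInfo.subTask[0];
  subTask->subTaskType = RT_FFTS_SUB_TASK_TYPE_AIC;
  subTask->threadDim = 1U;
  subTask->custom.autoThreadAicAiv.taskParamAddr = taskParamAddr;        // device memory holding kernel args (assumed)
  subTask->custom.autoThreadAicAiv.nonTailBlkDim = 1U;
  subTask->custom.autoThreadAicAiv.tailBlkDim = 1U;
  subTask->custom.autoThreadAicAiv.nonTailTaskFuncStub = "kernel_stub";  // hypothetical stub name
  subTask->custom.autoThreadAicAiv.tailTaskFuncStub = "kernel_stub";     // hypothetical stub name

  return rtFftsTaskLaunch(&taskInfo, stream);
}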
@@ -50,6 +50,7 @@ typedef enum tagModelTaskType { | |||||
RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX, | RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX, | ||||
RT_MODEL_TASK_STREAM_LABEL_GOTO, | RT_MODEL_TASK_STREAM_LABEL_GOTO, | ||||
RT_MODEL_TASK_MODEL_EXIT, | RT_MODEL_TASK_MODEL_EXIT, | ||||
RT_MODEL_TASK_FFTS_TASK, | |||||
RT_MODEL_TASK_ALL_KERNEL, | RT_MODEL_TASK_ALL_KERNEL, | ||||
} rtModelTaskType_t; | } rtModelTaskType_t; | ||||