@@ -59,7 +59,7 @@ struct GraphExecutionContext { | |||
uint64_t session_id = 0; | |||
uint64_t context_id = 0; | |||
const HybridModel *model = nullptr; | |||
HybridModel *model = nullptr; | |||
const GEThreadLocalContext *ge_context = nullptr; | |||
rtStream_t stream = nullptr; | |||
rtContext_t rt_context = nullptr; | |||
@@ -26,6 +26,7 @@ | |||
#include "hybrid/model/hybrid_model_builder.h" | |||
#include "hybrid/node_executor/node_executor.h" | |||
#include "common/op/ge_op_utils.h" | |||
#include "graph/load/model_manager/tbe_handle_store.h" | |||
namespace ge { | |||
namespace hybrid { | |||
@@ -37,6 +38,7 @@ HybridModel::HybridModel(GeRootModelPtr ge_model) : ge_root_model_(std::move(ge_ | |||
} | |||
/// Destructor. Invokes CleanTbeHandle() so the TBE handle usage recorded by this
/// model is handed back to the handle store, then writes a debug log entry
/// carrying the model name.
HybridModel::~HybridModel() {
  // Handle cleanup runs before the log statement, as in the original ordering.
  CleanTbeHandle();

  GELOGD("[%s] HybridModel destroyed.", model_name_.c_str());
}
@@ -454,5 +456,35 @@ Status HybridModel::GetOpAttr(const std::string &op_name, const std::string &att | |||
GELOGD("Get attr:%s of op:%s success, attr value:%s", attr_name.c_str(), op_name.c_str(), attr_value.c_str()); | |||
return SUCCESS; | |||
} | |||
/// Records @p handle_key in used_tbe_handle_map_ with a usage count of 1.
///
/// The entry is created when missing; an entry that already exists has its
/// count overwritten with 1 (it is not incremented).
///
/// @param handle_key  key identifying the TBE handle to mark as used
void HybridModel::SetUsedTbeHandleMap(const std::string &handle_key) {
  // operator[] default-inserts the key when absent, then the count is set.
  uint32_t &used_count = used_tbe_handle_map_[handle_key];
  used_count = 1;
}
void HybridModel::StoreTbeHandle(const std::string &handle_key) { | |||
// Online mode FE may call rtFunctionRegister. | |||
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); | |||
auto it = used_tbe_handle_map_.find(handle_key); | |||
if (it != used_tbe_handle_map_.end()) { | |||
// GE registered, increase reference. | |||
kernel_store.ReferTBEHandle(handle_key); | |||
it->second++; | |||
return; | |||
} | |||
void *bin_handle = nullptr; | |||
if (kernel_store.FindTBEHandle(handle_key, bin_handle)) { | |||
// GE registered, increase reference. | |||
used_tbe_handle_map_[handle_key] = 1; // Init used num to 1. | |||
kernel_store.ReferTBEHandle(handle_key); | |||
} | |||
} | |||
void HybridModel::CleanTbeHandle() { | |||
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance(); | |||
kernel_store.EraseTBEHandle(used_tbe_handle_map_); | |||
used_tbe_handle_map_.clear(); | |||
} | |||
} // namespace hybrid | |||
} // namespace ge |
@@ -137,6 +137,12 @@ class HybridModel { | |||
Status GetOpAttr(const std::string &op_name, const std::string &attr_name, std::string &attr_value) const; | |||
void SetUsedTbeHandleMap(const std::string &handle_key); | |||
void StoreTbeHandle(const std::string &handle_key); | |||
void CleanTbeHandle(); | |||
private: | |||
friend class HybridModelBuilder; | |||
friend class HybridModelAsyncExecutor; | |||
@@ -172,6 +178,7 @@ class HybridModel { | |||
std::unique_ptr<TensorBuffer> global_step_; | |||
// op name to attrs mapping | |||
std::map<std::string, std::map<std::string, std::vector<std::string>>> op_name_to_attrs_; | |||
map<string, uint32_t> used_tbe_handle_map_; | |||
}; | |||
} // namespace hybrid | |||
} // namespace ge | |||
@@ -47,7 +47,7 @@ Status AiCoreNodeExecutor::Initialize() { | |||
return SUCCESS; | |||
} | |||
Status AiCoreNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
Status AiCoreNodeExecutor::LoadTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
GE_CHECK_NOTNULL(node); | |||
GELOGI("AiCoreNodeExecutor(%s) LoadTask Start.", node->GetName().c_str()); | |||
bool is_single_op = model.IsSingleOp(); | |||
@@ -72,7 +72,7 @@ Status AiCoreNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &nod | |||
AiCoreTaskBuilder builder(node->GetOpDesc(), *task_defs); | |||
std::unique_ptr<AiCoreNodeTask> node_task; | |||
GE_CHK_STATUS_RET(builder.BuildTask(node_task, true, is_single_op), | |||
GE_CHK_STATUS_RET(builder.BuildTask(model, node_task, true, is_single_op), | |||
"[Invoke][BuildTask][%s] Failed to build op tasks.", node->GetName().c_str()); | |||
task = std::move(node_task); | |||
GELOGI("AiCoreNodeExecutor(%s) LoadTask End.", node->GetName().c_str()); | |||
@@ -123,7 +123,7 @@ std::shared_ptr<AiCoreNodeTask> AiCoreNodeTaskRegistry::GetTask(const std::strin | |||
return (iter != reg_node_tasks_.end()) ? iter->second : nullptr; | |||
} | |||
Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, | |||
Status AiCoreNodeExecutor::CompileTask(HybridModel &model, | |||
const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
auto node_item = model.GetNodeItem(node); | |||
GE_CHECK_NOTNULL(node_item); | |||
@@ -164,7 +164,7 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, | |||
AiCoreTaskBuilder builder(node->GetOpDesc(), task_defs); | |||
std::unique_ptr<AiCoreNodeTask> node_task; | |||
GE_CHK_STATUS_RET(builder.BuildTask(node_task, false), | |||
GE_CHK_STATUS_RET(builder.BuildTask(model, node_task, false), | |||
"[Invoke][BuildTask][%s] Failed to build op tasks.", node->GetName().c_str()); | |||
node_task->SetWorkspaceSizes(op_desc->GetWorkspaceBytes()); | |||
aicore_task = std::move(node_task); | |||
@@ -70,8 +70,8 @@ class AiCoreNodeTask : public NodeTask { | |||
class AiCoreNodeExecutor : public NodeExecutor { | |||
public: | |||
Status Initialize() override; | |||
Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override; | |||
Status CompileTask(const HybridModel &model, const NodePtr &node, | |||
Status LoadTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override; | |||
Status CompileTask(HybridModel &model, const NodePtr &node, | |||
std::shared_ptr<NodeTask> &task) const override; | |||
private: | |||
@@ -49,14 +49,14 @@ bool TbeHandleRegistry::AddHandle(std::unique_ptr<TbeHandleHolder> &&holder) { | |||
return ret.second; | |||
} | |||
Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { | |||
Status AiCoreOpTask::Init(HybridModel &hybrid_model, const OpDesc &op_desc, const domi::TaskDef &task_def) { | |||
log_name_ = op_desc.GetName() + "_tvmbin"; | |||
log_id_ = log_id++; | |||
auto op_desc_ptr = MakeShared<OpDesc>(op_desc); | |||
GE_CHECK_NOTNULL(op_desc_ptr); | |||
auto task_info = BuildTaskUtils::GetTaskInfo(op_desc_ptr); | |||
GELOGI("[TASK_INFO] %lu/%s %s.", log_id_, log_name_.c_str(), task_info.c_str()); | |||
GE_CHK_STATUS_RET_NOLOG(InitWithTaskDef(op_desc, task_def)); | |||
GE_CHK_STATUS_RET_NOLOG(InitWithTaskDef(hybrid_model, op_desc, task_def)); | |||
GE_CHK_STATUS_RET_NOLOG(InitTilingInfo(op_desc)); | |||
GE_CHECK_LE(op_desc.GetOutputsSize(), static_cast<size_t>(INT_MAX)); | |||
@@ -78,7 +78,7 @@ Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) | |||
return SUCCESS; | |||
} | |||
Status AiCoreOpTask::RegisterTbeHandle(const OpDesc &op_desc) { | |||
Status AiCoreOpTask::RegisterTbeHandle(HybridModel &hybrid_model, const OpDesc &op_desc) { | |||
rtError_t rt_ret = rtQueryFunctionRegistered(stub_name_.c_str()); | |||
if (rt_ret != RT_ERROR_NONE || is_single_op_) { | |||
auto op_desc_ptr = MakeShared<OpDesc>(op_desc); | |||
@@ -133,7 +133,11 @@ Status AiCoreOpTask::RegisterTbeHandle(const OpDesc &op_desc) { | |||
GELOGI("TBE: binfile_key=%s, kernel_name=%s", stub_name_.c_str(), kernel_name.c_str()); | |||
GE_CHK_RT_RET(rtFunctionRegister(bin_handle, stub_name_.c_str(), | |||
stub_name_.c_str(), kernel_name.c_str(), 0)); | |||
hybrid_model.SetUsedTbeHandleMap(stub_name_.c_str()); | |||
return SUCCESS; | |||
} | |||
// Kernel registered, increase used num in store. | |||
hybrid_model.StoreTbeHandle(stub_name_.c_str()); | |||
return SUCCESS; | |||
} | |||
@@ -190,11 +194,11 @@ Status AiCoreOpTask::RegisterKernelHandle(const OpDesc &op_desc) { | |||
return SUCCESS; | |||
} | |||
Status AiCoreOpTask::InitWithKernelDef(const OpDesc &op_desc, const domi::TaskDef &task_def) { | |||
Status AiCoreOpTask::InitWithKernelDef(HybridModel &hybrid_model, const OpDesc &op_desc, const domi::TaskDef &task_def) { | |||
const domi::KernelDef &kernel_def = task_def.kernel(); | |||
const domi::KernelContext &context = kernel_def.context(); | |||
stub_name_ = kernel_def.stub_func(); | |||
GE_CHK_STATUS_RET(RegisterTbeHandle(op_desc)); | |||
GE_CHK_STATUS_RET(RegisterTbeHandle(hybrid_model, op_desc)); | |||
GE_CHK_RT_RET(rtGetFunctionByName(stub_name_.c_str(), &stub_func_)); | |||
args_size_ = kernel_def.args_size(); | |||
block_dim_ = kernel_def.block_dim(); | |||
@@ -304,7 +308,7 @@ Status AiCoreOpTask::InitWithKernelDefWithHandle(const OpDesc &op_desc, const do | |||
return SUCCESS; | |||
} | |||
Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef &task_def) { | |||
Status AiCoreOpTask::InitWithTaskDef(HybridModel &hybrid_model, const OpDesc &op_desc, const domi::TaskDef &task_def) { | |||
auto rt_ret = ValidateTaskDef(task_def); | |||
if (rt_ret != SUCCESS) { | |||
@@ -316,7 +320,7 @@ Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef | |||
} | |||
if (task_def.type() != RT_MODEL_TASK_ALL_KERNEL) { | |||
GE_CHK_STATUS_RET(InitWithKernelDef(op_desc, task_def)); | |||
GE_CHK_STATUS_RET(InitWithKernelDef(hybrid_model, op_desc, task_def)); | |||
} else { | |||
GE_CHK_STATUS_RET(InitWithKernelDefWithHandle(op_desc, task_def)); | |||
} | |||
@@ -558,8 +562,8 @@ std::string AiCoreOpTask::GetKeyForKernelName(const OpDesc &op_desc) const { | |||
return op_desc.GetName() + "_kernelname"; | |||
} | |||
// Initializes the atomic-addr-clean task: first delegates to AiCoreOpTask::Init
// (returning early, without logging, if that fails), then returns the result of
// InitAtomicAddrCleanIndices(op_desc).
// NOTE(review): the signature and the first statement each appear twice below; this
// looks like the before/after pair of a diff hunk (the second variant adds the
// HybridModel &hybrid_model parameter and forwards it) — confirm which pair is current.
Status AtomicAddrCleanOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { | |||
GE_CHK_STATUS_RET_NOLOG(AiCoreOpTask::Init(op_desc, task_def)); | |||
Status AtomicAddrCleanOpTask::Init(HybridModel &hybrid_model, const OpDesc &op_desc, const domi::TaskDef &task_def) { | |||
GE_CHK_STATUS_RET_NOLOG(AiCoreOpTask::Init(hybrid_model, op_desc, task_def)); | |||
return InitAtomicAddrCleanIndices(op_desc); | |||
} | |||
@@ -19,6 +19,7 @@ | |||
#include <memory> | |||
#include <vector> | |||
#include "hybrid/model/hybrid_model.h" | |||
#include "common/ge_inner_error_codes.h" | |||
#include "runtime/stream.h" | |||
#include "hybrid/common/tensor_value.h" | |||
@@ -59,7 +60,7 @@ class AiCoreOpTask { | |||
AiCoreOpTask() = default; | |||
virtual ~AiCoreOpTask() = default; | |||
virtual Status Init(const OpDesc &op_desc, const domi::TaskDef &task_def); | |||
virtual Status Init(HybridModel &hybrid_model, const OpDesc &op_desc, const domi::TaskDef &task_def); | |||
bool IsDynamicShapeSupported(); | |||
@@ -94,11 +95,11 @@ class AiCoreOpTask { | |||
private: | |||
static Status ValidateTaskDef(const domi::TaskDef &task_def); | |||
Status InitWithTaskDef(const OpDesc &node, const domi::TaskDef &task_def); | |||
Status InitWithTaskDef(HybridModel &hybrid_model, const OpDesc &node, const domi::TaskDef &task_def); | |||
Status InitTilingInfo(const OpDesc &op_desc); | |||
Status RegisterTbeHandle(const OpDesc &op_desc); | |||
Status RegisterTbeHandle(HybridModel &hybrid_model, const OpDesc &op_desc); | |||
Status RegisterKernelHandle(const OpDesc &op_desc); | |||
Status InitWithKernelDef(const OpDesc &op_desc, const domi::TaskDef &task_def); | |||
Status InitWithKernelDef(HybridModel &hybrid_model, const OpDesc &op_desc, const domi::TaskDef &task_def); | |||
Status InitWithKernelDefWithHandle(const OpDesc &node, const domi::TaskDef &task_def); | |||
std::string stub_name_; | |||
@@ -121,7 +122,7 @@ class AiCoreOpTask { | |||
class AtomicAddrCleanOpTask : public AiCoreOpTask { | |||
public: | |||
Status Init(const OpDesc &op_desc, const domi::TaskDef &task_def) override; | |||
Status Init(HybridModel &hybrid_model, const OpDesc &op_desc, const domi::TaskDef &task_def) override; | |||
Status UpdateArgs(TaskContext &task_context) override; | |||
protected: | |||
@@ -37,7 +37,8 @@ AiCoreTaskBuilder::AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector | |||
: op_desc_(op_desc), task_defs_(task_defs) { | |||
} | |||
Status AiCoreTaskBuilder::BuildTask(std::unique_ptr<AiCoreNodeTask> &node_task, | |||
Status AiCoreTaskBuilder::BuildTask(HybridModel &hybrid_model, | |||
std::unique_ptr<AiCoreNodeTask> &node_task, | |||
bool ignore_failure_on_atomic, | |||
bool is_single_op) { | |||
GE_CHECK_NOTNULL(op_desc_); | |||
@@ -71,7 +72,7 @@ Status AiCoreTaskBuilder::BuildTask(std::unique_ptr<AiCoreNodeTask> &node_task, | |||
std::unique_ptr<AtomicAddrCleanOpTask>(new(std::nothrow)AtomicAddrCleanOpTask()); | |||
GE_CHECK_NOTNULL(atomic_task); | |||
atomic_task->SetSingleOp(is_single_op); | |||
GE_CHK_STATUS_RET(atomic_task->Init(*op_desc_, task_defs_.front()), | |||
GE_CHK_STATUS_RET(atomic_task->Init(hybrid_model, *op_desc_, task_defs_.front()), | |||
"[Invoke][AtomicAddrCleanOpTask::Init] failed for [%s].", | |||
op_desc_->GetName().c_str()); | |||
op_tasks.emplace_back(std::move(atomic_task)); | |||
@@ -81,7 +82,7 @@ Status AiCoreTaskBuilder::BuildTask(std::unique_ptr<AiCoreNodeTask> &node_task, | |||
auto aicore_task = std::unique_ptr<AiCoreOpTask>(new(std::nothrow)AiCoreOpTask()); | |||
GE_CHECK_NOTNULL(aicore_task); | |||
aicore_task->SetSingleOp(is_single_op); | |||
GE_CHK_STATUS_RET(aicore_task->Init(*op_desc_, task_defs_.back()), | |||
GE_CHK_STATUS_RET(aicore_task->Init(hybrid_model, *op_desc_, task_defs_.back()), | |||
"[Invoke][AiCoreOpTask::Init] failed for [%s].", | |||
op_desc_->GetName().c_str()); | |||
op_tasks.emplace_back(std::move(aicore_task)); | |||
@@ -48,7 +48,8 @@ class AiCoreTaskBuilder { | |||
AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector<domi::TaskDef> &task_defs); | |||
~AiCoreTaskBuilder() = default; | |||
Status BuildTask(std::unique_ptr<AiCoreNodeTask> &node_task, | |||
Status BuildTask(HybridModel &hybrid_model, | |||
std::unique_ptr<AiCoreNodeTask> &node_task, | |||
bool ignore_failure_on_atomic, | |||
bool is_single_op = false); | |||
@@ -858,7 +858,7 @@ Status AiCpuNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) cons | |||
return status; | |||
} | |||
Status AiCpuNodeExecutor::LoadTask(const HybridModel &model, | |||
Status AiCpuNodeExecutor::LoadTask(HybridModel &model, | |||
const NodePtr &node, | |||
std::shared_ptr<NodeTask> &task) const { | |||
GE_CHECK_NOTNULL(node); | |||
@@ -176,7 +176,7 @@ class AicpuNodeTask : public AicpuNodeTaskBase { | |||
class AiCpuNodeExecutor : public NodeExecutor { | |||
public: | |||
Status LoadTask(const HybridModel &model, | |||
Status LoadTask(HybridModel &hybrid_model, | |||
const NodePtr &node, | |||
std::shared_ptr<NodeTask> &task) const override; | |||
@@ -180,7 +180,7 @@ Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) cons | |||
return SUCCESS; | |||
} | |||
Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, | |||
Status KnownNodeExecutor::LoadTask(HybridModel &model, const NodePtr &node, | |||
shared_ptr<NodeTask> &task) const { | |||
GELOGI("[%s] KnownNodeExecutor::LoadTask in.", node->GetName().c_str()); | |||
GE_CHECK_NOTNULL(node); | |||
@@ -47,7 +47,7 @@ class KnownNodeTask : public NodeTask { | |||
class KnownNodeExecutor : public NodeExecutor { | |||
public: | |||
Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const; | |||
Status LoadTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const; | |||
Status PrepareTask(NodeTask &task, TaskContext &context) const; | |||
Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const; | |||
~KnownNodeExecutor() {} | |||
@@ -372,7 +372,7 @@ Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_conti | |||
return SUCCESS; | |||
} | |||
Status ControlOpNodeExecutor::LoadTask(const HybridModel &model, | |||
Status ControlOpNodeExecutor::LoadTask(HybridModel &model, | |||
const NodePtr &node, | |||
shared_ptr<NodeTask> &task) const { | |||
auto node_item = model.GetNodeItem(node); | |||
@@ -93,7 +93,7 @@ class WhileOpNodeTask : public ControlOpNodeTask { | |||
class ControlOpNodeExecutor : public NodeExecutor { | |||
public: | |||
Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override; | |||
Status LoadTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override; | |||
Status PrepareTask(NodeTask &task, TaskContext &context) const override; | |||
}; | |||
} // namespace hybrid | |||
@@ -219,7 +219,7 @@ Status GeLocalNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) co | |||
return status; | |||
} | |||
Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, | |||
Status GeLocalNodeExecutor::LoadTask(HybridModel &model, | |||
const NodePtr &node, | |||
std::shared_ptr<NodeTask> &task) const { | |||
GE_CHECK_NOTNULL(node); | |||
@@ -85,7 +85,7 @@ class GeLocalNodeExecutor : public NodeExecutor { | |||
Status PrepareTask(NodeTask &task, TaskContext &context) const override; | |||
virtual Status LoadTask(const HybridModel &model, | |||
virtual Status LoadTask(HybridModel &model, | |||
const NodePtr &node, | |||
std::shared_ptr<NodeTask> &task) const override; | |||
}; | |||
@@ -365,7 +365,7 @@ Status HcclNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const | |||
return SUCCESS; | |||
} | |||
Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
Status HcclNodeExecutor::LoadTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
GELOGI("[%s] HcclNodeExecutor::LoadTask in.", node->GetName().c_str()); | |||
GE_CHECK_NOTNULL(node); | |||
if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) { | |||
@@ -64,7 +64,7 @@ class RdmaNodeTask : public NodeTask { | |||
class HcclNodeExecutor : public NodeExecutor { | |||
public: | |||
Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const; | |||
Status LoadTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const; | |||
Status PrepareTask(NodeTask &task, TaskContext &context) const; | |||
Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const; | |||
Status Initialize() override; | |||
@@ -115,7 +115,7 @@ Status HostCpuNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) co | |||
return task.UpdateArgs(context); | |||
} | |||
Status HostCpuNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, | |||
Status HostCpuNodeExecutor::LoadTask(HybridModel &model, const NodePtr &node, | |||
std::shared_ptr<NodeTask> &task) const { | |||
GE_CHECK_NOTNULL(node); | |||
auto op_desc = node->GetOpDesc(); | |||
@@ -58,7 +58,7 @@ class HostCpuNodeExecutor : public NodeExecutor { | |||
public: | |||
Status PrepareTask(NodeTask &task, TaskContext &context) const override; | |||
Status LoadTask(const HybridModel &model, | |||
Status LoadTask(HybridModel &model, | |||
const NodePtr &node, | |||
std::shared_ptr<NodeTask> &task) const override; | |||
}; | |||
@@ -49,11 +49,11 @@ Status NodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, const std | |||
return SUCCESS; | |||
} | |||
// Base NodeExecutor implementation of LoadTask: does no work and returns UNSUPPORTED.
// None of the parameters are read here.
// NOTE(review): two signature lines follow; they look like the before/after pair of a
// diff hunk (const HybridModel & vs. HybridModel &) — confirm which one is current.
Status NodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
Status NodeExecutor::LoadTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
return UNSUPPORTED; | |||
} | |||
// Base NodeExecutor implementation of CompileTask: does no work and returns UNSUPPORTED.
// None of the parameters are read here.
// NOTE(review): two signature lines follow; they look like the before/after pair of a
// diff hunk (const HybridModel & vs. HybridModel &) — confirm which one is current.
Status NodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
Status NodeExecutor::CompileTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
return UNSUPPORTED; | |||
} | |||
@@ -112,7 +112,7 @@ class NodeExecutor { | |||
* @param task generated node task | |||
* @return SUCCESS on success, error code otherwise | |||
*/ | |||
virtual Status LoadTask(const HybridModel &model, | |||
virtual Status LoadTask(HybridModel &model, | |||
const NodePtr &node, | |||
std::shared_ptr<NodeTask> &task) const; | |||
@@ -123,7 +123,7 @@ class NodeExecutor { | |||
* @param task generated node task | |||
* @return SUCCESS on success, error code otherwise | |||
*/ | |||
virtual Status CompileTask(const HybridModel &model, | |||
virtual Status CompileTask(HybridModel &model, | |||
const NodePtr &node, | |||
std::shared_ptr<NodeTask> &task) const; | |||
@@ -66,7 +66,7 @@ Status PartitionedCallNodeTask::UpdateArgs(TaskContext &context) { | |||
return SUCCESS; | |||
} | |||
Status PartitionedCallNodeExecutor::LoadTask(const ge::hybrid::HybridModel &model, | |||
Status PartitionedCallNodeExecutor::LoadTask(ge::hybrid::HybridModel &model, | |||
const ge::NodePtr &node, | |||
std::shared_ptr<NodeTask> &task) const { | |||
GELOGD("Load dynamic partitioned call: [%s]", node->GetName().c_str()); | |||
@@ -45,7 +45,7 @@ class PartitionedCallNodeTask : public NodeTask { | |||
class PartitionedCallNodeExecutor : public NodeExecutor { | |||
public: | |||
Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override; | |||
Status LoadTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override; | |||
Status PrepareTask(NodeTask &task, TaskContext &context) const override; | |||
}; | |||
} // namespace hybrid | |||
@@ -130,7 +130,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< | |||
return SUCCESS; | |||
} | |||
Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
Status RtsNodeExecutor::LoadTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | |||
GE_CHECK_NOTNULL(node); | |||
GELOGD("[%s] Load for local task.", node->GetName().c_str()); | |||
std::string node_type; | |||
@@ -52,7 +52,7 @@ class ProfilingTraceNodeTask : public RtsNodeTask { | |||
class RtsNodeExecutor : public NodeExecutor { | |||
public: | |||
Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override; | |||
Status LoadTask(HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override; | |||
}; | |||
} // namespace hybrid | |||
} // namespace ge | |||
@@ -76,6 +76,10 @@ static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { | |||
} | |||
TEST_F(UtestGeHybrid, aicore_op_task_init_success) { | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||
HybridModel hybrid_model(ge_root_model); | |||
// build aicore task | |||
auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask()); | |||
domi::TaskDef task_def; | |||
@@ -99,7 +103,44 @@ TEST_F(UtestGeHybrid, aicore_op_task_init_success) { | |||
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); | |||
std::string kernel_name("kernel/Add"); | |||
AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); | |||
ASSERT_EQ(aicore_task->InitWithTaskDef(*op_desc.get(), task_def), SUCCESS); | |||
ASSERT_EQ(aicore_task->InitWithTaskDef(hybrid_model, *op_desc.get(), task_def), SUCCESS); | |||
rtStream_t stream = nullptr; | |||
rtStreamCreate(&stream, 0); | |||
ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS); | |||
char *handle = ""; | |||
aicore_task->handle_ = handle; | |||
aicore_task->tiling_key_ = 1; | |||
ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS); | |||
} | |||
TEST_F(UtestGeHybrid, aicore_op_task_init_success2) { | |||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||
HybridModel hybrid_model(ge_root_model); | |||
// build aicore task | |||
auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask()); | |||
domi::TaskDef task_def; | |||
task_def.set_type(RT_MODEL_TASK_KERNEL); | |||
domi::KernelDef *kernel = task_def.mutable_kernel(); | |||
kernel->set_original_kernel_key(""); | |||
kernel->set_node_info(""); | |||
kernel->set_block_dim(32); | |||
kernel->set_args_size(64); | |||
string args(64, '1'); | |||
kernel->set_args(args.data(), 64); | |||
domi::KernelContext *context = kernel->mutable_context(); | |||
context->set_op_index(1); | |||
context->set_kernel_type(2); // ccKernelType::TE | |||
uint16_t args_offset[9] = {0}; | |||
context->set_args_offset(args_offset, 9 * sizeof(uint16_t)); | |||
OpDescPtr op_desc = CreateOpDesc("Add", "Add"); | |||
std::vector<char> kernelBin; | |||
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin)); | |||
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); | |||
std::string kernel_name("kernel/Add"); | |||
AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); | |||
ASSERT_EQ(aicore_task->InitWithTaskDef(hybrid_model, *op_desc.get(), task_def), SUCCESS); | |||
rtStream_t stream = nullptr; | |||
rtStreamCreate(&stream, 0); | |||
ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS); | |||