From: @lichun30 Reviewed-by: @xchu42,@wqtshg Signed-off-by: @ji_chentags/v1.1.0
@@ -17,8 +17,6 @@ | |||
#include "aicore_node_executor.h" | |||
#include "cce/taskdown_common.hpp" | |||
#include "hybrid/executor/hybrid_execution_context.h" | |||
#include "init/gelib.h" | |||
#include "hybrid/executor/hybrid_execution_context.h" | |||
namespace ge { | |||
namespace hybrid { | |||
@@ -28,19 +26,10 @@ AiCoreNodeTask::AiCoreNodeTask(std::vector<std::unique_ptr<AiCoreOpTask>> &&task | |||
} | |||
Status AiCoreNodeExecutor::Initialize() { | |||
auto ge_lib = GELib::GetInstance(); | |||
GE_CHECK_NOTNULL(ge_lib); | |||
if (!ge_lib->InitFlag()) { | |||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed."); | |||
return GE_CLI_GE_NOT_INITIALIZED; | |||
compiler_ = TaskCompilerFactory::GetInstance().GetTaskCompiler(); | |||
if (compiler_ != nullptr) { | |||
GE_CHK_STATUS_RET(compiler_->Initialize(), "Failed to init aicore task compiler."); | |||
} | |||
auto &kernel_manager = ge_lib->OpsKernelManagerObj(); | |||
auto aic_ops_store = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); | |||
GE_CHECK_NOTNULL(aic_ops_store); | |||
compiler_.reset(new(std::nothrow)AiCoreTaskCompiler(aic_ops_store)); | |||
GE_CHECK_NOTNULL(compiler_); | |||
return SUCCESS; | |||
} | |||
@@ -120,6 +109,12 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, | |||
GE_CHECK_NOTNULL(op_desc); | |||
GELOGI("AiCoreNodeExecutor(%s) CompileTask Start.", node->GetName().c_str()); | |||
auto ori_node_name = node->GetName(); | |||
if (compiler_ == nullptr) { | |||
GELOGE(FAILED, "[%s] Can not find any valid aicore task compiler.", ori_node_name.c_str()); | |||
return FAILED; | |||
} | |||
AiCoreNodeTaskRegistry ®istry = AiCoreNodeTaskRegistry::GetInstance(); | |||
std::string shape_key; | |||
GE_CHK_STATUS_RET(GenNodeKey(node, shape_key), "GenNodeKey failed, op name = %s.", node->GetName().c_str()); | |||
@@ -133,7 +128,6 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, | |||
} | |||
std::vector<domi::TaskDef> task_defs; | |||
auto ori_node_name = node->GetName(); | |||
op_desc->SetName(ori_node_name + "_" + shape_key); | |||
GE_CHK_STATUS_RET(compiler_->CompileOp(node, task_defs), "Compile op(%s) failed.", ori_node_name.c_str()); | |||
op_desc->SetName(ori_node_name); | |||
@@ -239,5 +233,23 @@ bool AiCoreNodeTask::IsNoOp(TaskContext &task_context) { | |||
return true; | |||
} | |||
TaskCompilerFactory &TaskCompilerFactory::GetInstance() { | |||
static TaskCompilerFactory instance; | |||
return instance; | |||
} | |||
void TaskCompilerFactory::Register(CreateFn fn) { | |||
compiler_func_ = fn; | |||
} | |||
std::unique_ptr<TaskCompiler> TaskCompilerFactory::GetTaskCompiler() { | |||
auto compiler_instance = std::unique_ptr<TaskCompiler>(compiler_func_()); | |||
return compiler_instance; | |||
} | |||
CompilerFunctionRegistrar::CompilerFunctionRegistrar(CreateFn fn) { | |||
TaskCompilerFactory::GetInstance().Register(fn); | |||
} | |||
} // namespace hybrid | |||
} // namespace ge |
@@ -18,13 +18,21 @@ | |||
#define GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ | |||
#include "hybrid/node_executor/aicore/aicore_task_builder.h" | |||
#include "hybrid/node_executor/aicore/aicore_task_compiler.h" | |||
#include "hybrid/node_executor/node_executor.h" | |||
#include <map> | |||
#include <mutex> | |||
namespace ge { | |||
namespace hybrid { | |||
class TaskCompiler { | |||
public: | |||
TaskCompiler() = default; | |||
virtual ~TaskCompiler() = default; | |||
virtual Status CompileOp(const NodePtr &node, std::vector<domi::TaskDef> &tasks) = 0; | |||
virtual Status Initialize() = 0; | |||
}; | |||
class AiCoreNodeTaskRegistry { | |||
public: | |||
~AiCoreNodeTaskRegistry() = default; | |||
@@ -65,8 +73,33 @@ class AiCoreNodeExecutor : public NodeExecutor { | |||
private: | |||
static Status GenNodeKey(const NodePtr &node, std::string &node_key); | |||
std::unique_ptr<AiCoreTaskCompiler> compiler_; | |||
std::unique_ptr<TaskCompiler> compiler_; | |||
}; | |||
using CreateFn = TaskCompiler *(*)(); | |||
class TaskCompilerFactory { | |||
public: | |||
static TaskCompilerFactory &GetInstance(); | |||
void Register(CreateFn fn); | |||
std::unique_ptr<TaskCompiler> GetTaskCompiler(); | |||
private: | |||
CreateFn compiler_func_; | |||
}; | |||
class CompilerFunctionRegistrar { | |||
public: | |||
CompilerFunctionRegistrar(CreateFn fn); | |||
~CompilerFunctionRegistrar() = default; | |||
}; | |||
} // namespace hybrid | |||
} // namespace ge | |||
#endif //GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ | |||
#define REGISTER_TASK_COMPILER(compiler) \ | |||
static ::ge::hybrid::CompilerFunctionRegistrar register_compiler_function \ | |||
__attribute__((unused)) = \ | |||
::ge::hybrid::CompilerFunctionRegistrar([]()->::ge::hybrid::TaskCompiler* { \ | |||
return new (std::nothrow) compiler(); \ | |||
}) \ | |||
#endif //GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ |
@@ -18,6 +18,7 @@ | |||
#include "framework/common/debug/log.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "opskernel_manager/ops_kernel_builder_manager.h" | |||
#include "init/gelib.h" | |||
namespace ge { | |||
namespace hybrid { | |||
@@ -25,11 +26,22 @@ namespace { | |||
uintptr_t kWeightBase = 0x10000000; | |||
uintptr_t kMemBase = 0x20000000; | |||
uint64_t kFakeSize = 0x10000000UL; | |||
REGISTER_TASK_COMPILER(AiCoreTaskCompiler); | |||
} | |||
std::mutex AiCoreTaskCompiler::mu_; | |||
AiCoreTaskCompiler::AiCoreTaskCompiler(OpsKernelInfoStorePtr aic_kernel_store) | |||
: aic_kernel_store_(std::move(aic_kernel_store)) {} | |||
Status AiCoreTaskCompiler::Initialize() { | |||
auto ge_lib = GELib::GetInstance(); | |||
GE_CHECK_NOTNULL(ge_lib); | |||
if (!ge_lib->InitFlag()) { | |||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed."); | |||
return GE_CLI_GE_NOT_INITIALIZED; | |||
} | |||
auto &kernel_manager = ge_lib->OpsKernelManagerObj(); | |||
aic_kernel_store_ = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); | |||
GE_CHECK_NOTNULL(aic_kernel_store_); | |||
return SUCCESS; | |||
} | |||
Status AiCoreTaskCompiler::DoCompileOp(const NodePtr &node) const { | |||
GE_CHECK_NOTNULL(node); | |||
@@ -19,15 +19,17 @@ | |||
#include <mutex> | |||
#include "opskernel_manager/ops_kernel_manager.h" | |||
#include "aicore_node_executor.h" | |||
namespace ge { | |||
namespace hybrid { | |||
class AiCoreTaskCompiler { | |||
class AiCoreTaskCompiler : public TaskCompiler { | |||
public: | |||
explicit AiCoreTaskCompiler(OpsKernelInfoStorePtr aic_kernel_store); | |||
AiCoreTaskCompiler() = default; | |||
~AiCoreTaskCompiler() = default; | |||
Status CompileOp(const NodePtr &node, std::vector<domi::TaskDef> &tasks); | |||
Status CompileOp(const NodePtr &node, std::vector<domi::TaskDef> &tasks) override; | |||
Status Initialize() override; | |||
private: | |||
Status DoCompileOp(const NodePtr &node) const; | |||
Status DoGenerateTask(const Node &node, std::vector<domi::TaskDef> &tasks); | |||