diff --git a/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc index 5f3d6e45..04225557 100755 --- a/ge/hybrid/node_executor/node_executor.cc +++ b/ge/hybrid/node_executor/node_executor.cc @@ -58,8 +58,8 @@ Status NodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, } Status NodeExecutorManager::EnsureInitialized() { - GE_CHK_STATUS_RET(InitializeExecutors()); std::lock_guard lk(mu_); + ++ref_count_; if (initialized_) { return SUCCESS; } @@ -115,17 +115,14 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node return it->second; } -Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) const { +Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) { auto executor_type = ResolveExecutorType(node); + GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast(executor_type)); const auto it = executors_.find(executor_type); if (it == executors_.end()) { - REPORT_INNER_ERROR("E19999", "Failed to get executor by type: %d.", static_cast(executor_type)); - GELOGE(INTERNAL_ERROR, "[Check][ExecutorType]Failed to get executor by type: %d.", - static_cast(executor_type)); - return INTERNAL_ERROR; + return GetOrCreateExecutor(executor_type, executor); } - GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast(executor_type)); *executor = it->second.get(); return SUCCESS; } @@ -178,51 +175,50 @@ Status NodeExecutorManager::CalcOpRunningParam(Node &node) const { return OpsKernelBuilderManager::Instance().CalcOpRunningParam(node); } -Status NodeExecutorManager::InitializeExecutors() { +Status NodeExecutorManager::GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **out_executor) { std::lock_guard lk(mu_); - if (executor_initialized_) { - ++ref_count_; - GELOGI("Executor is already initialized. add ref count to [%d]", ref_count_); + const auto executor_it = executors_.find(executor_type); + if (executor_it != executors_.end()) { + *out_executor = executor_it->second.get(); return SUCCESS; } - GELOGI("Start to Initialize NodeExecutors"); - for (auto &it : builders_) { - auto engine_type = it.first; - auto build_fn = it.second; - GE_CHECK_NOTNULL(build_fn); - auto executor = std::unique_ptr(build_fn()); - if (executor == nullptr) { - REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for engine type = %d", - static_cast(engine_type)); - GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast(engine_type)); - return INTERNAL_ERROR; - } + GELOGI("Start to Initialize NodeExecutor, type = %d", static_cast(executor_type)); + auto it = builders_.find(executor_type); + if (it == builders_.end()) { + REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d", + static_cast(executor_type)); + GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for executor type = %d", static_cast(executor_type)); + return INTERNAL_ERROR; + } - GELOGD("Executor of engine type = %d was created successfully", static_cast(engine_type)); - auto ret = executor->Initialize(); - if (ret != SUCCESS) { - REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast(engine_type)); - GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast(engine_type)); - for (auto &executor_it : executors_) { - executor_it.second->Finalize(); - } - executors_.clear(); - return ret; - } + auto build_fn = it->second; + GE_CHECK_NOTNULL(build_fn); + auto executor = std::unique_ptr(build_fn()); + if (executor == nullptr) { + REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d", + static_cast(executor_type)); + GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast(executor_type)); + return INTERNAL_ERROR; + } - executors_.emplace(engine_type, std::move(executor)); + GELOGD("Executor of engine type = %d was created successfully", static_cast(executor_type)); + auto ret = executor->Initialize(); + if (ret != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast(executor_type)); + GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast(executor_type)); + return ret; } - ++ref_count_; - executor_initialized_ = true; - GELOGI("Initializing NodeExecutors successfully."); + *out_executor = executor.get(); + executors_.emplace(executor_type, std::move(executor)); + GELOGI("Initializing NodeExecutor successfully, type = %d", static_cast(executor_type)); return SUCCESS; } void NodeExecutorManager::FinalizeExecutors() { std::lock_guard lk(mu_); - if (!executor_initialized_) { + if (ref_count_ <= 0) { GELOGD("No need for finalizing for not initialized."); return; } @@ -237,7 +233,6 @@ void NodeExecutorManager::FinalizeExecutors() { it.second->Finalize(); } executors_.clear(); - executor_initialized_ = false; GELOGD("Done invoking Finalize successfully."); } diff --git a/ge/hybrid/node_executor/node_executor.h b/ge/hybrid/node_executor/node_executor.h index fffd4e7d..97c9cee9 100644 --- a/ge/hybrid/node_executor/node_executor.h +++ b/ge/hybrid/node_executor/node_executor.h @@ -179,8 +179,6 @@ class NodeExecutorManager { */ Status EnsureInitialized(); - Status InitializeExecutors(); - void FinalizeExecutors(); /** @@ -196,7 +194,7 @@ class NodeExecutorManager { * @param executor executor * @return SUCCESS on success, error code otherwise */ - Status GetExecutor(Node &node, const NodeExecutor **executor) const; + Status GetExecutor(Node &node, const NodeExecutor **executor); /** * Resolve executor type by node @@ -206,12 +204,13 @@ class NodeExecutorManager { ExecutorType ResolveExecutorType(Node &node) const; private: + Status GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **executor); + std::map> executors_; std::map> builders_; std::map engine_mapping_; std::mutex mu_; bool initialized_ = false; - bool executor_initialized_ = false; int ref_count_ = 0; }; diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 8b024820..631e18f8 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -839,6 +839,7 @@ set(HYBRID_TEST_FILES "hybrid/executor/subgraph_executor_unittest.cc" "hybrid/executor/worker/execution_engine_unittest.cc" "hybrid/model/hybrid_model_builder_unittest.cc" + "hybrid/node_executor/node_executor_unittest.cc" "hybrid/node_executor/rts/rts_node_task_unittest.cc" "hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc" "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" @@ -846,6 +847,7 @@ set(HYBRID_TEST_FILES "hybrid/executor/hybrid_model_async_executor_unittest.cc" "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" + ) set(OTHERS_TEST_FILES diff --git a/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc new file mode 100644 index 00000000..8a1240d3 --- /dev/null +++ b/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#define private public +#define protected public +#include "hybrid/node_executor/node_executor.h" +#undef protected +#undef private + +using namespace std; +using namespace testing; + +namespace ge { +using namespace hybrid; + +namespace { + bool finalized = false; +} + +class NodeExecutorTest : public testing::Test { + protected: + void SetUp() {} + void TearDown() { } +}; + +class FailureNodeExecutor : public NodeExecutor { + public: + Status Initialize() override { + return INTERNAL_ERROR; + } +}; + +class SuccessNodeExecutor : public NodeExecutor { + public: + Status Initialize() override { + initialized = true; + finalized = false; + return SUCCESS; + } + + Status Finalize() override { + finalized = true; + } + + bool initialized = false; +}; + +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICORE, FailureNodeExecutor); +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_TF, SuccessNodeExecutor); + +TEST_F(NodeExecutorTest, TestGetOrCreateExecutor) { + auto &manager = NodeExecutorManager::GetInstance(); + const NodeExecutor *executor = nullptr; + Status ret = SUCCESS; + // no builder + ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::RESERVED, &executor); + ASSERT_EQ(ret, INTERNAL_ERROR); + // initialize failure + ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICORE, &executor); + ASSERT_EQ(ret, INTERNAL_ERROR); + ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICPU_TF, &executor); + ASSERT_EQ(ret, SUCCESS); + ASSERT_TRUE(executor != nullptr); + ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICPU_TF, &executor); + ASSERT_EQ(ret, SUCCESS); + ASSERT_TRUE(executor != nullptr); + ASSERT_TRUE(((SuccessNodeExecutor*)executor)->initialized); +} + +TEST_F(NodeExecutorTest, TestInitAndFinalize) { + auto &manager = NodeExecutorManager::GetInstance(); + manager.FinalizeExecutors(); + manager.EnsureInitialized(); + manager.EnsureInitialized(); + const NodeExecutor *executor = nullptr; + auto ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICPU_TF, &executor); + ASSERT_EQ(ret, SUCCESS); + ASSERT_TRUE(executor != nullptr); + ASSERT_TRUE(((SuccessNodeExecutor*)executor)->initialized); + manager.FinalizeExecutors(); + ASSERT_FALSE(manager.executors_.empty()); + manager.FinalizeExecutors(); + ASSERT_TRUE(manager.executors_.empty()); + ASSERT_TRUE(finalized); +} +} // namespace ge