|
|
@@ -0,0 +1,114 @@ |
|
|
|
/** |
|
|
|
* Copyright 2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* You may obtain a copy of the License at |
|
|
|
* |
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <gtest/gtest.h> |
|
|
|
#include <gmock/gmock.h> |
|
|
|
#include <vector> |
|
|
|
|
|
|
|
#define private public |
|
|
|
#define protected public |
|
|
|
#include "hybrid/executor/subgraph_context.h" |
|
|
|
#include "hybrid/node_executor/ge_local/ge_local_node_executor.h" |
|
|
|
#include "model/ge_root_model.h" |
|
|
|
#undef protected |
|
|
|
#undef private |
|
|
|
|
|
|
|
using namespace std; |
|
|
|
using namespace testing; |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
using namespace hybrid; |
|
|
|
|
|
|
|
class UtestGeLocalNodeExecutor : public testing::Test { |
|
|
|
protected: |
|
|
|
void SetUp() {} |
|
|
|
void TearDown() { } |
|
|
|
}; |
|
|
|
|
|
|
|
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { |
|
|
|
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); |
|
|
|
op_desc->SetStreamId(0); |
|
|
|
static int32_t index = 0; |
|
|
|
op_desc->SetId(index++); |
|
|
|
|
|
|
|
GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); |
|
|
|
TensorUtils::SetSize(tensor, 64); |
|
|
|
vector<int64_t> input_offset; |
|
|
|
for (int i = 0; i < in_num; i++) { |
|
|
|
op_desc->AddInputDesc(tensor); |
|
|
|
input_offset.emplace_back(i * 64); |
|
|
|
} |
|
|
|
op_desc->SetInputOffset(input_offset); |
|
|
|
|
|
|
|
vector<int64_t> output_offset; |
|
|
|
for (int i = 0; i < out_num; i++) { |
|
|
|
op_desc->AddOutputDesc(tensor); |
|
|
|
output_offset.emplace_back(in_num * 64 + i * 64); |
|
|
|
} |
|
|
|
op_desc->SetOutputOffset(output_offset); |
|
|
|
|
|
|
|
op_desc->SetWorkspace({}); |
|
|
|
op_desc->SetWorkspaceBytes({}); |
|
|
|
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); |
|
|
|
|
|
|
|
return graph.AddNode(op_desc); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(UtestGeLocalNodeExecutor, test_no_op_task) { |
|
|
|
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); |
|
|
|
GeModelPtr ge_sub_model = std::make_shared<GeModel>(); |
|
|
|
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); |
|
|
|
ge_root_model->SetModelName("test_name"); |
|
|
|
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); |
|
|
|
HybridModel hybrid_model(ge_root_model); |
|
|
|
|
|
|
|
NodePtr node = CreateNode(*graph, "noop", NOOP, 0, 0); |
|
|
|
|
|
|
|
std::unique_ptr<NodeItem> new_node; |
|
|
|
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); |
|
|
|
NodeItem *node_item = new_node.get(); |
|
|
|
hybrid_model.node_items_[node] = std::move(new_node); |
|
|
|
node_item->input_start = 0; |
|
|
|
node_item->output_start = 0; |
|
|
|
|
|
|
|
GraphItem graph_item; |
|
|
|
graph_item.node_items_.emplace_back(node_item); |
|
|
|
graph_item.total_inputs_ = 0; |
|
|
|
graph_item.total_outputs_ = 0; |
|
|
|
|
|
|
|
GraphExecutionContext graph_context; |
|
|
|
SubgraphContext subgraph_context(&graph_item, &graph_context); |
|
|
|
ASSERT_EQ(subgraph_context.Init(), SUCCESS); |
|
|
|
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); |
|
|
|
|
|
|
|
auto node_state = subgraph_context.GetOrCreateNodeState(node_item); |
|
|
|
ASSERT_NE(node_state, nullptr); |
|
|
|
|
|
|
|
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); |
|
|
|
ASSERT_NE(unique_task_context, nullptr); |
|
|
|
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); |
|
|
|
node_state->SetTaskContext(shared_task_context); |
|
|
|
|
|
|
|
NodeTaskPtr task = nullptr; |
|
|
|
GeLocalNodeExecutor node_executor; |
|
|
|
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); |
|
|
|
ASSERT_NE(task, nullptr); |
|
|
|
|
|
|
|
ASSERT_EQ(task->UpdateArgs(*node_state->GetTaskContext()), SUCCESS); |
|
|
|
std::function<void()> done = []() {}; |
|
|
|
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); |
|
|
|
} |
|
|
|
} // namespace ge |