Browse Source

!1705 add NoOp node task

From: @selfws
Reviewed-by: @xchu42,@wqtshg
Signed-off-by: @wqtshg
tags/v1.3.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
87235df6b8
4 changed files with 144 additions and 2 deletions
  1. +21
    -2
      ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc
  2. +8
    -0
      ge/hybrid/node_executor/ge_local/ge_local_node_executor.h
  3. +1
    -0
      tests/ut/ge/CMakeLists.txt
  4. +114
    -0
      tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc

+ 21
- 2
ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc View File

@@ -37,7 +37,7 @@ const std::map<std::string, std::vector<uint32_t>>
{BROADCASTGRADIENTARGS, {}}
};

const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE, NOOP};
const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE};

Status RefInputTask::UpdateArgs(TaskContext &) {
// no need update args
@@ -252,9 +252,16 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model,
GELOGE(INTERNAL_ERROR, "[Get][Tensor] failed for name: %s", node->GetName().c_str());
return INTERNAL_ERROR;
}

task = MakeShared<ConstantNodeTask>(tensor);
GE_CHECK_NOTNULL(task);
} else if (node_type == NOOP) {
GELOGI("node %s type %s , use NoOpNodeTask.", node->GetName().c_str(), node_type.c_str());
task = MakeShared<NoOpNodeTask>();
if (task == nullptr) {
REPORT_CALL_ERROR("E19999", "Create NoOpNodeTask failed for NoOp node %s.", node->GetName().c_str());
GELOGE(MEMALLOC_FAILED, "[Create][NoOpNodeTask]failed for NoOp node %s.", node->GetName().c_str());
return MEMALLOC_FAILED;
}
} else {
GELOGE(UNSUPPORTED, "node %s type %s is not support in GeLocalNodeExecutor now.",
node->GetName().c_str(), node_type.c_str());
@@ -280,5 +287,17 @@ Status ConstantNodeTask::ExecuteAsync(TaskContext &context, std::function<void()
GELOGD("[%s] Done execute successfully.", context.GetNodeName());
return SUCCESS;
}

Status NoOpNodeTask::UpdateArgs(TaskContext &context) {
// no need to update args
return SUCCESS;
}

Status NoOpNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
GELOGD("[%s] Start execute.", context.GetNodeName());
GE_CHK_STATUS_RET(context.TryExecuteCallback(done_callback));
GELOGD("[%s] Done execute successfully.", context.GetNodeName());
return SUCCESS;
}
} // namespace hybrid
} // namespace ge

+ 8
- 0
ge/hybrid/node_executor/ge_local/ge_local_node_executor.h View File

@@ -80,6 +80,14 @@ class ConstantNodeTask : public NodeTask {
const TensorValue *tensor_;
};

class NoOpNodeTask : public NodeTask {
public:
explicit NoOpNodeTask() = default;
~NoOpNodeTask() = default;
Status UpdateArgs(TaskContext &context) override;
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
};

class GeLocalNodeExecutor : public NodeExecutor {
public:



+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -838,6 +838,7 @@ set(HYBRID_TEST_FILES
"hybrid/executor/worker/execution_engine_unittest.cc"
"hybrid/model/hybrid_model_builder_unittest.cc"
"hybrid/node_executor/rts/rts_node_task_unittest.cc"
"hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc"
"hybrid/executor/hybrid_model_async_executor_unittest.cc"
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc"
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc"


+ 114
- 0
tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc View File

@@ -0,0 +1,114 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>

#define private public
#define protected public
#include "hybrid/executor/subgraph_context.h"
#include "hybrid/node_executor/ge_local/ge_local_node_executor.h"
#include "model/ge_root_model.h"
#undef protected
#undef private

using namespace std;
using namespace testing;

namespace ge {
using namespace hybrid;

class UtestGeLocalNodeExecutor : public testing::Test {
protected:
void SetUp() {}
void TearDown() { }
};

static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
op_desc->SetStreamId(0);
static int32_t index = 0;
op_desc->SetId(index++);

GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64);
TensorUtils::SetSize(tensor, 64);
vector<int64_t> input_offset;
for (int i = 0; i < in_num; i++) {
op_desc->AddInputDesc(tensor);
input_offset.emplace_back(i * 64);
}
op_desc->SetInputOffset(input_offset);

vector<int64_t> output_offset;
for (int i = 0; i < out_num; i++) {
op_desc->AddOutputDesc(tensor);
output_offset.emplace_back(in_num * 64 + i * 64);
}
op_desc->SetOutputOffset(output_offset);

op_desc->SetWorkspace({});
op_desc->SetWorkspaceBytes({});
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

return graph.AddNode(op_desc);
}

TEST_F(UtestGeLocalNodeExecutor, test_no_op_task) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeModelPtr ge_sub_model = std::make_shared<GeModel>();
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
HybridModel hybrid_model(ge_root_model);

NodePtr node = CreateNode(*graph, "noop", NOOP, 0, 0);

std::unique_ptr<NodeItem> new_node;
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
NodeItem *node_item = new_node.get();
hybrid_model.node_items_[node] = std::move(new_node);
node_item->input_start = 0;
node_item->output_start = 0;

GraphItem graph_item;
graph_item.node_items_.emplace_back(node_item);
graph_item.total_inputs_ = 0;
graph_item.total_outputs_ = 0;

GraphExecutionContext graph_context;
SubgraphContext subgraph_context(&graph_item, &graph_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

NodeTaskPtr task = nullptr;
GeLocalNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
ASSERT_NE(task, nullptr);

ASSERT_EQ(task->UpdateArgs(*node_state->GetTaskContext()), SUCCESS);
std::function<void()> done = []() {};
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
}
} // namespace ge

Loading…
Cancel
Save