Browse Source

GE supports aicore DEPEND_SHAPE_RANGE op update outputs shape

pull/2008/head
zhujingjing 3 years ago
parent
commit
84b52ef058
1 changed files with 79 additions and 2 deletions
  1. +79
    -2
      tests/ut/ge/hybrid/node_executor/aicore/aicore_op_task_unittest.cc

+ 79
- 2
tests/ut/ge/hybrid/node_executor/aicore/aicore_op_task_unittest.cc View File

@@ -22,6 +22,7 @@
#define private public
#define protected public
#include "framework/common/taskdown_common.h"
#include "hybrid/executor/subgraph_context.h"
#include "hybrid/node_executor/aicore/aicore_op_task.h"
#include "init/gelib.h"
#undef private
@@ -80,12 +81,88 @@ TEST_F(UtestAiCoreOpTask, Init_failed) {
DEPEND_SHAPE_RANGE);
domi::TaskDef task_def;
task_def.set_type(RT_MODEL_TASK_KERNEL);
std::vector<uint8_t> args(10, 0);
std::vector<uint8_t> args(100, 0);
task_def.mutable_kernel()->set_args(args.data(), args.size());
task_def.mutable_kernel()->set_args_size(10);
task_def.mutable_kernel()->set_args_size(100);
task_def.mutable_kernel()->mutable_context()->set_kernel_type(
ccKernelType::TE);
uint16_t args_offset = 20;
char *a = reinterpret_cast<char *>(&args_offset);
task_def.mutable_kernel()->mutable_context()->set_args_offset(
a, 2 * sizeof(uint16_t));
EXPECT_EQ(task1->Init(*op_desc, task_def), ge::PARAM_INVALID);
dlog_setlevel(0, 3, 0);
}

TEST_F(UtestAiCoreOpTask, Init_success) {
dlog_setlevel(0, 0, 0);
std::unique_ptr<AiCoreOpTask> task1(new AiCoreOpTask());
OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1);
ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE,
DEPEND_SHAPE_RANGE);
domi::TaskDef task_def;
task_def.set_type(RT_MODEL_TASK_KERNEL);
std::vector<uint8_t> args(100, 0);
task_def.mutable_kernel()->set_args(args.data(), args.size());
task_def.mutable_kernel()->set_args_size(100);
task_def.mutable_kernel()->mutable_context()->set_kernel_type(
ccKernelType::TE);
uint16_t args_offset = 20;
char *a = reinterpret_cast<char *>(&args_offset);
task_def.mutable_kernel()->mutable_context()->set_args_offset(
a, 2 * sizeof(uint16_t));
EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS);
dlog_setlevel(0, 3, 0);
}

TEST_F(UtestAiCoreOpTask, UpdateOutputsShape_success) {
dlog_setlevel(0, 0, 0);
std::unique_ptr<AiCoreOpTask> task1(new AiCoreOpTask());
OpDescPtr op_desc = CreateOpDesc("Add", "Add", 2, 1);
ge::AttrUtils::SetInt(*op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE,
DEPEND_SHAPE_RANGE);
domi::TaskDef task_def;
task_def.set_type(RT_MODEL_TASK_KERNEL);
std::vector<uint8_t> args(100, 0);
task_def.mutable_kernel()->set_args(args.data(), args.size());
task_def.mutable_kernel()->set_args_size(100);
task_def.mutable_kernel()->mutable_context()->set_kernel_type(
ccKernelType::TE);
uint16_t args_offset = 20;
char *a = reinterpret_cast<char *>(&args_offset);
task_def.mutable_kernel()->mutable_context()->set_args_offset(
a, 2 * sizeof(uint16_t));
EXPECT_EQ(task1->Init(*op_desc, task_def), ge::SUCCESS);

ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
NodePtr node = graph->AddNode(op_desc);
std::unique_ptr<NodeItem> new_node;
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
NodeItem *node_item = new_node.get();
node_item->input_start = 0;
node_item->output_start = 0;
node_item->is_dynamic = true;
node_item->shape_inference_type = DEPEND_SHAPE_RANGE;

GraphItem graph_item;
graph_item.node_items_.emplace_back(node_item);
graph_item.total_inputs_ = 2;
graph_item.total_outputs_ = 1;

GraphExecutionContext graph_context;
SubgraphContext subgraph_context(&graph_item, &graph_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
graph_context.callback_manager =
std::unique_ptr<CallbackManager>(new CallbackManager());

auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);
auto outputs_shape =
reinterpret_cast<uint32_t(*)[1]>(task1->shape_buffer_->GetData());
outputs_shape[0][0] = 2;
outputs_shape[0][1] = 1;
outputs_shape[0][2] = 2;
ASSERT_EQ(task1->UpdateOutputsShape(*node_state->GetTaskContext()), SUCCESS);
dlog_setlevel(0, 3, 0);
}
} // namespace ge

Loading…
Cancel
Save