Browse Source

Merge branch 'master' of gitee.com:mindspore/graphengine into baby3

tags/v1.3.0
gengchao4@huawei.com 4 years ago
parent
commit
b50ec7d24f
15 changed files with 245 additions and 75 deletions
  1. +7
    -6
      ge/generator/ge_generator.cc
  2. +28
    -51
      ge/graph/manager/graph_manager.cc
  3. +2
    -0
      ge/graph/manager/graph_manager.h
  4. +18
    -10
      ge/hybrid/executor/worker/execution_engine.cc
  5. +2
    -0
      ge/hybrid/executor/worker/execution_engine.h
  6. +2
    -1
      ge/hybrid/node_executor/task_context.cc
  7. +11
    -0
      ge/ir_build/atc_ir_common.cc
  8. +9
    -5
      ge/offline/main.cc
  9. +4
    -0
      inc/external/ge/ge_api_types.h
  10. +1
    -1
      metadef
  11. +1
    -1
      parser
  12. +1
    -0
      tests/ut/ge/CMakeLists.txt
  13. +30
    -0
      tests/ut/ge/graph/manager/graph_manager_unittest.cc
  14. +10
    -0
      tests/ut/ge/graph_ir/ge_ir_build_unittest.cc
  15. +119
    -0
      tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc

+ 7
- 6
ge/generator/ge_generator.cc View File

@@ -783,9 +783,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
GELOGD("Inputs size is %zu, outputs size is %zu.", inputs.size(), outputs.size());
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
impl_->is_offline_ = is_offline;
if (!is_offline) {
(void)AttrUtils::SetBool(op_desc, ATTR_SINGLE_OP_SCENE, true);
}
(void)AttrUtils::SetBool(op_desc, ATTR_SINGLE_OP_SCENE, true);

if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
GELOGE(PARAM_INVALID, "input param is invalid when build single op!");
@@ -824,7 +822,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
auto node = comp_graph->FindNode(op_desc->GetName());
Status ret = CheckEngineTypeSupport(node, engine_type);
if (ret != SUCCESS) {
GELOGE(ret, "[Check][EngineType]value:%d for node:%s not support", engine_type, node->GetName().c_str());
GELOGE(ret, "[Check][EngineType]not support node:%s with engine of %d.", node->GetName().c_str(), engine_type);
return ret;
}
}
@@ -850,6 +848,11 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in

bool all_shape = false;
(void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape);
GELOGD("Node: %s, all_shape is %d, compile_flag is %d.", op_desc->GetName().c_str(), all_shape, compile_flag);
(void)AttrUtils::SetInt(ge_model, ATTR_NAME_BUILD_MODE, fuzz_compile_flag);
if (all_shape) {
(void)AttrUtils::SetBool(ge_model, kAicpuAllshape, all_shape);
}
if (all_shape && CheckNoAicore(root_graph)) {
GELOGD("Get aicpu all_shape kernel!");
vector<GeTensor> inputs_dynamic;
@@ -859,8 +862,6 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in
GE_CHK_STATUS_RET_NOLOG(
impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic));
} else if (fuzz_compile_flag) {
GELOGD("Get fuzz build result of %s.", op_desc->GetName().c_str());
(void)AttrUtils::SetInt(ge_model, ATTR_NAME_BUILD_MODE, fuzz_compile_flag);
GeAttrValue::LIST_NAMED_ATTRS fuzz_build_attrs;
if (GetFuzzBuildAttrs(op_desc, ge_root_model, fuzz_build_attrs) != SUCCESS) {
GELOGE(FAILED, "[Get][FuzzRet]Failed to get fuzz build result of %s.", op_desc->GetName().c_str());


+ 28
- 51
ge/graph/manager/graph_manager.cc View File

@@ -495,7 +495,7 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
auto compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
compute_graph->SetGraphID(graph_id);
(void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true);
SetSessionGraphId(compute_graph, graph_id);

if (CreateGraphNode(graph_id, graph, options) != SUCCESS) {
@@ -527,14 +527,7 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
return SUCCESS;
}

Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &graph,
const std::map<std::string, std::string> &options,
const OmgContext &omg_context) {
if (HasGraphNode(graph_id)) {
REPORT_INNER_ERROR("E19999", "graph_id:%u is exist, check invalid", graph_id);
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id);
return GE_GRAPH_GRAPH_ALREADY_EXIST;
}
Status GraphManager::CheckGraphAdded(const GraphId &graph_id, const Graph &graph) {
auto compute_graph = GraphUtils::GetComputeGraph(graph);
if (compute_graph != nullptr) {
compute_graph->SetGraphID(graph_id);
@@ -553,58 +546,44 @@ Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &grap
GELOGE(FAILED, "compute graph is null");
return FAILED;
}
std::vector<NodePtr> input_nodes;
std::vector<NodePtr> output_nodes;
auto new_compute_graph = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
std::string session_graph_id;
if (!AttrUtils::GetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) ||
session_graph_id.empty()) {
session_graph_id = "-1_" + to_string(graph_id);
if (!AttrUtils::SetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) {
GELOGW("Set attribute of compute graph failed.");
}
for (auto &subgraph : new_compute_graph->GetAllSubgraphs()) {
(void)AttrUtils::SetStr(*subgraph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id);
}
GELOGD("Get graph session_graph_id attr failed, set session id to default value: [0]");
}
return SUCCESS;
}

GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
if (graph_node == nullptr) {
REPORT_CALL_ERROR("E19999", "New GraphNode fail, graph_id:%u",
graph_id);
GELOGE(FAILED, "GraphNode make shared failed");
Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &graph,
const std::map<std::string, std::string> &options,
const OmgContext &omg_context) {
if (CheckGraphAdded(graph_id, graph) != SUCCESS) {
GELOGE(FAILED, "AddGraphWithCopy failed.");
return FAILED;
}
std::shared_ptr<Graph> graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(new_compute_graph);
if (graph_ptr == nullptr) {
REPORT_CALL_ERROR("E19999", "New Graph fail, graph_id:%u",
graph_id);
GELOGE(FAILED, "GraphPtr make shared failed");
IncreaseGraphCount(graph_id);
// Do add graph
auto compute_graph = GraphUtils::GetComputeGraph(graph);
std::vector<NodePtr> input_nodes;
std::vector<NodePtr> output_nodes;
auto new_compute_graph = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
GE_CHECK_NOTNULL(new_compute_graph);
new_compute_graph->SetGraphID(graph_id);
SetSessionGraphId(new_compute_graph, graph_id);
std::shared_ptr<Graph> new_graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(new_compute_graph);
if (CreateGraphNode(graph_id, *new_graph_ptr, options) != SUCCESS) {
GELOGE(FAILED, "Failed to create graph_node.");
return FAILED;
}
// update option about tuning graph
ParseOption(options, BUILD_MODE, options_.build_mode);
ParseOption(options, BUILD_STEP, options_.build_step);
ParseOption(options, TUNING_PATH, options_.tuning_path);

graph_node->SetGraph(graph_ptr);
graph_node->SetOptions(options);
AddGraphNode(graph_id, graph_node);

AddLocalOmgContext(graph_id, omg_context);
if (!options_.output_datatype.empty()) {
GetLocalOmgContext().output_type = options_.output_datatype;
}
if (InitDynamicParams(new_compute_graph) != SUCCESS) {
GELOGE(GRAPH_PARAM_INVALID, "Failed to init params when online infer is dynamic.");
return GRAPH_PARAM_INVALID;
}

CompilerStages &stages = GetCompilerStages(graph_id);
stages.preparer.SetOptions(options_);
Status status = stages.optimizer.SetOptions(options_);
if (status != SUCCESS) {
GELOGE(status, "Graph optimizer set options failed.");
return status;
if (SetStagesOptions(graph_id, options_) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Set stage options failed.");
return INTERNAL_ERROR;
}
stages.builder.SetOptions(options_);

var_acc_ctrl_.AddGraph(graph_id, new_compute_graph);
return SUCCESS;
@@ -1080,7 +1059,6 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std:
if (!graph_node->IsAsync()) {
ret = LoadGraph(ge_root_model, graph_node);
} else {
GE_CHECK_NOTNULL(ge_root_model);
ret = LoadGraphAsync(ge_root_model, graph_node);
}
if (ret != SUCCESS) {
@@ -1095,7 +1073,6 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std:
if (!graph_node->IsAsync()) {
ret = LoadGraph(ge_root_model_ptr, graph_node);
} else {
GE_CHECK_NOTNULL(ge_root_model);
ret = LoadGraphAsync(ge_root_model_ptr, graph_node);
}
if (ret != SUCCESS) {


+ 2
- 0
ge/graph/manager/graph_manager.h View File

@@ -413,6 +413,8 @@ class GraphManager {

void SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id);

static Status CheckGraphAdded(const GraphId &graph_id, const Graph &graph);

std::atomic_bool thread_run_flag_;
BlockingQueue<PreRunArgs> prerun_args_q_{};
BlockingQueue<RunArgs> run_args_q_{};


+ 18
- 10
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -364,20 +364,28 @@ Status ExecutionEngine::ExecuteAsync(NodeState &node_state,
GraphExecutionContext &execution_context) {
GELOGI("[%s] Node is ready for execution", task_context->GetNodeName());
RECORD_EXECUTION_EVENT(&execution_context, task_context->GetNodeName(), "Start");
auto cb = std::shared_ptr<NodeDoneCallback>(new(std::nothrow) NodeDoneCallback(&execution_context, task_context));
GE_CHECK_NOTNULL(cb);
auto callback = [task_context, cb]() {
auto ret = cb->OnNodeDone();
if (ret != SUCCESS) {
task_context->OnError(ret);
}
};

std::function<void()> callback = nullptr;
GE_CHK_STATUS_RET_NOLOG(InitCallback(task_context, execution_context, callback));
GE_CHK_STATUS_RET_NOLOG(DoExecuteAsync(node_state, *task_context, execution_context, callback));
GE_CHK_STATUS_RET_NOLOG(PropagateOutputs(*node_state.GetNodeItem(), *task_context, execution_context));
return SUCCESS;
}

Status ExecutionEngine::InitCallback(const std::shared_ptr<TaskContext> &task_context,
GraphExecutionContext &execution_context, std::function<void()> &callback) {
if (task_context->NeedCallback()) {
auto cb = std::shared_ptr<NodeDoneCallback>(new(std::nothrow) NodeDoneCallback(&execution_context, task_context));
GE_CHECK_NOTNULL(cb);
callback = [task_context, cb]() {
auto ret = cb->OnNodeDone();
if (ret != SUCCESS) {
task_context->OnError(ret);
}
};
}
return SUCCESS;
}

Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
TaskContext &task_context,
GraphExecutionContext &context,
@@ -385,7 +393,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
const auto &task = node_state.GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[Get][KernelTask] of [%s] is null.", node_state.GetName().c_str());
REPORT_INNER_ERROR("E19999", "GetKernelTask of %s is null.", node_state.GetName().c_str());
REPORT_INNER_ERROR("E19999", "GetKernelTask of %s failed.", node_state.GetName().c_str());
return INTERNAL_ERROR;
}



+ 2
- 0
ge/hybrid/executor/worker/execution_engine.h View File

@@ -35,6 +35,8 @@ class ExecutionEngine {
TaskContext &task_context,
GraphExecutionContext &context,
const std::function<void()> &callback);
static Status InitCallback(const std::shared_ptr<TaskContext> &task_context,
GraphExecutionContext &execution_context, std::function<void()> &callback);
};
} // namespace hybrid
} // namespace ge


+ 2
- 1
ge/hybrid/node_executor/task_context.cc View File

@@ -561,7 +561,8 @@ const DumpProperties &TaskContext::GetDumpProperties() const {
}

bool TaskContext::NeedCallback() {
return node_item_->has_observer || IsDumpEnabled() || execution_context_->profiling_level > 0;
return node_item_->has_observer || IsDumpEnabled() || execution_context_->profiling_level > 0 ||
!execution_context_->model->IsSingleOp();
}

Status TaskContext::Synchronize() {


+ 11
- 0
ge/ir_build/atc_ir_common.cc View File

@@ -55,6 +55,7 @@ const char *const kDigitError = "is not digit";
const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]";
const char *const kSelectImplmodeError = "only support high_performance, high_precision";
const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \"";
const char *const kDynamicImageSizeError = "It can only contains digit, \",\", \" \" and \";\"";
const char *const kKeepDtypeError = "file not found";
const char *const kInputShapeRangeInvalid = "format of shape range is invalid";
const char *const kShapeRangeValueConvertError = "transfer from string to int64 error";
@@ -170,6 +171,16 @@ bool CheckDynamicImagesizeInputShapeValid(map<string, vector<int64_t>> shape_map
}

EraseEndSemicolon(dynamic_image_size);
for (char c : dynamic_image_size) {
bool is_char_valid = isdigit(c) || (c == ',') || (c == ' ') || (c == ';');
if (!is_char_valid) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10033", {"value", "reason"}, {dynamic_image_size, kDynamicImageSizeError});
GELOGE(ge::PARAM_INVALID, "[Check][DynamicImageSizeInputShape] --dynamic_image_size:%s is invalid. reason: %s",
dynamic_image_size.c_str(), kDynamicImageSizeError);
return false;
}
}
// Different parameter sets are split string by ';'
std::vector<std::string> split_set = StringUtils::Split(dynamic_image_size, ';');
// Different dimensions are split by ','


+ 9
- 5
ge/offline/main.cc View File

@@ -220,6 +220,8 @@ DEFINE_string(performance_mode, "", "Optional; express high compile performance
"normal: no need to compile, used saved .o files directly;"
"high: need to recompile, high execute performance mode.");

DEFINE_string(device_id, "0", "Optional; user device id");

class GFlagUtils {
public:
/**
@@ -579,7 +581,7 @@ class GFlagUtils {
if (fileName.size() > static_cast<int>(PATH_MAX)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)});
GELOGE(ge::FAILED,
GELOGE(ge::FAILED,
"[Check][Path]Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX);
return false;
}
@@ -638,7 +640,7 @@ static bool CheckInputFormat() {
// only support NCHW ND
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kCaffeFormatSupport});
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
FLAGS_input_format.c_str(), kCaffeFormatSupport);
return false;
} else if ((FLAGS_framework == static_cast<int32_t>(domi::TENSORFLOW))) { // tf
@@ -648,7 +650,7 @@ static bool CheckInputFormat() {
// only support NCHW NHWC ND NCDHW NDHWC
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kTFFormatSupport});
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
FLAGS_input_format.c_str(), kTFFormatSupport);
return false;
} else if (FLAGS_framework == static_cast<int32_t>(domi::ONNX)) {
@@ -658,7 +660,7 @@ static bool CheckInputFormat() {
// only support NCHW ND
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, kONNXFormatSupport});
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
GELOGE(ge::FAILED, "[Check][InputFormat]Invalid value for --input_format[%s], %s.",
FLAGS_input_format.c_str(), kONNXFormatSupport);
return false;
}
@@ -903,7 +905,7 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--framework", std::to_string(fwk_type), kModelToJsonSupport});
GELOGE(ge::FAILED, "[Convert][ModelToJson]Invalid value for --framework[%d], %s.",
GELOGE(ge::FAILED, "[Convert][ModelToJson]Invalid value for --framework[%d], %s.",
fwk_type, kModelToJsonSupport);
ret = ge::FAILED;
}
@@ -1084,6 +1086,7 @@ static void SetEnvForSingleOp(std::map<string, string> &options) {
options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path);
options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path);
options.emplace(ge::PERFORMANCE_MODE, FLAGS_performance_mode);
options.emplace(ge::TUNE_DEVICE_IDS, FLAGS_device_id);
}

domi::Status GenerateSingleOp(const std::string& json_file_path) {
@@ -1176,6 +1179,7 @@ domi::Status GenerateOmModel() {
options.insert(std::pair<string, string>(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes));
options.insert(std::pair<string, string>(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf));
options.insert(std::pair<string, string>(string(ge::PRECISION_MODE), FLAGS_precision_mode));
options.insert(std::pair<string, string>(string(ge::TUNE_DEVICE_IDS), FLAGS_device_id));

options.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0)));
options.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0)));


+ 4
- 0
inc/external/ge/ge_api_types.h View File

@@ -166,6 +166,8 @@ const std::string COMPRESS_FLAG = "ge.compressFlag";

const std::string PRECISION_MODE = "ge.exec.precision_mode";

const std::string TUNE_DEVICE_IDS = "ge.exec.tuneDeviceIds";

// Configure single op flag for FE
// its value should be "0" or "1", default value is "0"
const std::string SINGLE_OP_FLAG = "ge.exec.single_op";
@@ -407,6 +409,7 @@ const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT,
DYNAMIC_DIMS,
INSERT_OP_FILE,
PRECISION_MODE,
TUNE_DEVICE_IDS,
EXEC_DISABLE_REUSED_MEMORY,
AUTO_TUNE_MODE,
OUTPUT_TYPE,
@@ -434,6 +437,7 @@ const std::set<std::string> global_options = {CORE_TYPE,
ENABLE_COMPRESS_WEIGHT,
COMPRESS_WEIGHT_CONF,
PRECISION_MODE,
TUNE_DEVICE_IDS,
EXEC_DISABLE_REUSED_MEMORY,
AUTO_TUNE_MODE,
ENABLE_SINGLE_STREAM,


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 4cf2633a8f2290dee165ab11f8d6b8a07cba1412
Subproject commit c1aea328cc04340188e796e639cd55a907488365

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit a41249dc9b50e4c4988eb62a662b7df29ac24ee7
Subproject commit 06e784fad01d7e9089cc7e8e0d00fce5b1901886

+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -815,6 +815,7 @@ set(PROFILING_MNG_TEST_FILES
set(HYBRID_TEST_FILES
"hybrid/ge_hybrid_unittest.cc"
"hybrid/known_node_executor_unittest.cc"
"hybrid/executor/worker/execution_engine_unittest.cc"
)

set(OTHERS_TEST_FILES


+ 30
- 0
tests/ut/ge/graph/manager/graph_manager_unittest.cc View File

@@ -373,3 +373,33 @@ TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_3) {
Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model);
EXPECT_NE(status, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_add_graph_with_copy_success) {
GraphId graph_id = 1;
GraphManager graph_manager;
// create graph
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);

std::map<std::string, std::string> options;
OmgContext context;
Status status = graph_manager.AddGraphWithCopy(graph_id, graph, options, context);
EXPECT_EQ(status, ge::SUCCESS);
}

TEST_F(UtestGraphManagerTest, test_add_graph_with_copy_fail) {
GraphId graph_id = 1;
GraphManager graph_manager;
// create graph
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test_graph");
Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);

std::map<std::string, std::string> options;
OmgContext context;
Status status = graph_manager.AddGraph(graph_id, graph, options, context);
EXPECT_EQ(status, ge::SUCCESS);
status = graph_manager.RemoveGraph(graph_id);
EXPECT_EQ(status, ge::SUCCESS);
status = graph_manager.AddGraphWithCopy(graph_id, graph, options, context);
EXPECT_NE(status, ge::SUCCESS);
}

+ 10
- 0
tests/ut/ge/graph_ir/ge_ir_build_unittest.cc View File

@@ -108,3 +108,13 @@ TEST(UtestIrCommon, update_dynamic_shape_range_failed) {
ret = UpdateDynamicInputShapeRange(graph, input_shape_range);
EXPECT_EQ(ret, ge::PARAM_INVALID);
}

TEST(UtestIrCommon, check_dynamic_image_size_fail) {
map<string, vector<int64_t>> shape_map;
shape_map["input1"] = {8, 3, -1, -1};
string input_format = "NCHW";
string dynamic_image_size = "@64,64;128,128;";

bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size);
EXPECT_EQ(ret, false);
}

+ 119
- 0
tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc View File

@@ -0,0 +1,119 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>
#include "runtime/rt.h"

#define protected public
#define private public
#include "hybrid/model/hybrid_model.h"
#include "hybrid/node_executor/node_executor.h"
#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/executor/hybrid_model_executor.h"
#include "hybrid/executor/worker/execution_engine.h"
#undef private
#undef protected

using namespace std;
using namespace testing;
using namespace ge;
using namespace hybrid;


class UtestExecutionEngine : public testing::Test {
protected:
void SetUp() {}

void TearDown() {
}
};
namespace {
const int kIntBase = 10;
}
static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") {
auto op_desc = std::make_shared<ge::OpDesc>(name, type);
op_desc->SetStreamId(0);
op_desc->SetId(0);
op_desc->SetWorkspace({});
op_desc->SetWorkspaceBytes({});
op_desc->SetInputOffset({});
op_desc->SetOutputOffset({});

ge::AttrUtils::SetStr(op_desc, ge::TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF_AIVEC");
bool support_dynamic = true;
ge::AttrUtils::GetBool(op_desc, "support_dynamicshape", support_dynamic);
return op_desc;
}

TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) {
auto graph = make_shared<ComputeGraph>("graph");
OpDescPtr op_desc = CreateOpDesc("Add", "Add");
GeShape shape({2, 16});
GeTensorDesc tensor_desc(shape);
op_desc->AddInputDesc(tensor_desc);
op_desc->AddOutputDesc(tensor_desc);
auto node = graph->AddNode(op_desc);
std::unique_ptr<NodeItem> node_item;
NodeItem::Create(node, node_item);
ASSERT_TRUE(node_item != nullptr);
node_item->input_start = 0;
node_item->output_start = 0;

GraphExecutionContext execution_context;
execution_context.profiling_level = 1;
SubgraphContext subgraph_context(nullptr, &execution_context);

NodeState node_state(*node_item, &subgraph_context);
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release());
node_state.SetTaskContext(shared_task_context);

ExecutionEngine execution_engine;
ASSERT_TRUE(node_state.GetTaskContext() != nullptr);
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context), INTERNAL_ERROR);
}

TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
auto graph = make_shared<ComputeGraph>("graph");
OpDescPtr op_desc = CreateOpDesc("Add", "Add");
GeShape shape({2, 16});
GeTensorDesc tensor_desc(shape);
op_desc->AddInputDesc(tensor_desc);
op_desc->AddOutputDesc(tensor_desc);
auto node = graph->AddNode(op_desc);
std::unique_ptr<NodeItem> node_item;
NodeItem::Create(node, node_item);
ASSERT_TRUE(node_item != nullptr);
node_item->input_start = 0;
node_item->output_start = 0;

GraphExecutionContext execution_context;
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
execution_context.model = &hybrid_model;
SubgraphContext subgraph_context(nullptr, &execution_context);

NodeState node_state(*node_item, &subgraph_context);
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release());
node_state.SetTaskContext(shared_task_context);

ExecutionEngine execution_engine;
ASSERT_TRUE(node_state.GetTaskContext() != nullptr);
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context), INTERNAL_ERROR);
}

Loading…
Cancel
Save