Browse Source

Add ut.

tags/v1.3.0
zhaozhixuan 4 years ago
parent
commit
1ab9ae32dc
1 changed files with 13 additions and 10 deletions
  1. +13
    -10
      ge/single_op/single_op_model.cc

+ 13
- 10
ge/single_op/single_op_model.cc View File

@@ -48,7 +48,7 @@ const uint32_t kInputIndexOfData = 0;
const uint32_t kOutputIndexOfData = 0;
constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape";

Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node, bool &flag) {
Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node, bool &is_host_mem) {
for (const auto &input_name : dependencies) {
auto op_desc = node->GetOpDesc();
int input_index = op_desc->GetInputIndexByName(input_name);
@@ -75,14 +75,14 @@ Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node
continue;
}
}
flag = false;
is_host_mem = false;
return SUCCESS;
}
flag = true;
is_host_mem = true;
return SUCCESS;
}

Status IfInferDepend(GeModelPtr &ge_model, bool &flag) {
Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_host_mem) {
auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
GE_CHECK_NOTNULL(comp_graph);
for (const auto &node : comp_graph->GetAllNodes()) {
@@ -93,16 +93,18 @@ Status IfInferDepend(GeModelPtr &ge_model, bool &flag) {
bool support_dynamic_shape = false;
(void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape);
if (!depends.empty() && support_dynamic_shape) {
CheckHostMem(depends, node, flag);
return SUCCESS;
is_infer_depend = true;
return CheckHostMem(depends, node, is_host_mem);
}
}
return SUCCESS;
}

Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) {
bool infer_depend_flag = false;
GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed.");
bool is_infer_depend = false;
bool is_host_mem = false;
GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed.");
bool need_d2h_cpy = is_infer_depend && !is_host_mem;
auto tasks = ge_model->GetModelTaskDefPtr()->task();
int32_t kernel_task_num = 0;
for (int i = 0; i < tasks.size(); ++i) {
@@ -112,7 +114,7 @@ Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) {
tasks[i].kernel_with_handle().context();
auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
if (kernel_type == ccKernelType::TE) {
if (infer_depend_flag) {
if (need_d2h_cpy) {
flag = true;
return SUCCESS;
}
@@ -553,7 +555,8 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) {
auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model);
bool infer_depend_flag = false;
GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed.");
bool is_host_mem = false;
GE_CHK_STATUS_RET(CheckInferDepend(ge_model, infer_depend_flag, is_host_mem)), "[Check][InferDepend] failed.");
if (infer_depend_flag) {
// construct single_op, do single op with HybridModelExecutor
GELOGD("Init hybrid model params of single op, and will do execute with hybrid model executor.");


Loading…
Cancel
Save