From 1ab9ae32dc4520be393242297ce900beeb9d2564 Mon Sep 17 00:00:00 2001 From: zhaozhixuan Date: Tue, 15 Jun 2021 10:00:19 +0800 Subject: [PATCH] Add ut. --- ge/single_op/single_op_model.cc | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index 3c0f7972..4a7638b1 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -48,7 +48,7 @@ const uint32_t kInputIndexOfData = 0; const uint32_t kOutputIndexOfData = 0; constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; -Status CheckHostMem(const std::vector &dependencies, const NodePtr &node, bool &flag) { +Status CheckHostMem(const std::vector &dependencies, const NodePtr &node, bool &is_host_mem) { for (const auto &input_name : dependencies) { auto op_desc = node->GetOpDesc(); int input_index = op_desc->GetInputIndexByName(input_name); @@ -75,14 +75,14 @@ Status CheckHostMem(const std::vector &dependencies, const NodePtr &node continue; } } - flag = false; + is_host_mem = false; return SUCCESS; } - flag = true; + is_host_mem = true; return SUCCESS; } -Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { +Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_host_mem) { auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); GE_CHECK_NOTNULL(comp_graph); for (const auto &node : comp_graph->GetAllNodes()) { @@ -93,16 +93,18 @@ Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { bool support_dynamic_shape = false; (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); if (!depends.empty() && support_dynamic_shape) { - CheckHostMem(depends, node, flag); - return SUCCESS; + is_infer_depend = true; + return CheckHostMem(depends, node, is_host_mem); } } return SUCCESS; } Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { - bool infer_depend_flag = false; - GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); + bool is_infer_depend = false; + bool is_host_mem = false; + GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed."); + bool need_d2h_cpy = is_infer_depend && !is_host_mem; auto tasks = ge_model->GetModelTaskDefPtr()->task(); int32_t kernel_task_num = 0; for (int i = 0; i < tasks.size(); ++i) { @@ -112,7 +114,7 @@ Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { tasks[i].kernel_with_handle().context(); auto kernel_type = static_cast(context.kernel_type()); if (kernel_type == ccKernelType::TE) { - if (infer_depend_flag) { + if (need_d2h_cpy) { flag = true; return SUCCESS; } @@ -553,7 +555,8 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { auto ge_model = model_helper_.GetGeModel(); GE_CHECK_NOTNULL(ge_model); bool infer_depend_flag = false; - GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); + bool is_host_mem = false; + GE_CHK_STATUS_RET(CheckInferDepend(ge_model, infer_depend_flag, is_host_mem)), "[Check][InferDepend] failed."); if (infer_depend_flag) { // construct single_op, do single op with HybridModelExecutor GELOGD("Init hybrid model params of single op, and will do execute with hybrid model executor.");