diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc index 2abd9cd6..9bf70d26 100755 --- a/ge/hybrid/executor/hybrid_model_executor.cc +++ b/ge/hybrid/executor/hybrid_model_executor.cc @@ -70,7 +70,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { context_.profiler->Dump(std::cout); context_.profiler->Reset(); } - root_graph_executor_->ResetContext(); + root_graph_executor_->ReleaseContext(); context_.iteration += 1; if (ret == END_OF_SEQUENCE) { diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index 85f9e4c3..e8ccd416 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -177,6 +177,10 @@ struct NodeState { void SetTaskContext(std::shared_ptr &task_context); std::shared_ptr GetTaskContext(); + void SetSkipInferShape(bool skip_infershape) { skip_infershape_ = skip_infershape; } + + bool GetSkipInferShape() const { return skip_infershape_; } + private: bool IsScheduleReady() const; void SetDataSchedule(const NodeState &node_state, const std::function &ready); @@ -204,6 +208,7 @@ struct NodeState { int merge_index_ = -1; // Use for Execute (Reset after Executed). int switch_index_ = -1; // Use for Schedule (Reset after Prepared). 
int group_ = -1; + bool skip_infershape_ = false; }; } // namespace hybrid } // namespace ge diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index c26eac9b..6979d05f 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -110,6 +110,7 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vectorSetShape(tensor_desc->GetShape()); output_desc->SetOriginShape(tensor_desc->GetOriginShape()); output_desc->SetDataType(tensor_desc->GetDataType()); + node_state->SetSkipInferShape(true); } } diff --git a/ge/hybrid/executor/subgraph_executor.h b/ge/hybrid/executor/subgraph_executor.h index 35f6e67e..76732c37 100644 --- a/ge/hybrid/executor/subgraph_executor.h +++ b/ge/hybrid/executor/subgraph_executor.h @@ -41,7 +41,7 @@ class SubgraphExecutor { Status PartialExecuteAsync(int task_group); - void ResetContext() { subgraph_context_.reset(nullptr); } + void ReleaseContext() { subgraph_context_.reset(nullptr); } /** * Execute subgraph async, output tensor address(not data) and output tensor descriptions are diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc index 18fed710..96959b80 100755 --- a/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -70,7 +70,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { // Do shape inference // Skipping infer shape of input node. 
GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); - if (node_state.GetType() != DATA_TYPE && node_state.GetType() != AIPP_DATA_TYPE) { + if (!node_state.GetSkipInferShape()) { RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), "[Invoke][InferShapeAndType] for %s failed.", node_item.NodeName().c_str()); diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index 182d1466..90a6362c 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -49,8 +49,8 @@ const uint32_t kOutputIndexOfData = 0; constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; Status CheckHostMem(const std::vector<std::string> &dependencies, const NodePtr &node, bool &is_host_mem) { + auto op_desc = node->GetOpDesc(); for (const auto &input_name : dependencies) { - auto op_desc = node->GetOpDesc(); int input_index = op_desc->GetInputIndexByName(input_name); if (input_index < 0) { GELOGE(INTERNAL_ERROR, "[Get][InputIndex]failed, node:[%s] inputname: %s.", @@ -60,11 +60,7 @@ Status CheckHostMem(const std::vector<std::string> &dependencies, const NodePtr &node return INTERNAL_ERROR; } - const auto &in_anchor = node->GetInDataAnchor(input_index); - GE_CHECK_NOTNULL(in_anchor); - const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - const auto &src_node = peer_out_anchor->GetOwnerNode(); + const auto &src_node = NodeUtils::GetInDataNodeByIndex(*node, input_index); GE_CHECK_NOTNULL(src_node); auto src_op_desc = src_node->GetOpDesc(); GE_CHECK_NOTNULL(src_op_desc); diff --git a/tests/ut/ge/single_op/single_op_model_unittest.cc b/tests/ut/ge/single_op/single_op_model_unittest.cc index 63a3eafe..1975f9f4 100644 --- a/tests/ut/ge/single_op/single_op_model_unittest.cc +++ b/tests/ut/ge/single_op/single_op_model_unittest.cc @@ -213,7 +213,7 @@
TEST_F(UtestSingleOpModel, test_build_dynamic_op) { // make graph ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); + auto data = builder.AddNode("Data", "Data", 1, 1); auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); builder.AddDataEdge(data, 0, transdata, 0); @@ -228,11 +228,6 @@ TEST_F(UtestSingleOpModel, test_build_dynamic_op) { op_desc->SetOpInferDepends(depend_names); (void)AttrUtils::SetBool(op_desc, kAttrSupportDynamicShape, true); - auto tensor = std::make_shared(); - auto data_desc = data->GetOpDesc(); - auto tensor_desc = data_desc->MutableInputDesc(0); - AttrUtils::SetTensor(tensor_desc, "_value", tensor); - // set task_def auto model_task_def = make_shared(); domi::TaskDef *task_def = model_task_def->add_task(); @@ -249,6 +244,12 @@ TEST_F(UtestSingleOpModel, test_build_dynamic_op) { op_desc->impl_->input_name_idx_["Data"] = 0; model.BuildDynamicOp(res, dynamic_single_op); + + auto tensor = std::make_shared(); + auto data_desc = data->GetOpDesc(); + auto tensor_desc = data_desc->MutableInputDesc(0); + AttrUtils::SetTensor(tensor_desc, "_value", tensor); + model.BuildDynamicOp(res, dynamic_single_op); } TEST_F(UtestSingleOpModel, test_host_mem) {