diff --git a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc index b2f3d095..90d95217 100755 --- a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc +++ b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc @@ -38,6 +38,7 @@ REGISTER_OP_CREATOR(ExpandDims, GeDeletedOp); REGISTER_OP_CREATOR(Reshape, GeDeletedOp); REGISTER_OP_CREATOR(ReFormat, GeDeletedOp); REGISTER_OP_CREATOR(Squeeze, GeDeletedOp); +REGISTER_OP_CREATOR(Unsqueeze, GeDeletedOp); REGISTER_OP_CREATOR(Size, GeDeletedOp); REGISTER_OP_CREATOR(Shape, GeDeletedOp); REGISTER_OP_CREATOR(ShapeN, GeDeletedOp); diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc index bb6281e1..27919589 100755 --- a/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -41,7 +41,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { // Wait for "const input nodes" if node's shape inference function requires any. // Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state)); - if (node_item.is_output_shape_static) { + if (node_item.is_output_shape_static && !node_item.is_need_force_infershape) { return SUCCESS; } diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index ac57b2ea..a349210d 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -50,6 +50,7 @@ const char *const kProfilingBpNode = "ProfilingBpNode"; const char *const kProfilingEndNode = "ProfilingEndNode"; const char *const kProfilingArNode = "ProfilingAllReduceNode"; const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; +const char *const kForceInfershape = "_force_infershape_when_running"; Status SetOutputNameAttr(ComputeGraph &graph) { vector output_names; @@ -171,6 +172,9 @@ Status HybridModelBuilder::ValidateParams() { Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { auto op_desc = node->GetOpDesc(); + GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), + "[%s] Failed to parse force_infershape node.", + node_item.NodeName().c_str()); vector dependencies = node->GetOpDesc()->GetOpInferDepends(); GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies), "[%s] Failed to parse node dependencies.", @@ -263,6 +267,17 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n return SUCCESS; } +Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + // not care result, if no this attr, stand for the op does not need force infershape + (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); + GELOGD("node [%s] is need do infershape , flag is %d", + op_desc->GetName().c_str(), + node_item.is_need_force_infershape); + return SUCCESS; +} + Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies) { std::set dependent_input_nodes; auto &ge_node = node_item.node; diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 71663a6e..313d5ca6 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -62,6 +62,7 @@ class HybridModelBuilder { Status IdentifySameInputs(NodeItem &node_item); Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); + Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); Status ParseDependentForFusedSubgraph(NodeItem &node_item); Status IndexTaskDefs(); diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 300744d1..631dbd9e 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -83,6 +83,7 @@ struct NodeItem { bool has_observer = false; bool has_optional_inputs = false; bool is_output_shape_static = true; + bool is_need_force_infershape = false; UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; std::string node_name; std::string node_type; diff --git a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc index 3d2e3084..9d92420e 100755 --- a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc +++ b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc @@ -33,6 +33,7 @@ const std::map> {RESHAPE, {}}, {EXPANDDIMS, {}}, {SQUEEZE, {}}, + {UNSQUEEZE, {}}, {BROADCASTGRADIENTARGS, {}} }; diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 0b6ca271..286186de 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -152,6 +152,20 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) { ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); } +TEST_F(UtestGeHybrid, parse_force_infershape_nodes) { + const char *const kForceInfershape = "_force_infershape_when_running"; + auto graph = make_shared("graph"); + OpDescPtr op_desc = CreateOpDesc("Conv2D", "Conv2D"); + ge::AttrUtils::SetBool(op_desc, kForceInfershape, true); + auto node = graph->AddNode(op_desc); + std::unique_ptr new_node; + NodeItem::Create(node, new_node); + GeRootModelPtr ge_root_model = make_shared(graph); + HybridModel hybrid_model(ge_root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS); +} + TEST_F(UtestGeHybrid, index_taskdefs_success) { // build aicore task domi::ModelTaskDef model_task_def;