From: @wan_xuelei Reviewed-by: @xchu42,@wqtshg Signed-off-by: @wqtshgtags/v1.2.0
@@ -38,6 +38,7 @@ REGISTER_OP_CREATOR(ExpandDims, GeDeletedOp); | |||||
REGISTER_OP_CREATOR(Reshape, GeDeletedOp); | REGISTER_OP_CREATOR(Reshape, GeDeletedOp); | ||||
REGISTER_OP_CREATOR(ReFormat, GeDeletedOp); | REGISTER_OP_CREATOR(ReFormat, GeDeletedOp); | ||||
REGISTER_OP_CREATOR(Squeeze, GeDeletedOp); | REGISTER_OP_CREATOR(Squeeze, GeDeletedOp); | ||||
REGISTER_OP_CREATOR(Unsqueeze, GeDeletedOp); | |||||
REGISTER_OP_CREATOR(Size, GeDeletedOp); | REGISTER_OP_CREATOR(Size, GeDeletedOp); | ||||
REGISTER_OP_CREATOR(Shape, GeDeletedOp); | REGISTER_OP_CREATOR(Shape, GeDeletedOp); | ||||
REGISTER_OP_CREATOR(ShapeN, GeDeletedOp); | REGISTER_OP_CREATOR(ShapeN, GeDeletedOp); | ||||
@@ -41,7 +41,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
// Wait for "const input nodes" if node's shape inference function requires any. | // Wait for "const input nodes" if node's shape inference function requires any. | ||||
// Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution | // Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution | ||||
GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state)); | GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state)); | ||||
if (node_item.is_output_shape_static) { | |||||
if (node_item.is_output_shape_static && !node_item.is_need_force_infershape) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -50,6 +50,7 @@ const char *const kProfilingBpNode = "ProfilingBpNode"; | |||||
const char *const kProfilingEndNode = "ProfilingEndNode"; | const char *const kProfilingEndNode = "ProfilingEndNode"; | ||||
const char *const kProfilingArNode = "ProfilingAllReduceNode"; | const char *const kProfilingArNode = "ProfilingAllReduceNode"; | ||||
const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | ||||
const char *const kForceInfershape = "_force_infershape_when_running"; | |||||
Status SetOutputNameAttr(ComputeGraph &graph) { | Status SetOutputNameAttr(ComputeGraph &graph) { | ||||
vector<string> output_names; | vector<string> output_names; | ||||
@@ -171,6 +172,9 @@ Status HybridModelBuilder::ValidateParams() { | |||||
Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { | Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), | |||||
"[%s] Failed to parse force_infershape node.", | |||||
node_item.NodeName().c_str()); | |||||
vector<string> dependencies = node->GetOpDesc()->GetOpInferDepends(); | vector<string> dependencies = node->GetOpDesc()->GetOpInferDepends(); | ||||
GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies), | GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies), | ||||
"[%s] Failed to parse node dependencies.", | "[%s] Failed to parse node dependencies.", | ||||
@@ -263,6 +267,17 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
// not care result, if no this attr, stand for the op does not need force infershape | |||||
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); | |||||
GELOGD("node [%s] is need do infershape , flag is %d", | |||||
op_desc->GetName().c_str(), | |||||
node_item.is_need_force_infershape); | |||||
return SUCCESS; | |||||
} | |||||
Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) { | Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) { | ||||
std::set<NodePtr> dependent_input_nodes; | std::set<NodePtr> dependent_input_nodes; | ||||
auto &ge_node = node_item.node; | auto &ge_node = node_item.node; | ||||
@@ -62,6 +62,7 @@ class HybridModelBuilder { | |||||
Status IdentifySameInputs(NodeItem &node_item); | Status IdentifySameInputs(NodeItem &node_item); | ||||
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | ||||
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | ||||
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | |||||
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | ||||
Status ParseDependentForFusedSubgraph(NodeItem &node_item); | Status ParseDependentForFusedSubgraph(NodeItem &node_item); | ||||
Status IndexTaskDefs(); | Status IndexTaskDefs(); | ||||
@@ -83,6 +83,7 @@ struct NodeItem { | |||||
bool has_observer = false; | bool has_observer = false; | ||||
bool has_optional_inputs = false; | bool has_optional_inputs = false; | ||||
bool is_output_shape_static = true; | bool is_output_shape_static = true; | ||||
bool is_need_force_infershape = false; | |||||
UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; | UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; | ||||
std::string node_name; | std::string node_name; | ||||
std::string node_type; | std::string node_type; | ||||
@@ -33,6 +33,7 @@ const std::map<std::string, std::vector<uint32_t>> | |||||
{RESHAPE, {}}, | {RESHAPE, {}}, | ||||
{EXPANDDIMS, {}}, | {EXPANDDIMS, {}}, | ||||
{SQUEEZE, {}}, | {SQUEEZE, {}}, | ||||
{UNSQUEEZE, {}}, | |||||
{BROADCASTGRADIENTARGS, {}} | {BROADCASTGRADIENTARGS, {}} | ||||
}; | }; | ||||
@@ -152,6 +152,20 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) { | |||||
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); | ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); | ||||
} | } | ||||
TEST_F(UtestGeHybrid, parse_force_infershape_nodes) { | |||||
const char *const kForceInfershape = "_force_infershape_when_running"; | |||||
auto graph = make_shared<ComputeGraph>("graph"); | |||||
OpDescPtr op_desc = CreateOpDesc("Conv2D", "Conv2D"); | |||||
ge::AttrUtils::SetBool(op_desc, kForceInfershape, true); | |||||
auto node = graph->AddNode(op_desc); | |||||
std::unique_ptr<NodeItem> new_node; | |||||
NodeItem::Create(node, new_node); | |||||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
HybridModel hybrid_model(ge_root_model); | |||||
HybridModelBuilder hybrid_model_builder(hybrid_model); | |||||
ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS); | |||||
} | |||||
TEST_F(UtestGeHybrid, index_taskdefs_success) { | TEST_F(UtestGeHybrid, index_taskdefs_success) { | ||||
// build aicore task | // build aicore task | ||||
domi::ModelTaskDef model_task_def; | domi::ModelTaskDef model_task_def; | ||||