Browse Source

add force infershape for some op

tags/v1.2.0
wxl 4 years ago
parent
commit
5ae267433b
2 changed files with 17 additions and 1 deletions
  1. +3
    -1
      ge/hybrid/model/hybrid_model_builder.cc
  2. +14
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 3
- 1
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -272,7 +272,9 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt
GE_CHECK_NOTNULL(op_desc);
// not care result, if no this attr, stand for the op does not need force infershape
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
GELOGD("node [%s] is need do infershape , flag is %d", node_item.is_need_force_infershape);
GELOGD("node [%s] is need do infershape , flag is %d",
op_desc->GetName().c_str(),
node_item.is_need_force_infershape);
return SUCCESS;
}



+ 14
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -152,6 +152,20 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) {
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR);
}

TEST_F(UtestGeHybrid, parse_force_infershape_nodes) {
const char *const kForceInfershape = "_force_infershape_when_running";
auto graph = make_shared<ComputeGraph>("graph");
OpDescPtr op_desc = CreateOpDesc("Conv2D", "Conv2D");
ge::AttrUtils::SetBool(op_desc, kForceInfershape, true);
auto node = graph->AddNode(op_desc);
std::unique_ptr<NodeItem> new_node;
NodeItem::Create(node, new_node);
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS);
}

TEST_F(UtestGeHybrid, index_taskdefs_success) {
// build aicore task
domi::ModelTaskDef model_task_def;


Loading…
Cancel
Save