Browse Source

!557 sync ge_dev to master 20220531

Merge pull request !557 from 唐豪杰/ge_dev
pull/560/MERGE
计晨 Gitee 3 years ago
parent
commit
e90a71d665
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 90 additions and 0 deletions
  1. +18
    -0
      parser/tensorflow/tensorflow_parser.cc
  2. +36
    -0
      tests/st/testcase/test_tensorflow_parser.cc
  3. +36
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 18
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -189,6 +189,23 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<Ascend
GELOGI("AclgrphParse graph %s success.", ParserUtils::GetGraphName(graph).c_str());
return ge::SUCCESS;
}
void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node)
{
std::vector<std::string> original_names;
auto parend_desc = parent_node->GetOpDesc();
(void)ge::AttrUtils::GetListStr(parend_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
if (original_names.empty()) {
original_names.emplace_back(string(subgraph_name).append("/").append(node->GetName()));
} else {
// for fusion node also used original_names[0]
(void)original_names[0].append("/").append(subgraph_name).append("/").append(node->GetName());
}

if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names)) {
GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), node->GetOpDesc()->GetName().c_str());
}
GELOGD("Add dump origin name %s for node %s.", original_names[0].c_str(), node->GetName().c_str());
}
} // namespace ge

namespace ge {
@@ -279,6 +296,7 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) {
if ((node->GetOpDesc() == nullptr) || (node->GetType() == "Variable") || (node->GetType() == "VariableV2")) {
continue;
}
AddDumpOriginName(arg.subgraph_name, arg.parent_node, node);
node->GetOpDesc()->SetName(node->GetOwnerComputeGraph()->GetName() + "/" + node->GetName());
}



+ 36
- 0
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -158,6 +158,8 @@ void STestTensorflowParser::RegisterCustomOp() {
domi::OpRegistry::Instance()->registrationDatas.clear();
}

extern void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node);

namespace {
NodeDef* AddNode(GraphDef& graph, string type, string name) {
NodeDef* nodeDef = graph.add_node();
@@ -4212,4 +4214,38 @@ TEST_F(STestTensorflowParser, tensorflow_optimizer_fmk_fusion_op) {
EXPECT_EQ(root_graph->GetDirectNode().size(), 3);
}

TEST_F(STestTensorflowParser, AddDumpOriginName_test)
{
GeTensorDesc scalar_tensor(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::WHILE);
data_op->SetName("WHILE0");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr while0 = graph->AddNode(data_op);

data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::LOOPCOND);
data_op->SetName("COND0");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr cond0 = graph->AddNode(data_op);
AddDumpOriginName(std::string("while"), while0, cond0);

data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::DATA);
data_op->SetName("Data1");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr data1 = graph->AddNode(data_op);
AddDumpOriginName(std::string("cond"), cond0, data1);

auto desc = data1->GetOpDesc();
std::vector<std::string> original_names;
(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.empty(), false);
EXPECT_EQ(original_names[0], "while/COND0/cond/Data1");
}

} // namespace ge

+ 36
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -169,6 +169,8 @@ static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_de
return SUCCESS;
}

extern void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node);

void UtestTensorflowParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Add")
.FrameworkType(domi::TENSORFLOW)
@@ -4679,4 +4681,38 @@ TEST_F(UtestTensorflowParser, tensorflow_ComputeArgRange)
EXPECT_EQ(ret, domi::INTERNAL_ERROR);
}

TEST_F(UtestTensorflowParser, AddDumpOriginName_test)
{
GeTensorDesc scalar_tensor(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::WHILE);
data_op->SetName("WHILE0");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr while0 = graph->AddNode(data_op);

data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::LOOPCOND);
data_op->SetName("COND0");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr cond0 = graph->AddNode(data_op);
AddDumpOriginName(std::string("while"), while0, cond0);

data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::DATA);
data_op->SetName("Data1");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr data1 = graph->AddNode(data_op);
AddDumpOriginName(std::string("cond"), cond0, data1);

auto desc = data1->GetOpDesc();
std::vector<std::string> original_names;
(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.empty(), false);
EXPECT_EQ(original_names[0], "while/COND0/cond/Data1");
}

} // namespace ge

Loading…
Cancel
Save