Merge pull request !557 from 唐豪杰/ge_devpull/560/MERGE
@@ -189,6 +189,23 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<Ascend | |||
GELOGI("AclgrphParse graph %s success.", ParserUtils::GetGraphName(graph).c_str()); | |||
return ge::SUCCESS; | |||
} | |||
void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node) | |||
{ | |||
std::vector<std::string> original_names; | |||
auto parend_desc = parent_node->GetOpDesc(); | |||
(void)ge::AttrUtils::GetListStr(parend_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); | |||
if (original_names.empty()) { | |||
original_names.emplace_back(string(subgraph_name).append("/").append(node->GetName())); | |||
} else { | |||
// for fusion node also used original_names[0] | |||
(void)original_names[0].append("/").append(subgraph_name).append("/").append(node->GetName()); | |||
} | |||
if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names)) { | |||
GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), node->GetOpDesc()->GetName().c_str()); | |||
} | |||
GELOGD("Add dump origin name %s for node %s.", original_names[0].c_str(), node->GetName().c_str()); | |||
} | |||
} // namespace ge | |||
namespace ge { | |||
@@ -279,6 +296,7 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) { | |||
if ((node->GetOpDesc() == nullptr) || (node->GetType() == "Variable") || (node->GetType() == "VariableV2")) { | |||
continue; | |||
} | |||
AddDumpOriginName(arg.subgraph_name, arg.parent_node, node); | |||
node->GetOpDesc()->SetName(node->GetOwnerComputeGraph()->GetName() + "/" + node->GetName()); | |||
} | |||
@@ -158,6 +158,8 @@ void STestTensorflowParser::RegisterCustomOp() { | |||
domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
} | |||
extern void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node); | |||
namespace { | |||
NodeDef* AddNode(GraphDef& graph, string type, string name) { | |||
NodeDef* nodeDef = graph.add_node(); | |||
@@ -4212,4 +4214,38 @@ TEST_F(STestTensorflowParser, tensorflow_optimizer_fmk_fusion_op) { | |||
EXPECT_EQ(root_graph->GetDirectNode().size(), 3); | |||
} | |||
TEST_F(STestTensorflowParser, AddDumpOriginName_test) | |||
{ | |||
GeTensorDesc scalar_tensor(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT); | |||
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default"); | |||
ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>(); | |||
data_op->SetType(parser::WHILE); | |||
data_op->SetName("WHILE0"); | |||
data_op->AddInputDesc(ge::GeTensorDesc()); | |||
data_op->AddOutputDesc(ge::GeTensorDesc()); | |||
ge::NodePtr while0 = graph->AddNode(data_op); | |||
data_op = std::make_shared<ge::OpDesc>(); | |||
data_op->SetType(parser::LOOPCOND); | |||
data_op->SetName("COND0"); | |||
data_op->AddInputDesc(ge::GeTensorDesc()); | |||
data_op->AddOutputDesc(ge::GeTensorDesc()); | |||
ge::NodePtr cond0 = graph->AddNode(data_op); | |||
AddDumpOriginName(std::string("while"), while0, cond0); | |||
data_op = std::make_shared<ge::OpDesc>(); | |||
data_op->SetType(parser::DATA); | |||
data_op->SetName("Data1"); | |||
data_op->AddInputDesc(ge::GeTensorDesc()); | |||
data_op->AddOutputDesc(ge::GeTensorDesc()); | |||
ge::NodePtr data1 = graph->AddNode(data_op); | |||
AddDumpOriginName(std::string("cond"), cond0, data1); | |||
auto desc = data1->GetOpDesc(); | |||
std::vector<std::string> original_names; | |||
(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); | |||
EXPECT_EQ(original_names.empty(), false); | |||
EXPECT_EQ(original_names[0], "while/COND0/cond/Data1"); | |||
} | |||
} // namespace ge |
@@ -169,6 +169,8 @@ static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_de | |||
return SUCCESS; | |||
} | |||
extern void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node); | |||
void UtestTensorflowParser::RegisterCustomOp() { | |||
REGISTER_CUSTOM_OP("Add") | |||
.FrameworkType(domi::TENSORFLOW) | |||
@@ -4679,4 +4681,38 @@ TEST_F(UtestTensorflowParser, tensorflow_ComputeArgRange) | |||
EXPECT_EQ(ret, domi::INTERNAL_ERROR); | |||
} | |||
TEST_F(UtestTensorflowParser, AddDumpOriginName_test) | |||
{ | |||
GeTensorDesc scalar_tensor(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT); | |||
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default"); | |||
ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>(); | |||
data_op->SetType(parser::WHILE); | |||
data_op->SetName("WHILE0"); | |||
data_op->AddInputDesc(ge::GeTensorDesc()); | |||
data_op->AddOutputDesc(ge::GeTensorDesc()); | |||
ge::NodePtr while0 = graph->AddNode(data_op); | |||
data_op = std::make_shared<ge::OpDesc>(); | |||
data_op->SetType(parser::LOOPCOND); | |||
data_op->SetName("COND0"); | |||
data_op->AddInputDesc(ge::GeTensorDesc()); | |||
data_op->AddOutputDesc(ge::GeTensorDesc()); | |||
ge::NodePtr cond0 = graph->AddNode(data_op); | |||
AddDumpOriginName(std::string("while"), while0, cond0); | |||
data_op = std::make_shared<ge::OpDesc>(); | |||
data_op->SetType(parser::DATA); | |||
data_op->SetName("Data1"); | |||
data_op->AddInputDesc(ge::GeTensorDesc()); | |||
data_op->AddOutputDesc(ge::GeTensorDesc()); | |||
ge::NodePtr data1 = graph->AddNode(data_op); | |||
AddDumpOriginName(std::string("cond"), cond0, data1); | |||
auto desc = data1->GetOpDesc(); | |||
std::vector<std::string> original_names; | |||
(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); | |||
EXPECT_EQ(original_names.empty(), false); | |||
EXPECT_EQ(original_names[0], "while/COND0/cond/Data1"); | |||
} | |||
} // namespace ge |