diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 951101e..c38763e 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -206,6 +206,14 @@ void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr paren } GELOGD("Add dump origin name %s for node %s.", original_names[0].c_str(), node->GetName().c_str()); } +void AddDumpOriginNameForRootGraph(const ge::ComputeGraphPtr& graph) { + for (auto &node : graph->GetDirectNode()) { + if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, {node->GetName()})) { + GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), node->GetOpDesc()->GetName().c_str()); + } + GELOGD("Add dump origin name %s for node %s.", node->GetName().c_str(), node->GetName().c_str()); + } +} } // namespace ge namespace ge { @@ -273,6 +281,7 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque Status PostOpProcessForSubgraph(const ParseArg &arg) { if (arg.parent_node == nullptr) { + AddDumpOriginNameForRootGraph(arg.graph); return SUCCESS; } std::string op_type = arg.parent_node->GetType();