diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index d7f33b4b..8d59d9f9 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -415,16 +415,16 @@ Status UpdateVarFormats(const NodePtr &var, const GeTensorDesc &tensor_desc) { Status RecoverTransRoadForVar(const NodePtr &var, const VarTransRoad &road) { GE_CHECK_NOTNULL(var); - int index = 0; + static std::atomic_int index(0); NodePtr last_node = var; for (auto iter = road.rbegin(); iter != road.rend(); ++iter) { auto trans_name = var->GetName() + "_trans_" + std::to_string(index++); auto ret = RecoverOneTransNodeForVar(trans_name, *iter, last_node, last_node); if (ret != SUCCESS) { - REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %d, type %s", - var->GetName().c_str(), index, iter->node_type.c_str()); - GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s, index %d, type %s", var->GetName().c_str(), - index, iter->node_type.c_str()); + REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %s, type %s", + var->GetName().c_str(), std::to_string(index).c_str(), iter->node_type.c_str()); + GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s, index %s, type %s", var->GetName().c_str(), + std::to_string(index).c_str(), iter->node_type.c_str()); return INTERNAL_ERROR; } // set stream_label @@ -460,17 +460,17 @@ Status RecoverTransRoadForVar(const NodePtr &var, const VarTransRoad &road) { Status RecoverTransRoadForVarRef(const std::set &nodes, const VarTransRoad &road) { for (auto &var : nodes) { GE_CHECK_NOTNULL(var); - int index = 0; + static std::atomic_int index(0); NodePtr last_node = var; GELOGI("Recover trans nodes for variable ref %s", var->GetName().c_str()); for (auto iter = road.rbegin(); iter != road.rend(); ++iter) { auto trans_name = var->GetName() + "_trans_" + std::to_string(index++); auto ret = RecoverOneTransNodeForVarRef(trans_name, *iter, last_node, last_node); if (ret != SUCCESS) { - REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %d, type %s", - var->GetName().c_str(), index, iter->node_type.c_str()); - GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s failed, index %d, type %s", - var->GetName().c_str(), index, iter->node_type.c_str()); + REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %s, type %s", + var->GetName().c_str(), std::to_string(index).c_str(), iter->node_type.c_str()); + GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s failed, index %s, type %s", + var->GetName().c_str(), std::to_string(index).c_str(), iter->node_type.c_str()); return INTERNAL_ERROR; } // set stream_label diff --git a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc index e53a9f96..b1c07d81 100644 --- a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc +++ b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc @@ -23,6 +23,7 @@ #include "graph/passes/graph_builder_utils.h" #include "graph/utils/attr_utils.h" #include "graph/debug/ge_attr_define.h" +#include "graph/manager/graph_var_manager.h" #define private public #define protected public @@ -285,4 +286,26 @@ TEST_F(UtestGraphPreproces, test_prepare_dyn_shape) { GraphPrepare graph_prepare; EXPECT_EQ(graph_prepare.PrepareDynShape(graph_node, user_input, compute_graph, 0), SUCCESS); } + +TEST_F(UtestGraphPreproces, test_updar_variable_formats) { + auto builder = ut::GraphBuilder("g1"); + auto var = builder.AddNode("var", VARIABLE, 1, 1); + auto g1 = builder.GetGraph(); + g1->SetSessionID(0); + TransNodeInfo trans_node_info; + VarTransRoad fusion_road; + fusion_road.emplace_back(trans_node_info); + VarManager::Instance(g1->GetSessionID())->SetTransRoad(var->GetName(), fusion_road); + GraphPrepare graph_prepare; + EXPECT_EQ(graph_prepare.UpdateVariableFormats(g1), INTERNAL_ERROR); + + auto builder1 = ut::GraphBuilder("g2"); + auto var1 = builder1.AddNode("var1", VARIABLE, 1, 1); + auto g2 = builder1.GetGraph(); + g2->SetSessionID(0); + VarTransRoad fusion_road1; + VarManager::Instance(g2->GetSessionID())->SetTransRoad(var1->GetName(), fusion_road1); + AttrUtils::SetStr(var1->GetOpDesc(), REF_VAR_SRC_VAR_NAME, "var1"); + EXPECT_EQ(graph_prepare.UpdateVariableFormats(g2), SUCCESS); +} } \ No newline at end of file