Browse Source

bugfix for RecoverTransRoadForVar

tags/v1.5.1
y00500818 wangzhengjun 3 years ago
parent
commit
ecd48b072d
2 changed files with 33 additions and 10 deletions
  1. +10
    -10
      ge/graph/preprocess/graph_preprocess.cc
  2. +23
    -0
      tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc

+ 10
- 10
ge/graph/preprocess/graph_preprocess.cc View File

@@ -415,16 +415,16 @@ Status UpdateVarFormats(const NodePtr &var, const GeTensorDesc &tensor_desc) {

Status RecoverTransRoadForVar(const NodePtr &var, const VarTransRoad &road) {
GE_CHECK_NOTNULL(var);
int index = 0;
static std::atomic_int index(0);
NodePtr last_node = var;
for (auto iter = road.rbegin(); iter != road.rend(); ++iter) {
auto trans_name = var->GetName() + "_trans_" + std::to_string(index++);
auto ret = RecoverOneTransNodeForVar(trans_name, *iter, last_node, last_node);
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %d, type %s",
var->GetName().c_str(), index, iter->node_type.c_str());
GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s, index %d, type %s", var->GetName().c_str(),
index, iter->node_type.c_str());
REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %s, type %s",
var->GetName().c_str(), std::to_string(index).c_str(), iter->node_type.c_str());
GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s, index %s, type %s", var->GetName().c_str(),
std::to_string(index).c_str(), iter->node_type.c_str());
return INTERNAL_ERROR;
}
// set stream_label
@@ -460,17 +460,17 @@ Status RecoverTransRoadForVar(const NodePtr &var, const VarTransRoad &road) {
Status RecoverTransRoadForVarRef(const std::set<NodePtr> &nodes, const VarTransRoad &road) {
for (auto &var : nodes) {
GE_CHECK_NOTNULL(var);
int index = 0;
static std::atomic_int index(0);
NodePtr last_node = var;
GELOGI("Recover trans nodes for variable ref %s", var->GetName().c_str());
for (auto iter = road.rbegin(); iter != road.rend(); ++iter) {
auto trans_name = var->GetName() + "_trans_" + std::to_string(index++);
auto ret = RecoverOneTransNodeForVarRef(trans_name, *iter, last_node, last_node);
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %d, type %s",
var->GetName().c_str(), index, iter->node_type.c_str());
GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s failed, index %d, type %s",
var->GetName().c_str(), index, iter->node_type.c_str());
REPORT_CALL_ERROR("E19999", "Failed to recover trans node for variable %s, index %s, type %s",
var->GetName().c_str(), std::to_string(index).c_str(), iter->node_type.c_str());
GELOGE(INTERNAL_ERROR, "[Recover][TransNode] for variable %s failed, index %s, type %s",
var->GetName().c_str(), std::to_string(index).c_str(), iter->node_type.c_str());
return INTERNAL_ERROR;
}
// set stream_label


+ 23
- 0
tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc View File

@@ -23,6 +23,7 @@
#include "graph/passes/graph_builder_utils.h"
#include "graph/utils/attr_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/manager/graph_var_manager.h"

#define private public
#define protected public
@@ -285,4 +286,26 @@ TEST_F(UtestGraphPreproces, test_prepare_dyn_shape) {
GraphPrepare graph_prepare;
EXPECT_EQ(graph_prepare.PrepareDynShape(graph_node, user_input, compute_graph, 0), SUCCESS);
}

TEST_F(UtestGraphPreproces, test_updar_variable_formats) {
auto builder = ut::GraphBuilder("g1");
auto var = builder.AddNode("var", VARIABLE, 1, 1);
auto g1 = builder.GetGraph();
g1->SetSessionID(0);
TransNodeInfo trans_node_info;
VarTransRoad fusion_road;
fusion_road.emplace_back(trans_node_info);
VarManager::Instance(g1->GetSessionID())->SetTransRoad(var->GetName(), fusion_road);
GraphPrepare graph_prepare;
EXPECT_EQ(graph_prepare.UpdateVariableFormats(g1), INTERNAL_ERROR);

auto builder1 = ut::GraphBuilder("g2");
auto var1 = builder1.AddNode("var1", VARIABLE, 1, 1);
auto g2 = builder1.GetGraph();
g2->SetSessionID(0);
VarTransRoad fusion_road1;
VarManager::Instance(g2->GetSessionID())->SetTransRoad(var1->GetName(), fusion_road1);
AttrUtils::SetStr(var1->GetOpDesc(), REF_VAR_SRC_VAR_NAME, "var1");
EXPECT_EQ(graph_prepare.UpdateVariableFormats(g2), SUCCESS);
}
}

Loading…
Cancel
Save