@@ -24,11 +24,7 @@ Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { | |||||
FindNodes(graph); | FindNodes(graph); | ||||
for (const auto &node : need_label_nodes_) { | for (const auto &node : need_label_nodes_) { | ||||
OpDescPtr op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { | |||||
GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); | |||||
} | |||||
GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); | |||||
} | } | ||||
GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); | GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); | ||||
@@ -55,13 +51,15 @@ Status AttachStreamLabelPass::ClearStatus() { | |||||
/// | /// | ||||
void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { | void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { | ||||
for (const NodePtr &node : graph->GetDirectNode()) { | for (const NodePtr &node : graph->GetDirectNode()) { | ||||
const std::string &type = node->GetType(); | |||||
if (type == STREAMSWITCH) { | |||||
const auto &op_desc = node->GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
continue; | |||||
} | |||||
const std::string &type = op_desc->GetType(); | |||||
if ((type == STREAMSWITCH) && op_desc->HasAttr(ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG)) { | |||||
stream_switch_nodes_.emplace_back(node); | stream_switch_nodes_.emplace_back(node); | ||||
} else if (type == STREAMMERGE) { | |||||
if ((node->GetOpDesc() != nullptr) && !node->GetOpDesc()->HasAttr(ATTR_NAME_NEXT_ITERATION)) { | |||||
need_label_nodes_.emplace_back(node); | |||||
} | |||||
} else if ((type == STREAMMERGE) && !op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { | |||||
need_label_nodes_.emplace_back(node); | |||||
} else if ((type == ENTER) || (type == REFENTER)) { | } else if ((type == ENTER) || (type == REFENTER)) { | ||||
enter_nodes_.emplace_back(node); | enter_nodes_.emplace_back(node); | ||||
} | } | ||||
@@ -83,11 +81,15 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { | |||||
/// | /// | ||||
Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | ||||
std::string stream_label; | std::string stream_label; | ||||
if (AttachFlag(node, stream_label) != SUCCESS) { | |||||
GELOGE(FAILED, "Attach flag for node %s failed.", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
std::unordered_set<NodePtr> branch_nodes; | std::unordered_set<NodePtr> branch_nodes; | ||||
std::unordered_set<NodePtr> visited; | std::unordered_set<NodePtr> visited; | ||||
std::stack<NodePtr> nodes; | std::stack<NodePtr> nodes; | ||||
nodes.push(node); | nodes.push(node); | ||||
static const std::set<std::string> end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; | static const std::set<std::string> end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; | ||||
while (!nodes.empty()) { | while (!nodes.empty()) { | ||||
NodePtr cur_node = nodes.top(); | NodePtr cur_node = nodes.top(); | ||||
@@ -95,10 +97,7 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | |||||
if (visited.count(cur_node) > 0) { | if (visited.count(cur_node) > 0) { | ||||
continue; | continue; | ||||
} | } | ||||
if (AttachFlag(cur_node, stream_label) != SUCCESS) { | |||||
GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
const std::string &type = cur_node->GetType(); | const std::string &type = cur_node->GetType(); | ||||
for (const auto &out_node : cur_node->GetOutAllNodes()) { | for (const auto &out_node : cur_node->GetOutAllNodes()) { | ||||
@@ -115,10 +114,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { | |||||
visited.insert(cur_node); | visited.insert(cur_node); | ||||
} | } | ||||
if (node->GetType() == STREAMSWITCH) { | |||||
GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); | |||||
} | |||||
for (const NodePtr &tmp_node : branch_nodes) { | for (const NodePtr &tmp_node : branch_nodes) { | ||||
GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); | GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); | ||||
GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); | GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); | ||||
@@ -148,11 +143,10 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea | |||||
GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, | GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, | ||||
"StreamSwitch get attr TRUE_BRANCH_STREAM failed."); | "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); | ||||
stream_label += (value ? "_t" : "_f"); | stream_label += (value ? "_t" : "_f"); | ||||
GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); | |||||
} else if (type == STREAMMERGE) { | } else if (type == STREAMMERGE) { | ||||
stream_label = node->GetName(); | stream_label = node->GetName(); | ||||
GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | ||||
} else if ((type == EXIT) || (type == REFEXIT)) { | |||||
GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -166,12 +160,13 @@ Status AttachStreamLabelPass::UpdateEnterNode() { | |||||
std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map; | std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map; | ||||
for (const auto &enter_node : enter_nodes_) { | for (const auto &enter_node : enter_nodes_) { | ||||
for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { | for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { | ||||
if (out_ctrl_node->GetType() == STREAMACTIVE) { | |||||
if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { | |||||
enter_active_map[out_ctrl_node] = {enter_node}; | |||||
} else { | |||||
enter_active_map[out_ctrl_node].emplace_back(enter_node); | |||||
} | |||||
if (out_ctrl_node->GetType() != STREAMACTIVE) { | |||||
continue; | |||||
} | |||||
if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { | |||||
enter_active_map[out_ctrl_node] = {enter_node}; | |||||
} else { | |||||
enter_active_map[out_ctrl_node].emplace_back(enter_node); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -226,9 +221,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||||
std::string stream_label; | std::string stream_label; | ||||
GE_CHECK_NOTNULL(active_node); | GE_CHECK_NOTNULL(active_node); | ||||
(void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); | (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); | ||||
if (stream_label.empty()) { | if (stream_label.empty()) { | ||||
GELOGW("stream_label of enter_active & enter_nodes is empty."); | |||||
GELOGD("stream_label of enter_active %s is empty.", active_node->GetName().c_str()); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -238,7 +232,6 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no | |||||
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); | ||||
} | } | ||||
} | } | ||||
GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -37,6 +37,12 @@ Status CondRemovePass::Run(NodePtr &node) { | |||||
OutDataAnchorPtr cond_out_anchor = nullptr; | OutDataAnchorPtr cond_out_anchor = nullptr; | ||||
InDataAnchorPtr cond_in_anchor = nullptr; | InDataAnchorPtr cond_in_anchor = nullptr; | ||||
Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); | Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); | ||||
if (ret == NOT_CHANGED) { | |||||
return SUCCESS; | |||||
} else if (ret != SUCCESS) { | |||||
GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
int32_t cond_index = 0; | int32_t cond_index = 0; | ||||
GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str()); | GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str()); | ||||
bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index); | bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index); | ||||
@@ -322,11 +328,11 @@ Status CondRemovePass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, | |||||
std::string type = node->GetType(); | std::string type = node->GetType(); | ||||
if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) { | if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) { | ||||
if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { | if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { | ||||
GELOGE(FAILED, "Get cond_info for if node failed."); | |||||
GELOGE(FAILED, "Get cond_info for if/case node failed."); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
} else { | } else { | ||||
GELOGD("no need cond_pass for node %s.", node->GetName().c_str()); | |||||
GELOGD("no need cond_remove_pass for node %s.", node->GetName().c_str()); | |||||
return NOT_CHANGED; | return NOT_CHANGED; | ||||
} | } | ||||
@@ -16,6 +16,7 @@ | |||||
#include "graph/passes/enter_pass.h" | #include "graph/passes/enter_pass.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -72,33 +73,25 @@ Status EnterPass::Run(NodePtr &node) { | |||||
} | } | ||||
Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { | Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { | ||||
auto out_nodes_of_in_node = in_node->GetOutAllNodes(); | |||||
if (out_nodes_of_in_node.size() != kOutNodesNum) { | |||||
if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if (!node->GetOutControlNodes().empty()) { | |||||
bool is_constant_flag = true; | |||||
(void)AttrUtils::GetBool(node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant_flag); | |||||
if (!is_constant_flag) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
for (const auto &out_node : node->GetOutDataNodes()) { | |||||
GE_CHECK_NOTNULL(out_node); | |||||
if (out_node->GetType() == MERGE) { | |||||
return SUCCESS; | |||||
} | |||||
} | |||||
GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); | ||||
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); | GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); | ||||
auto out_data_anchor = node->GetOutDataAnchor(0); | |||||
const auto &out_data_anchor = node->GetOutDataAnchor(0); | |||||
GE_CHECK_NOTNULL(out_data_anchor); | GE_CHECK_NOTNULL(out_data_anchor); | ||||
for (auto peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); | GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); | ||||
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); | GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); | ||||
} | } | ||||
auto graph = node->GetOwnerComputeGraph(); | |||||
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, node)) | |||||
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); | |||||
AddNodeDeleted(node); | |||||
AddRePassNodesWithInOut(in_node); | AddRePassNodesWithInOut(in_node); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -137,7 +137,7 @@ Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &n | |||||
for_info.ctrl_inputs = std::move(ctrl_inputs); | for_info.ctrl_inputs = std::move(ctrl_inputs); | ||||
for_info.ctrl_outputs = std::move(ctrl_outputs); | for_info.ctrl_outputs = std::move(ctrl_outputs); | ||||
GELOGI("Build for_info for node %s succ.", node->GetName().c_str()); | |||||
GELOGI("Build for_info for node %s success.", node->GetName().c_str()); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -159,13 +159,7 @@ OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
if (peer_out_anchor == nullptr) { | |||||
GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index); | |||||
return nullptr; | |||||
} | |||||
return peer_out_anchor; | |||||
return in_data_anchor->GetPeerOutAnchor(); | |||||
} | } | ||||
/// | /// | ||||
@@ -186,20 +180,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc | |||||
uint32_t input_data_num = node->GetAllInDataAnchorsSize(); | uint32_t input_data_num = node->GetAllInDataAnchorsSize(); | ||||
for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) { | for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) { | ||||
InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); | InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); | ||||
if (in_data_anchor == nullptr) { | |||||
GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index); | |||||
return FAILED; | |||||
} | |||||
GE_IF_BOOL_EXEC(in_data_anchor->GetPeerOutAnchor() == nullptr, | |||||
GELOGW("Get null input by index %d from node %s ", | |||||
in_data_anchor->GetIdx(), node->GetName().c_str()); | |||||
continue); | |||||
GE_CHECK_NOTNULL(in_data_anchor); | |||||
data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor()); | data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor()); | ||||
} | } | ||||
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
std::vector<ge::InDataAnchorPtr> peer_in_data_anchors; | std::vector<ge::InDataAnchorPtr> peer_in_data_anchors; | ||||
for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
peer_in_data_anchors.emplace_back(peer_in_data_anchor); | peer_in_data_anchors.emplace_back(peer_in_data_anchor); | ||||
} | } | ||||
data_outputs.emplace_back(peer_in_data_anchors); | data_outputs.emplace_back(peer_in_data_anchors); | ||||
@@ -207,13 +194,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnc | |||||
InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor(); | InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor(); | ||||
GE_CHECK_NOTNULL(in_ctrl_anchor); | GE_CHECK_NOTNULL(in_ctrl_anchor); | ||||
for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { | |||||
for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { | |||||
ctrl_inputs.emplace_back(peer_out_ctrl_anchor); | ctrl_inputs.emplace_back(peer_out_ctrl_anchor); | ||||
} | } | ||||
OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor(); | OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor(); | ||||
GE_CHECK_NOTNULL(out_ctrl_anchor); | GE_CHECK_NOTNULL(out_ctrl_anchor); | ||||
for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
ctrl_outputs.emplace_back(peer_in_ctrl_anchor); | ctrl_outputs.emplace_back(peer_in_ctrl_anchor); | ||||
} | } | ||||
@@ -21,16 +21,12 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "common/ge_inner_error_codes.h" | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/passes/pass_utils.h" | #include "graph/passes/pass_utils.h" | ||||
using domi::PARAM_INVALID; | |||||
using domi::SUCCESS; | |||||
namespace ge { | namespace ge { | ||||
const int kValueIndexOutputIndex = 1; | const int kValueIndexOutputIndex = 1; | ||||
@@ -52,8 +48,7 @@ Status MergePass::Run(NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
auto out_data_anchors = node->GetAllOutDataAnchors(); | |||||
if (out_data_anchors.empty()) { | |||||
if (node->GetAllOutDataAnchors().empty()) { | |||||
GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str()); | GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -63,7 +58,7 @@ Status MergePass::Run(NodePtr &node) { | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
auto in_data_nodes = node->GetInDataNodes(); | |||||
const auto &in_data_nodes = node->GetInDataNodes(); | |||||
switch (in_data_nodes.size()) { | switch (in_data_nodes.size()) { | ||||
case 0: { | case 0: { | ||||
/// Case A: input_count = 0, the output of merge node is inactive as well | /// Case A: input_count = 0, the output of merge node is inactive as well | ||||
@@ -22,9 +22,6 @@ | |||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
using std::string; | |||||
using std::vector; | |||||
namespace ge { | namespace ge { | ||||
Status MultiBatchPass::Run(ComputeGraphPtr graph) { | Status MultiBatchPass::Run(ComputeGraphPtr graph) { | ||||
GELOGD("MultiBatchPass Enter"); | GELOGD("MultiBatchPass Enter"); | ||||
@@ -53,7 +50,7 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
std::vector<std::vector<int64_t>> batch_shape; | std::vector<std::vector<int64_t>> batch_shape; | ||||
vector<vector<int64_t>> combined_batch; | |||||
std::vector<std::vector<int64_t>> combined_batch; | |||||
if (!CheckSwitchN(batch_shape, combined_batch)) { | if (!CheckSwitchN(batch_shape, combined_batch)) { | ||||
GELOGE(FAILED, "CheckSwitchN failed."); | GELOGE(FAILED, "CheckSwitchN failed."); | ||||
return FAILED; | return FAILED; | ||||
@@ -104,6 +101,7 @@ Status MultiBatchPass::ClearStatus() { | |||||
/// | /// | ||||
Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) { | Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) { | ||||
const auto &func_desc = case_node->GetOpDesc(); | const auto &func_desc = case_node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(func_desc); | |||||
if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { | if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { | ||||
GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str()); | GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str()); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -114,7 +112,7 @@ Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr | |||||
const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]); | const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]); | ||||
GE_CHECK_NOTNULL(subgraph); | GE_CHECK_NOTNULL(subgraph); | ||||
const string batch_label = "Batch_" + std::to_string(i); | |||||
const std::string batch_label = "Batch_" + std::to_string(i); | |||||
for (const auto &node : subgraph->GetDirectNode()) { | for (const auto &node : subgraph->GetDirectNode()) { | ||||
(void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); | (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); | ||||
} | } | ||||
@@ -139,12 +137,12 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor | |||||
continue; | continue; | ||||
} | } | ||||
InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||||
const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||||
if (in_data_anchor == nullptr) { | if (in_data_anchor == nullptr) { | ||||
GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); | GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor(); | |||||
const auto &pred_input = in_data_anchor->GetPeerOutAnchor(); | |||||
if (pred_input == nullptr) { | if (pred_input == nullptr) { | ||||
GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); | GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); | ||||
return FAILED; | return FAILED; | ||||
@@ -178,12 +176,10 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status MultiBatchPass::GetDynamicType() { | Status MultiBatchPass::GetDynamicType() { | ||||
for (const auto &switchn : switch_n_nodes_) { | |||||
auto switchn_desc = switchn->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(switchn_desc); | |||||
for (const auto &switch_n : switch_n_nodes_) { | |||||
int32_t dynamic_type = static_cast<int32_t>(FIXED); | int32_t dynamic_type = static_cast<int32_t>(FIXED); | ||||
if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) { | |||||
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str()); | |||||
if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) { | |||||
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (dynamic_type == static_cast<int32_t>(FIXED)) { | if (dynamic_type == static_cast<int32_t>(FIXED)) { | ||||
@@ -191,7 +187,7 @@ Status MultiBatchPass::GetDynamicType() { | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) { | if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) { | ||||
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.", | |||||
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.", | |||||
dynamic_type, dynamic_type_); | dynamic_type, dynamic_type_); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -212,21 +208,19 @@ Status MultiBatchPass::GetDynamicType() { | |||||
Status MultiBatchPass::GetUserDesignateShape() { | Status MultiBatchPass::GetUserDesignateShape() { | ||||
data_name_order_.clear(); | data_name_order_.clear(); | ||||
bool first_check = true; | bool first_check = true; | ||||
for (const auto &switchn : switch_n_nodes_) { | |||||
auto switchn_desc = switchn->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(switchn_desc); | |||||
vector<string> cur_switchn_data_name_order; | |||||
if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) { | |||||
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str()); | |||||
for (const auto &switch_n : switch_n_nodes_) { | |||||
std::vector<std::string> cur_data_name_order; | |||||
if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) { | |||||
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (first_check) { | if (first_check) { | ||||
data_name_order_ = cur_switchn_data_name_order; | |||||
data_name_order_ = cur_data_name_order; | |||||
first_check = false; | first_check = false; | ||||
} else { | } else { | ||||
if (data_name_order_ != cur_switchn_data_name_order) { | |||||
if (data_name_order_ != cur_data_name_order) { | |||||
GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", | GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", | ||||
switchn->GetName().c_str()); | |||||
switch_n->GetName().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
} | } | ||||
@@ -245,7 +239,8 @@ Status MultiBatchPass::GetUserDesignateShape() { | |||||
/// @param [out] combined_batch | /// @param [out] combined_batch | ||||
/// @return bool | /// @return bool | ||||
/// | /// | ||||
bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<vector<int64_t>> &combined_batch) { | |||||
bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape, | |||||
std::vector<std::vector<int64_t>> &combined_batch) { | |||||
// Check if output_num of different SwitchN is same | // Check if output_num of different SwitchN is same | ||||
uint32_t batch_num = 0; | uint32_t batch_num = 0; | ||||
for (const NodePtr &node : switch_n_nodes_) { | for (const NodePtr &node : switch_n_nodes_) { | ||||
@@ -281,7 +276,8 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v | |||||
} | } | ||||
size_t tmp_combined_dim_num = combined_batch[i].size(); | size_t tmp_combined_dim_num = combined_batch[i].size(); | ||||
if (combined_dim_num != tmp_combined_dim_num) { | if (combined_dim_num != tmp_combined_dim_num) { | ||||
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num); | |||||
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", | |||||
combined_dim_num, i, tmp_combined_dim_num); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -296,11 +292,11 @@ bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<v | |||||
/// @param [out] combined_batch | /// @param [out] combined_batch | ||||
/// @return bool | /// @return bool | ||||
/// | /// | ||||
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &batch_shape, | |||||
vector<vector<int64_t>> &combined_batch) { | |||||
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<int64_t>> &batch_shape, | |||||
std::vector<std::vector<int64_t>> &combined_batch) { | |||||
// Check if output_shape of different SwitchN is same | // Check if output_shape of different SwitchN is same | ||||
vector<vector<int64_t>> idx_batch_shape; | |||||
vector<vector<int64_t>> idx_combined_batch; | |||||
std::vector<std::vector<int64_t>> idx_batch_shape; | |||||
std::vector<std::vector<int64_t>> idx_combined_batch; | |||||
for (uint32_t i = 0; i < batch_num; i++) { | for (uint32_t i = 0; i < batch_num; i++) { | ||||
idx_batch_shape.clear(); | idx_batch_shape.clear(); | ||||
idx_combined_batch.clear(); | idx_combined_batch.clear(); | ||||
@@ -310,7 +306,7 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &b | |||||
GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); | GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); | ||||
return false; | return false; | ||||
} | } | ||||
vector<int64_t> output_dims; | |||||
std::vector<int64_t> output_dims; | |||||
if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { | if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { | ||||
GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); | GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); | ||||
return false; | return false; | ||||
@@ -385,8 +381,8 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, | Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, | ||||
const vector<vector<int64_t>> &batch_shape, | |||||
const vector<vector<int64_t>> &combined_batch) { | |||||
const std::vector<std::vector<int64_t>> &batch_shape, | |||||
const std::vector<std::vector<int64_t>> &combined_batch) { | |||||
NodePtr pred_value_node = pred_value->GetOwnerNode(); | NodePtr pred_value_node = pred_value->GetOwnerNode(); | ||||
// Create SwitchCase node | // Create SwitchCase node | ||||
const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; | const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; | ||||
@@ -429,31 +425,11 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s | |||||
return false; | return false; | ||||
} | } | ||||
size_t num = output_shape.size(); | |||||
size_t dim_num = output_shape[0].size(); | |||||
for (size_t i = 1; i < num; i++) { | |||||
size_t tmp_dim_num = output_shape[i].size(); | |||||
if (dim_num != tmp_dim_num) { | |||||
GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num); | |||||
for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) { | |||||
if (output_shape[0] != *iter) { | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
if (dim_num == 0) { | |||||
return true; | |||||
} | |||||
for (size_t i = 0; i < dim_num; i++) { | |||||
int64_t dim_value = output_shape[0][i]; | |||||
for (size_t j = 1; j < num; j++) { | |||||
int64_t tmp_dim_value = output_shape[j][i]; | |||||
if (dim_value != tmp_dim_value) { | |||||
GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i, | |||||
dim_value, j, tmp_dim_value); | |||||
return false; | |||||
} | |||||
} | |||||
} | |||||
return true; | return true; | ||||
} | } | ||||
@@ -468,8 +444,8 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s | |||||
/// | /// | ||||
NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, | NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, | ||||
const OutDataAnchorPtr &pred_value, | const OutDataAnchorPtr &pred_value, | ||||
const vector<vector<int64_t>> &batch_shape, | |||||
const vector<vector<int64_t>> &combined_batch) { | |||||
const std::vector<std::vector<int64_t>> &batch_shape, | |||||
const std::vector<std::vector<int64_t>> &combined_batch) { | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN); | OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN); | ||||
if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); | GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); | ||||
@@ -512,7 +488,7 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const | |||||
GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); | GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); | |||||
const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); | |||||
if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { | if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { | ||||
GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); | GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); | ||||
return nullptr; | return nullptr; | ||||
@@ -72,25 +72,26 @@ Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &gra | |||||
std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map; | std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map; | ||||
for (const NodePtr &node : graph->GetDirectNode()) { | for (const NodePtr &node : graph->GetDirectNode()) { | ||||
GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); | GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); | ||||
if ((type == SWITCH) || (type == REFSWITCH)) { | |||||
InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||||
GE_CHECK_NOTNULL(in_cond_anchor); | |||||
OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); | |||||
GE_CHECK_NOTNULL(peer_out_anchor); | |||||
if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) { | |||||
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
if ((type != SWITCH) && (type != REFSWITCH)) { | |||||
continue; | |||||
} | |||||
InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); | |||||
GE_CHECK_NOTNULL(in_cond_anchor); | |||||
OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); | |||||
GE_CHECK_NOTNULL(peer_out_anchor); | |||||
if (FindSwitchCondInput(peer_out_anchor) != SUCCESS) { | |||||
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
NodePtr cond_node = peer_out_anchor->GetOwnerNode(); | |||||
auto iter = cond_switch_map.find(cond_node); | |||||
if (iter == cond_switch_map.end()) { | |||||
cond_switch_map[cond_node] = { node }; | |||||
} else { | |||||
iter->second.emplace_back(node); | |||||
} | |||||
switch_nodes_.emplace_back(node); | |||||
NodePtr cond_node = peer_out_anchor->GetOwnerNode(); | |||||
auto iter = cond_switch_map.find(cond_node); | |||||
if (iter == cond_switch_map.end()) { | |||||
cond_switch_map[cond_node] = { node }; | |||||
} else { | |||||
iter->second.emplace_back(node); | |||||
} | } | ||||
switch_nodes_.emplace_back(node); | |||||
} | } | ||||
MarkCycleDependence(cond_switch_map); | MarkCycleDependence(cond_switch_map); | ||||
@@ -241,10 +242,6 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou | |||||
if (idx == SWITCH_DATA_INPUT) { | if (idx == SWITCH_DATA_INPUT) { | ||||
peer_data_anchor = peer_out_anchor; | peer_data_anchor = peer_out_anchor; | ||||
} else { | } else { | ||||
if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) { | |||||
GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
peer_cond_anchor = peer_out_anchor; | peer_cond_anchor = peer_out_anchor; | ||||
} | } | ||||
} | } | ||||
@@ -254,15 +251,14 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou | |||||
/// | /// | ||||
/// @brief Find Switch cond input | /// @brief Find Switch cond input | ||||
/// @param [in] pass_switch_flag | |||||
/// @param [out] peer_cond_anchor | /// @param [out] peer_cond_anchor | ||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) { | |||||
Status SwitchToStreamSwitchPass::FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor) { | |||||
NodePtr tmp_node = nullptr; | NodePtr tmp_node = nullptr; | ||||
string type; | |||||
bool need_pass_type = true; | |||||
while (need_pass_type) { | |||||
std::string type; | |||||
bool pass_flag = true; | |||||
while (pass_flag) { | |||||
if (tmp_node == nullptr) { | if (tmp_node == nullptr) { | ||||
tmp_node = peer_cond_anchor->GetOwnerNode(); | tmp_node = peer_cond_anchor->GetOwnerNode(); | ||||
} else { | } else { | ||||
@@ -274,7 +270,7 @@ Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutD | |||||
} | } | ||||
GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); | GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); | ||||
need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH))); | |||||
pass_flag = ((type == SWITCH) || (type == REFSWITCH)); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -369,7 +365,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ | |||||
} | } | ||||
} else { | } else { | ||||
int64_t switch_group_id = GetGroupId(stream_switch); | int64_t switch_group_id = GetGroupId(stream_switch); | ||||
map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map; | |||||
std::map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map; | |||||
std::list<NodePtr> false_node_list; | std::list<NodePtr> false_node_list; | ||||
std::list<NodePtr> true_node_list; | std::list<NodePtr> true_node_list; | ||||
std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list; | std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list; | ||||
@@ -389,7 +385,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ | |||||
/// @return group_id | /// @return group_id | ||||
/// | /// | ||||
int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | ||||
string tailing_optimization_option; | |||||
std::string tailing_optimization_option; | |||||
bool is_tailing_optimization = false; | bool is_tailing_optimization = false; | ||||
if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { | if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { | ||||
// "1" means it's True from frontend option | // "1" means it's True from frontend option | ||||
@@ -400,7 +396,7 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||||
return 0; | return 0; | ||||
} | } | ||||
string hccl_group_id; | |||||
std::string hccl_group_id; | |||||
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { | if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { | ||||
GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); | GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); | ||||
return 0; | return 0; | ||||
@@ -432,6 +428,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | ||||
OutDataAnchorPtr peer_cond_anchor = iter->first; | OutDataAnchorPtr peer_cond_anchor = iter->first; | ||||
GE_CHECK_NOTNULL(peer_cond_anchor); | |||||
NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); | NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); | ||||
GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); | GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); | ||||
@@ -549,6 +546,7 @@ NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, con | |||||
NodePtr cast_node = graph->AddNode(cast_desc); | NodePtr cast_node = graph->AddNode(cast_desc); | ||||
GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); | GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); | ||||
// Cast node has and only has one input | |||||
GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); | GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); | ||||
return cast_node; | return cast_node; | ||||
@@ -614,24 +612,24 @@ Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_no | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
for (const NodePtr &in_ctl_node : switch_node->GetInControlNodes()) { | |||||
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), | |||||
for (const NodePtr &in_ctrl_node : switch_node->GetInControlNodes()) { | |||||
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), | |||||
"Remove ctl edge failed."); | "Remove ctl edge failed."); | ||||
GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { | |||||
GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||||
GE_IF_BOOL_EXEC(!in_ctrl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { | |||||
GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||||
"Add ctl edge failed."); | "Add ctl edge failed."); | ||||
}); | }); | ||||
GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); | |||||
if (same_cond_switch.count(in_ctl_node) > 0) { | |||||
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||||
GE_IF_BOOL_EXEC(in_ctrl_node->GetType() != STREAMSWITCH, continue); | |||||
if (same_cond_switch.count(in_ctrl_node) > 0) { | |||||
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), | |||||
"Remove ctl edge failed."); | "Remove ctl edge failed."); | ||||
continue; | continue; | ||||
} | } | ||||
auto find_res1 = switch_node_map_.find(in_ctl_node); | |||||
auto find_res1 = switch_node_map_.find(in_ctrl_node); | |||||
GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { | GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { | ||||
GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctrl_node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
}); | }); | ||||
auto find_res2 = find_res1->second.find(orig_switch_name); | auto find_res2 = find_res1->second.find(orig_switch_name); | ||||
@@ -131,11 +131,10 @@ class SwitchToStreamSwitchPass : public GraphPass { | |||||
/// | /// | ||||
/// @brief Find Switch cond input | /// @brief Find Switch cond input | ||||
/// @param [in] pass_switch_flag | |||||
/// @param [out] peer_cond_anchor | /// @param [out] peer_cond_anchor | ||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor); | |||||
Status FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor); | |||||
/// | /// | ||||
/// @brief Create StreamSwitch Node | /// @brief Create StreamSwitch Node | ||||