From 6cdea5a3d91514ac1951c56b1c3a03d6c7a9f7b9 Mon Sep 17 00:00:00 2001 From: dingshihao2 Date: Fri, 16 Apr 2021 19:21:06 +0800 Subject: [PATCH] fix gen_mask control-edges bug --- ge/graph/passes/link_gen_mask_nodes_pass.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index 14f5dfc3..27b12ffc 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -28,7 +28,10 @@ using std::vector; namespace ge { namespace { +<<<<<<< Updated upstream const size_t kGenMaskInputIndex = 1; +======= +>>>>>>> Stashed changes const size_t kDefaultMaxParallelNum = 1; } // namespace @@ -105,8 +108,18 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetInDataNodes(); +<<<<<<< Updated upstream if (in_data_nodes.size() > kGenMaskInputIndex) { NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); +======= + for (auto &in_data_node : in_data_nodes) { + // node gen_mask is located at different place in the fused node + if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { + continue; + } + NodePtr &gen_mask = in_data_node; + +>>>>>>> Stashed changes if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { continue; }