diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index 14f5dfc3..27b12ffc 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -28,7 +28,10 @@ using std::vector; namespace ge { namespace { +<<<<<<< Updated upstream const size_t kGenMaskInputIndex = 1; +======= +>>>>>>> Stashed changes const size_t kDefaultMaxParallelNum = 1; } // namespace @@ -105,8 +108,18 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetInDataNodes(); +<<<<<<< Updated upstream if (in_data_nodes.size() > kGenMaskInputIndex) { NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); +======= + for (auto &in_data_node : in_data_nodes) { + // node gen_mask is located at different place in the fused node + if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { + continue; + } + NodePtr &gen_mask = in_data_node; + +>>>>>>> Stashed changes if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { continue; }