From 8a6ff82ccf9f3e418a9f57916a916d077ebe29c1 Mon Sep 17 00:00:00 2001 From: dingshihao2 Date: Wed, 14 Apr 2021 10:18:26 +0800 Subject: [PATCH] fix gen_mask control-edges bug --- ge/graph/passes/link_gen_mask_nodes_pass.cc | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index 14f5dfc3..a58dbef2 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -29,6 +29,8 @@ using std::vector; namespace ge { namespace { const size_t kGenMaskInputIndex = 1; +const size_t K_GEN_MASK_FUSED_INPUT_INDEX1 = 2; +const size_t K_GEN_MASK_FUSED_INPUT_INDEX2 = 3; const size_t kDefaultMaxParallelNum = 1; } // namespace @@ -93,10 +95,29 @@ bool LinkGenMaskNodesPass::AreAllInputsConst(const NodePtr &node) const { return true; } +void GetMatMulFusionNodes(const NodePtr &node, NodePtr &gen_mask) { + // "batch_matmul + dropout_do_mask" is transformed to batch_matmul in a ub fusion pass + // node gen_mask is located at different place in the fused node + auto in_data_nodes = node->GetInDataNodes(); + if (in_data_nodes.size() > K_GEN_MASK_FUSED_INPUT_INDEX1 && node->GetType() == "BatchMatMul") { + NodePtr &gen_mask_candidate = in_data_nodes.at(K_GEN_MASK_FUSED_INPUT_INDEX1); + if (gen_mask_candidate->GetName().find("DropOutGenMaskV3") != gen_mask_candidate->GetName().npos) { + gen_mask = gen_mask_candidate; + } + } else if (in_data_nodes.size() > K_GEN_MASK_FUSED_INPUT_INDEX2 && node->GetType() == "MatMulV2") { + NodePtr &gen_mask_candidate = in_data_nodes.at(K_GEN_MASK_FUSED_INPUT_INDEX2); + if (gen_mask_candidate->GetName().find("DropOutGenMaskV3") != gen_mask_candidate->GetName().npos) { + gen_mask = gen_mask_candidate; + } + } +} + void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector &gen_mask_nodes) const { set nodes_set; for (const NodePtr &node : graph->GetDirectNode()) { - if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && node->GetType() != DROPOUTDOMASKV3D) { + bool not_dropout_do_mask_flag = (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && + node->GetType() != DROPOUTDOMASKV3D && node->GetType() != "BatchMatMul" && node->GetType() != "MatMulV2"); + if (not_dropout_do_mask_flag) { continue; } @@ -107,6 +128,9 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetInDataNodes(); if (in_data_nodes.size() > kGenMaskInputIndex) { NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); + + GetMatMulFusionNodes(node, gen_mask); + if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { continue; }