Browse Source

Pre Merge pull request !1521 from 丁世浩/master

pull/1521/MERGE
丁世浩 Gitee 4 years ago
parent
commit
d98f50cc37
1 changed files with 25 additions and 1 deletions
  1. +25
    -1
      ge/graph/passes/link_gen_mask_nodes_pass.cc

+ 25
- 1
ge/graph/passes/link_gen_mask_nodes_pass.cc View File

@@ -29,6 +29,8 @@ using std::vector;
namespace ge { namespace ge {
namespace { namespace {
const size_t kGenMaskInputIndex = 1; const size_t kGenMaskInputIndex = 1;
const size_t K_GEN_MASK_FUSED_INPUT_INDEX1 = 2;
const size_t K_GEN_MASK_FUSED_INPUT_INDEX2 = 3;
const size_t kDefaultMaxParallelNum = 1; const size_t kDefaultMaxParallelNum = 1;
} // namespace } // namespace


@@ -93,10 +95,29 @@ bool LinkGenMaskNodesPass::AreAllInputsConst(const NodePtr &node) const {
return true; return true;
} }


void GetMatMulFusionNodes(const NodePtr &node, NodePtr &gen_mask) {
// "batch_matmul + dropout_do_mask" is transformed to batch_matmul in a ub fusion pass
// node gen_mask is located at different place in the fused node
auto in_data_nodes = node->GetInDataNodes();
if (in_data_nodes.size() > K_GEN_MASK_FUSED_INPUT_INDEX1 && node->GetType() == "BatchMatMul") {
NodePtr &gen_mask_candidate = in_data_nodes.at(K_GEN_MASK_FUSED_INPUT_INDEX1);
if (gen_mask_candidate->GetName().find("DropOutGenMaskV3") != gen_mask_candidate->GetName().npos) {
gen_mask = gen_mask_candidate;
}
} else if (in_data_nodes.size() > K_GEN_MASK_FUSED_INPUT_INDEX2 && node->GetType() == "MatMulV2") {
NodePtr &gen_mask_candidate = in_data_nodes.at(K_GEN_MASK_FUSED_INPUT_INDEX2);
if (gen_mask_candidate->GetName().find("DropOutGenMaskV3") != gen_mask_candidate->GetName().npos) {
gen_mask = gen_mask_candidate;
}
}
}

void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<NodePtr> &gen_mask_nodes) const { void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<NodePtr> &gen_mask_nodes) const {
set<NodePtr> nodes_set; set<NodePtr> nodes_set;
for (const NodePtr &node : graph->GetDirectNode()) { for (const NodePtr &node : graph->GetDirectNode()) {
if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && node->GetType() != DROPOUTDOMASKV3D) {
bool not_dropout_do_mask_flag = (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 &&
node->GetType() != DROPOUTDOMASKV3D && node->GetType() != "BatchMatMul" && node->GetType() != "MatMulV2");
if (not_dropout_do_mask_flag) {
continue; continue;
} }


@@ -107,6 +128,9 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node
auto in_data_nodes = node->GetInDataNodes(); auto in_data_nodes = node->GetInDataNodes();
if (in_data_nodes.size() > kGenMaskInputIndex) { if (in_data_nodes.size() > kGenMaskInputIndex) {
NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex);

GetMatMulFusionNodes(node, gen_mask);

if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) {
continue; continue;
} }


Loading…
Cancel
Save