|
@@ -29,6 +29,8 @@ using std::vector; |
|
|
namespace ge { |
|
|
namespace ge { |
|
|
// File-local input-position constants used when locating gen_mask producer nodes.
namespace { |
|
|
namespace { |
|
|
// Default position, within a node's input data nodes, from which GetAllGenMaskNodes
// takes the gen_mask producer.
const size_t kGenMaskInputIndex = 1; |
|
|
const size_t kGenMaskInputIndex = 1; |
|
|
|
|
|
// Input position that GetMatMulFusionNodes inspects on a "BatchMatMul" node when looking
// for a producer whose name contains "DropOutGenMaskV3".
const size_t K_GEN_MASK_FUSED_INPUT_INDEX1 = 2; |
|
|
|
|
|
// Input position that GetMatMulFusionNodes inspects on a "MatMulV2" node when looking
// for a producer whose name contains "DropOutGenMaskV3".
const size_t K_GEN_MASK_FUSED_INPUT_INDEX2 = 3; |
|
|
// NOTE(review): not referenced in the visible part of this file; presumably the fallback
// limit on parallel gen_mask nodes - confirm against its users.
const size_t kDefaultMaxParallelNum = 1; |
|
|
const size_t kDefaultMaxParallelNum = 1; |
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
@@ -93,10 +95,29 @@ bool LinkGenMaskNodesPass::AreAllInputsConst(const NodePtr &node) const { |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void GetMatMulFusionNodes(const NodePtr &node, NodePtr &gen_mask) { |
|
|
|
|
|
// "batch_matmul + dropout_do_mask" is transformed to batch_matmul in a ub fusion pass |
|
|
|
|
|
// node gen_mask is located at different place in the fused node |
|
|
|
|
|
auto in_data_nodes = node->GetInDataNodes(); |
|
|
|
|
|
if (in_data_nodes.size() > K_GEN_MASK_FUSED_INPUT_INDEX1 && node->GetType() == "BatchMatMul") { |
|
|
|
|
|
NodePtr &gen_mask_candidate = in_data_nodes.at(K_GEN_MASK_FUSED_INPUT_INDEX1); |
|
|
|
|
|
if (gen_mask_candidate->GetName().find("DropOutGenMaskV3") != gen_mask_candidate->GetName().npos) { |
|
|
|
|
|
gen_mask = gen_mask_candidate; |
|
|
|
|
|
} |
|
|
|
|
|
} else if (in_data_nodes.size() > K_GEN_MASK_FUSED_INPUT_INDEX2 && node->GetType() == "MatMulV2") { |
|
|
|
|
|
NodePtr &gen_mask_candidate = in_data_nodes.at(K_GEN_MASK_FUSED_INPUT_INDEX2); |
|
|
|
|
|
if (gen_mask_candidate->GetName().find("DropOutGenMaskV3") != gen_mask_candidate->GetName().npos) { |
|
|
|
|
|
gen_mask = gen_mask_candidate; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<NodePtr> &gen_mask_nodes) const { |
|
|
void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<NodePtr> &gen_mask_nodes) const { |
|
|
set<NodePtr> nodes_set; |
|
|
set<NodePtr> nodes_set; |
|
|
for (const NodePtr &node : graph->GetDirectNode()) { |
|
|
for (const NodePtr &node : graph->GetDirectNode()) { |
|
|
if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && node->GetType() != DROPOUTDOMASKV3D) { |
|
|
|
|
|
|
|
|
bool not_dropout_do_mask_flag = (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && |
|
|
|
|
|
node->GetType() != DROPOUTDOMASKV3D && node->GetType() != "BatchMatMul" && node->GetType() != "MatMulV2"); |
|
|
|
|
|
if (not_dropout_do_mask_flag) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@@ -107,6 +128,9 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node |
|
|
auto in_data_nodes = node->GetInDataNodes(); |
|
|
auto in_data_nodes = node->GetInDataNodes(); |
|
|
if (in_data_nodes.size() > kGenMaskInputIndex) { |
|
|
if (in_data_nodes.size() > kGenMaskInputIndex) { |
|
|
NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); |
|
|
NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); |
|
|
|
|
|
|
|
|
|
|
|
GetMatMulFusionNodes(node, gen_mask); |
|
|
|
|
|
|
|
|
if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { |
|
|
if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|