diff --git a/ge/common/types.cc b/ge/common/types.cc index 33b7f437..4aa7ce01 100644 --- a/ge/common/types.cc +++ b/ge/common/types.cc @@ -92,6 +92,7 @@ REGISTER_OPTYPE_DEFINE(DROPOUTGENMASK, "DropOutGenMask"); REGISTER_OPTYPE_DEFINE(DROPOUTDOMASK, "DropOutDoMask"); REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3, "DropOutDoMaskV3"); REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D"); +REGISTER_OPTYPE_DEFINE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D"); REGISTER_OPTYPE_DEFINE(CONCAT, "Concat"); REGISTER_OPTYPE_DEFINE(ROIPOOLING, "ROIPooling"); REGISTER_OPTYPE_DEFINE(PROPOSAL, "Proposal"); diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index e00ede45..a50f22df 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -96,7 +96,8 @@ bool LinkGenMaskNodesPass::AreAllInputsConst(const NodePtr &node) const { void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector &gen_mask_nodes) const { set nodes_set; for (const NodePtr &node : graph->GetDirectNode()) { - if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && node->GetType() != DROPOUTDOMASKV3D) { + if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && + node->GetType() != DROPOUTDOMASKV3D && node->GetType() != SOFTMAXV2WITHDROPOUTDOMASKV3D) { continue; }