diff --git a/ge/common/types.cc b/ge/common/types.cc index 33b7f437..4aa7ce01 100644 --- a/ge/common/types.cc +++ b/ge/common/types.cc @@ -92,6 +92,7 @@ REGISTER_OPTYPE_DEFINE(DROPOUTGENMASK, "DropOutGenMask"); REGISTER_OPTYPE_DEFINE(DROPOUTDOMASK, "DropOutDoMask"); REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3, "DropOutDoMaskV3"); REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D"); +REGISTER_OPTYPE_DEFINE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D"); REGISTER_OPTYPE_DEFINE(CONCAT, "Concat"); REGISTER_OPTYPE_DEFINE(ROIPOOLING, "ROIPooling"); REGISTER_OPTYPE_DEFINE(PROPOSAL, "Proposal"); diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index e00ede45..76df7af1 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -96,7 +96,9 @@ bool LinkGenMaskNodesPass::AreAllInputsConst(const NodePtr &node) const { void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector &gen_mask_nodes) const { set nodes_set; for (const NodePtr &node : graph->GetDirectNode()) { - if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && node->GetType() != DROPOUTDOMASKV3D) { + bool not_domask = node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && + node->GetType() != DROPOUTDOMASKV3D && node->GetType() != SOFTMAXV2WITHDROPOUTDOMASKV3D; + if (not_domask) { continue; } diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h index 91759b8f..4242118d 100644 --- a/inc/framework/common/types.h +++ b/inc/framework/common/types.h @@ -132,6 +132,7 @@ REGISTER_OPTYPE_DECLARE(DROPOUT, "Dropout"); REGISTER_OPTYPE_DECLARE(DROPOUTDOMASK, "DropOutDoMask"); REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3, "DropOutDoMaskV3"); REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D"); +REGISTER_OPTYPE_DECLARE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D"); REGISTER_OPTYPE_DECLARE(DROPOUTGENMASK, "DropOutGenMask"); REGISTER_OPTYPE_DECLARE(CONCAT, "Concat"); REGISTER_OPTYPE_DECLARE(ROIPOOLING, "ROIPooling");