From 57cb65bbd0618b9038d9b52bf88430b8c7395541 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 17 May 2021 11:21:20 +0800 Subject: [PATCH] add new node for link_gen_mask --- ge/common/types.cc | 1 + ge/graph/passes/link_gen_mask_nodes_pass.cc | 3 ++- inc/framework/common/types.h | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ge/common/types.cc b/ge/common/types.cc index 33b7f437..4aa7ce01 100644 --- a/ge/common/types.cc +++ b/ge/common/types.cc @@ -92,6 +92,7 @@ REGISTER_OPTYPE_DEFINE(DROPOUTGENMASK, "DropOutGenMask"); REGISTER_OPTYPE_DEFINE(DROPOUTDOMASK, "DropOutDoMask"); REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3, "DropOutDoMaskV3"); REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D"); +REGISTER_OPTYPE_DEFINE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D"); REGISTER_OPTYPE_DEFINE(CONCAT, "Concat"); REGISTER_OPTYPE_DEFINE(ROIPOOLING, "ROIPooling"); REGISTER_OPTYPE_DEFINE(PROPOSAL, "Proposal"); diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc index e00ede45..a50f22df 100755 --- a/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -96,7 +96,8 @@ bool LinkGenMaskNodesPass::AreAllInputsConst(const NodePtr &node) const { void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector &gen_mask_nodes) const { set nodes_set; for (const NodePtr &node : graph->GetDirectNode()) { - if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && node->GetType() != DROPOUTDOMASKV3D) { + if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && + node->GetType() != DROPOUTDOMASKV3D && node->GetType() != SOFTMAXV2WITHDROPOUTDOMASKV3D) { continue; } diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h index 91759b8f..4242118d 100644 --- a/inc/framework/common/types.h +++ b/inc/framework/common/types.h @@ -132,6 +132,7 @@ REGISTER_OPTYPE_DECLARE(DROPOUT, "Dropout"); REGISTER_OPTYPE_DECLARE(DROPOUTDOMASK, "DropOutDoMask"); REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3, "DropOutDoMaskV3"); REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D"); +REGISTER_OPTYPE_DECLARE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D"); REGISTER_OPTYPE_DECLARE(DROPOUTGENMASK, "DropOutGenMask"); REGISTER_OPTYPE_DECLARE(CONCAT, "Concat"); REGISTER_OPTYPE_DECLARE(ROIPOOLING, "ROIPooling");