@@ -92,6 +92,7 @@ REGISTER_OPTYPE_DEFINE(DROPOUTGENMASK, "DropOutGenMask"); | |||||
REGISTER_OPTYPE_DEFINE(DROPOUTDOMASK, "DropOutDoMask"); | REGISTER_OPTYPE_DEFINE(DROPOUTDOMASK, "DropOutDoMask"); | ||||
REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3, "DropOutDoMaskV3"); | REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3, "DropOutDoMaskV3"); | ||||
REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D"); | REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D"); | ||||
REGISTER_OPTYPE_DEFINE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D"); | |||||
REGISTER_OPTYPE_DEFINE(CONCAT, "Concat"); | REGISTER_OPTYPE_DEFINE(CONCAT, "Concat"); | ||||
REGISTER_OPTYPE_DEFINE(ROIPOOLING, "ROIPooling"); | REGISTER_OPTYPE_DEFINE(ROIPOOLING, "ROIPooling"); | ||||
REGISTER_OPTYPE_DEFINE(PROPOSAL, "Proposal"); | REGISTER_OPTYPE_DEFINE(PROPOSAL, "Proposal"); | ||||
@@ -96,7 +96,9 @@ bool LinkGenMaskNodesPass::AreAllInputsConst(const NodePtr &node) const { | |||||
void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<NodePtr> &gen_mask_nodes) const { | void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<NodePtr> &gen_mask_nodes) const { | ||||
set<NodePtr> nodes_set; | set<NodePtr> nodes_set; | ||||
for (const NodePtr &node : graph->GetDirectNode()) { | for (const NodePtr &node : graph->GetDirectNode()) { | ||||
if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && node->GetType() != DROPOUTDOMASKV3D) { | |||||
bool not_domask = node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && | |||||
node->GetType() != DROPOUTDOMASKV3D && node->GetType() != SOFTMAXV2WITHDROPOUTDOMASKV3D; | |||||
if (not_domask) { | |||||
continue; | continue; | ||||
} | } | ||||
@@ -132,6 +132,7 @@ REGISTER_OPTYPE_DECLARE(DROPOUT, "Dropout"); | |||||
REGISTER_OPTYPE_DECLARE(DROPOUTDOMASK, "DropOutDoMask"); | REGISTER_OPTYPE_DECLARE(DROPOUTDOMASK, "DropOutDoMask"); | ||||
REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3, "DropOutDoMaskV3"); | REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3, "DropOutDoMaskV3"); | ||||
REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D"); | REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D"); | ||||
REGISTER_OPTYPE_DECLARE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D"); | |||||
REGISTER_OPTYPE_DECLARE(DROPOUTGENMASK, "DropOutGenMask"); | REGISTER_OPTYPE_DECLARE(DROPOUTGENMASK, "DropOutGenMask"); | ||||
REGISTER_OPTYPE_DECLARE(CONCAT, "Concat"); | REGISTER_OPTYPE_DECLARE(CONCAT, "Concat"); | ||||
REGISTER_OPTYPE_DECLARE(ROIPOOLING, "ROIPooling"); | REGISTER_OPTYPE_DECLARE(ROIPOOLING, "ROIPooling"); | ||||