Browse Source

add new node pass

pull/1554/head
unknown 4 years ago
parent
commit
bf88d8c57e
2 changed files with 3 additions and 1 deletions
  1. +1
    -0
      ge/common/types.cc
  2. +2
    -1
      ge/graph/passes/link_gen_mask_nodes_pass.cc

+ 1
- 0
ge/common/types.cc View File

@@ -92,6 +92,7 @@ REGISTER_OPTYPE_DEFINE(DROPOUTGENMASK, "DropOutGenMask");
REGISTER_OPTYPE_DEFINE(DROPOUTDOMASK, "DropOutDoMask");
REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3, "DropOutDoMaskV3");
REGISTER_OPTYPE_DEFINE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D");
REGISTER_OPTYPE_DEFINE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D");
REGISTER_OPTYPE_DEFINE(CONCAT, "Concat");
REGISTER_OPTYPE_DEFINE(ROIPOOLING, "ROIPooling");
REGISTER_OPTYPE_DEFINE(PROPOSAL, "Proposal");


+ 2
- 1
ge/graph/passes/link_gen_mask_nodes_pass.cc View File

@@ -96,7 +96,8 @@ bool LinkGenMaskNodesPass::AreAllInputsConst(const NodePtr &node) const {
void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<NodePtr> &gen_mask_nodes) const {
set<NodePtr> nodes_set;
for (const NodePtr &node : graph->GetDirectNode()) {
if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 && node->GetType() != DROPOUTDOMASKV3D) {
if (node->GetType() != DROPOUTDOMASK && node->GetType() != DROPOUTDOMASKV3 &&
node->GetType() != DROPOUTDOMASKV3D && node->GetType() != SOFTMAXV2WITHDROPOUTDOMASKV3D) {
continue;
}



Loading…
Cancel
Save