Browse Source

fix gen_mask control-edges bug

tags/v1.3.0
dingshihao2 4 years ago
parent
commit
6cdea5a3d9
1 changed files with 13 additions and 0 deletions
  1. +13
    -0
      ge/graph/passes/link_gen_mask_nodes_pass.cc

+ 13
- 0
ge/graph/passes/link_gen_mask_nodes_pass.cc View File

@@ -28,7 +28,10 @@ using std::vector;

namespace ge {
namespace {
<<<<<<< Updated upstream
const size_t kGenMaskInputIndex = 1;
=======
>>>>>>> Stashed changes
const size_t kDefaultMaxParallelNum = 1;
} // namespace

@@ -105,8 +108,18 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node
}

auto in_data_nodes = node->GetInDataNodes();
<<<<<<< Updated upstream
if (in_data_nodes.size() > kGenMaskInputIndex) {
NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex);
=======
for (auto &in_data_node : in_data_nodes) {
// node gen_mask is located at different place in the fused node
if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) {
continue;
}
NodePtr &gen_mask = in_data_node;

>>>>>>> Stashed changes
if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) {
continue;
}


Loading…
Cancel
Save