Browse Source

fix gen_mask control-edges bug

tags/v1.3.0
dingshihao2 4 years ago
parent
commit
1a492ec892
1 changed files with 0 additions and 9 deletions
  1. +0
    -9
      ge/graph/passes/link_gen_mask_nodes_pass.cc

+ 0
- 9
ge/graph/passes/link_gen_mask_nodes_pass.cc View File

@@ -28,10 +28,6 @@ using std::vector;


namespace ge { namespace ge {
namespace { namespace {
<<<<<<< Updated upstream
const size_t kGenMaskInputIndex = 1;
=======
>>>>>>> Stashed changes
const size_t kDefaultMaxParallelNum = 1; const size_t kDefaultMaxParallelNum = 1;
} // namespace } // namespace


@@ -108,10 +104,6 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node
} }


auto in_data_nodes = node->GetInDataNodes(); auto in_data_nodes = node->GetInDataNodes();
<<<<<<< Updated upstream
if (in_data_nodes.size() > kGenMaskInputIndex) {
NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex);
=======
for (auto &in_data_node : in_data_nodes) { for (auto &in_data_node : in_data_nodes) {
// node gen_mask is located at different place in the fused node // node gen_mask is located at different place in the fused node
if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) {
@@ -119,7 +111,6 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node
} }
NodePtr &gen_mask = in_data_node; NodePtr &gen_mask = in_data_node;


>>>>>>> Stashed changes
if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) {
continue; continue;
} }


Loading…
Cancel
Save