|
@@ -28,10 +28,6 @@ using std::vector; |
|
|
|
|
|
|
|
|
namespace ge { |
|
|
namespace ge { |
|
|
namespace { |
|
|
namespace { |
|
|
<<<<<<< Updated upstream |
|
|
|
|
|
const size_t kGenMaskInputIndex = 1; |
|
|
|
|
|
======= |
|
|
|
|
|
>>>>>>> Stashed changes |
|
|
|
|
|
const size_t kDefaultMaxParallelNum = 1; |
|
|
const size_t kDefaultMaxParallelNum = 1; |
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
@@ -108,10 +104,6 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
auto in_data_nodes = node->GetInDataNodes(); |
|
|
auto in_data_nodes = node->GetInDataNodes(); |
|
|
<<<<<<< Updated upstream |
|
|
|
|
|
if (in_data_nodes.size() > kGenMaskInputIndex) { |
|
|
|
|
|
NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); |
|
|
|
|
|
======= |
|
|
|
|
|
for (auto &in_data_node : in_data_nodes) { |
|
|
for (auto &in_data_node : in_data_nodes) { |
|
|
// node gen_mask is located at different place in the fused node |
|
|
// node gen_mask is located at different place in the fused node |
|
|
if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { |
|
|
if (in_data_node->GetName().find(DROPOUTGENMASK) == in_data_node->GetName().npos) { |
|
@@ -119,7 +111,6 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vector<Node |
|
|
} |
|
|
} |
|
|
NodePtr &gen_mask = in_data_node; |
|
|
NodePtr &gen_mask = in_data_node; |
|
|
|
|
|
|
|
|
>>>>>>> Stashed changes |
|
|
|
|
|
if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { |
|
|
if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|