Browse Source

revert MergeInputMemcpyPass

tags/v1.3.0
zhangxiaokun 4 years ago
parent
commit
1001a8a859
3 changed files with 1 additions and 115 deletions
  1. +1
    -1
      ge/graph/passes/mark_branch_force_unknown_pass.cc
  2. +0
    -99
      ge/graph/passes/merge_input_memcpy_pass.cc
  3. +0
    -15
      ge/graph/passes/merge_input_memcpy_pass.h

+ 1
- 1
ge/graph/passes/mark_branch_force_unknown_pass.cc View File

@@ -40,7 +40,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) {
for (const auto &node : graph->GetDirectNode()) {
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed.");
if ((node_type != MERGE) && (node_type != REFMERGE)) {
if (kMergeOpTypes.count(node_type) == 0) {
continue;
}



+ 0
- 99
ge/graph/passes/merge_input_memcpy_pass.cc View File

@@ -16,18 +16,11 @@

#include "graph/passes/merge_input_memcpy_pass.h"

#include <queue>

#include "common/ge/ge_util.h"
#include "ge/ge_api_types.h"
#include "graph/common/omg_util.h"

namespace ge {
namespace {
const std::set<std::string> kLoopMergeInputs{
ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION
};
}
Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) {
GELOGD("MergeInputMemcpyPass Enter");
std::unordered_map<NodePtr, std::vector<NodePtr>> switch_groups;
@@ -41,10 +34,8 @@ Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(node->GetOpDesc());
GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, node->GetOpDesc()->HasAttr(ATTR_INSERT_BY_MBATCH)),
"Merge add memcpy node failed.");
CollectSwitchGroup(node, switch_groups);
}

MarkUnknownForSwitch(switch_groups);
GELOGD("MergeInputMemcpyPass Leave");
return SUCCESS;
}
@@ -114,94 +105,4 @@ NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph

return graph->AddNode(op_desc);
}

///
/// @brief Mark force unknown shape for Switch node
/// @param [in] merge node
/// @param [out] switch_groups
/// @return
///
void MergeInputMemcpyPass::CollectSwitchGroup(const NodePtr &node,
std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups) {
const auto &op_desc = node->GetOpDesc();
for (const auto &in_anchor : node->GetAllInDataAnchors()) {
const auto &src_out_anchor = in_anchor->GetPeerOutAnchor();
if (src_out_anchor == nullptr) {
continue;
}

std::string node_type;
GetOriginalType(src_out_anchor->GetOwnerNode(), node_type);
if (kLoopMergeInputs.count(node_type) > 0) {
return;
}
}

// Switch --> {Switch --> Merge} --> Merge
std::queue<std::pair<NodePtr, uint32_t>> search_queue;
search_queue.push({node, 0});
std::vector<NodePtr> &switch_group = switch_groups[node];
while (!search_queue.empty()) {
const auto dst_node = search_queue.front().first;
const auto dst_span = search_queue.front().second;
search_queue.pop();

// Switch --> Identity --> Constant
for (const auto &in_ctrl_node : dst_node->GetInControlNodes()) {
if (in_ctrl_node->GetType() == IDENTITY) {
GELOGD("Travel node: %s, In control: %s, span is: %u",
dst_node->GetName().c_str(), in_ctrl_node->GetName().c_str(), dst_span);
search_queue.push({in_ctrl_node, dst_span});
}
}

for (const auto &in_data_node : dst_node->GetInDataNodes()) {
std::string node_type;
GetOriginalType(in_data_node, node_type);
GELOGD("Travel node: %s, %s node: %s, span is: %u",
dst_node->GetName().c_str(), node_type.c_str(), in_data_node->GetName().c_str(), dst_span);
if (node_type == SWITCH || node_type == REFSWITCH) {
if (dst_span > 0) {
search_queue.push({in_data_node, dst_span - 1});
} else {
switch_group.emplace_back(in_data_node);
}
} else if (node_type == MERGE || node_type == REFMERGE) {
search_queue.push({in_data_node, dst_span + 1});
} else {
search_queue.push({in_data_node, dst_span});
}
}
}

if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) {
GELOGI("Mark [%s] as for unknown shape, switch groups: %zu", node->GetName().c_str(), switch_groups.size());
MarkForceUnknownShape(node, true);
for (const auto &n : switch_group) {
MarkForceUnknownShape(n, true);
}
}
}

void MergeInputMemcpyPass::MarkUnknownForSwitch(const std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups) {
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
};

for (const auto &item : switch_groups) {
const auto &node = item.first;
if (node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) {
continue;
}

const std::vector<NodePtr> &switch_group = item.second;
if (std::any_of(switch_group.begin(), switch_group.end(), callback)) {
GELOGI("Mark [%s] as force unknown shape, switch nodes: %zu", node->GetName().c_str(), switch_group.size());
MarkForceUnknownShape(node, true);
for (const auto &n : switch_group) {
MarkForceUnknownShape(n, true);
}
}
}
}
} // namespace ge

+ 0
- 15
ge/graph/passes/merge_input_memcpy_pass.h View File

@@ -44,21 +44,6 @@ class MergeInputMemcpyPass : public GraphPass {
///
NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name,
const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag);

///
/// @brief Mark force unknown shape for Switch node
/// @param [in] merge node
/// @param [out] switch_groups
/// @return
///
void CollectSwitchGroup(const NodePtr &node, std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups);

///
/// @brief Mark force unknown shape for Switch node
/// @param [in] switch_groups
/// @return
///
void MarkUnknownForSwitch(const std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_MERGE_ADD_INPUT_MEMCPY_PASS_H_

Loading…
Cancel
Save