|
|
@@ -16,18 +16,11 @@ |
|
|
|
|
|
|
|
#include "graph/passes/merge_input_memcpy_pass.h" |
|
|
|
|
|
|
|
#include <queue> |
|
|
|
|
|
|
|
#include "common/ge/ge_util.h" |
|
|
|
#include "ge/ge_api_types.h" |
|
|
|
#include "graph/common/omg_util.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
namespace { |
|
|
|
const std::set<std::string> kLoopMergeInputs{ |
|
|
|
ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION |
|
|
|
}; |
|
|
|
} |
|
|
|
Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { |
|
|
|
GELOGD("MergeInputMemcpyPass Enter"); |
|
|
|
std::unordered_map<NodePtr, std::vector<NodePtr>> switch_groups; |
|
|
@@ -41,10 +34,8 @@ Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { |
|
|
|
GE_CHECK_NOTNULL(node->GetOpDesc()); |
|
|
|
GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, node->GetOpDesc()->HasAttr(ATTR_INSERT_BY_MBATCH)), |
|
|
|
"Merge add memcpy node failed."); |
|
|
|
CollectSwitchGroup(node, switch_groups); |
|
|
|
} |
|
|
|
|
|
|
|
MarkUnknownForSwitch(switch_groups); |
|
|
|
GELOGD("MergeInputMemcpyPass Leave"); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
@@ -114,94 +105,4 @@ NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph |
|
|
|
|
|
|
|
return graph->AddNode(op_desc); |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
/// @brief Mark force unknown shape for Switch node |
|
|
|
/// @param [in] merge node |
|
|
|
/// @param [out] switch_groups |
|
|
|
/// @return |
|
|
|
/// |
|
|
|
void MergeInputMemcpyPass::CollectSwitchGroup(const NodePtr &node, |
|
|
|
std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups) { |
|
|
|
const auto &op_desc = node->GetOpDesc(); |
|
|
|
for (const auto &in_anchor : node->GetAllInDataAnchors()) { |
|
|
|
const auto &src_out_anchor = in_anchor->GetPeerOutAnchor(); |
|
|
|
if (src_out_anchor == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
std::string node_type; |
|
|
|
GetOriginalType(src_out_anchor->GetOwnerNode(), node_type); |
|
|
|
if (kLoopMergeInputs.count(node_type) > 0) { |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Switch --> {Switch --> Merge} --> Merge |
|
|
|
std::queue<std::pair<NodePtr, uint32_t>> search_queue; |
|
|
|
search_queue.push({node, 0}); |
|
|
|
std::vector<NodePtr> &switch_group = switch_groups[node]; |
|
|
|
while (!search_queue.empty()) { |
|
|
|
const auto dst_node = search_queue.front().first; |
|
|
|
const auto dst_span = search_queue.front().second; |
|
|
|
search_queue.pop(); |
|
|
|
|
|
|
|
// Switch --> Identity --> Constant |
|
|
|
for (const auto &in_ctrl_node : dst_node->GetInControlNodes()) { |
|
|
|
if (in_ctrl_node->GetType() == IDENTITY) { |
|
|
|
GELOGD("Travel node: %s, In control: %s, span is: %u", |
|
|
|
dst_node->GetName().c_str(), in_ctrl_node->GetName().c_str(), dst_span); |
|
|
|
search_queue.push({in_ctrl_node, dst_span}); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &in_data_node : dst_node->GetInDataNodes()) { |
|
|
|
std::string node_type; |
|
|
|
GetOriginalType(in_data_node, node_type); |
|
|
|
GELOGD("Travel node: %s, %s node: %s, span is: %u", |
|
|
|
dst_node->GetName().c_str(), node_type.c_str(), in_data_node->GetName().c_str(), dst_span); |
|
|
|
if (node_type == SWITCH || node_type == REFSWITCH) { |
|
|
|
if (dst_span > 0) { |
|
|
|
search_queue.push({in_data_node, dst_span - 1}); |
|
|
|
} else { |
|
|
|
switch_group.emplace_back(in_data_node); |
|
|
|
} |
|
|
|
} else if (node_type == MERGE || node_type == REFMERGE) { |
|
|
|
search_queue.push({in_data_node, dst_span + 1}); |
|
|
|
} else { |
|
|
|
search_queue.push({in_data_node, dst_span}); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) { |
|
|
|
GELOGI("Mark [%s] as for unknown shape, switch groups: %zu", node->GetName().c_str(), switch_groups.size()); |
|
|
|
MarkForceUnknownShape(node, true); |
|
|
|
for (const auto &n : switch_group) { |
|
|
|
MarkForceUnknownShape(n, true); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void MergeInputMemcpyPass::MarkUnknownForSwitch(const std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups) { |
|
|
|
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { |
|
|
|
return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); |
|
|
|
}; |
|
|
|
|
|
|
|
for (const auto &item : switch_groups) { |
|
|
|
const auto &node = item.first; |
|
|
|
if (node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
const std::vector<NodePtr> &switch_group = item.second; |
|
|
|
if (std::any_of(switch_group.begin(), switch_group.end(), callback)) { |
|
|
|
GELOGI("Mark [%s] as force unknown shape, switch nodes: %zu", node->GetName().c_str(), switch_group.size()); |
|
|
|
MarkForceUnknownShape(node, true); |
|
|
|
for (const auto &n : switch_group) { |
|
|
|
MarkForceUnknownShape(n, true); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace ge |