Browse Source

Fix DSP for v1 control flow

pull/1702/head
zhangxiaokun chenyemeng 4 years ago
parent
commit
8ab99be735
14 changed files with 190 additions and 73 deletions
  1. +16
    -4
      ge/graph/common/omg_util.cc
  2. +2
    -1
      ge/graph/common/omg_util.h
  3. +0
    -3
      ge/graph/manager/graph_manager.cc
  4. +39
    -21
      ge/graph/partition/dynamic_shape_partition.cc
  5. +6
    -3
      ge/graph/partition/dynamic_shape_partition.h
  6. +46
    -17
      ge/graph/passes/mark_force_unknown_for_cond_pass.cc
  7. +4
    -3
      ge/graph/passes/merge_to_stream_merge_pass.cc
  8. +14
    -11
      ge/graph/passes/next_iteration_pass.cc
  9. +2
    -1
      ge/graph/passes/next_iteration_pass.h
  10. +8
    -5
      ge/graph/passes/switch_to_stream_switch_pass.cc
  11. +13
    -0
      ge/graph/preprocess/graph_preprocess.cc
  12. +1
    -0
      ge/graph/preprocess/graph_preprocess.h
  13. +37
    -4
      ge/hybrid/executor/node_state.cc
  14. +2
    -0
      ge/hybrid/executor/node_state.h

+ 16
- 4
ge/graph/common/omg_util.cc View File

@@ -272,20 +272,32 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) {
/// @brief Set Op _force_unknown_shape flag /// @brief Set Op _force_unknown_shape flag
/// @param [in] node /// @param [in] node
/// @param [in] force_unknown, set attribute if true /// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return /// @return
/// ///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown) {
GE_RT_VOID_CHECK_NOTNULL(node);
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) {
if (!force_unknown) { if (!force_unknown) {
return; return;
} }


GELOGD("[%s] mark as force unknown shape node", node->GetName().c_str());
if (!AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) {
GE_RT_VOID_CHECK_NOTNULL(node);
const auto &op_desc = node->GetOpDesc();
GE_RT_VOID_CHECK_NOTNULL(op_desc);

// op_desc is a valid AttrHolderAdapter, so setting the attribute always succeeds; the failure branch only logs for checking.
GELOGD("Mark [%s] as force unknown shape node, group index: %ld", node->GetName().c_str(), group_index);
if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) {
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(),
node->GetName().c_str(), node->GetType().c_str()); node->GetName().c_str(), node->GetType().c_str());
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(),
node->GetName().c_str(), node->GetType().c_str()); node->GetName().c_str(), node->GetType().c_str());
} }

if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(),
node->GetName().c_str(), node->GetType().c_str());
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(),
node->GetName().c_str(), node->GetType().c_str());
}
} }
} // namespace ge } // namespace ge

+ 2
- 1
ge/graph/common/omg_util.h View File

@@ -129,9 +129,10 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc);
/// @brief Set Op _force_unknown_shape flag /// @brief Set Op _force_unknown_shape flag
/// @param [in] node /// @param [in] node
/// @param [in] force_unknown, set attribute if true /// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return /// @return
/// ///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown);
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index);
} // namespace ge } // namespace ge


#endif // GE_GRAPH_COMMON_OMG_UTIL_H_ #endif // GE_GRAPH_COMMON_OMG_UTIL_H_

+ 0
- 3
ge/graph/manager/graph_manager.cc View File

@@ -65,7 +65,6 @@
#include "graph/passes/merge_pass.h" #include "graph/passes/merge_pass.h"
#include "graph/passes/merge_input_memcpy_pass.h" #include "graph/passes/merge_input_memcpy_pass.h"
#include "graph/passes/merge_to_stream_merge_pass.h" #include "graph/passes/merge_to_stream_merge_pass.h"
#include "graph/passes/mark_force_unknown_for_cond_pass.h"
#include "graph/passes/multi_batch_pass.h" #include "graph/passes/multi_batch_pass.h"
#include "graph/passes/next_iteration_pass.h" #include "graph/passes/next_iteration_pass.h"
#include "graph/passes/permute_pass.h" #include "graph/passes/permute_pass.h"
@@ -2582,8 +2581,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass));
auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass;
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkForceUnknownForCondPass", mark_force_unknown_pass));
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass))
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass))
GE_CHK_STATUS_RET( GE_CHK_STATUS_RET(


+ 39
- 21
ge/graph/partition/dynamic_shape_partition.cc View File

@@ -46,11 +46,6 @@
#define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) #define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__)


namespace ge { namespace ge {
namespace {
const std::set<std::string> kControlFlowOps{
STREAMACTIVE, STREAMSWITCH, STREAMMERGE, ENTER, REFENTER, LOOPCOND, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT
};
}
using Cluster = DynamicShapePartitioner::Cluster; using Cluster = DynamicShapePartitioner::Cluster;
using ClusterPtr = std::shared_ptr<Cluster>; using ClusterPtr = std::shared_ptr<Cluster>;


@@ -279,9 +274,17 @@ Status DynamicShapePartitioner::InitClusters() {
auto cluster = MakeShared<Cluster>(rank++, type, node, this); auto cluster = MakeShared<Cluster>(rank++, type, node, this);
REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster.");
node_2_cluster_[node] = cluster; node_2_cluster_[node] = cluster;
if (cluster->IsUnknownShape() && !cluster->IsControlFlow()) {
if (cluster->IsUnknownShape()) {
ordered_cluster_.push_back(cluster); ordered_cluster_.push_back(cluster);
} }

int64_t group_index = -1;
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
GELOGD("[%s] is rts control flow Op, group index: %ld", node->GetName().c_str(), group_index);
auto &control_cluster = control_clusters_[group_index];
control_cluster.emplace_back(cluster);
}

// Already sorted topologically, so access to the parent cluster is safe // Already sorted topologically, so access to the parent cluster is safe
for (const auto &parent : node->GetInAllNodes()) { for (const auto &parent : node->GetInAllNodes()) {
cluster->AddInput(node_2_cluster_[parent]); cluster->AddInput(node_2_cluster_[parent]);
@@ -350,14 +353,38 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) {
} }
} }


///
/// @brief Merge clusters step0: for every entry of control_clusters_, merge each
///        earlier cluster of the group into the group's last cluster, and point
///        the nodes of every merged cluster at that last cluster in node_2_cluster_.
/// @return void
///
void DynamicShapePartitioner::MergeClustersControlFlow() {
  for (const auto &group : control_clusters_) {
    const auto &group_clusters = group.second;
    if (group_clusters.empty()) {
      GELOGW("Invalid empty control flow cluster.");
      continue;
    }

    // The last cluster in the group's vector is the merge target.
    const auto &target = group_clusters.back();
    // Visit the remaining clusters from the second-to-last down to the first.
    for (auto pos = group_clusters.size() - 1; pos > 0; --pos) {
      const auto &source = group_clusters[pos - 1];
      auto absorbed = target->MergeAllPathFrom(source);
      GELOGD("Merge all path cluster from %lu to %lu %s.", source->Id(), target->Id(),
             ToString(absorbed).c_str());
      // Every node of a cluster reported as merged now belongs to the target.
      for (const auto &absorbed_cluster : absorbed) {
        for (const auto &node : absorbed_cluster->Nodes()) {
          node_2_cluster_[node] = target;
        }
      }
    }
  }
}

void DynamicShapePartitioner::MergeClustersUnknownShape() { void DynamicShapePartitioner::MergeClustersUnknownShape() {
// Merge unknown shape clusters // Merge unknown shape clusters
for (const auto &cluster : ordered_cluster_) { for (const auto &cluster : ordered_cluster_) {
if (cluster->IsIndependent() || cluster->IsControlFlow()) {
if (cluster->IsIndependent()) {
continue; continue;
} }
for (const auto &in_cluster : cluster->Inputs()) { for (const auto &in_cluster : cluster->Inputs()) {
if (!in_cluster->IsUnknownShape() || in_cluster->IsControlFlow()) {
if (!in_cluster->IsUnknownShape()) {
continue; continue;
} }
auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); auto merged_clusters = cluster->MergeAllPathFrom(in_cluster);
@@ -419,6 +446,7 @@ void DynamicShapePartitioner::MergeClustersInputData() {
} }


Status DynamicShapePartitioner::MergeClusters() { Status DynamicShapePartitioner::MergeClusters() {
MergeClustersControlFlow();
MergeClustersUnknownShape(); MergeClustersUnknownShape();
REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters.");
MergeClustersKnownShape(); MergeClustersKnownShape();
@@ -608,13 +636,6 @@ bool Cluster::IsRefVariable() const {
return false; return false;
} }


bool Cluster::IsControlFlow() const {
const auto &op_desc = nodes_[0]->GetOpDesc();
bool is_ctrl_flow = kControlFlowOps.count(op_desc->GetType()) > 0 && op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
GELOGD("[%s] %s rts control flow Op ", op_desc->GetName().c_str(), is_ctrl_flow ? "Is" : "Not");
return is_ctrl_flow;
}

void Cluster::AddInput(ClusterPtr in) { void Cluster::AddInput(ClusterPtr in) {
if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return; if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return;
in_clusters_.insert(in_clusters_.end(), in); in_clusters_.insert(in_clusters_.end(), in);
@@ -694,10 +715,7 @@ std::vector<ClusterPtr> Cluster::MergeAllPathFrom(ClusterPtr other) {
if (other->IsIndependent()) { if (other->IsIndependent()) {
return path_clusters; return path_clusters;
} }
if (std::find(other->out_clusters_.begin(), other->out_clusters_.end(), shared_from_this()) ==
other->out_clusters_.end()) {
return path_clusters;
}

path_clusters.push_back(other); path_clusters.push_back(other);
forward_reached_queue.push(other); forward_reached_queue.push(other);
backward_reached_queue.push(shared_from_this()); backward_reached_queue.push(shared_from_this());
@@ -761,7 +779,7 @@ InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_->
OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); };


Status Cluster::BuildFrame() { Status Cluster::BuildFrame() {
if ((IsUnknownShape() || IsKnownShape() || IsInputNode()) && !IsControlFlow()) {
if (IsUnknownShape() || IsKnownShape() || IsInputNode()) {
return BuildPartitionFrame(); return BuildPartitionFrame();
} else { } else {
auto node = nodes_.front(); auto node = nodes_.front();
@@ -896,7 +914,7 @@ Status Cluster::CombinePartitionFrame() {
} }


Status Cluster::BuildPartitionSubgraph() { Status Cluster::BuildPartitionSubgraph() {
if (IsData() || IsNetOutput() || IsIndependent() || IsControlFlow()) {
if (IsData() || IsNetOutput() || IsIndependent()) {
return SUCCESS; return SUCCESS;
} }
int64_t parent_node_index = 0; int64_t parent_node_index = 0;


+ 6
- 3
ge/graph/partition/dynamic_shape_partition.h View File

@@ -47,7 +47,6 @@ class DynamicShapePartitioner {
bool IsUnknownShape() const; bool IsUnknownShape() const;
bool IsIndependent() const; bool IsIndependent() const;
bool IsNetOutput() const; bool IsNetOutput() const;
bool IsControlFlow() const;
std::vector<std::shared_ptr<Cluster>> Inputs() const; std::vector<std::shared_ptr<Cluster>> Inputs() const;
std::vector<std::shared_ptr<Cluster>> Outputs() const; std::vector<std::shared_ptr<Cluster>> Outputs() const;
bool IsInputNode() const; bool IsInputNode() const;
@@ -126,13 +125,15 @@ class DynamicShapePartitioner {
// and there's only one path between the two clusters , merge the two clusters // and there's only one path between the two clusters , merge the two clusters
// 3) Iterate through the INPUT_DATA clusters, merge all INPUT_DATA // 3) Iterate through the INPUT_DATA clusters, merge all INPUT_DATA
Status MergeClusters(); Status MergeClusters();
// Merge clusters step0
void MergeClustersControlFlow();
// Merge clusters step1 // Merge clusters step1
void MergeClustersUnknownShape(); void MergeClustersUnknownShape();
// Merge clusters step2 // Merge clusters step2
void MergeClustersKnownShape(); void MergeClustersKnownShape();
// Merge clusters step3 // Merge clusters step3
void MergeClustersInputData(); void MergeClustersInputData();
// Topological sort clusters after merge unknow shape clusters.
// Topological sort clusters after merge unknown shape clusters.
Status TopologicalSortClusters(); Status TopologicalSortClusters();
// Deduplicate merged clusters // Deduplicate merged clusters
void PruneUniqueClusters(); void PruneUniqueClusters();
@@ -140,7 +141,7 @@ class DynamicShapePartitioner {
Status BuildPartitionFrame(); Status BuildPartitionFrame();
// Establish connection between corresponding partitioned of clusters // Establish connection between corresponding partitioned of clusters
Status CombinePartitionFrame(); Status CombinePartitionFrame();
// Convert the nodes in cluster into a complete ComputeGraoh
// Convert the nodes in cluster into a complete ComputeGraph
Status BuildPartitionSubgraph(); Status BuildPartitionSubgraph();
// Clear resource and break circular dependency // Clear resource and break circular dependency
void ClearResource(); void ClearResource();
@@ -155,6 +156,8 @@ class DynamicShapePartitioner {
Status CtrlEdgeTransfer(); Status CtrlEdgeTransfer();
ge::ComputeGraphPtr root_graph_; // The original graph to partition ge::ComputeGraphPtr root_graph_; // The original graph to partition
std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to
// V1 control flow clusters, keyed by group index; the clusters of each group need to be merged into one graph.
std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
// topological sorted clusters, this field will change with the splitting. // topological sorted clusters, this field will change with the splitting.
// When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters
// When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters


+ 46
- 17
ge/graph/passes/mark_force_unknown_for_cond_pass.cc View File

@@ -18,20 +18,25 @@


#include <queue> #include <queue>


#include "graph/utils/node_utils.h"
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"


namespace ge { namespace ge {
namespace { namespace {
const std::set<std::string> kMergeOpTypes{ MERGE, REFMERGE };
inline bool IsMergeInLoop(const NodePtr &node) {
const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION };


const std::set<std::string> kSwitchOpTypes{ SWITCH, REFSWITCH };
std::string node_type;
(void)GetOriginalType(node, node_type);
return kLoopMergeInputs.count(node_type) > 0;
}


const std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION };
inline bool IsSwitchInLoop(const NodePtr &node) {
const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND };


inline bool IsMergeInLoop(const NodePtr &node) {
std::string node_type; std::string node_type;
(void)GetOriginalType(node, node_type); (void)GetOriginalType(node, node_type);
return kLoopMergeInputs.count(node_type) > 0;
return kLoopSwitchInputs.count(node_type) > 0;
} }
} }


@@ -103,7 +108,13 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
if (dst_span > 0) { if (dst_span > 0) {
search_queue.push({in_node, dst_span - 1}); search_queue.push({in_node, dst_span - 1});
} else { } else {
switch_group.emplace_back(in_node);
const auto &all_in_nodes = in_node->GetInDataNodes();
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) {
GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(),
in_node->GetName().c_str());
} else {
switch_group.emplace_back(in_node);
}
} }
} else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node.
search_queue.push({in_node, dst_span + 1}); search_queue.push({in_node, dst_span + 1});
@@ -121,19 +132,37 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
/// ///
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) {
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP);
}; };


for (const auto &group : switch_groups) {
const auto &node = group.first;
const auto &switch_group = group.second;
const auto &op_desc = node->GetOpDesc();
if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) ||
std::any_of(switch_group.begin(), switch_group.end(), callback)) {
GELOGI("Mark [%s] as force unknown shape", node->GetName().c_str());
MarkForceUnknownShape(node, true);
for (const auto &n : switch_group) {
MarkForceUnknownShape(n, true);
for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) {
const auto &op_node1 = it1->first;
const auto &op_desc1 = op_node1->GetOpDesc();
if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
}

if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) {
int64_t group_index = op_desc1->GetId();
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index);
MarkForceUnknownShape(op_node1, true, group_index);
for (const auto &n : it1->second) {
MarkForceUnknownShape(n, true, group_index);
}

for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) {
const auto &op_node2 = it2->first;
const auto &op_desc2 = op_node2->GetOpDesc();
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
}

if (std::any_of(it2->second.begin(), it2->second.end(), callback)) {
MarkForceUnknownShape(op_node2, true, group_index);
for (const auto &n : it2->second) {
MarkForceUnknownShape(n, true, group_index);
}
}
} }
} }
} }


+ 4
- 3
ge/graph/passes/merge_to_stream_merge_pass.cc View File

@@ -84,8 +84,9 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
GE_CHK_BOOL_EXEC(node != nullptr, GE_CHK_BOOL_EXEC(node != nullptr,
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
return FAILED, "Param of pre node is null."); return FAILED, "Param of pre node is null.");
bool force_unknown = node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
MarkForceUnknownShape(node, force_unknown);
int64_t group_index = -1;
bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
MarkForceUnknownShape(node, force_unknown, group_index);
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
@@ -102,7 +103,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str()); GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str());
return FAILED; return FAILED;
} }
MarkForceUnknownShape(active_node, force_unknown);
MarkForceUnknownShape(active_node, force_unknown, group_index);
} }


return SUCCESS; return SUCCESS;


+ 14
- 11
ge/graph/passes/next_iteration_pass.cc View File

@@ -18,6 +18,7 @@


#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "graph/common/omg_util.h" #include "graph/common/omg_util.h"
#include "graph/utils/node_utils.h"


using std::string; using std::string;


@@ -203,6 +204,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
for (const auto &loop_cond_iter : loop_group_map_) { for (const auto &loop_cond_iter : loop_group_map_) {
const LoopCondGroup &loop_group = *loop_cond_iter.second; const LoopCondGroup &loop_group = *loop_cond_iter.second;
const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName();
const int64_t group_index = loop_group.loop_cond->GetOpDesc()->GetId();
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());


// Create Active node, Enter->Active->Merge, NextIteration->Active->Merge // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
@@ -223,7 +225,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
enter_active->GetName().c_str()); enter_active->GetName().c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape);
MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape, group_index);
} }


for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { for (const auto &pair : loop_cond_iter.second->merge_next_pairs) {
@@ -253,8 +255,8 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


MarkForceUnknownShape(next_node, loop_group.is_unknown_shape);
MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape);
MarkForceUnknownShape(next_node, loop_group.is_unknown_shape, group_index);
MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape, group_index);
} }


if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
@@ -263,10 +265,10 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape);
MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape);
MarkForceUnknownShape(next_active, loop_group.is_unknown_shape);
HandleSwitchExitNodes(loop_group);
MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape, group_index);
MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape, group_index);
MarkForceUnknownShape(next_active, loop_group.is_unknown_shape, group_index);
HandleSwitchExitNodes(loop_group, group_index);
} }


return SUCCESS; return SUCCESS;
@@ -275,20 +277,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
/// ///
/// @brief Mark force unknown for Exit node /// @brief Mark force unknown for Exit node
/// @param [in] group of LoopCond /// @param [in] group of LoopCond
/// @param [in] index of LoopCond Node
/// @return void /// @return void
/// ///
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group) {
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) {
if (!loop_group.is_unknown_shape) { if (!loop_group.is_unknown_shape) {
return; return;
} }


for (const auto &switch_node : loop_group.switch_nodes) { for (const auto &switch_node : loop_group.switch_nodes) {
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape);
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape, group_index);
for (const auto &node : switch_node->GetOutDataNodes()) { for (const auto &node : switch_node->GetOutDataNodes()) {
std::string node_type; std::string node_type;
(void)GetOriginalType(node, node_type); (void)GetOriginalType(node, node_type);
if (node_type == EXIT || node_type == REFEXIT) {
MarkForceUnknownShape(node, loop_group.is_unknown_shape);
if (kExitOpTypes.count(node_type) > 0) {
MarkForceUnknownShape(node, loop_group.is_unknown_shape, group_index);
} }
} }
} }


+ 2
- 1
ge/graph/passes/next_iteration_pass.h View File

@@ -96,9 +96,10 @@ class NextIterationPass : public GraphPass {
/// ///
/// @brief Mark force unknown for Exit node /// @brief Mark force unknown for Exit node
/// @param [in] group of LoopCond /// @param [in] group of LoopCond
/// @param [in] index of LoopCond Node
/// @return void /// @return void
/// ///
void HandleSwitchExitNodes(const LoopCondGroup &loop_group);
void HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index);


// map<frame_name, LoopCondGroup> // map<frame_name, LoopCondGroup>
std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_;


+ 8
- 5
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -369,7 +369,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &
GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)),
"StreamSwitch node add cond edge failed."); "StreamSwitch node add cond edge failed.");


MarkForceUnknownShape(stream_switch, switch_node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE));
int64_t group_index = -1;
bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
MarkForceUnknownShape(stream_switch, force_unknown, group_index);
return stream_switch; return stream_switch;
} }


@@ -488,11 +490,12 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
return FAILED; return FAILED;
} }


std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
int64_t group_index = -1;
std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) {
return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
}; };
bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
MarkForceUnknownShape(active_node, is_unknown_shape);
MarkForceUnknownShape(active_node, is_unknown_shape, group_index);


const std::string &cond_group = cond_node->GetName(); const std::string &cond_group = cond_node->GetName();
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
@@ -522,7 +525,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)),
"Cast add data edge failed."); "Cast add data edge failed.");


MarkForceUnknownShape(stream_switch, is_unknown_shape);
MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index);
for (const NodePtr &node : switch_list) { for (const NodePtr &node : switch_list) {
GE_IF_BOOL_EXEC(node != stream_switch, { GE_IF_BOOL_EXEC(node != stream_switch, {
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),


+ 13
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -74,6 +74,7 @@
#include "graph/passes/unused_const_pass.h" #include "graph/passes/unused_const_pass.h"
#include "graph/passes/var_is_initialized_op_pass.h" #include "graph/passes/var_is_initialized_op_pass.h"
#include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h"
#include "graph/passes/mark_force_unknown_for_cond_pass.h"
#include "graph/preprocess/insert_op/util_insert_aipp_op.h" #include "graph/preprocess/insert_op/util_insert_aipp_op.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "inc/pass_manager.h" #include "inc/pass_manager.h"
@@ -1675,6 +1676,7 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std::
PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); PP_RUN_AND_DUMP("InsertAipp", TryDoAipp);
PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape);
PP_RUN_AND_DUMP("InferFormatAndShape", FormatAndShapeProcess); PP_RUN_AND_DUMP("InferFormatAndShape", FormatAndShapeProcess);
PP_RUN_AND_DUMP("CtrlFlowPreProcess", CtrlFlowPreProcess);
PP_RUN_AND_DUMP("GetDynamicOutputShape", multibatch::GetDynamicOutputShape, compute_graph_); PP_RUN_AND_DUMP("GetDynamicOutputShape", multibatch::GetDynamicOutputShape, compute_graph_);
PP_RUN_AND_DUMP("ProcessAippStage2", InsertNewOpUtil::Instance().UpdateDataNodeByAipp, compute_graph_); PP_RUN_AND_DUMP("ProcessAippStage2", InsertNewOpUtil::Instance().UpdateDataNodeByAipp, compute_graph_);
PP_RUN("SaveOriginalGraphToOmModel", SaveOriginalGraphToOmModel); PP_RUN("SaveOriginalGraphToOmModel", SaveOriginalGraphToOmModel);
@@ -1683,6 +1685,17 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std::
return SUCCESS; return SUCCESS;
} }


///
/// @brief Run the control-flow pre-processing passes on compute_graph_.
///        Currently this registers and runs MarkForceUnknownForCondPass only.
/// @return SUCCESS if the pass was added and the pass manager ran successfully;
///         otherwise the status propagated by GE_CHK_STATUS_RET.
///
Status GraphPrepare::CtrlFlowPreProcess() {
  PassManager ctrl_flow_passes;

  // After InferShape, mark v1 control flow for unknown shape.
  GE_CHK_STATUS_RET(ctrl_flow_passes.AddPass("PreRun::MarkForceUnknownForCondPass",
                                             new (std::nothrow) MarkForceUnknownForCondPass));
  GE_CHK_STATUS_RET(ctrl_flow_passes.Run(compute_graph_));

  return SUCCESS;
}

Status GraphPrepare::RecordAIPPInfo(ge::ComputeGraphPtr &compute_graph) { Status GraphPrepare::RecordAIPPInfo(ge::ComputeGraphPtr &compute_graph) {
PP_RUN("RecordAIPPInfo", InsertNewOpUtil::Instance().RecordAIPPInfoToData, compute_graph_); PP_RUN("RecordAIPPInfo", InsertNewOpUtil::Instance().RecordAIPPInfoToData, compute_graph_);
return SUCCESS; return SUCCESS;


+ 1
- 0
ge/graph/preprocess/graph_preprocess.h View File

@@ -79,6 +79,7 @@ class GraphPrepare {
Status ProcessNetOutput(); Status ProcessNetOutput();
Status ProcessBeforeInfershape(); Status ProcessBeforeInfershape();
Status UpdateInputOutputByOptions(); Status UpdateInputOutputByOptions();
Status CtrlFlowPreProcess();


bool IsTansDataOpData(const ge::NodePtr &var_node); bool IsTansDataOpData(const ge::NodePtr &var_node);




+ 37
- 4
ge/hybrid/executor/node_state.cc View File

@@ -104,11 +104,47 @@ void ShapeInferenceState::UpdateInputShapeFuture(int idx, ShapeFuture &&future)
} }
} }


///
/// @brief Copy shape, origin shape and tensor size of the selected merge input from the cached
///        input_tensor_desc entry into the matching input desc of node_item.
///        The input is selected by the ATTR_NAME_MERGE_INPUT_INDEX attribute read from node_item.op_desc.
/// @param [in] context: graph execution context; not used by this function.
/// @return SUCCESS when the input desc was updated.
///         FAILED when the attribute cannot be read or the index is outside [0, input_tensor_desc.size()).
///         A null input desc is handled by GE_CHECK_NOTNULL.
///
Status ShapeInferenceState::UpdateInputForMerge(const GraphExecutionContext &context) {
  (void)context;  // Kept in the signature for callers; nothing is read from it here.

  int merge_index = -1;
  // NOTE(review): presumably a scoped lock on the node item, held until this function returns -- confirm.
  const auto &guard = node_item.MutexGuard("UpdateInputForMerge");
  (void)guard;  // Only its lifetime matters; the variable itself is never read.

  if (!AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_MERGE_INPUT_INDEX, merge_index)) {
    GELOGE(FAILED, "[%s] Get attr %s failed", node_item.NodeName().c_str(), ATTR_NAME_MERGE_INPUT_INDEX.c_str());
    return FAILED;
  }

  // Reject negative values before the unsigned comparison against the container size.
  if (merge_index < 0 || static_cast<size_t>(merge_index) >= input_tensor_desc.size()) {
    GELOGE(FAILED, "[%s] merge index: %d invalid, should in range[0, %zu)",
           node_item.NodeName().c_str(), merge_index, input_tensor_desc.size());
    return FAILED;
  }

  auto dst_tensor_desc = node_item.MutableInputDesc(merge_index);
  GE_CHECK_NOTNULL(dst_tensor_desc);

  // Best effort: when GetSize fails, tensor_size keeps its -1 initial value and is written as such.
  int64_t tensor_size = -1;
  auto &tensor_desc = input_tensor_desc[merge_index];
  (void)TensorUtils::GetSize(tensor_desc, tensor_size);

  dst_tensor_desc->SetShape(tensor_desc.MutableShape());
  dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape());
  (void)TensorUtils::SetSize(*dst_tensor_desc, tensor_size);

  // merge_index is a signed int, so it is printed with %d, matching the error log above.
  GELOGD("[%s] Update input shape [%d] with shape: [%s] and ori_shape: [%s], tensor size = %ld",
         node_item.NodeName().c_str(), merge_index, dst_tensor_desc->GetShape().ToString().c_str(),
         dst_tensor_desc->GetOriginShape().ToString().c_str(), tensor_size);

  return SUCCESS;
}

Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &context) { Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &context) {
if (!node_item.is_dynamic) { if (!node_item.is_dynamic) {
return SUCCESS; return SUCCESS;
} }
std::unique_lock<std::mutex> lk(mu_); std::unique_lock<std::mutex> lk(mu_);
if (node_item.IsMergeOp()) {
return UpdateInputForMerge(context);
}

if (num_pending_shapes_ > 0) { if (num_pending_shapes_ > 0) {
GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str()); GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str());
int try_count = 0; int try_count = 0;
@@ -169,7 +205,7 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex


int64_t tensor_size = -1; int64_t tensor_size = -1;
(void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size);
GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu",
GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], tensor size = %ld",
node_item.NodeName().c_str(), node_item.NodeName().c_str(),
idx, idx,
src_tensor_desc->GetShape().ToString().c_str(), src_tensor_desc->GetShape().ToString().c_str(),
@@ -283,11 +319,8 @@ void NodeState::ResetContext(int group) {
} }


switch_index_ = -1; switch_index_ = -1;
const auto &guard = node_item_->MutexGuard("ResetContext");
shape_inference_state_.InitShapeState();
subgraph_context_->ResetContext(node_item_->node); subgraph_context_->ResetContext(node_item_->node);
GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_);
(void)guard;
} }


void NodeState::ResetSchedule() { void NodeState::ResetSchedule() {


+ 2
- 0
ge/hybrid/executor/node_state.h View File

@@ -67,6 +67,8 @@ struct ShapeInferenceState {
const NodeItem &node_item; const NodeItem &node_item;


private: private:
Status UpdateInputForMerge(const GraphExecutionContext &context);

friend struct NodeState; friend struct NodeState;
std::vector<std::pair<int, ShapeFuture>> shape_futures; std::vector<std::pair<int, ShapeFuture>> shape_futures;
// do not directly update op_desc, in case race condition across pipelines // do not directly update op_desc, in case race condition across pipelines


Loading…
Cancel
Save