Browse Source

Support Enter feed Loop inside.

pull/1691/head^2
zhangxiaokun 4 years ago
parent
commit
52ba516f55
6 changed files with 19 additions and 0 deletions
  1. +3
    -0
      ge/graph/partition/dynamic_shape_partition.cc
  2. +4
    -0
      ge/graph/partition/dynamic_shape_partition.h
  3. +1
    -0
      ge/hybrid/executor/subgraph_executor.cc
  4. +5
    -0
      ge/hybrid/model/node_item.cc
  5. +1
    -0
      ge/hybrid/model/node_item.h
  6. +5
    -0
      ge/hybrid/node_executor/task_context.cc

+ 3
- 0
ge/graph/partition/dynamic_shape_partition.cc View File

@@ -387,6 +387,9 @@ void DynamicShapePartitioner::MergeClustersUnknownShape() {
if (!in_cluster->IsUnknownShape()) { if (!in_cluster->IsUnknownShape()) {
continue; continue;
} }
if (!cluster->IsAdjoinNodes(in_cluster)) {
continue;
}
auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); auto merged_clusters = cluster->MergeAllPathFrom(in_cluster);
GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(),
ToString(merged_clusters).c_str()); ToString(merged_clusters).c_str());


+ 4
- 0
ge/graph/partition/dynamic_shape_partition.h View File

@@ -80,6 +80,10 @@ class DynamicShapePartitioner {
Status BuildPartitionSubgraph(); Status BuildPartitionSubgraph();
// Clear resource and break circular dependency // Clear resource and break circular dependency
void Clear(); void Clear();
bool IsAdjoinNodes(const std::shared_ptr<Cluster> &other) const {
const auto &out_clusters = other->out_clusters_;
return std::find(out_clusters.begin(), out_clusters.end(), shared_from_this()) != out_clusters.end();
}


private: private:
static thread_local size_t unique_id_; static thread_local size_t unique_id_;


+ 1
- 0
ge/hybrid/executor/subgraph_executor.cc View File

@@ -537,6 +537,7 @@ Status SubgraphExecutor::LaunchTasks() {


Status SubgraphExecutor::ScheduleTasks(int group) { Status SubgraphExecutor::ScheduleTasks(int group) {
GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str());
subgraph_context_->SetGroup(group);
auto prepare_future = std::async(std::launch::async, [&]() -> Status { auto prepare_future = std::async(std::launch::async, [&]() -> Status {
GetContext().SetSessionId(context_->session_id); GetContext().SetSessionId(context_->session_id);
GetContext().SetContextId(context_->context_id); GetContext().SetContextId(context_->context_id);


+ 5
- 0
ge/hybrid/model/node_item.cc View File

@@ -401,6 +401,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
if (is_root_node_) { if (is_root_node_) {
node_item->root_data_.emplace(this); node_item->root_data_.emplace(this);
} }
// If Enter feed Not Merge, take as root Node.
if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) {
node_item->root_data_.emplace(this);
node_item->enter_inside_.emplace(anchor_index);
}
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
} }




+ 1
- 0
ge/hybrid/model/node_item.h View File

@@ -142,6 +142,7 @@ struct NodeItem {
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to
std::set<int> enter_inside_; // Enter feed loop inside Node, Not cross Merge.


std::shared_ptr<NodeTask> kernel_task; std::shared_ptr<NodeTask> kernel_task;
std::unique_ptr<FusedSubgraph> fused_subgraph; std::unique_ptr<FusedSubgraph> fused_subgraph;


+ 5
- 0
ge/hybrid/node_executor/task_context.cc View File

@@ -489,6 +489,11 @@ void TaskContext::ReleaseInputsAndOutputs() {
} }


void TaskContext::ReleaseInput(int index) { void TaskContext::ReleaseInput(int index) {
if (node_item_->enter_inside_.count(index) > 0) {
GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index);
return;
}

auto input_tensor = MutableInput(index); auto input_tensor = MutableInput(index);
if (input_tensor != nullptr) { if (input_tensor != nullptr) {
input_tensor->Destroy(); input_tensor->Destroy();


Loading…
Cancel
Save