Browse Source

Fix dynamic shape partition

pull/1782/head
zhangxiaokun 4 years ago
parent
commit
1571c9b731
16 changed files with 123 additions and 69 deletions
  1. +17
    -1
      ge/graph/partition/dynamic_shape_partition.cc
  2. +1
    -1
      ge/graph/partition/dynamic_shape_partition.h
  3. +16
    -18
      ge/graph/passes/mark_force_unknown_for_cond_pass.cc
  4. +6
    -0
      ge/graph/passes/mark_graph_unknown_status_pass.cc
  5. +9
    -1
      ge/graph/passes/next_iteration_pass.cc
  6. +49
    -8
      ge/hybrid/executor/node_state.cc
  7. +4
    -0
      ge/hybrid/executor/node_state.h
  8. +1
    -1
      ge/hybrid/executor/subgraph_context.cc
  9. +2
    -2
      ge/hybrid/executor/subgraph_context.h
  10. +3
    -15
      ge/hybrid/executor/subgraph_executor.cc
  11. +0
    -1
      ge/hybrid/executor/subgraph_executor.h
  12. +2
    -3
      ge/hybrid/model/node_item.cc
  13. +2
    -2
      ge/hybrid/model/node_item.h
  14. +3
    -3
      ge/hybrid/node_executor/aicore/aicore_op_task.cc
  15. +7
    -10
      ge/hybrid/node_executor/task_context.cc
  16. +1
    -3
      ge/hybrid/node_executor/task_context.h

+ 17
- 1
ge/graph/partition/dynamic_shape_partition.cc View File

@@ -364,6 +364,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) {
}

void DynamicShapePartitioner::MergeClustersControlFlow() {
std::unordered_set<ClusterPtr> all_merged_clusters;
for (const auto &item : control_clusters_) {
const auto &control_cluster = item.second;
auto rit = control_cluster.rbegin();
@@ -373,17 +374,27 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
}

const auto &cluster = *rit;
if (all_merged_clusters.count(cluster) > 0) {
continue;
}

bool is_unknown_cluster = cluster->IsUnknownShape();
for (++rit; rit != control_cluster.rend(); ++rit) {
const auto &cluster_from = *rit;
auto merged_clusters = cluster->MergeAllPathFrom(cluster_from);
GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(),
ToString(merged_clusters).c_str());
for (const auto &merged_cluster : merged_clusters) {
all_merged_clusters.emplace(merged_cluster);
for (const auto &node : merged_cluster->Nodes()) {
node_2_cluster_[node] = cluster;
}
}
}

if (!is_unknown_cluster && cluster->IsUnknownShape()) {
ordered_cluster_.push_back(cluster);
}
}
}

@@ -703,7 +714,12 @@ void Cluster::Merge(ClusterPtr other) {
if (other->min_ < min_) {
min_ = other->min_;
}
};

if (!IsUnknownShape() && other->IsUnknownShape()) {
type_ = UNKNOWN_SHAPE;
}
}

bool Cluster::TryMerge(ClusterPtr other) {
std::queue<ClusterPtr> forward_reached;
forward_reached.push(other);


+ 1
- 1
ge/graph/partition/dynamic_shape_partition.h View File

@@ -161,7 +161,7 @@ class DynamicShapePartitioner {
ge::ComputeGraphPtr root_graph_; // The original graph to partition
std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to
// V1 control flow cluster, need merge to one Graph.
std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
std::map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
// topological sorted clusters, this field will change with the splitting.
// When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters
// When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters


+ 16
- 18
ge/graph/passes/mark_force_unknown_for_cond_pass.cc View File

@@ -143,26 +143,24 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, s
continue;
}

if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) {
int64_t group_index = op_desc1->GetId();
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index);
MarkForceUnknownShape(op_node1, true, group_index);
for (const auto &n : it1->second) {
MarkForceUnknownShape(n, true, group_index);
}
int64_t group_index = op_desc1->GetId();
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index);
SetControlFlowGroup(op_node1, group_index);
for (const auto &n : it1->second) {
SetControlFlowGroup(n, group_index);
}

for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) {
const auto &op_node2 = it2->first;
const auto &op_desc2 = op_node2->GetOpDesc();
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
}
for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) {
const auto &op_node2 = it2->first;
const auto &op_desc2 = op_node2->GetOpDesc();
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
}

if (std::any_of(it2->second.begin(), it2->second.end(), callback)) {
MarkForceUnknownShape(op_node2, true, group_index);
for (const auto &n : it2->second) {
MarkForceUnknownShape(n, true, group_index);
}
if (std::any_of(it2->second.begin(), it2->second.end(), callback)) {
SetControlFlowGroup(op_node2, group_index);
for (const auto &n : it2->second) {
SetControlFlowGroup(n, group_index);
}
}
}


+ 6
- 0
ge/graph/passes/mark_graph_unknown_status_pass.cc View File

@@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) {
}
}

const auto &node = graph->GetParentNode();
if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) {
GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape),
"[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str());
}

for (const auto &node : graph->GetDirectNode()) {
GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str());
(void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape);


+ 9
- 1
ge/graph/passes/next_iteration_pass.cc View File

@@ -284,13 +284,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
/// @return void
///
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) {
std::string node_type;
for (const auto &switch_node : loop_group.switch_nodes) {
SetControlFlowGroup(switch_node, group_index);
for (const auto &node : switch_node->GetOutDataNodes()) {
std::string node_type;
(void)GetOriginalType(node, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(node, group_index);
} else {
// For: Switch -> Cast -> Exit
for (const auto &n : node->GetOutDataNodes()) {
(void)GetOriginalType(n, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(n, group_index);
}
}
}
}
}


+ 49
- 8
ge/hybrid/executor/node_state.cc View File

@@ -19,8 +19,9 @@
#include "framework/common/debug/log.h"
#include "graph/compute_graph.h"
#include "graph/utils/tensor_utils.h"
#include "hybrid_execution_context.h"
#include "subgraph_context.h"
#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/executor/subgraph_context.h"
#include "hybrid/node_executor/task_context.h"

#define INC_ITERATION_COUNT(iteration) \
do { \
@@ -258,6 +259,8 @@ ShapeFuture::ShapeFuture(NodeState *src_node,
NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context)
: node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) {
this->op_desc_ = node_item.node->GetOpDesc();
auto unique_task_context = TaskContext::Create(this, subgraph_context_);
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());
}

Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const {
@@ -314,15 +317,53 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
return task_context_;
}

void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) {
if (node_item_->root_data_.count(input_idx) > 0) {
GELOGD("[%s] Save Const input tensor: %d", GetName().c_str(), input_idx);
root_tensor_value_[input_idx] = tensor;
}

if (node_item_->enter_data_.count(input_idx) > 0) {
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx);
root_tensor_value_[input_idx] = tensor;
}
}

void NodeState::UpdateRootTensor(int input_idx) {
const auto it = root_tensor_value_.find(input_idx);
if (it == root_tensor_value_.end()) {
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx);
return;
}

auto tensor = task_context_->MutableInput(input_idx);
if (tensor == nullptr) {
GELOGW("[%s] Not found input tensor: %d", GetName().c_str(), input_idx);
return;
}

*tensor = it->second;
GELOGW("[%s] Update input tensor: %d", GetName().c_str(), input_idx);
}

void NodeState::ResetContext(uint64_t iteration) {
switch_index_ = -1;
subgraph_context_->ResetContext(node_item_->node);
if (iteration == 0) {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
} else {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size());
auto unique_task_context = TaskContext::Create(this, subgraph_context_);
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());

data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
for (auto item : node_item_->root_data_) {
UpdateRootTensor(item.first);
}

if (iteration > 0) {
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size());
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size());
for (auto item : node_item_->enter_data_) {
UpdateRootTensor(item.first);
}
}

iteration_count_ = iteration;


+ 4
- 0
ge/hybrid/executor/node_state.h View File

@@ -129,6 +129,8 @@ struct NodeState {
void RunStreamActive();
void RunNextIteration();

void SaveRootTensor(int input_idx, const TensorValue &tensor);

Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;

void SetScheduleFuture(std::future<Status> &&future);
@@ -187,6 +189,7 @@ struct NodeState {
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void ResetContext(uint64_t iteration);
void ScheduleContext(const NodeState &node_state);
void UpdateRootTensor(int input_idx);

const NodeItem *node_item_ = nullptr;
std::shared_ptr<NodeTask> kernel_task_ = nullptr;
@@ -199,6 +202,7 @@ struct NodeState {

std::future<Status> schedule_future_;
std::shared_ptr<FrameState> frame_state_;
std::map<int, TensorValue> root_tensor_value_;
uint64_t active_count_ = 0;
uint64_t iteration_count_ = 0;
uint32_t ctrl_scheduled_ = 0;


+ 1
- 1
ge/hybrid/executor/subgraph_context.cc View File

@@ -19,7 +19,7 @@

namespace ge {
namespace hybrid {
SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context)
SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context)
: graph_item_(graph_item), execution_context_(execution_context) {
}



+ 2
- 2
ge/hybrid/executor/subgraph_context.h View File

@@ -30,7 +30,7 @@ namespace ge {
namespace hybrid {
class SubgraphContext {
public:
explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context);
explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context);
~SubgraphContext();

Status Init();
@@ -54,7 +54,7 @@ class SubgraphContext {
FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock
friend class TaskContext;
const GraphItem *graph_item_;
const GraphExecutionContext *execution_context_;
GraphExecutionContext *execution_context_;
mmRWLock_t rw_lock_;
std::vector<TensorValue> all_inputs_;
std::vector<TensorValue> all_outputs_;


+ 3
- 15
ge/hybrid/executor/subgraph_executor.cc View File

@@ -175,16 +175,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue
GE_CHECK_NOTNULL(node_state);
node_state->SetKernelTask(node_item->kernel_task);

known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(known_shape_task_context_);
node_state->SetTaskContext(known_shape_task_context_);

std::function<void()> callback;
GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback));
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback),
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback),
"[%s] Failed to execute node [%s] for known subgraph.",
graph_item_->GetName().c_str(),
known_shape_task_context_->GetNodeName());
node_state->GetName().c_str());

GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str());
return SUCCESS;
@@ -271,16 +267,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) {
} else {
node_state->SetKernelTask(node_item.kernel_task);
}
auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(unique_task_context);
const auto &task = node_state->GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str());
REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str());
return INTERNAL_ERROR;
}
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);
GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state));
return AfterPrepared(p_node_state);
}
@@ -480,19 +472,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta
} else {
node_state.SetKernelTask(node_item.kernel_task);
}
auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get());
GE_CHECK_NOTNULL(unique_task_context);
const auto &task = node_state.GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str());
REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str());
return INTERNAL_ERROR;
}
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state.SetTaskContext(shared_task_context);
GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context));
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start");
GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws
GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext())); // update op_desc before alloc ws
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end");
return SUCCESS;
}


+ 0
- 1
ge/hybrid/executor/subgraph_executor.h View File

@@ -125,7 +125,6 @@ class SubgraphExecutor {
ThreadPool pre_run_pool_;
BlockingQueue<NodeState *> ready_queue_;
std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_;
std::shared_ptr<TaskContext> known_shape_task_context_;

std::mutex mu_; // Guard for prepare_queues_.
std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_;


+ 2
- 3
ge/hybrid/model/node_item.cc View File

@@ -398,12 +398,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
data_send_.emplace(node_item);
node_item->data_recv_[this] = anchor_index;
if (is_root_node_) {
node_item->root_data_.emplace(this);
node_item->root_data_[anchor_index] = this;
}
// If Enter feed Not Merge, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) {
node_item->enter_data_.emplace(this);
node_item->enter_inside_.emplace(anchor_index);
node_item->enter_data_[anchor_index] = this;
}
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
}


+ 2
- 2
ge/hybrid/model/node_item.h View File

@@ -148,9 +148,9 @@ struct NodeItem {
int64_t frame_index_ = -1;
int64_t parent_frame_ = -1;
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node
std::set<const NodeItem *> root_data_; // Recv data from root node
std::map<int, const NodeItem *> root_data_; // Recv data from root node
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node
std::set<const NodeItem *> enter_data_; // Recv data from Enter node
std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node
std::set<const NodeItem *> data_send_; // Send data notify to
std::map<const NodeItem *, int> data_recv_; // Recv data notify from
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to


+ 3
- 3
ge/hybrid/node_executor/aicore/aicore_op_task.cc View File

@@ -306,7 +306,7 @@ Status AiCoreOpTask::InitWithKernelDefWithHandle(const OpDesc &op_desc, const do
}

Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef &task_def) {
auto rt_ret = ValidateTaskDef(task_def);
if (rt_ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "op:%s(op_type:%s) failed to validate task def:%s",
@@ -315,7 +315,7 @@ Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef
op_desc.GetName().c_str(), op_desc.GetType().c_str(), task_def.DebugString().c_str());
return rt_ret;
}
if (task_def.type() != RT_MODEL_TASK_ALL_KERNEL) {
GE_CHK_STATUS_RET(InitWithKernelDef(op_desc, task_def));
} else {
@@ -474,7 +474,7 @@ Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) {

if (task_context.IsTraceEnabled()) {
for (int i = 0; i < index; ++i) {
GELOGD("[%s] Arg[%d] = %lu", stub_name_.c_str(), i, arg_base_[i]);
GELOGD("[%s] Arg[%d] = %p", stub_name_.c_str(), i, arg_base_[i]);
}
}



+ 7
- 10
ge/hybrid/node_executor/task_context.cc View File

@@ -52,9 +52,7 @@ void TaskContext::ReleaseWorkspace() {
}
}

std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state,
GraphExecutionContext *execution_context,
SubgraphContext *subgraph_context) {
std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, SubgraphContext *subgraph_context) {
const NodeItem &node_item = *node_state->GetNodeItem();
GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.",
node_item.NodeName().c_str(),
@@ -75,7 +73,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state,
}

auto task_context = std::unique_ptr<TaskContext>(
new(std::nothrow)TaskContext(execution_context, node_state, subgraph_context));
new(std::nothrow)TaskContext(subgraph_context->execution_context_, node_state, subgraph_context));
if (task_context == nullptr) {
REPORT_CALL_ERROR("E19999", "Create TaskContext failed for [%s].", node_item.NodeName().c_str());
GELOGE(MEMALLOC_FAILED, "[Create][TaskContext] failed for [%s].", node_item.NodeName().c_str());
@@ -85,7 +83,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state,
task_context->node_item_ = &node_item;
task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start;
task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start;
task_context->iteration_ = execution_context->iteration;
task_context->iteration_ = subgraph_context->execution_context_->iteration;
return task_context;
}

@@ -460,6 +458,10 @@ Status TaskContext::PropagateOutputs() {
subgraph_context_->all_inputs_[input_offset].SetName(
node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx));
}

auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item);
GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SaveRootTensor(dst_input_idx, *tensor);
}
}
(void)guard;
@@ -489,11 +491,6 @@ void TaskContext::ReleaseInputsAndOutputs() {
}

void TaskContext::ReleaseInput(int index) {
if (node_item_->enter_inside_.count(index) > 0) {
GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index);
return;
}

auto input_tensor = MutableInput(index);
if (input_tensor != nullptr) {
input_tensor->Destroy();


+ 1
- 3
ge/hybrid/node_executor/task_context.h View File

@@ -36,9 +36,7 @@ class SubgraphContext;

class TaskContext {
public:
static std::unique_ptr<TaskContext> Create(NodeState *node_state,
GraphExecutionContext *execution_context,
SubgraphContext *subgraph_context);
static std::unique_ptr<TaskContext> Create(NodeState *node_state, SubgraphContext *subgraph_context);

~TaskContext();



Loading…
Cancel
Save