@@ -364,6 +364,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) { | |||||
} | } | ||||
void DynamicShapePartitioner::MergeClustersControlFlow() { | void DynamicShapePartitioner::MergeClustersControlFlow() { | ||||
std::unordered_set<ClusterPtr> all_merged_clusters; | |||||
for (const auto &item : control_clusters_) { | for (const auto &item : control_clusters_) { | ||||
const auto &control_cluster = item.second; | const auto &control_cluster = item.second; | ||||
auto rit = control_cluster.rbegin(); | auto rit = control_cluster.rbegin(); | ||||
@@ -373,17 +374,27 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { | |||||
} | } | ||||
const auto &cluster = *rit; | const auto &cluster = *rit; | ||||
if (all_merged_clusters.count(cluster) > 0) { | |||||
continue; | |||||
} | |||||
bool is_unknown_cluster = cluster->IsUnknownShape(); | |||||
for (++rit; rit != control_cluster.rend(); ++rit) { | for (++rit; rit != control_cluster.rend(); ++rit) { | ||||
const auto &cluster_from = *rit; | const auto &cluster_from = *rit; | ||||
auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); | auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); | ||||
GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), | GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), | ||||
ToString(merged_clusters).c_str()); | ToString(merged_clusters).c_str()); | ||||
for (const auto &merged_cluster : merged_clusters) { | for (const auto &merged_cluster : merged_clusters) { | ||||
all_merged_clusters.emplace(merged_cluster); | |||||
for (const auto &node : merged_cluster->Nodes()) { | for (const auto &node : merged_cluster->Nodes()) { | ||||
node_2_cluster_[node] = cluster; | node_2_cluster_[node] = cluster; | ||||
} | } | ||||
} | } | ||||
} | } | ||||
if (!is_unknown_cluster && cluster->IsUnknownShape()) { | |||||
ordered_cluster_.push_back(cluster); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -703,7 +714,12 @@ void Cluster::Merge(ClusterPtr other) { | |||||
if (other->min_ < min_) { | if (other->min_ < min_) { | ||||
min_ = other->min_; | min_ = other->min_; | ||||
} | } | ||||
}; | |||||
if (!IsUnknownShape() && other->IsUnknownShape()) { | |||||
type_ = UNKNOWN_SHAPE; | |||||
} | |||||
} | |||||
bool Cluster::TryMerge(ClusterPtr other) { | bool Cluster::TryMerge(ClusterPtr other) { | ||||
std::queue<ClusterPtr> forward_reached; | std::queue<ClusterPtr> forward_reached; | ||||
forward_reached.push(other); | forward_reached.push(other); | ||||
@@ -161,7 +161,7 @@ class DynamicShapePartitioner { | |||||
ge::ComputeGraphPtr root_graph_; // The original graph to partition | ge::ComputeGraphPtr root_graph_; // The original graph to partition | ||||
std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | ||||
// V1 control flow cluster, need merge to one Graph. | // V1 control flow cluster, need merge to one Graph. | ||||
std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||||
std::map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||||
// topological sorted clusters, this field will change with the splitting. | // topological sorted clusters, this field will change with the splitting. | ||||
// When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | ||||
// When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | ||||
@@ -143,26 +143,24 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, s | |||||
continue; | continue; | ||||
} | } | ||||
if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) { | |||||
int64_t group_index = op_desc1->GetId(); | |||||
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||||
MarkForceUnknownShape(op_node1, true, group_index); | |||||
for (const auto &n : it1->second) { | |||||
MarkForceUnknownShape(n, true, group_index); | |||||
} | |||||
int64_t group_index = op_desc1->GetId(); | |||||
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||||
SetControlFlowGroup(op_node1, group_index); | |||||
for (const auto &n : it1->second) { | |||||
SetControlFlowGroup(n, group_index); | |||||
} | |||||
for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||||
const auto &op_node2 = it2->first; | |||||
const auto &op_desc2 = op_node2->GetOpDesc(); | |||||
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
continue; | |||||
} | |||||
for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||||
const auto &op_node2 = it2->first; | |||||
const auto &op_desc2 = op_node2->GetOpDesc(); | |||||
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
continue; | |||||
} | |||||
if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||||
MarkForceUnknownShape(op_node2, true, group_index); | |||||
for (const auto &n : it2->second) { | |||||
MarkForceUnknownShape(n, true, group_index); | |||||
} | |||||
if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||||
SetControlFlowGroup(op_node2, group_index); | |||||
for (const auto &n : it2->second) { | |||||
SetControlFlowGroup(n, group_index); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) { | |||||
} | } | ||||
} | } | ||||
const auto &node = graph->GetParentNode(); | |||||
if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) { | |||||
GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), | |||||
"[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str()); | |||||
} | |||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str()); | GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str()); | ||||
(void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape); | (void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape); | ||||
@@ -284,13 +284,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
/// @return void | /// @return void | ||||
/// | /// | ||||
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | ||||
std::string node_type; | |||||
for (const auto &switch_node : loop_group.switch_nodes) { | for (const auto &switch_node : loop_group.switch_nodes) { | ||||
SetControlFlowGroup(switch_node, group_index); | SetControlFlowGroup(switch_node, group_index); | ||||
for (const auto &node : switch_node->GetOutDataNodes()) { | for (const auto &node : switch_node->GetOutDataNodes()) { | ||||
std::string node_type; | |||||
(void)GetOriginalType(node, node_type); | (void)GetOriginalType(node, node_type); | ||||
if (kExitOpTypes.count(node_type) > 0) { | if (kExitOpTypes.count(node_type) > 0) { | ||||
SetControlFlowGroup(node, group_index); | SetControlFlowGroup(node, group_index); | ||||
} else { | |||||
// For: Switch -> Cast -> Exit | |||||
for (const auto &n : node->GetOutDataNodes()) { | |||||
(void)GetOriginalType(n, node_type); | |||||
if (kExitOpTypes.count(node_type) > 0) { | |||||
SetControlFlowGroup(n, group_index); | |||||
} | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -19,8 +19,9 @@ | |||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "hybrid_execution_context.h" | |||||
#include "subgraph_context.h" | |||||
#include "hybrid/executor/hybrid_execution_context.h" | |||||
#include "hybrid/executor/subgraph_context.h" | |||||
#include "hybrid/node_executor/task_context.h" | |||||
#define INC_ITERATION_COUNT(iteration) \ | #define INC_ITERATION_COUNT(iteration) \ | ||||
do { \ | do { \ | ||||
@@ -258,6 +259,8 @@ ShapeFuture::ShapeFuture(NodeState *src_node, | |||||
NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) | NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) | ||||
: node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) { | : node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) { | ||||
this->op_desc_ = node_item.node->GetOpDesc(); | this->op_desc_ = node_item.node->GetOpDesc(); | ||||
auto unique_task_context = TaskContext::Create(this, subgraph_context_); | |||||
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
} | } | ||||
Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | ||||
@@ -314,15 +317,53 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
return task_context_; | return task_context_; | ||||
} | } | ||||
void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { | |||||
if (node_item_->root_data_.count(input_idx) > 0) { | |||||
GELOGD("[%s] Save Const input tensor: %d", GetName().c_str(), input_idx); | |||||
root_tensor_value_[input_idx] = tensor; | |||||
} | |||||
if (node_item_->enter_data_.count(input_idx) > 0) { | |||||
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | |||||
root_tensor_value_[input_idx] = tensor; | |||||
} | |||||
} | |||||
void NodeState::UpdateRootTensor(int input_idx) { | |||||
const auto it = root_tensor_value_.find(input_idx); | |||||
if (it == root_tensor_value_.end()) { | |||||
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); | |||||
return; | |||||
} | |||||
auto tensor = task_context_->MutableInput(input_idx); | |||||
if (tensor == nullptr) { | |||||
GELOGW("[%s] Not found input tensor: %d", GetName().c_str(), input_idx); | |||||
return; | |||||
} | |||||
*tensor = it->second; | |||||
GELOGW("[%s] Update input tensor: %d", GetName().c_str(), input_idx); | |||||
} | |||||
void NodeState::ResetContext(uint64_t iteration) { | void NodeState::ResetContext(uint64_t iteration) { | ||||
switch_index_ = -1; | switch_index_ = -1; | ||||
subgraph_context_->ResetContext(node_item_->node); | subgraph_context_->ResetContext(node_item_->node); | ||||
if (iteration == 0) { | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||||
} else { | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size()); | |||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size()); | |||||
auto unique_task_context = TaskContext::Create(this, subgraph_context_); | |||||
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||||
for (auto item : node_item_->root_data_) { | |||||
UpdateRootTensor(item.first); | |||||
} | |||||
if (iteration > 0) { | |||||
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | |||||
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | |||||
for (auto item : node_item_->enter_data_) { | |||||
UpdateRootTensor(item.first); | |||||
} | |||||
} | } | ||||
iteration_count_ = iteration; | iteration_count_ = iteration; | ||||
@@ -129,6 +129,8 @@ struct NodeState { | |||||
void RunStreamActive(); | void RunStreamActive(); | ||||
void RunNextIteration(); | void RunNextIteration(); | ||||
void SaveRootTensor(int input_idx, const TensorValue &tensor); | |||||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | ||||
void SetScheduleFuture(std::future<Status> &&future); | void SetScheduleFuture(std::future<Status> &&future); | ||||
@@ -187,6 +189,7 @@ struct NodeState { | |||||
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | ||||
void ResetContext(uint64_t iteration); | void ResetContext(uint64_t iteration); | ||||
void ScheduleContext(const NodeState &node_state); | void ScheduleContext(const NodeState &node_state); | ||||
void UpdateRootTensor(int input_idx); | |||||
const NodeItem *node_item_ = nullptr; | const NodeItem *node_item_ = nullptr; | ||||
std::shared_ptr<NodeTask> kernel_task_ = nullptr; | std::shared_ptr<NodeTask> kernel_task_ = nullptr; | ||||
@@ -199,6 +202,7 @@ struct NodeState { | |||||
std::future<Status> schedule_future_; | std::future<Status> schedule_future_; | ||||
std::shared_ptr<FrameState> frame_state_; | std::shared_ptr<FrameState> frame_state_; | ||||
std::map<int, TensorValue> root_tensor_value_; | |||||
uint64_t active_count_ = 0; | uint64_t active_count_ = 0; | ||||
uint64_t iteration_count_ = 0; | uint64_t iteration_count_ = 0; | ||||
uint32_t ctrl_scheduled_ = 0; | uint32_t ctrl_scheduled_ = 0; | ||||
@@ -19,7 +19,7 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context) | |||||
SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context) | |||||
: graph_item_(graph_item), execution_context_(execution_context) { | : graph_item_(graph_item), execution_context_(execution_context) { | ||||
} | } | ||||
@@ -30,7 +30,7 @@ namespace ge { | |||||
namespace hybrid { | namespace hybrid { | ||||
class SubgraphContext { | class SubgraphContext { | ||||
public: | public: | ||||
explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context); | |||||
explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context); | |||||
~SubgraphContext(); | ~SubgraphContext(); | ||||
Status Init(); | Status Init(); | ||||
@@ -54,7 +54,7 @@ class SubgraphContext { | |||||
FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock | FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock | ||||
friend class TaskContext; | friend class TaskContext; | ||||
const GraphItem *graph_item_; | const GraphItem *graph_item_; | ||||
const GraphExecutionContext *execution_context_; | |||||
GraphExecutionContext *execution_context_; | |||||
mmRWLock_t rw_lock_; | mmRWLock_t rw_lock_; | ||||
std::vector<TensorValue> all_inputs_; | std::vector<TensorValue> all_inputs_; | ||||
std::vector<TensorValue> all_outputs_; | std::vector<TensorValue> all_outputs_; | ||||
@@ -175,16 +175,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue | |||||
GE_CHECK_NOTNULL(node_state); | GE_CHECK_NOTNULL(node_state); | ||||
node_state->SetKernelTask(node_item->kernel_task); | node_state->SetKernelTask(node_item->kernel_task); | ||||
known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||||
GE_CHECK_NOTNULL(known_shape_task_context_); | |||||
node_state->SetTaskContext(known_shape_task_context_); | |||||
std::function<void()> callback; | std::function<void()> callback; | ||||
GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); | GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); | ||||
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback), | |||||
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback), | |||||
"[%s] Failed to execute node [%s] for known subgraph.", | "[%s] Failed to execute node [%s] for known subgraph.", | ||||
graph_item_->GetName().c_str(), | graph_item_->GetName().c_str(), | ||||
known_shape_task_context_->GetNodeName()); | |||||
node_state->GetName().c_str()); | |||||
GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str()); | GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str()); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -271,16 +267,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { | |||||
} else { | } else { | ||||
node_state->SetKernelTask(node_item.kernel_task); | node_state->SetKernelTask(node_item.kernel_task); | ||||
} | } | ||||
auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||||
GE_CHECK_NOTNULL(unique_task_context); | |||||
const auto &task = node_state->GetKernelTask(); | const auto &task = node_state->GetKernelTask(); | ||||
if (task == nullptr) { | if (task == nullptr) { | ||||
GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str()); | GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str()); | ||||
REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str()); | REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state->SetTaskContext(shared_task_context); | |||||
GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state)); | GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state)); | ||||
return AfterPrepared(p_node_state); | return AfterPrepared(p_node_state); | ||||
} | } | ||||
@@ -480,19 +472,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta | |||||
} else { | } else { | ||||
node_state.SetKernelTask(node_item.kernel_task); | node_state.SetKernelTask(node_item.kernel_task); | ||||
} | } | ||||
auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get()); | |||||
GE_CHECK_NOTNULL(unique_task_context); | |||||
const auto &task = node_state.GetKernelTask(); | const auto &task = node_state.GetKernelTask(); | ||||
if (task == nullptr) { | if (task == nullptr) { | ||||
GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str()); | GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str()); | ||||
REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str()); | REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
node_state.SetTaskContext(shared_task_context); | |||||
GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context)); | GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context)); | ||||
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start"); | RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start"); | ||||
GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws | |||||
GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext())); // update op_desc before alloc ws | |||||
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end"); | RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end"); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -125,7 +125,6 @@ class SubgraphExecutor { | |||||
ThreadPool pre_run_pool_; | ThreadPool pre_run_pool_; | ||||
BlockingQueue<NodeState *> ready_queue_; | BlockingQueue<NodeState *> ready_queue_; | ||||
std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_; | std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_; | ||||
std::shared_ptr<TaskContext> known_shape_task_context_; | |||||
std::mutex mu_; // Guard for prepare_queues_. | std::mutex mu_; // Guard for prepare_queues_. | ||||
std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_; | std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_; | ||||
@@ -398,12 +398,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
data_send_.emplace(node_item); | data_send_.emplace(node_item); | ||||
node_item->data_recv_[this] = anchor_index; | node_item->data_recv_[this] = anchor_index; | ||||
if (is_root_node_) { | if (is_root_node_) { | ||||
node_item->root_data_.emplace(this); | |||||
node_item->root_data_[anchor_index] = this; | |||||
} | } | ||||
// If Enter feed Not Merge, take as root Node. | // If Enter feed Not Merge, take as root Node. | ||||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | ||||
node_item->enter_data_.emplace(this); | |||||
node_item->enter_inside_.emplace(anchor_index); | |||||
node_item->enter_data_[anchor_index] = this; | |||||
} | } | ||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
} | } | ||||
@@ -148,9 +148,9 @@ struct NodeItem { | |||||
int64_t frame_index_ = -1; | int64_t frame_index_ = -1; | ||||
int64_t parent_frame_ = -1; | int64_t parent_frame_ = -1; | ||||
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | ||||
std::set<const NodeItem *> root_data_; // Recv data from root node | |||||
std::map<int, const NodeItem *> root_data_; // Recv data from root node | |||||
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | ||||
std::set<const NodeItem *> enter_data_; // Recv data from Enter node | |||||
std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node | |||||
std::set<const NodeItem *> data_send_; // Send data notify to | std::set<const NodeItem *> data_send_; // Send data notify to | ||||
std::map<const NodeItem *, int> data_recv_; // Recv data notify from | std::map<const NodeItem *, int> data_recv_; // Recv data notify from | ||||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | ||||
@@ -306,7 +306,7 @@ Status AiCoreOpTask::InitWithKernelDefWithHandle(const OpDesc &op_desc, const do | |||||
} | } | ||||
Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef &task_def) { | Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef &task_def) { | ||||
auto rt_ret = ValidateTaskDef(task_def); | auto rt_ret = ValidateTaskDef(task_def); | ||||
if (rt_ret != SUCCESS) { | if (rt_ret != SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "op:%s(op_type:%s) failed to validate task def:%s", | REPORT_CALL_ERROR("E19999", "op:%s(op_type:%s) failed to validate task def:%s", | ||||
@@ -315,7 +315,7 @@ Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef | |||||
op_desc.GetName().c_str(), op_desc.GetType().c_str(), task_def.DebugString().c_str()); | op_desc.GetName().c_str(), op_desc.GetType().c_str(), task_def.DebugString().c_str()); | ||||
return rt_ret; | return rt_ret; | ||||
} | } | ||||
if (task_def.type() != RT_MODEL_TASK_ALL_KERNEL) { | if (task_def.type() != RT_MODEL_TASK_ALL_KERNEL) { | ||||
GE_CHK_STATUS_RET(InitWithKernelDef(op_desc, task_def)); | GE_CHK_STATUS_RET(InitWithKernelDef(op_desc, task_def)); | ||||
} else { | } else { | ||||
@@ -474,7 +474,7 @@ Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) { | |||||
if (task_context.IsTraceEnabled()) { | if (task_context.IsTraceEnabled()) { | ||||
for (int i = 0; i < index; ++i) { | for (int i = 0; i < index; ++i) { | ||||
GELOGD("[%s] Arg[%d] = %lu", stub_name_.c_str(), i, arg_base_[i]); | |||||
GELOGD("[%s] Arg[%d] = %p", stub_name_.c_str(), i, arg_base_[i]); | |||||
} | } | ||||
} | } | ||||
@@ -52,9 +52,7 @@ void TaskContext::ReleaseWorkspace() { | |||||
} | } | ||||
} | } | ||||
std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||||
GraphExecutionContext *execution_context, | |||||
SubgraphContext *subgraph_context) { | |||||
std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, SubgraphContext *subgraph_context) { | |||||
const NodeItem &node_item = *node_state->GetNodeItem(); | const NodeItem &node_item = *node_state->GetNodeItem(); | ||||
GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", | GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", | ||||
node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
@@ -75,7 +73,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||||
} | } | ||||
auto task_context = std::unique_ptr<TaskContext>( | auto task_context = std::unique_ptr<TaskContext>( | ||||
new(std::nothrow)TaskContext(execution_context, node_state, subgraph_context)); | |||||
new(std::nothrow)TaskContext(subgraph_context->execution_context_, node_state, subgraph_context)); | |||||
if (task_context == nullptr) { | if (task_context == nullptr) { | ||||
REPORT_CALL_ERROR("E19999", "Create TaskContext failed for [%s].", node_item.NodeName().c_str()); | REPORT_CALL_ERROR("E19999", "Create TaskContext failed for [%s].", node_item.NodeName().c_str()); | ||||
GELOGE(MEMALLOC_FAILED, "[Create][TaskContext] failed for [%s].", node_item.NodeName().c_str()); | GELOGE(MEMALLOC_FAILED, "[Create][TaskContext] failed for [%s].", node_item.NodeName().c_str()); | ||||
@@ -85,7 +83,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||||
task_context->node_item_ = &node_item; | task_context->node_item_ = &node_item; | ||||
task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start; | task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start; | ||||
task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start; | task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start; | ||||
task_context->iteration_ = execution_context->iteration; | |||||
task_context->iteration_ = subgraph_context->execution_context_->iteration; | |||||
return task_context; | return task_context; | ||||
} | } | ||||
@@ -460,6 +458,10 @@ Status TaskContext::PropagateOutputs() { | |||||
subgraph_context_->all_inputs_[input_offset].SetName( | subgraph_context_->all_inputs_[input_offset].SetName( | ||||
node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | ||||
} | } | ||||
auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | |||||
GE_CHECK_NOTNULL(dst_node_state); | |||||
dst_node_state->SaveRootTensor(dst_input_idx, *tensor); | |||||
} | } | ||||
} | } | ||||
(void)guard; | (void)guard; | ||||
@@ -489,11 +491,6 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||||
} | } | ||||
void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
if (node_item_->enter_inside_.count(index) > 0) { | |||||
GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index); | |||||
return; | |||||
} | |||||
auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
input_tensor->Destroy(); | input_tensor->Destroy(); | ||||
@@ -36,9 +36,7 @@ class SubgraphContext; | |||||
class TaskContext { | class TaskContext { | ||||
public: | public: | ||||
static std::unique_ptr<TaskContext> Create(NodeState *node_state, | |||||
GraphExecutionContext *execution_context, | |||||
SubgraphContext *subgraph_context); | |||||
static std::unique_ptr<TaskContext> Create(NodeState *node_state, SubgraphContext *subgraph_context); | |||||
~TaskContext(); | ~TaskContext(); | ||||