Browse Source

Add Init to NodeState

tags/v1.3.0
zhangxiaokun 4 years ago
parent
commit
ab65075326
4 changed files with 30 additions and 15 deletions
  1. +9
    -0
      ge/hybrid/executor/node_state.cc
  2. +2
    -8
      ge/hybrid/executor/node_state.h
  3. +18
    -7
      ge/hybrid/executor/subgraph_context.cc
  4. +1
    -0
      ge/hybrid/executor/subgraph_context.h

+ 9
- 0
ge/hybrid/executor/node_state.cc View File

@@ -259,8 +259,16 @@ ShapeFuture::ShapeFuture(NodeState *src_node,
NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context)
: node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) { : node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) {
this->op_desc_ = node_item.node->GetOpDesc(); this->op_desc_ = node_item.node->GetOpDesc();
}

Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) {
GE_CHECK_NOTNULL(frame_state);
group_ = group;
frame_state_ = frame_state;
auto unique_task_context = TaskContext::Create(this, subgraph_context_); auto unique_task_context = TaskContext::Create(this, subgraph_context_);
GE_CHECK_NOTNULL(unique_task_context);
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());
return SUCCESS;
} }


Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const {
@@ -350,6 +358,7 @@ void NodeState::ResetContext(uint64_t iteration) {
switch_index_ = -1; switch_index_ = -1;
subgraph_context_->ResetContext(node_item_->node); subgraph_context_->ResetContext(node_item_->node);
auto unique_task_context = TaskContext::Create(this, subgraph_context_); auto unique_task_context = TaskContext::Create(this, subgraph_context_);
GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context);
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());


data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());


+ 2
- 8
ge/hybrid/executor/node_state.h View File

@@ -100,6 +100,8 @@ struct NodeState {
NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context);
~NodeState() = default; ~NodeState() = default;


Status Init(int group, const shared_ptr<FrameState> &frame_state);

OpDesc *GetOpDesc() const { OpDesc *GetOpDesc() const {
return op_desc_.get(); return op_desc_.get();
} }
@@ -152,18 +154,10 @@ struct NodeState {
return merge_index_; return merge_index_;
} }


void SetGroup(int group) {
group_ = group;
}

int GetGroup() const { int GetGroup() const {
return group_; return group_;
} }


void SetFrameState(const shared_ptr<FrameState> &frame_state) {
frame_state_ = frame_state;
}

const shared_ptr<NodeTask> &GetKernelTask() const { const shared_ptr<NodeTask> &GetKernelTask() const {
return kernel_task_; return kernel_task_;
} }


+ 18
- 7
ge/hybrid/executor/subgraph_context.cc View File

@@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
return nullptr; return nullptr;
} }


return CreateNodeState(node_item);
}

NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) {
GELOGD("[%s] lock for write", node_item->NodeName().c_str()); GELOGD("[%s] lock for write", node_item->NodeName().c_str());
if (mmRWLockWRLock(&rw_lock_) != EN_OK) { if (mmRWLockWRLock(&rw_lock_) != EN_OK) {
REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str()); REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str());
GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str()); GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str());
return nullptr; return nullptr;
} }

auto &node_state = node_states_[node_item]; auto &node_state = node_states_[node_item];
if (node_state == nullptr) {
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
node_state.reset(new(std::nothrow)NodeState(*node_item, this));
node_state->SetFrameState(GetOrCreateFrameState(*node_item));
node_state->SetGroup(group_);
(void)guard;
}
do {
if (node_state == nullptr) {
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
node_state.reset(new(std::nothrow)NodeState(*node_item, this));
if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str());
REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str());
break;
}
(void)guard;
}
} while (0);

GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); GELOGD("[%s] unlock for write", node_item->NodeName().c_str());
if (mmWRLockUnLock(&rw_lock_) != EN_OK) { if (mmWRLockUnLock(&rw_lock_) != EN_OK) {
REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str()); REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str());


+ 1
- 0
ge/hybrid/executor/subgraph_context.h View File

@@ -51,6 +51,7 @@ class SubgraphContext {
void NodeDone(const NodePtr &node); void NodeDone(const NodePtr &node);


private: private:
NodeStatePtr CreateNodeState(const NodeItem *node_item);
FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock
friend class TaskContext; friend class TaskContext;
const GraphItem *graph_item_; const GraphItem *graph_item_;


Loading…
Cancel
Save