/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GE_HYBRID_EXECUTOR_NODE_STATE_H_
#define GE_HYBRID_EXECUTOR_NODE_STATE_H_

#include <condition_variable>
#include <functional>
#include <future>
#include <map>
#include <memory>
#include <mutex>

#include "common/blocking_queue.h"
#include "external/ge/ge_api_error_codes.h"
#include "hybrid/model/node_item.h"
#include "node_done_manager.h"

namespace ge {
namespace hybrid {
class NodeTask;
struct GraphExecutionContext;
class SubgraphContext;
class TaskContext;
struct NodeState;
struct FrameState;

using NodeStatePtr = std::shared_ptr<NodeState>;
using FrameStatePtr = std::shared_ptr<FrameState>;

// Resolves the runtime shape of output `src_index` of `src_node` on demand.
class ShapeFuture {
 public:
  ShapeFuture(NodeState *src_node, uint32_t src_index, SubgraphContext *subgraph_context);
  ~ShapeFuture() = default;
  Status Get(GeShape &ori_shape, GeShape &shape);
  Status GetTensorDesc(const GeTensorDesc **tensor_desc);

 private:
  NodeState *src_node_;
  uint32_t src_index_;
  SubgraphContext *subgraph_context_;
};

struct ShapeInferenceState {
  explicit ShapeInferenceState(const NodeItem &node_item);

  void InitShapeState();

  Status UpdateInputShape(int idx, const GeTensorDesc &tensor_desc);

  void UpdateInputShapeFuture(int idx, ShapeFuture &&future);

  Status AwaitShapesReady(const GraphExecutionContext &context);

  Status UpdateOutputDesc();

  const vector<GeTensorDesc> &GetOutputTensorDesc() const;

  const NodeItem &node_item;

 private:
  Status UpdateInputForMerge(const GraphExecutionContext &context);

  friend struct NodeState;
  std::vector<std::pair<int, ShapeFuture>> shape_futures;
  // Do not update op_desc directly, to avoid race conditions across pipelines.
  std::vector<GeTensorDesc> input_tensor_desc;
  std::vector<GeTensorDesc> output_tensor_desc;

  int num_pending_shapes_ = 0;
  std::condition_variable ready_cv_;
  std::mutex mu_;
};

struct FrameState {
 public:
  FrameState(int64_t id) : frame_id_(id) {}
  ~FrameState() = default;

  int64_t frame_id_{0};
  uint64_t active_count_{0};
  uint64_t iteration_count_{0};
  std::shared_ptr<FrameState> parent_frame_;
};
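// Illustrative usage sketch (not part of this header): how a consumer might
// register an input shape that is only produced at runtime, then block until
// every pending shape is resolved. The locals `src_node`, `subgraph_ctx` and
// `exec_ctx` are hypothetical names introduced here for illustration only.
//
//   ShapeInferenceState &infer_state = node_state.GetShapeInferenceState();
//   // Input 0 depends on output 0 of `src_node`; record a future for it.
//   infer_state.UpdateInputShapeFuture(0, ShapeFuture(src_node, 0, subgraph_ctx));
//   // Blocks on ready_cv_ until num_pending_shapes_ drops to zero.
//   if (infer_state.AwaitShapesReady(*exec_ctx) != SUCCESS) {
//     // handle failure or cancellation of the execution context
//   }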
// Holds the per-node state that changes dynamically during execution.
struct NodeState {
 public:
  NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context);
  ~NodeState() = default;

  Status Init(int group, const shared_ptr<FrameState> &frame_state);

  OpDesc *GetOpDesc() const {
    return op_desc_.get();
  }

  inline const NodeItem *GetNodeItem() const {
    return node_item_;
  }

  inline const string &GetName() const {
    return node_item_->NodeName();
  }

  inline const string &GetType() const {
    return node_item_->NodeType();
  }

  ShapeInferenceState &GetShapeInferenceState() {
    return shape_inference_state_;
  }

  Status UpdateOutputShapes(int index, const GeShape &shape, const GeShape &ori_shape);

  inline bool IsShapeDependence() const {
    return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE;
  }

  void RunStreamActive();
  void RunNextIteration();

  void SavePersistTensor(int input_idx, const TensorValue &tensor);
  void UpdatePersistTensor();

  Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;

  void SetScheduleFuture(std::future<Status> &&future);
  Status WaitForScheduleDone();

  void SetSwitchIndex(int index) {
    switch_index_ = index;
  }

  int GetSwitchIndex() const {
    return switch_index_;
  }

  void SetMergeIndex(int index) {
    merge_index_ = index;
  }

  int GetMergeIndex() const {
    return merge_index_;
  }

  int GetGroup() const {
    return group_;
  }

  const shared_ptr<NodeTask> &GetKernelTask() const {
    return kernel_task_;
  }

  void SetKernelTask(const shared_ptr<NodeTask> &kernel_task) {
    kernel_task_ = kernel_task;
  }

  Status WaitForPrepareDone();

  void SetPrepareFuture(std::future<Status> &&prepare_future) {
    this->prepare_future_ = std::move(prepare_future);
  }

  Status AwaitInputTensors(GraphExecutionContext &context) const;

  void SetTaskContext(std::shared_ptr<TaskContext> &task_context);
  std::shared_ptr<TaskContext> GetTaskContext();

  void SetSkipInferShape(bool skip_infershape) {
    skip_infershape_ = skip_infershape;
  }

  bool MaySkipShapeInference() const {
    return skip_infershape_;
  }

 private:
  bool IsScheduleReady() const;
  void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
  void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
  void ResetContext(uint64_t iteration);
  void ScheduleContext(const NodeState &node_state);
  void UpdatePersistTensor(int input_idx);

  const NodeItem *node_item_ = nullptr;
  std::shared_ptr<NodeTask> kernel_task_ = nullptr;
  std::future<Status> prepare_future_;
  OpDescPtr op_desc_;
  ShapeInferenceState shape_inference_state_;
  SubgraphContext *subgraph_context_;
  std::shared_ptr<TaskContext> task_context_ = nullptr;
  std::mutex mu_;

  std::future<Status> schedule_future_;
  std::shared_ptr<FrameState> frame_state_;
  std::map<int, TensorValue> root_tensor_values_;
  uint64_t active_count_ = 0;
  uint64_t iteration_count_ = 0;
  uint32_t ctrl_scheduled_ = 0;
  uint32_t data_scheduled_ = 0;
  int merge_index_ = -1;   // Used by Execute; reset after the node is executed.
  int switch_index_ = -1;  // Used by Schedule; reset after the node is prepared.
  int group_ = -1;
  bool skip_infershape_ = false;
};
}  // namespace hybrid
}  // namespace ge

#endif  // GE_HYBRID_EXECUTOR_NODE_STATE_H_
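// Illustrative usage sketch (not part of this header): how an executor thread
// might attach asynchronous prepare work to a NodeState and join it later.
// `PrepareNode` is a hypothetical routine named here for illustration; the
// real executor commits this work to its own thread pool. A minimal sketch
// using std::async:
//
//   node_state->SetPrepareFuture(std::async(std::launch::async, [&]() -> Status {
//     return PrepareNode(*node_state);  // hypothetical prepare routine
//   }));
//   ...
//   // Joins the prepare work and returns its Status.
//   if (node_state->WaitForPrepareDone() != SUCCESS) {
//     // propagate the error to the caller
//   }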