/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef AIR_CXX_RUNTIME_V2_CORE_MODEL_V_2_EXECUTOR_H_
#define AIR_CXX_RUNTIME_V2_CORE_MODEL_V_2_EXECUTOR_H_
// NOTE(review): the header name of the following #include is missing in this copy of the file.
// It is presumably a standard header such as <memory> (needed for std::unique_ptr) -- restore it.
// NOTE(review): std::vector, std::array and std::is_standard_layout are used below, but
// <vector>, <array> and <type_traits> are not visibly included; they are only reachable
// transitively (if at all). Consider including them explicitly.
#include
#include "graph/compute_graph.h"
#include "graph/ge_error_codes.h"
#include "model_desc.h"
#include "runtime/stream.h"
#include "exe_graph/runtime/tensor.h"

// NOTE(review): throughout this copy of the header the template argument lists appear to have
// been stripped (e.g. `std::unique_ptr execution_data`, `std::vector> nodes_guarder_`,
// `std::array graphs_`, `std::is_standard_layout::value`). As written it cannot compile;
// restore the arguments from the original source rather than guessing them.
namespace gert {
// Identifies the sub execution graphs of a model. kSubExeGraphTypeEnd is not a real graph type:
// it is the number of entries and is used as the bound of kSubExeGraphTypeStrs below.
enum SubExeGraphType { kInitExeGraph, kMainExeGraph, kDeInitExeGraph, kSubExeGraphTypeEnd };

// Human-readable names, indexed by SubExeGraphType.
// NOTE(review): the C-style (char *) casts discard the const of the string literals; writing
// through these pointers would be undefined behavior. `static constexpr const char *const` would
// express the intent without casts. Being `static` at namespace scope in a header, each
// translation unit that includes this file gets its own copy of the array.
static constexpr char *kSubExeGraphTypeStrs[kSubExeGraphTypeEnd] = {(char *)"Init", (char *)"Main", (char *)"DeInit"};

// Returns the display name of `type`.
// NOTE(review): there is no bounds check -- passing kSubExeGraphTypeEnd (or any value outside
// [kInitExeGraph, kDeInitExeGraph]) reads past the end of kSubExeGraphTypeStrs.
inline const char *GetSubExeGraphTypeStr(SubExeGraphType type) {
  return kSubExeGraphTypeStrs[type];
}

// Holds the buffers and objects that an executor hands to it, keeping them alive in the
// std::unique_ptr / std::vector members below until the guard is destroyed or a member is reset.
// Most Reset* functions take a std::unique_ptr and return void *; presumably the returned pointer
// is the raw address of the buffer now being held -- confirm against the implementation, which is
// not visible here.
class ResourceGuard {
 public:
  void *ResetExecutionData(std::unique_ptr execution_data);
  // `count` is presumably the number of values stored in `any_values` (it is kept in
  // any_values_num_) -- confirm.
  void ResetAnyValue(std::unique_ptr any_values, size_t count);
  // Take raw pointers; presumably ownership moves into nodes_guarder_ / watchers_guarder_ -- confirm.
  void PushNode(void *node);
  void PushWatcher(void *watcher);
  void *ResetNodesArray(std::unique_ptr nodes_array);
  void *ResetStartNodesArray(std::unique_ptr start_nodes_array);
  void *ResetNodesIndgreeArray(std::unique_ptr nodes_indgree_array);
  void *ResetNodesWaitIndgreeArray(std::unique_ptr nodes_indgree_array);
  void *ResetInputsArray(std::unique_ptr inputs_array);
  void *ResetOutputsArray(std::unique_ptr outputs_array);
  void *ResetWatchersArray(std::unique_ptr watchers_array);
  // Takes a raw pointer; presumably held by ready_queue_guarder_ afterwards -- confirm.
  void *ResetReadyQueue(void *ready_queue);
  void *ResetBuffer(std::unique_ptr buffer);
  void *ResetComputeNodeInfo(std::unique_ptr compute_node_info);
  void *ResetKernelExtendInfo(std::unique_ptr kernel_extend_info);
  void *ResetModelDesc(std::unique_ptr model_desc);

  // Defined out of line (not visible here).
  ~ResourceGuard();

 private:
  std::unique_ptr execution_data_holder_;
  // NOTE(review): no default member initializer, so its value is indeterminate until it is
  // assigned (presumably by ResetAnyValue). If the destructor reads it, a guard that never had
  // ResetAnyValue called would read garbage -- consider `size_t any_values_num_{0};`.
  size_t any_values_num_;
  std::unique_ptr any_values_guard_;

  std::vector> nodes_guarder_;
  std::vector> watchers_guarder_;
  std::unique_ptr continuous_buffer_guarder_;
  std::unique_ptr buffer_guarder_;
  std::unique_ptr compute_node_info_guarder_;
  std::unique_ptr kernel_extend_info_guarder_;
  std::unique_ptr model_desc_guarder_;

  std::unique_ptr nodes_array_guarder_;
  std::unique_ptr start_nodes_array_guarder_;
  std::unique_ptr nodes_indgree_array_guarder_;
  std::unique_ptr nodes_wait_indgree_array_guarder_;
  std::unique_ptr inputs_array_guarder_;
  std::unique_ptr outputs_array_guarder_;
  std::unique_ptr watchers_array_guarder_;
  // Constructed with a null pointer and a null deleter, i.e. the deleter type is a
  // pointer-like type supplied later together with the queue.
  std::unique_ptr ready_queue_guarder_{nullptr, nullptr};
};

// Per-call arguments for ModelV2Executor::Execute: the runtime stream to execute on.
struct ModelExecuteArg {
  rtStream_t stream;
};
// NOTE(review): the message says "POD", but only standard layout is checked (not triviality).
static_assert(std::is_standard_layout::value, "The class ModelExecuteArg must be a POD");

// Executes one sub execution graph, using the execution data set through SetExecutionData and
// the resources held by its own ResourceGuard.
class ExeGraphExecutor {
 public:
  // TODO: release the AnyValue resources on UnLoad.
  // Both currently do nothing and always report success.
  ge::graphStatus Load() {
    return ge::GRAPH_SUCCESS;
  }
  ge::graphStatus UnLoad() {
    return ge::GRAPH_SUCCESS;
  }

  /**
   * Sets the inputs/outputs for graph execution. Note that the caller is responsible for making
   * sure the inputs/outputs are completely refreshed!!!
   * For SpecifyInputs, `start` is presumably the index of the first input slot to overwrite and
   * `num` the number of entries taken from `inputs` -- confirm against the implementation.
   */
  ge::graphStatus SpecifyInputs(void **inputs, size_t start, size_t num);
  ge::graphStatus SpecifyOutputs(void **outputs, size_t num);
  ge::graphStatus Execute();

  // Returns the execution_data_ pointer as-is.
  // NOTE(review): execution_data_ is not initialized (see below), so this returns an
  // indeterminate value until it has been assigned.
  const void *GetExecutionData() const {
    return execution_data_;
  }

  ResourceGuard &GetResourceGuard();
  // Presumably stores `execution_data` in resource_guard_ and updates execution_data_ -- confirm.
  void *SetExecutionData(std::unique_ptr execution_data);

 private:
  friend class ModelV2ExecutorTestHelper;

  // NOTE(review): no initializer. ModelV2Executor holds these objects in graphs_ and its
  // constructor is defaulted, so this pointer is indeterminate until assigned; consider
  // `void *execution_data_ = nullptr;`.
  void *execution_data_;
  ResourceGuard resource_guard_;
};

// Executes a model built from a compute graph. It holds one ExeGraphExecutor per sub execution
// graph in graphs_. Instances are created only through the static Create factory or by the
// friend classes, because the default constructor is private. They are neither copyable
// nor movable.
class ModelV2Executor {
 public:
  static std::unique_ptr Create(const ge::ComputeGraphPtr &root_graph);

  ge::graphStatus Load();
  // Executes on the stream carried by `arg`.
  ge::graphStatus Execute(const ModelExecuteArg &arg, Tensor **inputs, size_t input_num, Tensor **outputs,
                          size_t output_num);
  // Takes no ModelExecuteArg; presumably runs on default_stream_ and waits for completion -- confirm.
  ge::graphStatus ExecuteSync(Tensor **inputs, size_t input_num, Tensor **outputs, size_t output_num);
  ge::graphStatus UnLoad();

  // NOTE(review): model_desc_ defaults to nullptr; if the implementation dereferences it, calling
  // this before SetModelDesc would dereference a null pointer.
  const ModelDesc &GetModelDesc() const;
  // Stores a raw, non-owning-looking pointer; ownership is presumably handled elsewhere
  // (e.g. via ResourceGuard::ResetModelDesc) -- confirm.
  void SetModelDesc(ModelDesc *model_desc);

  ModelV2Executor(const ModelV2Executor &) = delete;
  ModelV2Executor(ModelV2Executor &&) = delete;
  ModelV2Executor &operator=(const ModelV2Executor &) = delete;
  ModelV2Executor &operator=(ModelV2Executor &&) = delete;

 private:
  friend class ModelV2ExecutorBuilder;
  friend class ModelV2ExecutorTestHelper;
  ModelV2Executor() = default;

 private:
  // Presumably one executor per SubExeGraphType (template arguments stripped in this copy) -- confirm.
  std::array graphs_;
  ResourceGuard resource_guard_;
  ModelDesc *model_desc_ = nullptr;
  rtStream_t default_stream_ = nullptr;
};
}  // namespace gert

#endif  // AIR_CXX_RUNTIME_V2_CORE_MODEL_V_2_EXECUTOR_H_