/** * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBERS_SCHEDULER_H_ #define AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBERS_SCHEDULER_H_ #include #include #include "executor_subscriber_guarder.h" #include "built_in_subscriber_definitions.h" #include "global_profiling.h" #include "framework/common/ge_visibility.h" namespace gert { class ModelV2Executor; class VISIBILITY_EXPORT ExecutorSubscribersScheduler { public: static void OnExecuteEvent(ExecutorSubscribersScheduler *ins, ExecutorEvent event, const void *node, KernelStatus result); ExecutorSubscribersScheduler() : executor_(nullptr), enabled_(false), built_in_subscribers_ptr_(), subscribers_(), subscriber_wrapper_({reinterpret_cast<::SubscriberFunc>(ExecutorSubscribersScheduler::OnExecuteEvent), this}) {} #ifdef ONLY_COMPILE_OPEN_SRC ~ExecutorSubscribersScheduler(); #endif void Init(ModelV2Executor *executor); ExecutorSubscribersScheduler(const ExecutorSubscribersScheduler &) = delete; ExecutorSubscribersScheduler &operator=(const ExecutorSubscribersScheduler &) = delete; ExecutorSubscriber &GetSubscriber() { if (subscribers_.size() == 1UL) { return subscribers_[0].GetSubscriber(); } else { return subscriber_wrapper_; } } ModelV2Executor *GetModelV2Executor() { return executor_; } const ModelV2Executor *GetModelV2Executor() const { return executor_; } /** * 设置订阅者,订阅者需要实现一个static方法,原型为: * ```c++ * static void OnExecuteEvent(T *void_arg, ExecutorEvent event, const void *node, KernelStatus result); * ``` * * 默认情况下,subscribers处于disable状态,在添加首个subscriber时,自动将状态切换到enable状态。 * * @tparam T 订阅者类型 * @tparam Args 订阅者初始化参数类型 * @param args 订阅者初始化参数 * @return 添加的subscriber指针,注意subscriber所有权归`ExecutorSubscribersScheduler`所有,外部使用者不可以释放此指针 */ template T *AddSubscriber(Args... args) { auto ins = new (std::nothrow) T(args...); if (ins == nullptr) { return nullptr; } constexpr size_t kInitSubscriberSize = 1UL; // profiler exists when ess init if (subscribers_.size() == kInitSubscriberSize) { enabled_ = true; } subscribers_.emplace_back(reinterpret_cast<::SubscriberFunc>(T::OnExecuteEvent), ins, ObjectDeleter); return ins; } /** * 添加一个内置的subscriber * 内置subscriber较少,当前没有使用注册机制,后续如果需要扩展,那么可以考虑通过注册机制自动注册。 * 为了易用性,在本类提供了获取内置subscriber的指针的接口。而自注册的subscriber将丢失此能力。 * @param subscriber_type */ void AddBuiltIn(BuiltInSubscriberType subscriber_type, uint64_t enable_flag); void RemoveSubscriber(void *subscriber_ptr) { for (auto iter = subscribers_.begin(); iter != subscribers_.end(); ++iter) { if (iter->GetSubscriber().arg == subscriber_ptr) { subscribers_.erase(iter); break; } } for (auto &built_in_subscriber : built_in_subscribers_ptr_) { if (built_in_subscriber == subscriber_ptr) { built_in_subscriber = nullptr; } } if (subscribers_.empty()) { enabled_ = false; } } template inline T *MutableBuiltInSubscriber(const BuiltInSubscriberType type) { return static_cast(built_in_subscribers_ptr_[static_cast(type)]); } template inline const T *GetBuiltInSubscriber(const BuiltInSubscriberType type) { return static_cast(built_in_subscribers_ptr_[static_cast(type)]); } bool IsEnable() const { return enabled_ || GlobalProfilingWrapper::GetInstance()->GetEnableFlags(); } void SetEnable(bool enable_flag) { enabled_ = enable_flag; } void Clear() { subscribers_.clear(); for (auto &built_in_subscriber : built_in_subscribers_ptr_) { built_in_subscriber = nullptr; } enabled_ = false; } size_t GetSize() const { return subscribers_.size(); } private: ModelV2Executor *executor_{nullptr}; bool enabled_{false}; std::array(BuiltInSubscriberType::kNum)> built_in_subscribers_ptr_; std::vector subscribers_; ExecutorSubscriber subscriber_wrapper_; }; } // namespace gert #endif // AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBERS_SCHEDULER_H_