/** * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INC_FRAMEWORK_PNE_MODEL_H_ #define INC_FRAMEWORK_PNE_MODEL_H_ #include #include #include #include "graph/compute_graph.h" #include "framework/common/debug/log.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_types.h" #include "framework/engine/dnnengine.h" #include "external/ge/ge_ir_build.h" namespace ge { const std::string PNE_ID_NPU = "NPU"; const std::string PNE_ID_CPU = "HOST_CPU"; struct ModelRelation; class PneModel { public: PneModel() = default; explicit PneModel(const ComputeGraphPtr &root_graph) : root_graph_(root_graph){}; virtual ~PneModel() = default; PneModel(const PneModel &other) = delete; PneModel &operator=(const PneModel &other) = delete; public: inline Status AddSubModel(const shared_ptr &submodel, std::string type = "") { if (submodel == nullptr) { GELOGE(INTERNAL_ERROR, "submodel is nullptr, type = %s", type.c_str()); return INTERNAL_ERROR; } submodel->SetModelType(type); if (!submodels_.emplace(submodel->GetModelName(), submodel).second) { GELOGE(INTERNAL_ERROR, "submodel already exist, name = %s, type = %s", submodel->GetModelName().c_str(), type.c_str()); return INTERNAL_ERROR; } return SUCCESS; } inline const std::shared_ptr GetSubmodel(const std::string &name) const { const auto &it = submodels_.find(name); if (it == submodels_.end()) { return nullptr; } return it->second; } inline const std::map> &GetSubmodels() const { return submodels_; } inline void SetModelType(const std::string &type) { model_type_ = type; } inline const std::string &GetModelType() const { return model_type_; } inline void SetModelName(const std::string &model_name) { model_name_ = model_name; } inline const std::string &GetModelName() const { return model_name_; } inline void SetRootGraph(const ComputeGraphPtr graph) { root_graph_ = graph; } inline const ComputeGraphPtr &GetRootGraph() const { return root_graph_; } inline void SetModelRelation(std::shared_ptr model_relation) { model_relation_ = std::move(model_relation); } inline const std::shared_ptr GetModelRelation() const { return model_relation_; } public: virtual Status SerializeModel(ModelBufferData &model_buff) = 0; virtual Status UnSerializeModel(const ModelBufferData &model_buff) = 0; virtual void SetModelId(const uint32_t model_id) { model_id_ = model_id; } virtual uint32_t GetModelId() const { return model_id_; } private: std::map> submodels_; std::shared_ptr model_relation_; ComputeGraphPtr root_graph_ = nullptr; std::string model_name_; std::string model_type_; uint32_t model_id_ = INVALID_MODEL_ID; }; using PneModelPtr = std::shared_ptr; } // namespace ge #endif // INC_FRAMEWORK_PNE_MODEL_H_