@@ -262,7 +262,6 @@ set(COMPILER_SRC_LIST | |||||
"common/dump/dump_op.cc" | "common/dump/dump_op.cc" | ||||
"common/ge/op_tiling_manager.cc" | "common/ge/op_tiling_manager.cc" | ||||
"common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
"common/helper/model_cache_helper.cc" | |||||
"common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
"engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
"ge_local_engine/engine/host_cpu_engine.cc" | "ge_local_engine/engine/host_cpu_engine.cc" | ||||
@@ -300,7 +299,6 @@ set(COMPILER_SRC_LIST | |||||
"graph/manager/graph_var_manager.cc" | "graph/manager/graph_var_manager.cc" | ||||
"graph/manager/host_mem_allocator.cc" | "graph/manager/host_mem_allocator.cc" | ||||
"graph/manager/host_mem_manager.cc" | "graph/manager/host_mem_manager.cc" | ||||
"graph/manager/model_manager/event_manager.cc" | |||||
"graph/manager/rdma_pool_allocator.cc" | "graph/manager/rdma_pool_allocator.cc" | ||||
"graph/manager/session_scope_mem_allocator.cc" | "graph/manager/session_scope_mem_allocator.cc" | ||||
"graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
@@ -1,123 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ | |||||
#define GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ | |||||
#include <nlohmann/json.hpp> | |||||
#include <set> | |||||
#include <string> | |||||
#include "external/ge/ge_api_error_codes.h" | |||||
#include "graph/compute_graph.h" | |||||
#include "graph/manager/graph_var_manager.h" | |||||
#include "common/model/ge_model.h" | |||||
namespace ge { | |||||
using Json = nlohmann::json; | |||||
struct CacheInfo { | |||||
size_t node_num; | |||||
size_t edge_num; | |||||
size_t graph_hash; | |||||
map<std::string, size_t> nodes_hash; | |||||
CacheInfo() : node_num(0), edge_num(0), graph_hash(0) {} | |||||
}; | |||||
class ModelCacheHelper { | |||||
public: | |||||
ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph); | |||||
~ModelCacheHelper(); | |||||
Status SaveCacheInfoToCache () const; | |||||
Status SaveVarManagerToCache(bool before_build) const; | |||||
Status SaveOmModelToCache(const GeModelPtr &ge_model) const; | |||||
bool IsModelCacheHit() const; | |||||
Status RecoverVarManagerFromCache() const; | |||||
Status LoadOmModelFromCache(GeModelPtr &ge_model) const; | |||||
Status RefreshComputeGraph(const ComputeGraphPtr &compute_graph); | |||||
Status ClearCache(uint32_t graph_id) const; | |||||
private: | |||||
Status GetComputeGraphHash(size_t &hash) const; | |||||
Status GetNodesHash(map<std::string, size_t> &hash_map) const; | |||||
Status GetCacheInfo(CacheInfo &cache_info) const; | |||||
Status RecoverMemResource(const Json &json) const; | |||||
Status RecoverAllocatedGraphId(const Json &json) const; | |||||
Status RecoverChangedGraphId(const Json &json) const; | |||||
Status RecoverVarAddrAndTensorDesc(const Json &json) const; | |||||
Status RecoverBroadcastInfo(const Json &json) const; | |||||
Status RecoverTransRoads(const Json &json) const; | |||||
static Status GetNodesNeedRecompile(ComputeGraphPtr &graph, vector<NodePtr> &nodes); | |||||
static Status RecompileNodes(GeModelPtr &ge_model); | |||||
bool IsNodeHashSameAsCache(const map<std::string, size_t> &hash_map) const; | |||||
bool IsMemResourceSameAsCache(Json &json) const; | |||||
bool IsChangedGraphIdSameAsCache(Json &json) const; | |||||
bool IsAllocatedGraphIdSameAsCache(Json &json) const; | |||||
bool IsCurVarTensorDescSameAsCache(Json &json) const; | |||||
bool IsVarAddrMgrMapSameAsCache(Json &json) const; | |||||
bool IsBroadcastInfoSameAsCache(Json &json) const; | |||||
bool IsTransRoadsSameAsCache(Json &json) const; | |||||
bool IsVarManagerSameAsCache(Json &json) const; | |||||
bool IsVarManagerParamSameAsCache(Json &json) const; | |||||
Status SaveJsonToFile(const string &file_name, const Json &json) const; | |||||
Status LoadJsonFromFile(const string &file_name, Json &json) const; | |||||
Status GetNodesHashMapJson(Json &json) const; | |||||
Status GetMemResourceMap(Json &json) const; | |||||
Status GetVarAddrMgrMapJson(Json &json) const; | |||||
Status GetCurVarTensorDescMapJson(Json &json) const; | |||||
Status GetTransRoadsJson(Json &json) const; | |||||
Status GetChangedGraphIdJson(Json &json) const; | |||||
Status GetAllocatedGraphIdJson(Json &json) const; | |||||
Status GetBroadcastInfoJson(Json &json) const; | |||||
Status GetVarResourceJson(Json &json) const; | |||||
Status GetVarManagerJson(Json &json) const; | |||||
static Status TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json); | |||||
static Status JsonToTensorDesc(const Json &json, GeTensorDesc &ge_tensor_desc); | |||||
static Status ParseMemResourceFromJson(const Json &json, map<rtMemType_t, int64_t> &mem_resource); | |||||
static Status ParseVarAddrMgrMapFromJson(const Json &json, | |||||
std::vector<std::pair<std::string, VarAddrMgr>> &var_addr_mgr_vector, | |||||
std::set<uint64_t> &var_offset_set); | |||||
static Status ParseCurVarTensorDescMapFromJson( | |||||
const Json &json, std::unordered_map<std::string, ge::GeTensorDesc> &cur_var_tensor_desc_map); | |||||
static Status ParseTransRoadsFromJson(const Json &json, | |||||
std::unordered_map<std::string, std::vector<TransNodeInfo>> &trans_roads); | |||||
static Status ParseChangedGraphIdFromJson(const Json &json, | |||||
std::map<std::string, uint32_t> &changed_graph_id); | |||||
static Status ParseAllocatedGraphIdFromJson(const Json &json, | |||||
std::map<std::string, uint32_t> &allocated_graph_id); | |||||
static Status ParseBroadcastInfoFromJson(const Json &json, | |||||
std::unordered_map<std::string, VarBroadCastInfo> &var_broadcast_info); | |||||
static Status GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, string &var_name); | |||||
uint64_t session_id_; | |||||
uint32_t graph_id_; | |||||
string cache_path_; | |||||
ComputeGraphPtr compute_graph_; | |||||
std::set<string> var_names_; | |||||
bool is_cache_path_valid_for_output; | |||||
static map<uint32_t, uint32_t> graph_id_run_times_; | |||||
}; | |||||
using ModelCacheHelperPtr = std::shared_ptr<ModelCacheHelper>; | |||||
} // namespace ge | |||||
#endif // GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ |
@@ -27,6 +27,7 @@ | |||||
#include "graph/load/graph_loader.h" | #include "graph/load/graph_loader.h" | ||||
#include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
#include "graph/manager/graph_mem_manager.h" | #include "graph/manager/graph_mem_manager.h" | ||||
#include "graph/manager/graph_var_manager.h" | |||||
#include "single_op/single_op_manager.h" | #include "single_op/single_op_manager.h" | ||||
#include "graph/load/model_manager/davinci_model.h" | #include "graph/load/model_manager/davinci_model.h" | ||||
#include "opskernel_manager/ops_kernel_builder_manager.h" | #include "opskernel_manager/ops_kernel_builder_manager.h" | ||||
@@ -30,6 +30,7 @@ | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
#include "graph/manager/graph_manager.h" | #include "graph/manager/graph_manager.h" | ||||
#include "graph/manager/graph_var_manager.h" | |||||
#include "graph/manager/util/rt_context_util.h" | #include "graph/manager/util/rt_context_util.h" | ||||
#include "graph/operator_factory_impl.h" | #include "graph/operator_factory_impl.h" | ||||
#include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
@@ -248,7 +248,6 @@ Status GraphManager::Finalize() { | |||||
Analyzer::GetInstance()->DestroyGraphJsonObject(session_id, graph_id); | Analyzer::GetInstance()->DestroyGraphJsonObject(session_id, graph_id); | ||||
} | } | ||||
graph_map_.clear(); | graph_map_.clear(); | ||||
cache_helper_map_.clear(); | |||||
graph_count_.clear(); | graph_count_.clear(); | ||||
// graph context | // graph context | ||||
@@ -1005,13 +1004,6 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge | |||||
} | } | ||||
} | } | ||||
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther); | |||||
// when set incre build, save om model and var manager | |||||
GeModelPtr ge_model = nullptr; | |||||
auto save_ret = SaveCacheAfterBuild(graph_node->GetGraphId(), compute_graph, ge_model); | |||||
if (save_ret != SUCCESS) { | |||||
GELOGW("Fail to save cache."); | |||||
} | |||||
GEEVENT("[GEPERFTRACE] GE PreRun End"); | GEEVENT("[GEPERFTRACE] GE PreRun End"); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -1063,18 +1055,15 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||||
graph_node->GetGraphId()); | graph_node->GetGraphId()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
GeModelPtr ge_model = nullptr; | |||||
// check need incre build. | |||||
ret = IncreBuild(graph_node, ge_model); | |||||
ret = PreRun(graph_node, inputs, ge_root_model, session_id); | |||||
// release rts generate context | |||||
RtContextUtil::GetInstance().DestroyRtContexts(session_id, graph_node->GetGraphId()); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
ret = PreRun(graph_node, inputs, ge_root_model, session_id); | |||||
// release rts generate context | |||||
RtContextUtil::GetInstance().DestroyRtContexts(session_id, graph_node->GetGraphId()); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[Call][PreRun] Failed, graph_id:%u, session_id:%lu.", graph_node->GetGraphId(), session_id); | |||||
return ret; | |||||
} | |||||
GELOGE(ret, "[Call][PreRun] Failed, graph_id:%u, session_id:%lu.", graph_node->GetGraphId(), session_id); | |||||
return ret; | |||||
} | } | ||||
ret = LoadGraph(ge_root_model, graph_node); | ret = LoadGraph(ge_root_model, graph_node); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "[Load][Graph] Failed, graph_id:%u.", graph_node->GetGraphId()); | GELOGE(ret, "[Load][Graph] Failed, graph_id:%u.", graph_node->GetGraphId()); | ||||
@@ -1104,91 +1093,6 @@ Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphN | |||||
return executor_->LoadGraph(ge_root_model, graph_node); | return executor_->LoadGraph(ge_root_model, graph_node); | ||||
} | } | ||||
Status GraphManager::LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, | |||||
GeModelPtr &ge_model) { | |||||
auto graph_id = graph_node->GetGraphId(); | |||||
auto ret = cache_helper->LoadOmModelFromCache(ge_model); | |||||
if (ret != SUCCESS) { | |||||
GELOGW("Fail to load om model from cache."); | |||||
if (cache_helper->ClearCache(graph_id) != SUCCESS) { | |||||
GELOGW("Fail to clear cache of graph %u.", graph_id); | |||||
} | |||||
return FAILED; | |||||
} | |||||
ret = cache_helper->RecoverVarManagerFromCache(); | |||||
if (ret != SUCCESS) { | |||||
GELOGW("Fail to recover VarManager from cache."); | |||||
if (cache_helper->ClearCache(graph_id) != SUCCESS) { | |||||
GELOGW("Fail to clear cache of graph %u.", graph_id); | |||||
} | |||||
return FAILED; | |||||
} | |||||
ComputeGraphPtr compute_graph_in_model = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | |||||
if (compute_graph_in_model == nullptr) { | |||||
GELOGW("Error occurred when get compute graph from om, abandon."); | |||||
return FAILED; | |||||
} else { | |||||
graph_node->SetComputeGraph(compute_graph_in_model); | |||||
graph_node->SetGeModel(ge_model); | |||||
GELOGI("Load model and graph form cache om file."); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper) { | |||||
auto ret = cache_helper->SaveCacheInfoToCache(); | |||||
if (ret != SUCCESS) { | |||||
GELOGW("Fail to save cache info of graph[%d] to cache.", graph_id); | |||||
return FAILED; | |||||
} | |||||
ret = cache_helper->SaveVarManagerToCache(true); | |||||
if (ret != SUCCESS) { | |||||
GELOGW("Fail to save var manager to cache."); | |||||
cache_helper->ClearCache(graph_id); | |||||
return FAILED; | |||||
} | |||||
GELOGI("Cache files have been saved."); | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::SaveCacheAfterBuild(uint32_t graph_id, ge::ComputeGraphPtr graph, GeModelPtr &ge_model) { | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if ((instance_ptr == nullptr) || !instance_ptr->InitFlag()) { | |||||
GELOGW("GELib not initialized."); | |||||
return FAILED; | |||||
} | |||||
if (instance_ptr->IsIncreBuild()) { | |||||
std::lock_guard<std::mutex> lock(member_mutex_); | |||||
auto iter = cache_helper_map_.find(graph_id); | |||||
if (iter == cache_helper_map_.end()) { | |||||
GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); | |||||
return FAILED; | |||||
} else { | |||||
ModelCacheHelperPtr cache_helper = iter->second; | |||||
auto ret = cache_helper->RefreshComputeGraph(graph); | |||||
if (ret != SUCCESS) { | |||||
cache_helper->ClearCache(graph_id); | |||||
GELOGW("Fail to refresh cache helper's compute graph"); | |||||
return FAILED; | |||||
} | |||||
ret = cache_helper->SaveVarManagerToCache(false); | |||||
if (ret != SUCCESS) { | |||||
cache_helper->ClearCache(graph_id); | |||||
GELOGW("Fail to save VarManager to cache"); | |||||
return FAILED; | |||||
} | |||||
ret = cache_helper->SaveOmModelToCache(ge_model); | |||||
if (ret != SUCCESS) { | |||||
cache_helper->ClearCache(graph_id); | |||||
GELOGW("Fail to save om model to cache"); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, | Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, | ||||
const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs) { | const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs) { | ||||
GE_CHECK_NOTNULL(executor_); | GE_CHECK_NOTNULL(executor_); | ||||
@@ -1239,8 +1143,6 @@ Status GraphManager::RunGraphWithStreamAsync(const GraphId &graph_id, rtStream_t | |||||
graph_node->SetIsSpecificStream(true); | graph_node->SetIsSpecificStream(true); | ||||
ComputeGraphPtr compute_graph_tmp = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); | ComputeGraphPtr compute_graph_tmp = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); | ||||
// when set incre build, add cache helper map | |||||
AddModelCacheHelperToMap(graph_id, session_id, compute_graph_tmp); | |||||
if (options_.local_fmk_op_flag) { | if (options_.local_fmk_op_flag) { | ||||
GetCompilerStages(graph_id).optimizer.TranFrameOp(compute_graph_tmp); | GetCompilerStages(graph_id).optimizer.TranFrameOp(compute_graph_tmp); | ||||
} | } | ||||
@@ -1299,9 +1201,6 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector<GeTenso | |||||
"compute_graph_tmp is NULL, graph id = %u.", graph_id); | "compute_graph_tmp is NULL, graph id = %u.", graph_id); | ||||
return GE_GRAPH_GRAPH_NODE_NULL;)) | return GE_GRAPH_GRAPH_NODE_NULL;)) | ||||
// when set incre build, add cache helper map | |||||
AddModelCacheHelperToMap(graph_id, session_id, compute_graph_tmp); | |||||
if (options_.local_fmk_op_flag) { | if (options_.local_fmk_op_flag) { | ||||
GetCompilerStages(graph_id).optimizer.TranFrameOp(compute_graph_tmp); | GetCompilerStages(graph_id).optimizer.TranFrameOp(compute_graph_tmp); | ||||
} | } | ||||
@@ -1518,16 +1417,6 @@ Status GraphManager::SaveParams(ge::GeModel &model, const std::string &type, con | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void GraphManager::RemoveModelCacheHelper(const GraphId &graph_id) { | |||||
std::lock_guard<std::mutex> lock(member_mutex_); | |||||
auto iter = cache_helper_map_.find(graph_id); | |||||
if (iter != cache_helper_map_.end()) { | |||||
cache_helper_map_.erase(iter); | |||||
} else { | |||||
GELOGW("[GraphManager] cache helper does not exist, graph_id = %u", graph_id); | |||||
} | |||||
} | |||||
bool GraphManager::CheckModelLoad(const GeRootModelPtr &ge_root_model, bool load_flag) { | bool GraphManager::CheckModelLoad(const GeRootModelPtr &ge_root_model, bool load_flag) { | ||||
return ((ge_root_model != nullptr) && (ge_root_model->GetModelId() != INVALID_MODEL_ID) && load_flag); | return ((ge_root_model != nullptr) && (ge_root_model->GetModelId() != INVALID_MODEL_ID) && load_flag); | ||||
} | } | ||||
@@ -1555,7 +1444,6 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { | |||||
var_acc_ctrl_.RemoveGraph(graph_id); | var_acc_ctrl_.RemoveGraph(graph_id); | ||||
RemoveGraphNode(graph_id); | RemoveGraphNode(graph_id); | ||||
RemoveModelCacheHelper(graph_id); | |||||
auto ge_root_model = graph_node->GetGeRootModel(); | auto ge_root_model = graph_node->GetGeRootModel(); | ||||
if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) { | if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) { | ||||
@@ -2727,61 +2615,6 @@ Status GraphManager::RunGraphAsync(const GraphId &graph_id, const std::vector<ge | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void GraphManager::AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, | |||||
ComputeGraphPtr &compute_graph) { | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) { | |||||
std::lock_guard<std::mutex> lock(member_mutex_); | |||||
auto iter = cache_helper_map_.find(graph_id); | |||||
if (iter == cache_helper_map_.end()) { | |||||
ModelCacheHelperPtr cache_helper = MakeShared<ge::ModelCacheHelper>(session_id, graph_id, compute_graph); | |||||
if (cache_helper != nullptr) { | |||||
cache_helper_map_.emplace(std::make_pair(graph_id, cache_helper)); | |||||
} else { | |||||
GELOGW("Cache helper make shared failed, graph_id = %u.", graph_id); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
ModelCacheHelperPtr GraphManager::FindModelCacheHelper(GraphId graph_id) { | |||||
std::lock_guard<std::mutex> lock(member_mutex_); | |||||
auto iter = cache_helper_map_.find(graph_id); | |||||
if (iter != cache_helper_map_.end()) { | |||||
return iter->second; | |||||
} | |||||
return nullptr; | |||||
} | |||||
Status GraphManager::IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model) { | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr == nullptr || !instance_ptr->IsIncreBuild()) { | |||||
return FAILED; | |||||
} | |||||
const uint32_t graph_id = graph_node->GetGraphId(); | |||||
ModelCacheHelperPtr cache_helper = FindModelCacheHelper(graph_id); | |||||
if (cache_helper == nullptr) { | |||||
GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); | |||||
return FAILED; | |||||
} | |||||
if (cache_helper->IsModelCacheHit()) { | |||||
GEEVENT("Model cache hit."); | |||||
Status ret = LoadFromCache(graph_node, cache_helper, ge_model); | |||||
if (ret == SUCCESS) { | |||||
return SUCCESS; | |||||
} else { | |||||
GELOGW("Error occurred when load from cache, abandon."); | |||||
} | |||||
} else { | |||||
GEEVENT("Model cache miss."); | |||||
} | |||||
if (SaveCacheBeforeBuild(graph_node->GetGraphId(), cache_helper) != SUCCESS) { | |||||
GELOGW("Error occurred when save cache."); | |||||
} | |||||
return FAILED; | |||||
} | |||||
Status GraphManager::CheckIncreBuildAndPreRun(const PreRunArgs &args, | Status GraphManager::CheckIncreBuildAndPreRun(const PreRunArgs &args, | ||||
GraphNodePtr &graph_node, GeRootModelPtr &ge_root_model) { | GraphNodePtr &graph_node, GeRootModelPtr &ge_root_model) { | ||||
if (!IsGraphNeedBuild(graph_node)) { | if (!IsGraphNeedBuild(graph_node)) { | ||||
@@ -2796,20 +2629,18 @@ Status GraphManager::CheckIncreBuildAndPreRun(const PreRunArgs &args, | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
// check need incre build. | // check need incre build. | ||||
GeModelPtr ge_model = nullptr; | |||||
if (IncreBuild(graph_node, ge_model) != SUCCESS) { | |||||
std::vector<GeTensor> ge_inputs; | |||||
for (const auto &item: args.input_tensor) { | |||||
ge_inputs.emplace_back(TensorAdapter::AsGeTensor(item)); | |||||
} | |||||
Status ret = PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); | |||||
// release rts generate context | |||||
RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); | |||||
if (ret != SUCCESS) { | |||||
ReturnError(args.callback, ret, "PreRun Failed."); | |||||
return ret; | |||||
} | |||||
std::vector<GeTensor> ge_inputs; | |||||
for (const auto &item: args.input_tensor) { | |||||
ge_inputs.emplace_back(TensorAdapter::AsGeTensor(item)); | |||||
} | } | ||||
Status ret = PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); | |||||
// release rts generate context | |||||
RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); | |||||
if (ret != SUCCESS) { | |||||
ReturnError(args.callback, ret, "PreRun Failed."); | |||||
return ret; | |||||
} | |||||
graph_node->SetBuildFlag(true); | graph_node->SetBuildFlag(true); | ||||
var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -2878,10 +2709,6 @@ void GraphManager::PreRunThread() { | |||||
graph_node->Unlock(); | graph_node->Unlock(); | ||||
return; | return; | ||||
} | } | ||||
// when set incre build, save cache helper. | |||||
AddModelCacheHelperToMap(args.graph_id, args.session_id, compute_graph_tmp); | |||||
std::vector<GeModelPtr> ge_models; | |||||
if (options_.local_fmk_op_flag) { | if (options_.local_fmk_op_flag) { | ||||
GetCompilerStages(graph_node->GetGraphId()).optimizer.TranFrameOp(compute_graph_tmp); | GetCompilerStages(graph_node->GetGraphId()).optimizer.TranFrameOp(compute_graph_tmp); | ||||
@@ -27,7 +27,6 @@ | |||||
#include "common/blocking_queue.h" | #include "common/blocking_queue.h" | ||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "common/helper/model_cache_helper.h" | |||||
#include "external/graph/types.h" | #include "external/graph/types.h" | ||||
#include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
#include "graph/build/graph_builder.h" | #include "graph/build/graph_builder.h" | ||||
@@ -339,14 +338,6 @@ class GraphManager { | |||||
bool IsGraphNeedBuild(const GraphNodePtr &graph_node); | bool IsGraphNeedBuild(const GraphNodePtr &graph_node); | ||||
Status LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, GeModelPtr &ge_model); | |||||
Status SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper); | |||||
Status SaveCacheAfterBuild(uint32_t graph_id, ComputeGraphPtr graph, GeModelPtr &ge_model); | |||||
void AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, ComputeGraphPtr &compute_graph); | |||||
Status IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model); | |||||
void RemoveModelCacheHelper(const GraphId &graph_id); | |||||
ModelCacheHelperPtr FindModelCacheHelper(GraphId graph_id); | |||||
void SetRunContext(const GraphNodePtr &graph_node); | void SetRunContext(const GraphNodePtr &graph_node); | ||||
void PushGraph(const RunArgs &args); | void PushGraph(const RunArgs &args); | ||||
@@ -411,7 +402,6 @@ class GraphManager { | |||||
std::thread prerun_thread_; | std::thread prerun_thread_; | ||||
ComputeGraphPtr compute_graph_; | ComputeGraphPtr compute_graph_; | ||||
std::map<GraphId, GraphNodePtr> graph_map_; | std::map<GraphId, GraphNodePtr> graph_map_; | ||||
std::map<GraphId, ModelCacheHelperPtr> cache_helper_map_; | |||||
// summary and checkpoint callback function list for ME, key is summary or checkpoint | // summary and checkpoint callback function list for ME, key is summary or checkpoint | ||||
std::map<std::string, std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)>> me_callback_map_; | std::map<std::string, std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)>> me_callback_map_; | ||||
@@ -70,45 +70,9 @@ void GraphNode::IncreaseLoadCount() { | |||||
++load_count_; | ++load_count_; | ||||
} | } | ||||
SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr), malloc_flag_(false) {} | |||||
SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr) {} | |||||
SubGraphInfo::~SubGraphInfo() { | SubGraphInfo::~SubGraphInfo() { | ||||
if (malloc_flag_) { | |||||
for (auto &buffer_addr : buffer_addr_) { | |||||
if (buffer_addr == nullptr) { | |||||
continue; | |||||
} | |||||
rtError_t rt_ret; | |||||
rt_ret = rtFreeHost(buffer_addr); | |||||
buffer_addr = nullptr; | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(rt_ret, "[Call][RtFreeHost] subgraph free buffer failed, modelId = %u", | |||||
model_id_info_.model_id); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
Status SubGraphInfo::FreeInOutBuffer() { | |||||
if (malloc_flag_) { | |||||
for (auto iter = buffer_addr_.begin(); iter != buffer_addr_.end(); ++iter) { | |||||
rtError_t rt_ret; | |||||
rt_ret = rtFreeHost(*iter); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
REPORT_CALL_ERROR("E19999", "Call rtFreeHost fail, ret:%d", rt_ret); | |||||
GELOGE(rt_ret, "[Call][RtFreeHost] subgraph free buffer failed, modelId = %u", model_id_info_.model_id); | |||||
buffer_addr_.erase(buffer_addr_.begin(), iter); | |||||
return GE_GRAPH_FREE_FAILED; | |||||
} | |||||
} | |||||
buffer_addr_.clear(); | |||||
malloc_flag_ = false; | |||||
return SUCCESS; | |||||
} else { | |||||
GELOGI("[GraphManager] not malloc buffer, modelId = %u", model_id_info_.model_id); | |||||
return SUCCESS; | |||||
} | |||||
} | } | ||||
GraphModelListener::GraphModelListener(std::mutex &mutex, std::condition_variable &cond) | GraphModelListener::GraphModelListener(std::mutex &mutex, std::condition_variable &cond) | ||||
@@ -86,8 +86,6 @@ class SubGraphInfo { | |||||
void SetGeModelPtr(const GeModelPtr &ge_model_ptr) { ge_model_ptr_ = ge_model_ptr; } | void SetGeModelPtr(const GeModelPtr &ge_model_ptr) { ge_model_ptr_ = ge_model_ptr; } | ||||
bool GeModelIsValid() const { return ge_model_ptr_ != nullptr; } | bool GeModelIsValid() const { return ge_model_ptr_ != nullptr; } | ||||
Status FreeInOutBuffer(); | |||||
void SetOutputContext(const std::string &output) { output_names_ = output; } | void SetOutputContext(const std::string &output) { output_names_ = output; } | ||||
std::string GetOutputContext() const { return output_names_; } | std::string GetOutputContext() const { return output_names_; } | ||||
@@ -429,10 +429,6 @@ ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTenso | |||||
return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type); | return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type); | ||||
} | } | ||||
void VarManager::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) { | |||||
var_resource_->GetAllVarAddrMgr(var_addr_mgr_map); | |||||
} | |||||
int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { | int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { | ||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
MemResource *mem_resource = nullptr; | MemResource *mem_resource = nullptr; | ||||
@@ -453,36 +449,6 @@ int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { | |||||
return mem_resource->GetVarMemSize(); | return mem_resource->GetVarMemSize(); | ||||
} | } | ||||
Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) { | |||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
MemResource *mem_resource = nullptr; | |||||
auto iter = mem_resource_map_.find(memory_type); | |||||
if (iter == mem_resource_map_.end()) { | |||||
mem_resource = MemResource::BuildMemResourceFromType(memory_type); | |||||
if (mem_resource == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu", | |||||
memory_type, session_id_); | |||||
GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu", | |||||
memory_type, session_id_); | |||||
return ge::INTERNAL_ERROR; | |||||
} else { | |||||
mem_resource_map_[memory_type] = mem_resource; | |||||
} | |||||
} else { | |||||
mem_resource = iter->second; | |||||
} | |||||
if (mem_resource == nullptr) { | |||||
REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu", | |||||
memory_type, session_id_); | |||||
GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%u, session_id:%lu", | |||||
memory_type, session_id_); | |||||
return FAILED; | |||||
} | |||||
mem_resource->UpdateVarMemSize(mem_size); | |||||
return SUCCESS; | |||||
} | |||||
ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, | ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, | ||||
rtMemType_t memory_type) { | rtMemType_t memory_type) { | ||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
@@ -638,16 +604,6 @@ ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastIn | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { | |||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
if (var_resource_ == nullptr) { | |||||
GELOGW("VarManager has not been init."); | |||||
return ge::INTERNAL_ERROR; | |||||
} | |||||
return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info); | |||||
} | |||||
ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) { | ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) { | ||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str()); | GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str()); | ||||
@@ -223,14 +223,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ||||
rtMemType_t &memory_type); | rtMemType_t &memory_type); | ||||
void GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map); | |||||
ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | ||||
ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ||||
ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | |||||
ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); | ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); | ||||
ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); | ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); | ||||
@@ -273,8 +269,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
int64_t GetVarMemSize(rtMemType_t memory_type); | int64_t GetVarMemSize(rtMemType_t memory_type); | ||||
Status UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size); | |||||
bool IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); | bool IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); | ||||
bool IsVarExist(const std::string &var_name); | bool IsVarExist(const std::string &var_name); | ||||
@@ -1,83 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/manager/model_manager/event_manager.h" | |||||
#define RETURN_IF_COND_NOT_MET(condition, ...) \ | |||||
do { \ | |||||
if (!(condition)) { \ | |||||
GELOGE(FAILED, __VA_ARGS__); \ | |||||
return; \ | |||||
} \ | |||||
} while (0); | |||||
namespace ge { | |||||
Status EventManager::Init(size_t event_num) { | |||||
if (this->inited_) { | |||||
return SUCCESS; | |||||
} | |||||
rtEvent_t event = nullptr; | |||||
current_idx_ = 0; | |||||
for (size_t i = 0; i < event_num; ++i) { | |||||
GE_CHK_RT_RET(rtEventCreate(&event)); | |||||
this->event_list_.push_back(event); | |||||
} | |||||
this->inited_ = true; | |||||
return SUCCESS; | |||||
} | |||||
void EventManager::Release() noexcept { | |||||
for (size_t i = 0; i < this->event_list_.size(); ++i) { | |||||
rtError_t rt_ret = rtEventDestroy(this->event_list_[i]); | |||||
RETURN_IF_COND_NOT_MET(rt_ret == RT_ERROR_NONE, "[Destroy][Event] failed, idx is %zu, ret is 0x%x.", i, rt_ret); | |||||
} | |||||
this->event_list_.clear(); | |||||
this->inited_ = false; | |||||
} | |||||
Status EventManager::EventRecord(size_t event_idx, rtStream_t stream) { | |||||
GE_CHK_BOOL_RET_STATUS_NOLOG(this->inited_, INTERNAL_ERROR); | |||||
GE_CHK_BOOL_RET_STATUS_NOLOG(event_idx < this->event_list_.size(), PARAM_INVALID); | |||||
GE_CHK_RT_RET(rtEventRecord(this->event_list_[event_idx], stream)); | |||||
current_idx_ = static_cast<uint32_t>(event_idx); | |||||
return SUCCESS; | |||||
} | |||||
Status EventManager::EventElapsedTime(size_t start_event_idx, size_t stop_event_idx, float &time) { | |||||
GE_CHK_BOOL_RET_STATUS_NOLOG(this->inited_, INTERNAL_ERROR); | |||||
GE_CHK_BOOL_RET_STATUS_NOLOG(start_event_idx < this->event_list_.size() && | |||||
stop_event_idx < this->event_list_.size() && start_event_idx <= stop_event_idx, | |||||
PARAM_INVALID); | |||||
GE_CHK_RT_RET(rtEventElapsedTime(&time, this->event_list_[start_event_idx], this->event_list_[stop_event_idx])); | |||||
return SUCCESS; | |||||
} | |||||
Status EventManager::GetEvent(uint32_t index, rtEvent_t &event) { | |||||
GE_CHK_BOOL_RET_STATUS_NOLOG(index < this->event_list_.size(), PARAM_INVALID); | |||||
event = this->event_list_[index]; | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge |
@@ -1,98 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ | |||||
#define GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ | |||||
#include <vector> | |||||
#include "framework/common/fmk_error_codes.h" | |||||
#include "framework/common/fmk_types.h" | |||||
#include "framework/common/util.h" | |||||
#include "runtime/event.h" | |||||
namespace ge { | |||||
class EventManager { | |||||
public: | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief constructor | |||||
/// | |||||
EventManager() : inited_(false), current_idx_(0) {} | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief destructor | |||||
/// | |||||
~EventManager() { this->Release(); } | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief init and create event list | |||||
/// @param [in] event_num event number created | |||||
/// @return exec result | |||||
/// | |||||
Status Init(size_t event_num); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief event record | |||||
/// @param [in] event_idx event index | |||||
/// @param [in] stream related stream | |||||
/// @return exec result | |||||
/// | |||||
Status EventRecord(size_t event_idx, rtStream_t stream); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief time between start and end in ms | |||||
/// @param [in] start_event_idx start event index | |||||
/// @param [in] stop_event_idx stop event index | |||||
/// @param [out] time | |||||
/// @return exec result | |||||
/// | |||||
Status EventElapsedTime(size_t start_event_idx, size_t stop_event_idx, float &time); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief current event index | |||||
/// @return | |||||
/// | |||||
uint32_t CurrentIdx() const { return current_idx_; } | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief get event at specific loc | |||||
/// @param [in] index event index | |||||
/// @return | |||||
/// | |||||
Status GetEvent(uint32_t index, rtEvent_t &event); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief release event list | |||||
/// @param [in] | |||||
/// @return | |||||
/// | |||||
void Release() noexcept; | |||||
private: | |||||
std::vector<rtEvent_t> event_list_; | |||||
bool inited_; | |||||
uint32_t current_idx_; | |||||
}; // EventManager | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ |
@@ -24,7 +24,6 @@ | |||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "runtime/context.h" | #include "runtime/context.h" | ||||
#include "graph/manager/graph_var_manager.h" | |||||
namespace ge { | namespace ge { | ||||
class TransVarDataUtils { | class TransVarDataUtils { | ||||
@@ -24,7 +24,6 @@ | |||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "graph/manager/graph_var_manager.h" | |||||
#include "graph/passes/pass_utils.h" | #include "graph/passes/pass_utils.h" | ||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
@@ -28,7 +28,6 @@ | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/anchor_utils.h" | #include "graph/utils/anchor_utils.h" | ||||
#include "graph/manager/graph_var_manager.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
@@ -63,13 +62,7 @@ class GE_FUNC_VISIBILITY GELib { | |||||
bool InitFlag() const { return init_flag_; } | bool InitFlag() const { return init_flag_; } | ||||
// get TrainMode flag | // get TrainMode flag | ||||
bool isTrainMode() { return is_train_mode_; } | |||||
// get incre build flag | |||||
bool IsIncreBuild() const { return is_incre_build_; } | |||||
// get incre build cache path | |||||
const std::string &GetIncreBuildCachePath() const { return incre_build_cache_path_; } | |||||
bool IsTrainMode() { return is_train_mode_; } | |||||
void InitProfiling(Options &options); | void InitProfiling(Options &options); | ||||
void ShutDownProfiling(); | void ShutDownProfiling(); | ||||
@@ -100,8 +93,6 @@ class GE_FUNC_VISIBILITY GELib { | |||||
bool is_system_inited = false; | bool is_system_inited = false; | ||||
bool is_shutdown = false; | bool is_shutdown = false; | ||||
bool is_use_hcom = false; | bool is_use_hcom = false; | ||||
bool is_incre_build_ = false; | |||||
std::string incre_build_cache_path_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -140,7 +140,6 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | "${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | ||||
"${GE_CODE_DIR}/ge/graph/build/graph_builder.cc" | "${GE_CODE_DIR}/ge/graph/build/graph_builder.cc" | ||||
"${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" | ||||
"${GE_CODE_DIR}/ge/common/helper/model_cache_helper.cc" | |||||
"${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc" | "${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc" | ||||
"${GE_CODE_DIR}/ge/ir_build/attr_options/utils.cc" | "${GE_CODE_DIR}/ge/ir_build/attr_options/utils.cc" | ||||
"${GE_CODE_DIR}/ge/ir_build/attr_options/keep_dtype_option.cc" | "${GE_CODE_DIR}/ge/ir_build/attr_options/keep_dtype_option.cc" | ||||
@@ -248,7 +247,6 @@ set(GRAPH_DAVINCI_MODEL_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc" | "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc" | ||||
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc" | "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc" | ||||
"${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" | "${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" | ||||
"${GE_CODE_DIR}/ge/graph/manager/model_manager/event_manager.cc" | |||||
) | ) | ||||
set(GRAPH_EXECUTE_COMMON_SRC_FILES | set(GRAPH_EXECUTE_COMMON_SRC_FILES | ||||
@@ -520,13 +518,9 @@ set(COMMON_TEST_FILES | |||||
set(DISTINCT_GRAPH_LOAD_TEST_FILES | set(DISTINCT_GRAPH_LOAD_TEST_FILES | ||||
"graph/load/data_dumper_unittest.cc" | "graph/load/data_dumper_unittest.cc" | ||||
#"graph/load/new_model_manager_data_inputer_unittest.cc" | |||||
#"graph/load/new_model_manager_davinci_model_unittest.cc" | |||||
"graph/load/model_manager_unittest.cc" | "graph/load/model_manager_unittest.cc" | ||||
"graph/load/new_model_manager_model_manager_aicpu_unittest.cc" | "graph/load/new_model_manager_model_manager_aicpu_unittest.cc" | ||||
"graph/load/end_graph_task_unittest.cc" | "graph/load/end_graph_task_unittest.cc" | ||||
"graph/load/new_model_manager_event_manager_unittest.cc" | |||||
#"graph/load/output_net_output_unittest.cc" | |||||
"graph/load/davinci_model_unittest.cc" | "graph/load/davinci_model_unittest.cc" | ||||
"graph/load/tbe_handle_store_unittest.cc" | "graph/load/tbe_handle_store_unittest.cc" | ||||
"graph/load/hccl_task_info_unittest.cc" | "graph/load/hccl_task_info_unittest.cc" | ||||
@@ -536,7 +530,6 @@ set(DISTINCT_GRAPH_LOAD_TEST_FILES | |||||
"graph/load/memcpy_addr_async_task_info_unittest.cc" | "graph/load/memcpy_addr_async_task_info_unittest.cc" | ||||
"graph/load/memcpy_async_task_info_unittest.cc" | "graph/load/memcpy_async_task_info_unittest.cc" | ||||
"graph/load/cpu_queue_schedule_unittest.cc" | "graph/load/cpu_queue_schedule_unittest.cc" | ||||
#"graph/graph_load_unittest.cc" | |||||
"graph/ge_executor_unittest.cc" | "graph/ge_executor_unittest.cc" | ||||
"graph/load/model_helper_unittest.cc" | "graph/load/model_helper_unittest.cc" | ||||
"graph/load/model_utils_unittest.cc" | "graph/load/model_utils_unittest.cc" | ||||
@@ -20,6 +20,7 @@ | |||||
#define private public | #define private public | ||||
#include "graph/execute/model_executor.h" | #include "graph/execute/model_executor.h" | ||||
#include "graph/manager/graph_manager.h" | #include "graph/manager/graph_manager.h" | ||||
#include "graph/manager/graph_var_manager.h" | |||||
#include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
#include "graph/load/model_manager/davinci_model.h" | #include "graph/load/model_manager/davinci_model.h" | ||||
@@ -1,93 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <memory> | |||||
#include <mutex> | |||||
#include <thread> | |||||
#include <vector> | |||||
#include "common/debug/log.h" | |||||
#include "common/helper/model_helper.h" | |||||
#include "common/op/ge_op_utils.h" | |||||
#include "common/types.h" | |||||
#include "graph/op_desc.h" | |||||
#include "graph/types.h" | |||||
#include "graph/utils/attr_utils.h" | |||||
#include "graph/utils/op_desc_utils.h" | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/load/graph_loader.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
#include "graph/load/model_manager/model_manager.h" | |||||
#include "graph/manager/graph_manager_utils.h" | |||||
#include "common/model/ge_model.h" | |||||
#undef private | |||||
#undef protected | |||||
using namespace testing; | |||||
namespace ge { | |||||
class UtestGraphGraphLoad : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
TEST_F(UtestGraphGraphLoad, load_graph_param_invalid1) { | |||||
std::shared_ptr<GraphModelListener> graph_run_listener = nullptr; | |||||
SubGraphInfo sub_graph1; | |||||
ge::SubGraphInfoPtr sub_graph_ptr1 = std::make_shared<SubGraphInfo>(sub_graph1); | |||||
ModelIdInfo model_Id_info; | |||||
model_Id_info.model_id = 1; | |||||
GeModelPtr ge_model_ptr = std::make_shared<GeModel>(); | |||||
sub_graph_ptr1->SetGeModelPtr(ge_model_ptr); | |||||
std::vector<bool> input_flag; | |||||
input_flag.push_back(false); | |||||
sub_graph_ptr1->SetInputFlag(input_flag); | |||||
ge::GraphLoader graph_load; | |||||
EXPECT_EQ(GE_GRAPH_PARAM_NULLPTR, graph_load.LoadGraph(sub_graph_ptr1->ge_model_ptr_, graph_run_listener, model_Id_info)); | |||||
sub_graph_ptr1->SetModelIdInfo(model_Id_info); | |||||
} | |||||
TEST_F(UtestGraphGraphLoad, load_graph_param_invalid2) { | |||||
std::mutex sync_run_mutex; | |||||
std::condition_variable condition; | |||||
std::shared_ptr<GraphModelListener> listener = std::make_shared<GraphModelListener>(sync_run_mutex, condition); | |||||
SubGraphInfo sub_graph1; | |||||
ge::SubGraphInfoPtr sub_graph_ptr1 = std::make_shared<SubGraphInfo>(sub_graph1); | |||||
ModelIdInfo model_Id_info; | |||||
model_Id_info.model_id = 1; | |||||
GeModelPtr ge_model_ptr = std::make_shared<GeModel>(); | |||||
sub_graph_ptr1->SetGeModelPtr(ge_model_ptr); | |||||
std::vector<bool> input_flag; | |||||
input_flag.push_back(false); | |||||
sub_graph_ptr1->SetInputFlag(input_flag); | |||||
ge::GraphLoader graph_load; | |||||
EXPECT_EQ(GE_GRAPH_PARAM_NULLPTR, graph_load.LoadGraph(sub_graph_ptr1->ge_model_ptr_, listener, model_Id_info)); | |||||
sub_graph_ptr1->SetModelIdInfo(model_Id_info); | |||||
} | |||||
} // namespace ge |
@@ -1,64 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "graph/load/model_manager/data_inputer.h" | |||||
#include "common/debug/log.h" | |||||
#include "common/debug/memory_dumper.h" | |||||
#include "common/types.h" | |||||
#include "new_op_test_utils.h" | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
class UtestModelManagerDataInputer : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
/// InputDataWrapper | |||||
/// constructor | |||||
/// GetInput | |||||
TEST_F(UtestModelManagerDataInputer, inputdatawrapper_construct) { | |||||
InputDataWrapper *input_data_wrapper = new InputDataWrapper(); | |||||
input_data_wrapper->GetInput(); | |||||
delete input_data_wrapper; | |||||
} | |||||
/// InputDataWrapper | |||||
/// Init func with correct input | |||||
TEST_F(UtestModelManagerDataInputer, success_inputdatawrapper_init) { | |||||
InputDataWrapper *input_data_wrapper = new InputDataWrapper(); | |||||
ge::InputData input_data; | |||||
ge::OutputData output_data; | |||||
Status ret = input_data_wrapper->Init(input_data, output_data); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
delete input_data_wrapper; | |||||
input_data_wrapper = NULL; | |||||
} | |||||
} // namespace ge |
@@ -1,117 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "common/debug/log.h" | |||||
#include "common/debug/memory_dumper.h" | |||||
#include "common/types.h" | |||||
#define private public | |||||
#include "graph/manager/model_manager/event_manager.h" | |||||
#undef private | |||||
using namespace ge; | |||||
using namespace std; | |||||
using namespace testing; | |||||
class UtestModelManagerEventManager : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
// test repeat initialize | |||||
TEST_F(UtestModelManagerEventManager, repeat_initialization) { | |||||
ge::EventManager event_manager; | |||||
size_t event_num = 1; | |||||
event_manager.Init(event_num); | |||||
Status ret = event_manager.Init(event_num); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
TEST_F(UtestModelManagerEventManager, call_event_record_normal) { | |||||
ge::EventManager event_manager; | |||||
size_t event_num = 1; | |||||
Status ret = event_manager.Init(event_num); | |||||
EXPECT_EQ(SUCCESS, ret); | |||||
EXPECT_NE(event_manager.event_list_.size(), 0); | |||||
ret = event_manager.EventRecord(0, NULL); | |||||
EXPECT_EQ(SUCCESS, ret); | |||||
} | |||||
// test load EventRecore when uninited | |||||
TEST_F(UtestModelManagerEventManager, call_event_record_while_uninited) { | |||||
ge::EventManager event_manager; | |||||
Status ret = event_manager.EventRecord(1, NULL); | |||||
EXPECT_EQ(ge::INTERNAL_ERROR, ret); | |||||
} | |||||
// test with invalid param when load EventRecord | |||||
TEST_F(UtestModelManagerEventManager, call_event_record_with_invalid_param) { | |||||
ge::EventManager event_manager; | |||||
Status ret = event_manager.Init(1); | |||||
EXPECT_EQ(SUCCESS, ret); | |||||
ret = event_manager.EventRecord(1, NULL); | |||||
EXPECT_EQ(ge::PARAM_INVALID, ret); | |||||
} | |||||
// test load EventElapsedTime when uninited | |||||
TEST_F(UtestModelManagerEventManager, call_event_elapsed_time_while_uninited) { | |||||
ge::EventManager event_manager; | |||||
float time = .0f; | |||||
Status ret = event_manager.EventElapsedTime(1, 2, time); | |||||
EXPECT_EQ(ge::INTERNAL_ERROR, ret); | |||||
} | |||||
// test with invalid param when load EventElapsedTime | |||||
TEST_F(UtestModelManagerEventManager, call_event_elapsed_time_with_invalid_param) { | |||||
ge::EventManager *event_manager = new ge::EventManager; | |||||
size_t event_num = 2; | |||||
Status ret = event_manager->Init(event_num); | |||||
EXPECT_EQ(SUCCESS, ret); | |||||
float time = .0f; | |||||
// normal load | |||||
ret = event_manager->EventElapsedTime(0, 1, time); | |||||
EXPECT_EQ(SUCCESS, ret); | |||||
// startevent_idx overstep boundary | |||||
ret = event_manager->EventElapsedTime(2, 1, time); | |||||
EXPECT_EQ(ge::PARAM_INVALID, ret); | |||||
// stopevent_idx overstep boundary | |||||
ret = event_manager->EventElapsedTime(1, 2, time); | |||||
EXPECT_EQ(ge::PARAM_INVALID, ret); | |||||
// startevent_idx > stopevent_idx | |||||
ret = event_manager->EventElapsedTime(1, 0, time); | |||||
EXPECT_EQ(ge::PARAM_INVALID, ret); | |||||
delete event_manager; | |||||
} | |||||
TEST_F(UtestModelManagerEventManager, call_get_event) { | |||||
ge::EventManager event_manager; | |||||
size_t event_num = 1; | |||||
event_manager.Init(event_num); | |||||
rtEvent_t event = nullptr; | |||||
Status ret = event_manager.GetEvent(2, event); | |||||
EXPECT_EQ(ge::PARAM_INVALID, ret); | |||||
ret = event_manager.GetEvent(0, event); | |||||
EXPECT_EQ(SUCCESS, ret); | |||||
} |
@@ -1,115 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "common/debug/log.h" | |||||
#include "common/debug/memory_dumper.h" | |||||
#include "common/types.h" | |||||
#include "new_op_test_utils.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/utils/attr_utils.h" | |||||
#include "graph/detail/model_serialize_imp.h" | |||||
#include "proto/ge_ir.pb.h" | |||||
#define private public | |||||
#define protected public | |||||
#include "graph/compute_graph.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/model_serialize.h" | |||||
#include "graph/load/model_manager/davinci_model.h" | |||||
#include "common/properties_manager.h" | |||||
#include "common/op/ge_op_utils.h" | |||||
#include <cce/taskdown_api.h> | |||||
#include "runtime/dev.h" | |||||
#include "runtime/kernel.h" | |||||
#include "cce/fwk_adpt_struct.h" | |||||
#undef private | |||||
#undef protected | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
class UtestModelManagerTaskBuilder : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
/// data weight | |||||
/// | | | | | |||||
/// |-conv-| | | | |||||
/// | | | | |||||
/// conv2d | | |||||
/// | | | |||||
/// |-resApply | |||||
void BuildGraph(ComputeGraphPtr graph) { | |||||
OpDescPtr data = std::make_shared<OpDesc>("DATA1", "data"); | |||||
OpDescPtr weight = std::make_shared<OpDesc>("WEIGHT", "weight"); | |||||
OpDescPtr conv_op = std::make_shared<OpDesc>("conv", "conv"); | |||||
OpDescPtr conv_2D = std::make_shared<OpDesc>("conv_2D", "conv2d"); | |||||
OpDescPtr res_apply_op = std::make_shared<OpDesc>("res_apply_op", "resapply"); | |||||
// add descriptor | |||||
vector<int64_t> dim(4, 4); | |||||
GeShape shape(dim); | |||||
GeTensorDesc out_desc(shape); | |||||
int32_t blockSize = 4096; | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 1); | |||||
data->AddOutputDesc(out_desc); | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 2); | |||||
weight->AddOutputDesc(out_desc); | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 1); | |||||
conv_op->AddInputDesc(out_desc); | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 2); | |||||
conv_op->AddInputDesc(out_desc); | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 3); | |||||
conv_op->AddOutputDesc(out_desc); | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 3); | |||||
conv_2D->AddInputDesc(out_desc); | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 2); | |||||
conv_2D->AddInputDesc(out_desc); | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 4); | |||||
conv_2D->AddOutputDesc(out_desc); | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 4); | |||||
res_apply_op->AddInputDesc(out_desc); | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 1); | |||||
res_apply_op->AddInputDesc(out_desc); | |||||
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 5); | |||||
res_apply_op->AddOutputDesc(out_desc); | |||||
NodePtr data_node = graph->AddNode(data); | |||||
NodePtr weigth_node = graph->AddNode(weight); | |||||
NodePtr conv_node = graph->AddNode(conv_op); | |||||
NodePtr conv_2D_node = graph->AddNode(conv_2D); | |||||
NodePtr res_node = graph->AddNode(res_apply_op); | |||||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(weigth_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv_2D_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(weigth_node->GetOutDataAnchor(0), conv_2D_node->GetInDataAnchor(1)); | |||||
GraphUtils::AddEdge(conv_2D_node->GetOutDataAnchor(0), res_node->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(weigth_node->GetOutDataAnchor(0), res_node->GetInDataAnchor(1)); | |||||
return; | |||||
} | |||||
}; | |||||
} // namespace ge |
@@ -1,300 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <memory> | |||||
#include "securec.h" | |||||
#define protected public | |||||
#define private public | |||||
#include "common/debug/memory_dumper.h" | |||||
#include "common/op/ge_op_utils.h" | |||||
#include "graph/load/model_manager/davinci_model.h" | |||||
#include "graph/load/model_manager/model_utils.h" | |||||
#include "graph/manager/graph_var_manager.h" | |||||
#include "new_op_test_utils.h" | |||||
#include "proto/om.pb.h" | |||||
using namespace std; | |||||
namespace ge { | |||||
class UtestNetOutput : public testing::Test { | |||||
protected: | |||||
void TearDown() {} | |||||
shared_ptr<OmeTestOpDescBuilder> GenOpdef(OpDescPtr &op_desc, int flag) { | |||||
shared_ptr<OmeTestOpDescBuilder> builder = make_shared<OmeTestOpDescBuilder>(op_desc); | |||||
builder->SetStreamId(0); | |||||
builder->AddInput(1); | |||||
builder->SetType("NetOutput"); | |||||
if (flag == 1) { | |||||
auto input_desc_1 = builder->AddInputDesc({1, 1, 10, 10}, FORMAT_NCHW, DT_FLOAT16); | |||||
} | |||||
auto input_desc_1 = builder->AddInputDesc({1, 1, 10, 10}, FORMAT_NCHW, DT_FLOAT16); | |||||
if (flag == 2) { | |||||
auto input_desc_2 = builder->AddInputDesc({1, 1, 10, 10}, FORMAT_NCHW, DT_FLOAT16); | |||||
} | |||||
if (flag == 3) { | |||||
builder->AddInput(10); | |||||
} | |||||
return builder; | |||||
} | |||||
shared_ptr<OmeTestOpDescBuilder> GenOpdef2(OpDescPtr &op_desc) { | |||||
shared_ptr<OmeTestOpDescBuilder> builder = make_shared<OmeTestOpDescBuilder>(op_desc); | |||||
builder->SetStreamId(0); | |||||
builder->SetType("NetOutput"); | |||||
builder->AddInput(10); | |||||
auto input_desc_1 = builder->AddInputDesc({64, 32, 5, 5}, FORMAT_FRACTAL_Z, DT_FLOAT); | |||||
builder->AddInput(1000000); | |||||
auto input_desc_2 = builder->AddInputDesc({1, 10, 10, 1}, FORMAT_NHWC, DT_FLOAT); | |||||
builder->AddOutput(2000000); | |||||
auto output_desc_1 = builder->AddOutputDesc({64, 32, 5, 5}, FORMAT_NCHW, DT_FLOAT); | |||||
builder->AddOutput(2100000); | |||||
output_desc_1 = builder->AddOutputDesc({1, 10, 10, 1}, FORMAT_NHWC, DT_FLOAT); | |||||
return builder; | |||||
} | |||||
public: | |||||
shared_ptr<DavinciModel> dav_model_; | |||||
}; | |||||
TEST_F(UtestNetOutput, test_get_input_size) { | |||||
shared_ptr<OpDesc> custom_op_desc = make_shared<OpDesc>(); | |||||
OmeTestOpDescBuilder builder(custom_op_desc); | |||||
builder.SetName("netoutput"); | |||||
builder.SetStreamId(0); | |||||
builder.SetType("NetOutput"); | |||||
auto input_desc_1 = builder.AddInputDesc({1, 1, 1, 1}, FORMAT_FRACTAL_Z, DT_FLOAT); | |||||
builder.AddInput(1); | |||||
auto output_desc = builder.AddOutputDesc({1, 1, 1, 1}, FORMAT_NCHW, DT_FLOAT); | |||||
builder.AddOutput(1); | |||||
builder.Finish(); | |||||
vector<int64_t> v_output_size = ModelUtils::GetInputSize(custom_op_desc); | |||||
EXPECT_EQ(v_output_size.size(), 1); | |||||
} | |||||
// test ModelUtils::IsOutput | |||||
TEST_F(UtestNetOutput, success_is_output) { | |||||
ModelUtils *model_utils = new ModelUtils(); | |||||
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>(); | |||||
OmeTestOpDescBuilder builder(op_desc); | |||||
builder.SetType("NetOutput"); | |||||
vector<GeTensorDescPtr> outputs_desc; | |||||
std::shared_ptr<GeTensorDesc> desc = std::make_shared<GeTensorDesc>(); | |||||
outputs_desc.push_back(desc); | |||||
op_desc->outputs_desc_ = outputs_desc; | |||||
bool ret = model_utils->IsOutput(op_desc); | |||||
EXPECT_EQ(false, ret); | |||||
delete model_utils; | |||||
} | |||||
// test ModelUtils::IsOutput | |||||
TEST_F(UtestNetOutput, true_is_output) { | |||||
ModelUtils *model_utils = new ModelUtils(); | |||||
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>(); | |||||
OmeTestOpDescBuilder builder(op_desc); | |||||
builder.SetType("NetOutput"); | |||||
vector<GeTensorDescPtr> outputs_desc; | |||||
std::shared_ptr<GeTensorDesc> desc = std::make_shared<GeTensorDesc>(); | |||||
outputs_desc.push_back(desc); | |||||
op_desc->outputs_desc_ = outputs_desc; | |||||
ge::TensorUtils::SetOutputTensor(*(outputs_desc[0].get()), true); | |||||
bool ret = model_utils->IsOutput(op_desc); | |||||
EXPECT_EQ(true, ret); | |||||
delete model_utils; | |||||
} | |||||
// test ModelUtils::IsInputTensorNeedTrans | |||||
TEST_F(UtestNetOutput, success_is_output_tensor_need_trans) { | |||||
ModelUtils *model_utils = new ModelUtils(); | |||||
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>(); | |||||
OmeTestOpDescBuilder builder(op_desc); | |||||
builder.SetType("NetOutput"); | |||||
size_t tensor_index = 1; | |||||
vector<GeTensorDescPtr> outputs_desc; | |||||
std::shared_ptr<GeTensorDesc> desc = std::make_shared<GeTensorDesc>(); | |||||
outputs_desc.push_back(desc); | |||||
op_desc->outputs_desc_ = outputs_desc; | |||||
op_desc->inputs_desc_ = outputs_desc; | |||||
bool ret = model_utils->IsInputTensorNeedTrans(op_desc, tensor_index); | |||||
EXPECT_EQ(false, ret); | |||||
delete model_utils; | |||||
} | |||||
// test ModelUtils::GetOutputSize | |||||
TEST_F(UtestNetOutput, success_get_output_size) { | |||||
vector<int64_t> v_output_size; | |||||
ModelUtils *model_utils = new ModelUtils(); | |||||
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>(); | |||||
vector<GeTensorDescPtr> outputs_desc; | |||||
std::shared_ptr<GeTensorDesc> desc = std::make_shared<GeTensorDesc>(); | |||||
outputs_desc.push_back(desc); | |||||
op_desc->outputs_desc_ = outputs_desc; | |||||
EXPECT_EQ(v_output_size, model_utils->GetOutputSize(op_desc)); | |||||
vector<int64_t> output = {1}; | |||||
op_desc->SetOutputOffset(output); | |||||
uint32_t tensor_size = 0; | |||||
v_output_size.push_back(tensor_size); | |||||
EXPECT_EQ(v_output_size, model_utils->GetOutputSize(op_desc)); | |||||
delete model_utils; | |||||
} | |||||
// test ModelUtils::GetWorkspaceSize | |||||
TEST_F(UtestNetOutput, success_get_workspace_size) { | |||||
vector<int64_t> v_workspace_size; | |||||
ModelUtils *model_utils = new ModelUtils(); | |||||
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>(); | |||||
vector<int64_t> workspace = {1}; | |||||
op_desc->SetWorkspace(workspace); | |||||
EXPECT_EQ(v_workspace_size, model_utils->GetWorkspaceSize(op_desc)); | |||||
op_desc->SetWorkspaceBytes(workspace); | |||||
v_workspace_size.push_back(1); | |||||
EXPECT_EQ(v_workspace_size, model_utils->GetWorkspaceSize(op_desc)); | |||||
delete model_utils; | |||||
} | |||||
// test ModelUtils::GetWeightSize | |||||
TEST_F(UtestNetOutput, success_get_weight_size) { | |||||
vector<int64_t> v_weight_size; | |||||
ModelUtils *model_utils = new ModelUtils(); | |||||
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>(); | |||||
op_desc->SetType("Const"); | |||||
EXPECT_EQ(v_weight_size, model_utils->GetWeightSize(op_desc)); | |||||
op_desc->SetType("NetOutput"); | |||||
vector<GeTensorDescPtr> inputs_desc; | |||||
std::shared_ptr<GeTensorDesc> desc = std::make_shared<GeTensorDesc>(); | |||||
inputs_desc.push_back(desc); | |||||
op_desc->inputs_desc_ = inputs_desc; | |||||
vector<bool> is_input_const = {true}; | |||||
op_desc->SetIsInputConst(is_input_const); | |||||
v_weight_size.push_back(0); | |||||
EXPECT_EQ(v_weight_size, model_utils->GetWeightSize(op_desc)); | |||||
delete model_utils; | |||||
} | |||||
// test ModelUtils::GetWeights | |||||
TEST_F(UtestNetOutput, success_get_weights) { | |||||
vector<ConstGeTensorPtr> v_weights; | |||||
ModelUtils *model_utils = new ModelUtils(); | |||||
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>(); | |||||
op_desc->SetType("Const"); | |||||
EXPECT_EQ(v_weights, model_utils->GetWeights(op_desc)); | |||||
op_desc->SetType("NetOutput"); | |||||
vector<GeTensorDescPtr> inputs_desc; | |||||
std::shared_ptr<GeTensorDesc> desc = std::make_shared<GeTensorDesc>(); | |||||
inputs_desc.push_back(desc); | |||||
op_desc->inputs_desc_ = inputs_desc; | |||||
vector<bool> is_input_const = {true}; | |||||
op_desc->SetIsInputConst(is_input_const); | |||||
GeTensorDesc tensor_desc; | |||||
EXPECT_EQ(v_weights, model_utils->GetWeights(op_desc)); | |||||
delete model_utils; | |||||
} | |||||
// test ModelUtils::GetInputDescs | |||||
TEST_F(UtestNetOutput, success_get_input_descs) { | |||||
vector<::opTensor_t> v_input_descs; | |||||
vector<::tagCcAICPUTensor> ret; | |||||
ModelUtils *model_utils = new ModelUtils(); | |||||
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>(); | |||||
ret = model_utils->GetInputDescs(op_desc); | |||||
EXPECT_EQ(v_input_descs.size(), ret.size()); | |||||
vector<GeTensorDescPtr> inputs_desc; | |||||
std::shared_ptr<GeTensorDesc> desc = std::make_shared<GeTensorDesc>(); | |||||
inputs_desc.push_back(desc); | |||||
op_desc->inputs_desc_ = inputs_desc; | |||||
vector<bool> is_input_const = {false}; | |||||
op_desc->SetIsInputConst(is_input_const); | |||||
opTensor_t tmp; | |||||
tmp.format = OP_TENSOR_FORMAT_NC1HWC0; | |||||
tmp.dim_cnt = 0; | |||||
tmp.data_type = OP_DATA_FLOAT; | |||||
v_input_descs.push_back(tmp); | |||||
ret = model_utils->GetInputDescs(op_desc); | |||||
EXPECT_EQ(v_input_descs.size(), ret.size()); | |||||
delete model_utils; | |||||
} | |||||
// test ModelUtils::GetOutputDescs | |||||
TEST_F(UtestNetOutput, success_get_output_descs) { | |||||
vector<::opTensor_t> v_output_descs; | |||||
vector<::tagCcAICPUTensor> ret; | |||||
ModelUtils *model_utils = new ModelUtils(); | |||||
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>(); | |||||
ret = model_utils->GetOutputDescs(op_desc); | |||||
EXPECT_EQ(v_output_descs.size(), ret.size()); | |||||
vector<GeTensorDescPtr> outputs_desc; | |||||
std::shared_ptr<GeTensorDesc> desc = std::make_shared<GeTensorDesc>(); | |||||
outputs_desc.push_back(desc); | |||||
op_desc->outputs_desc_ = outputs_desc; | |||||
opTensor_t tmp; | |||||
tmp.format = OP_TENSOR_FORMAT_NC1HWC0; | |||||
tmp.dim_cnt = 0; | |||||
tmp.data_type = OP_DATA_FLOAT; | |||||
v_output_descs.push_back(tmp); | |||||
ret = model_utils->GetOutputDescs(op_desc); | |||||
EXPECT_EQ(v_output_descs.size(), ret.size()); | |||||
delete model_utils; | |||||
} | |||||
// test Output::GetOutputData | |||||
TEST_F(UtestNetOutput, success_get_output_data) { | |||||
Output *output = new Output(nullptr, nullptr); | |||||
output->v_input_data_addr_.push_back((void *)1); | |||||
output->v_input_size_.push_back(1); | |||||
output->input_num_ = 1; | |||||
vector<void *> v_data_addr; | |||||
vector<int64_t> v_data_size; | |||||
output->GetOutputData(v_data_addr, v_data_size); | |||||
EXPECT_EQ(output->v_input_data_addr_, v_data_addr); | |||||
EXPECT_EQ(output->v_input_size_, v_data_size); | |||||
delete output; | |||||
} | |||||
} // namespace ge |
@@ -30,9 +30,6 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/manager/graph_manager.h" | #include "graph/manager/graph_manager.h" | ||||
#define const | |||||
#include "common/helper/model_cache_helper.h" | |||||
#undef const | |||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
#include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||