From bdee8d1e058bd29ec778b90cc8fd7a3da8675e0d Mon Sep 17 00:00:00 2001 From: wjm Date: Sat, 12 Jun 2021 05:43:45 +0800 Subject: [PATCH 01/33] fix --- .../format_transfers/format_transfer_fracz_hwcn.cc | 16 +++++----- ge/common/helper/model_cache_helper.cc | 7 +++++ ge/hybrid/node_executor/hccl/hccl_node_executor.cc | 34 ++++++++++++---------- 3 files changed, 33 insertions(+), 24 deletions(-) mode change 100755 => 100644 ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc old mode 100755 new mode 100644 index abe6263b..ed3a062c --- a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc +++ b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc @@ -17,6 +17,7 @@ #include "common/formats/format_transfers/format_transfer_fracz_hwcn.h" #include + #include #include "common/formats/utils/formats_definitions.h" @@ -35,8 +36,8 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { auto dst_shape = args.dst_shape; if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_HWCN) { std::string error = "Dose not support trans format from " + - FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + - FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); + FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + + FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return ACL_ERROR_GE_FORMAT_INVALID; } @@ -52,15 +53,13 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { if (!CheckShapeValid(src_shape, kFracZDimsNum)) { GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, src shape %s", ShapeToString(src_shape).c_str()); - REPORT_CALL_ERROR("E19999", "Src shape %s check invalid", - ShapeToString(src_shape).c_str()); + REPORT_CALL_ERROR("E19999", "Src shape 
%s check invalid", ShapeToString(src_shape).c_str()); return ACL_ERROR_GE_SHAPE_INVALID; } if (!CheckShapeValid(dst_shape, kHwcnDimsNum)) { GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Value is invalid, dst shape %s", ShapeToString(dst_shape).c_str()); - REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", - ShapeToString(dst_shape).c_str()); + REPORT_CALL_ERROR("E19999", "Dst shape %s check invalid", ShapeToString(dst_shape).c_str()); return ACL_ERROR_GE_SHAPE_INVALID; } int64_t c0 = GetCubeSizeByDataType(args.src_data_type); @@ -71,9 +70,8 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast(kNiSize)); if (src_shape.at(kFracZHWC1) != dst_shape.at(kHwcnH) * dst_shape.at(kHwcnW) * c1 || src_shape.at(kFracZC0) != c0 || src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { - std::string error = "Failed to check relationship between src shape" + - FmtToStr(ShapeToString(src_shape)) + " and dst shape" + - FmtToStr(ShapeToString(dst_shape)); + std::string error = "Failed to check relationship between src shape" + FmtToStr(ShapeToString(src_shape)) + + " and dst shape" + FmtToStr(ShapeToString(dst_shape)); GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); return ACL_ERROR_GE_SHAPE_INVALID; } diff --git a/ge/common/helper/model_cache_helper.cc b/ge/common/helper/model_cache_helper.cc index 9cd88ef1..0e6c6329 100755 --- a/ge/common/helper/model_cache_helper.cc +++ b/ge/common/helper/model_cache_helper.cc @@ -1679,6 +1679,13 @@ Status ModelCacheHelper::LoadOmModelFromCache(GeModelPtr &ge_model) const { GELOGW("LoadOmModelFromCache: Load model from file failed. 
ret = %u", ret); return ret; } + std::function callback = [&]() { + if (model_data.model_data != nullptr) { + delete[] reinterpret_cast(model_data.model_data); + model_data.model_data = nullptr; + } + }; + GE_MAKE_GUARD(release, callback); ModelHelper model_helper; ret = model_helper.LoadModel(model_data); diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc index 31f2c7a1..6be9849c 100644 --- a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc +++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc @@ -15,15 +15,16 @@ */ #include "hybrid/node_executor/hccl/hccl_node_executor.h" + #include "common/ge/plugin_manager.h" #include "common/math/math_util.h" #include "external/graph/attr_value.h" +#include "external/graph/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/manager/util/hcom_util.h" #include "graph/utils/type_utils.h" -#include "external/graph/types.h" -#include "hybrid/executor/hybrid_execution_context.h" #include "hccl/hcom.h" +#include "hybrid/executor/hybrid_execution_context.h" #include "runtime/event.h" namespace ge { @@ -267,14 +268,16 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector do } Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams ¶ms) { - void **input_addrs[kAllToAllVInputNums] = {¶ms.sendbuf, ¶ms.sendcounts, ¶ms.sdispls, - ¶ms.recvcounts, ¶ms.rdispls}; + void **input_addrs[kAllToAllVInputNums] = {¶ms.sendbuf, ¶ms.sendcounts, ¶ms.sdispls, ¶ms.recvcounts, + ¶ms.rdispls}; for (size_t i = 0; i < kAllToAllVInputNums; ++i) { auto addr = context.MutableInput(i); GE_CHECK_NOTNULL(addr); @@ -383,13 +386,14 @@ Status BuildAllToAllVparams(TaskContext &context, HcomAllToAllVParams ¶ms) { } params.sendtype = iter->second; params.recvtype = iter->second; + params.group = nullptr; return SUCCESS; } Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams ¶ms) { - void **input_addrs[kGatherAllToAllVInputNums] = 
{¶ms.addrInfo, ¶ms.addrInfoCountPerRank, - ¶ms.recvcounts, ¶ms.rdispls}; + void **input_addrs[kGatherAllToAllVInputNums] = {¶ms.addrInfo, ¶ms.addrInfoCountPerRank, ¶ms.recvcounts, + ¶ms.rdispls}; for (size_t i = 0; i < kGatherAllToAllVInputNums; ++i) { auto addr = context.MutableInput(i); GE_CHECK_NOTNULL(addr); @@ -418,8 +422,9 @@ Status BuildGatherAllToAllParams(TaskContext &context, HcomGatherAllToAllVParams params.recvtype = iter->second; int64_t addr_len = 0; - (void) ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len); + (void)ge::AttrUtils::GetInt(op_desc, "addr_length", addr_len); params.addrLength = static_cast(addr_len); + params.group = nullptr; return SUCCESS; } @@ -428,7 +433,7 @@ Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::functionGetNodeName()); p_ctx->SetStatus(FAILED); @@ -460,7 +465,6 @@ Status AllToAllNodeTask::ExecuteAsync(TaskContext &context, std::function Date: Tue, 13 Jul 2021 15:32:29 +0800 Subject: [PATCH 02/33] fix static check warning --- ge/graph/common/omg_util.h | 9 ++++----- tests/ut/ge/graph/optimize/graph_optimize_unittest.cc | 11 ++++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ge/graph/common/omg_util.h b/ge/graph/common/omg_util.h index d55cc7c8..83057dfb 100644 --- a/ge/graph/common/omg_util.h +++ b/ge/graph/common/omg_util.h @@ -27,11 +27,10 @@ #include "graph/node.h" namespace ge { -namespace { -const int64_t kBufferPoolMemAlignSize = 512; -const uint32_t kBufferPoolNodeOutIndex = 0; -const uint32_t kEventReuseThreshold = 65500; -} // namespace +static constexpr int64_t kBufferPoolMemAlignSize = 512; +static constexpr uint32_t kBufferPoolNodeOutIndex = 0; +static constexpr uint32_t kEventReuseThreshold = 65500; + /// /// @brief get the Original Type of FrameworkOp /// @param [in] node diff --git a/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc b/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc index 5468ec97..7f26aa8c 100644 --- 
a/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc +++ b/tests/ut/ge/graph/optimize/graph_optimize_unittest.cc @@ -32,14 +32,14 @@ using namespace ge; namespace { const char *const kVectorCore = "VectorCore"; const char *const kAicoreEngine = "AIcoreEngine"; -string CreateEngineConfigJson() { +void CreateEngineConfigJson(string &dir_path, string &file_path) { GELOGI("Begin to create engine config json file."); string base_path = PluginManager::GetPath(); GELOGI("Base path is %s.", base_path.c_str()); - string dir_path = base_path.substr(0, base_path.rfind('/') + 1) + "plugin/nnengine/ge_config"; + dir_path = base_path.substr(0, base_path.rfind('/') + 1) + "plugin/nnengine/ge_config"; string cmd = "mkdir -p " + dir_path; system(cmd.c_str()); - string file_path = dir_path + "/engine_conf.json"; + file_path = dir_path + "/engine_conf.json"; GELOGI("Begin to write into the config file: %s.", file_path.c_str()); ofstream ofs(file_path, ios::out); EXPECT_EQ(!ofs, false); @@ -56,7 +56,6 @@ string CreateEngineConfigJson() { "}"; ofs.close(); GELOGI("Json config file %s has been written.", file_path.c_str()); - return file_path; } void DeleteFile(const string &file_name) { @@ -69,14 +68,16 @@ void DeleteFile(const string &file_name) { class UtestGraphOptimizeTest : public testing::Test { protected: void SetUp() { - config_file_ = CreateEngineConfigJson(); + CreateEngineConfigJson(config_dir_, config_file_); } void TearDown() { DeleteFile(config_file_); + DeleteFile(config_dir_); } private: + string config_dir_; string config_file_; }; From 4f049ae644495582cb5707cb16c572dc9dc74f5a Mon Sep 17 00:00:00 2001 From: zhou_chao1993 Date: Tue, 13 Jul 2021 17:13:38 +0800 Subject: [PATCH 03/33] modify ge common so --- .clang-format | 2 +- ge/CMakeLists.txt | 13 -- ge/client/ge_api.cc | 2 +- ge/common/CMakeLists.txt | 15 ++- ge/common/auth/file_saver.cc | 22 ++-- ge/{graph => }/common/bcast.cc | 4 +- ge/{graph => }/common/bcast.h | 6 +- ge/common/context/ctx.cc | 2 +- 
ge/common/cust_aicpu_kernel_store.h | 2 +- ge/common/debug/memory_dumper.cc | 11 +- ge/common/dump/dump_manager.cc | 19 ++- ge/common/dump/dump_properties.cc | 74 +++++------ ge/common/fmk_error_codes.cc | 7 +- .../format_transfers/format_transfer_transpose.h | 1 - ge/common/formats/formats.cc | 16 +-- ge/common/formats/utils/formats_trans_utils.cc | 5 +- ge/common/fp16_t.cc | 74 ++++++++--- ge/common/ge/datatype_util.h | 2 +- ge/common/ge/tbe_plugin_manager.cc | 6 +- ge/{graph => }/common/ge_call_wrapper.h | 0 ge/common/ge_format_util.cc | 4 +- ge/common/helper/model_cache_helper.h | 2 +- ge/common/helper/model_helper.cc | 77 +++++------ ge/common/helper/om_file_helper.cc | 40 +++--- ge/common/kernel_store.h | 2 +- ge/{graph => }/common/local_context.cc | 2 +- ge/{graph => }/common/local_context.h | 0 ge/common/math/fp16_math.cc | 28 ++-- ge/{ => common}/model/ge_model.cc | 2 +- ge/{ => common}/model/ge_model.h | 12 +- ge/{ => common}/model/ge_root_model.cc | 3 +- ge/{ => common}/model/ge_root_model.h | 2 +- ge/common/model_parser/model_parser.cc | 12 +- ge/common/model_saver.cc | 3 +- ge/{graph => }/common/omg_util.cc | 22 ++-- ge/{graph => }/common/omg_util.h | 0 ge/common/op/attr_value_util.cc | 141 ++++++++++----------- ge/common/op/ge_op_utils.cc | 41 ++---- ge/common/profiling/ge_profiling.cc | 2 +- ge/common/profiling/profiling_manager.cc | 65 ++++------ ge/common/profiling/profiling_manager.h | 2 +- ge/common/properties_manager.cc | 15 +-- ge/common/tbe_kernel_store.h | 2 +- ge/common/thread_pool.cc | 4 +- ge/common/thread_pool.h | 2 +- ge/{graph => }/common/transop_util.cc | 4 +- ge/{graph => }/common/transop_util.h | 2 +- ge/common/util.cc | 34 ++--- ge/executor/CMakeLists.txt | 14 +- ge/executor/module.mk | 6 +- ge/ge_inference.mk | 8 +- ge/ge_runner.mk | 8 +- ge/generator/ge_generator.cc | 2 +- ge/graph/build/graph_builder.cc | 4 +- ge/graph/build/graph_builder.h | 2 +- ge/graph/build/logical_stream_allocator.cc | 2 +- 
ge/graph/build/memory/block_mem_assigner.cc | 2 +- ge/graph/build/memory/buffer_pool_mem_assigner.cc | 2 +- ge/graph/build/memory/graph_mem_assigner.cc | 2 +- ge/graph/build/memory/var_mem_assign_util.cc | 2 +- ge/graph/build/model_builder.cc | 6 +- ge/graph/build/model_builder.h | 2 +- ge/graph/build/run_context.cc | 2 +- ge/graph/build/stream_allocator.cc | 2 +- ge/graph/build/task_generator.cc | 2 +- ge/graph/execute/model_executor.cc | 4 +- ge/graph/load/model_manager/davinci_model.cc | 6 +- ge/graph/load/model_manager/davinci_model.h | 2 +- ge/graph/load/model_manager/model_manager.cc | 4 +- ge/graph/load/model_manager/model_manager.h | 2 +- ge/graph/manager/graph_manager.cc | 10 +- ge/graph/manager/graph_manager.h | 2 +- ge/graph/manager/graph_manager_utils.h | 6 +- ge/graph/optimize/graph_optimize.cc | 2 +- ge/graph/optimize/mem_rw_conflict_optimize.cc | 2 +- ge/graph/partition/dynamic_shape_partition.cc | 2 +- ge/graph/partition/graph_partition.cc | 2 +- ge/graph/partition/graph_partition.h | 2 + ge/graph/passes/atomic_addr_clean_pass.cc | 2 +- ge/graph/passes/attach_stream_label_pass.cc | 2 +- ge/graph/passes/buffer_pool_memory_pass.cc | 2 +- ge/graph/passes/cast_remove_pass.cc | 2 +- ge/graph/passes/cast_translate_pass.cc | 2 +- ge/graph/passes/compile_nodes_pass.cc | 2 +- ge/graph/passes/control_trigger_pass.cc | 2 +- ge/graph/passes/dimension_adjust_pass.h | 2 +- ge/graph/passes/flow_ctrl_pass.cc | 2 +- ge/graph/passes/get_original_format_pass.cc | 2 +- ge/graph/passes/guarantee_const_pass.cc | 2 +- ge/graph/passes/hccl_tailing_optimization_pass.cc | 2 +- ge/graph/passes/identity_pass.cc | 2 +- ge/graph/passes/infershape_pass.cc | 2 +- ge/graph/passes/iterator_op_pass.cc | 2 +- .../passes/mark_force_unknown_for_cond_pass.cc | 2 +- ge/graph/passes/mark_node_unknown_shape_pass.cc | 2 +- ge/graph/passes/merge_input_memcpy_pass.cc | 2 +- ge/graph/passes/merge_pass.cc | 2 +- ge/graph/passes/merge_to_stream_merge_pass.cc | 2 +- 
ge/graph/passes/multi_batch_clone_pass.cc | 4 +- ge/graph/passes/multi_batch_pass.cc | 2 +- ge/graph/passes/net_output_pass.cc | 2 +- ge/graph/passes/next_iteration_pass.cc | 2 +- ge/graph/passes/pass_manager.cc | 2 +- ge/graph/passes/pass_utils.cc | 2 +- ge/graph/passes/permute_pass.cc | 2 +- ge/graph/passes/placeholder_with_default_pass.cc | 2 +- ge/graph/passes/prevent_gradient_pass.cc | 2 +- ge/graph/passes/print_op_pass.h | 2 +- ge/graph/passes/ref_identity_delete_op_pass.cc | 2 +- ge/graph/passes/replace_transshape_pass.cc | 2 +- ge/graph/passes/snapshot_pass.cc | 2 +- ge/graph/passes/stop_gradient_pass.h | 2 +- ge/graph/passes/switch_dead_branch_elimination.cc | 2 +- ge/graph/passes/switch_to_stream_switch_pass.cc | 2 +- ge/graph/passes/transop_breadth_fusion_pass.cc | 2 +- ge/graph/passes/transop_depth_fusion_pass.cc | 2 +- .../passes/transop_nearby_allreduce_fusion_pass.cc | 2 +- .../passes/transop_symmetry_elimination_pass.cc | 2 +- .../passes/transop_without_reshape_fusion_pass.cc | 2 +- ge/graph/passes/variable_op_pass.h | 2 +- ge/graph/passes/variable_prepare_op_pass.cc | 2 +- ge/graph/preprocess/graph_preprocess.cc | 6 +- ge/graph/preprocess/insert_op/ge_aipp_op.cc | 2 +- ge/graph/preprocess/multi_batch_copy_graph.cc | 4 +- ge/graph/preprocess/multi_batch_options.cc | 4 +- ge/host_kernels/add_kernel.cc | 2 +- ge/host_kernels/broadcast_args_kernel.cc | 2 +- ge/host_kernels/broadcast_gradient_args_kernel.cc | 2 +- ge/host_kernels/cast_kernel.cc | 2 +- ge/host_kernels/floormod_kernel.cc | 2 +- ge/host_kernels/greater_kernel.cc | 2 +- ge/host_kernels/maximum_kernel.cc | 2 +- ge/host_kernels/mul_kernel.cc | 2 +- ge/host_kernels/permute_kernel.cc | 2 +- ge/host_kernels/sub_kernel.cc | 2 +- ge/host_kernels/transdata_kernel.cc | 2 +- ge/hybrid/hybrid_davinci_model.h | 2 +- ge/hybrid/model/hybrid_model.h | 2 +- ge/hybrid/model/hybrid_model_builder.cc | 2 +- ge/hybrid/model/hybrid_model_builder.h | 2 +- ge/init/gelib.cc | 2 +- ge/ir_build/attr_options/utils.cc 
| 2 +- ge/ir_build/ge_ir_build.cc | 2 +- ge/session/inner_session.cc | 2 +- inc/framework/common/helper/model_helper.h | 20 ++- tests/ut/ge/CMakeLists.txt | 13 +- tests/ut/ge/common/fp16_unittest.cc | 56 ++++++++ tests/ut/ge/graph/build/model_builder_unittest.cc | 2 +- tests/ut/ge/graph/graph_load_unittest.cc | 2 +- tests/ut/ge/graph/load/model_helper_unittest.cc | 2 +- .../ut/ge/graph/manager/graph_manager_unittest.cc | 10 +- .../partition/dynamic_shape_partition_unittest.cc | 2 +- .../mark_node_unknown_shape_pass_unittest.cc | 2 +- .../passes/multi_batch_clone_pass_unittest.cc | 2 +- .../subgraph_const_migration_pass_unittest.cc | 2 +- tests/ut/ge/graph/transop_util_unittest.cc | 2 +- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 4 +- .../hybrid/model/hybrid_model_builder_unittest.cc | 2 +- .../ge_local/ge_local_node_executor_unittest.cc | 2 +- .../host_cpu/host_cpu_node_task_unittest.cc | 2 +- .../node_executor/rts/rts_node_task_unittest.cc | 2 +- 161 files changed, 592 insertions(+), 599 deletions(-) rename ge/{graph => }/common/bcast.cc (98%) rename ge/{graph => }/common/bcast.h (98%) rename ge/{graph => }/common/ge_call_wrapper.h (100%) rename ge/{graph => }/common/local_context.cc (97%) rename ge/{graph => }/common/local_context.h (100%) rename ge/{ => common}/model/ge_model.cc (99%) rename ge/{ => common}/model/ge_model.h (90%) rename ge/{ => common}/model/ge_root_model.cc (95%) rename ge/{ => common}/model/ge_root_model.h (98%) rename ge/{graph => }/common/omg_util.cc (95%) rename ge/{graph => }/common/omg_util.h (100%) rename ge/{graph => }/common/transop_util.cc (97%) rename ge/{graph => }/common/transop_util.h (95%) create mode 100644 tests/ut/ge/common/fp16_unittest.cc diff --git a/.clang-format b/.clang-format index 6faea40d..dd8abe32 100644 --- a/.clang-format +++ b/.clang-format @@ -11,7 +11,7 @@ AlignTrailingComments: true AllowAllParametersOfDeclarationOnNextLine: true AllowShortBlocksOnASingleLine: false AllowShortCaseLabelsOnASingleLine: false 
-AllowShortFunctionsOnASingleLine: All +AllowShortFunctionsOnASingleLine: Empty AllowShortIfStatementsOnASingleLine: true AllowShortLoopsOnASingleLine: true AlwaysBreakAfterDefinitionReturnType: None diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index dead6aa5..cb4c84b1 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -112,7 +112,6 @@ set(EXECUTOR_SRC_LIST "analyzer/analyzer.cc" "common/dump/dump_manager.cc" "common/dump/dump_op.cc" - "common/dump/dump_properties.cc" "common/dump/exception_dumper.cc" "common/dump/opdebug_register.cc" "common/formats/format_transfers/format_transfer_transpose.cc" @@ -126,9 +125,6 @@ set(EXECUTOR_SRC_LIST "executor/ge_executor.cc" "ge_local_engine/engine/host_cpu_engine.cc" "graph/build/memory/var_mem_assign_util.cc" - "graph/common/bcast.cc" - "graph/common/local_context.cc" - "graph/common/omg_util.cc" "graph/execute/graph_execute.cc" "graph/execute/model_executor.cc" "graph/load/graph_loader.cc" @@ -255,8 +251,6 @@ set(EXECUTOR_SRC_LIST "hybrid/node_executor/rts/rts_task_factory.cc" "hybrid/node_executor/task_context.cc" "init/gelib.cc" - "model/ge_model.cc" - "model/ge_root_model.cc" "opskernel_manager/ops_kernel_builder_manager.cc" "opskernel_manager/ops_kernel_manager.cc" "single_op/single_op.cc" @@ -274,7 +268,6 @@ set(EXECUTOR_SRC_LIST ################################################################## set(COMPILER_SRC_LIST "analyzer/analyzer.cc" - "common/dump/dump_manager.cc" "common/dump/dump_op.cc" "common/dump/dump_properties.cc" "common/formats/format_transfers/datatype_transfer.cc" @@ -322,10 +315,6 @@ set(COMPILER_SRC_LIST "graph/build/stream_allocator.cc" "graph/build/stream_graph_optimizer.cc" "graph/build/task_generator.cc" - "graph/common/bcast.cc" - "graph/common/local_context.cc" - "graph/common/omg_util.cc" - "graph/common/transop_util.cc" "graph/label/case_label_maker.cc" "graph/label/if_label_maker.cc" "graph/label/label_maker.cc" @@ -508,8 +497,6 @@ set(COMPILER_SRC_LIST 
"ir_build/attr_options/weight_compress_option.cc" "ir_build/ge_ir_build.cc" "ir_build/option_utils.cc" - "model/ge_model.cc" - "model/ge_root_model.cc" "opskernel_manager/ops_kernel_builder_manager.cc" "opskernel_manager/ops_kernel_manager.cc" ) diff --git a/ge/client/ge_api.cc b/ge/client/ge_api.cc index 3cf7c3c4..e4a016b3 100644 --- a/ge/client/ge_api.cc +++ b/ge/client/ge_api.cc @@ -29,7 +29,7 @@ #include "graph/opsproto_manager.h" #include "graph/utils/type_utils.h" #include "graph/manager/util/rt_context_util.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "register/op_registry.h" #include "common/ge/tbe_plugin_manager.h" #include "common/util/error_manager/error_manager.h" diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index 1872b4c2..0d41b86f 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -2,16 +2,23 @@ set(SRC_LIST "context/ctx.cc" "model_saver.cc" "ge/datatype_util.cc" + "ge/plugin_manager.cc" + "ge/op_tiling_manager.cc" "helper/om_file_helper.cc" "helper/model_helper.cc" - "../model/ge_model.cc" - "../model/ge_root_model.cc" + "model/ge_model.cc" + "model/ge_root_model.cc" + "bcast.cc" + "local_context.cc" + "omg_util.cc" + "transop_util.cc" "auth/file_saver.cc" "fp16_t.cc" "math/fp16_math.cc" "debug/memory_dumper.cc" "formats/utils/formats_trans_utils.cc" "dump/dump_properties.cc" + "dump/dump_manager.cc" "formats/format_transfers/datatype_transfer.cc" "formats/format_transfers/format_transfer_transpose.cc" "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" @@ -63,7 +70,7 @@ target_compile_definitions(ge_common PRIVATE ) target_compile_options(ge_common PRIVATE - -fvisibility=hidden + -fvisibility=default -O2 -Werror -Wno-deprecated-declarations @@ -183,7 +190,7 @@ target_compile_definitions(ge_common PRIVATE ) target_compile_options(ge_common PRIVATE - -fvisibility=hidden + -fvisibility=default -O2 -Werror -Wno-deprecated-declarations diff --git 
a/ge/common/auth/file_saver.cc b/ge/common/auth/file_saver.cc index 57ab901b..d6f24497 100755 --- a/ge/common/auth/file_saver.cc +++ b/ge/common/auth/file_saver.cc @@ -238,7 +238,7 @@ Status FileSaver::SaveToBuffWithFileHeader(const ModelFileHeader &file_header, return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::CheckPath(const std::string &file_path) { +Status FileSaver::CheckPath(const std::string &file_path) { // Determine file path length if (file_path.size() >= MMPA_MAX_PATH) { GELOGE(FAILED, "[Check][FilePath]Failed, file path's length:%zu > mmpa_max_path:%d", @@ -271,8 +271,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::CheckPath(con return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -FileSaver::SaveToFile(const string &file_path, const ge::ModelData &model, const ModelFileHeader *model_file_header) { +Status FileSaver::SaveToFile(const string &file_path, const ge::ModelData &model, + const ModelFileHeader *model_file_header) { if (file_path.empty() || model.model_data == nullptr || model.model_len == 0) { GELOGE(FAILED, "[Save][File]Incorrect input param, " "file_path is empty or model_data is nullptr or model_len is 0"); @@ -301,19 +301,18 @@ FileSaver::SaveToFile(const string &file_path, const ge::ModelData &model, const return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, ModelPartitionTable &model_partition_table, - const std::vector &partition_datas) { +Status FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, + ModelPartitionTable &model_partition_table, + const std::vector &partition_datas) { const Status ret = SaveWithFileHeader(file_path, file_header, model_partition_table, partition_datas); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "save file failed, file_path:%s, file header len:%u.", file_path.c_str(), file_header.length); return 
SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, - vector &model_partition_tables, - const vector> &all_partition_datas) { +Status FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, + vector &model_partition_tables, + const vector> &all_partition_datas) { const Status ret = SaveWithFileHeader(file_path, file_header, model_partition_tables, all_partition_datas); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "save file failed, file_path:%s, file header len:%u.", file_path.c_str(), file_header.length); @@ -372,8 +371,7 @@ Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFi return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(const string &file_path, const void *data, - int len) { +Status FileSaver::SaveToFile(const string &file_path, const void *data, int len) { if (data == nullptr || len <= 0) { GELOGE(FAILED, "[Check][Param]Failed, model_data is null or the " "length[%d] is less than 1.", len); diff --git a/ge/graph/common/bcast.cc b/ge/common/bcast.cc similarity index 98% rename from ge/graph/common/bcast.cc rename to ge/common/bcast.cc index fcc8f9a1..a4e8d1a1 100644 --- a/ge/graph/common/bcast.cc +++ b/ge/common/bcast.cc @@ -14,12 +14,12 @@ * limitations under the License. 
*/ -#include "graph/common/bcast.h" +#include "common/bcast.h" #include #include "common/math_util.h" -#include "framework/common/util.h" +#include "common/util.h" using domi::Status; diff --git a/ge/graph/common/bcast.h b/ge/common/bcast.h similarity index 98% rename from ge/graph/common/bcast.h rename to ge/common/bcast.h index 184751fe..a8399896 100644 --- a/ge/graph/common/bcast.h +++ b/ge/common/bcast.h @@ -21,11 +21,11 @@ #include #include -#include "framework/common/debug/log.h" -#include "framework/common/types.h" +#include "common/debug/log.h" +#include "common/types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "external/graph/attr_value.h" +#include "graph/attr_value.h" #include "graph/ge_tensor.h" #include "graph/utils/tensor_adapter.h" diff --git a/ge/common/context/ctx.cc b/ge/common/context/ctx.cc index 9fe2f8c7..8e138ade 100755 --- a/ge/common/context/ctx.cc +++ b/ge/common/context/ctx.cc @@ -18,7 +18,7 @@ using ge::OmgContext; namespace domi { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { +OmgContext &GetContext() { static OmgContext context; return context; } diff --git a/ge/common/cust_aicpu_kernel_store.h b/ge/common/cust_aicpu_kernel_store.h index 033a636b..38124587 100755 --- a/ge/common/cust_aicpu_kernel_store.h +++ b/ge/common/cust_aicpu_kernel_store.h @@ -21,7 +21,7 @@ namespace ge { -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY CustAICPUKernelStore : public KernelStore { +class CustAICPUKernelStore : public KernelStore { public: CustAICPUKernelStore(); ~CustAICPUKernelStore() {} diff --git a/ge/common/debug/memory_dumper.cc b/ge/common/debug/memory_dumper.cc index 78ef2daa..f4a49440 100644 --- a/ge/common/debug/memory_dumper.cc +++ b/ge/common/debug/memory_dumper.cc @@ -30,13 +30,12 @@ const int kInvalidFd = (-1); } // namespace namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::MemoryDumper() : fd_(kInvalidFd) 
{} +MemoryDumper::MemoryDumper() : fd_(kInvalidFd) {} -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::~MemoryDumper() { Close(); } +MemoryDumper::~MemoryDumper() { Close(); } // Dump the data to the file -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::DumpToFile(const char *filename, void *data, - int64_t len) { +Status MemoryDumper::DumpToFile(const char *filename, void *data, int64_t len) { #ifdef FMK_SUPPORT_DUMP GE_CHECK_NOTNULL(filename); GE_CHECK_NOTNULL(data); @@ -81,7 +80,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::DumpToFile } // Open file -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::Open(const char *filename) { +Status MemoryDumper::Open(const char *filename) { GE_CHK_BOOL_RET_STATUS(filename != nullptr, FAILED, "Incorrect parameter. filename is nullptr"); // Try to remove file first for reduce the close time by overwriting way @@ -104,7 +103,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::Open(const } // Dump the data to file -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::Dump(void *data, uint32_t len) const { +Status MemoryDumper::Dump(void *data, uint32_t len) const { GE_CHK_BOOL_RET_STATUS(data != nullptr, FAILED, "Incorrect parameter. 
data is nullptr"); #ifdef FMK_SUPPORT_DUMP diff --git a/ge/common/dump/dump_manager.cc b/ge/common/dump/dump_manager.cc index ebe16fed..da8160ff 100644 --- a/ge/common/dump/dump_manager.cc +++ b/ge/common/dump/dump_manager.cc @@ -15,6 +15,7 @@ */ #include "common/dump/dump_manager.h" + #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" @@ -26,7 +27,7 @@ const uint64_t kInferSessionId = 0; const uint32_t kAllOverflow = 3; } // namespace namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpManager &DumpManager::GetInstance() { +DumpManager &DumpManager::GetInstance() { static DumpManager instance; return instance; } @@ -74,7 +75,7 @@ void DumpManager::SetDumpList(const DumpConfig &dump_config, DumpProperties &dum Status DumpManager::SetNormalDumpConf(const DumpConfig &dump_config, DumpProperties &dump_properties) { if (dump_config.dump_status == kDumpOn) { - GELOGI("Only do normal dump process, dump status is %s.", dump_config.dump_status.c_str()); + GELOGI("Only do normal dump process, dump status is %s", dump_config.dump_status.c_str()); dump_properties.SetDumpStatus(dump_config.dump_status); std::string dump_op_switch = dump_config.dump_op_switch; dump_properties.SetDumpOpSwitch(dump_op_switch); @@ -104,8 +105,8 @@ Status DumpManager::SetNormalDumpConf(const DumpConfig &dump_config, DumpPropert Status DumpManager::SetDumpPath(const DumpConfig &dump_config, DumpProperties &dump_properties) { std::string dump_path = dump_config.dump_path; if (dump_path.empty()) { - GELOGE(PARAM_INVALID, "[Check][DumpPath]It is empty"); - REPORT_INNER_ERROR("E19999", "Dump path check is empty"); + GELOGE(PARAM_INVALID, "[Check][DumpPath]It is empty."); + REPORT_INNER_ERROR("E19999", "Dump path check is empty."); return PARAM_INVALID; } if (dump_path[dump_path.size() - 1] != '/') { @@ -117,7 +118,7 @@ Status DumpManager::SetDumpPath(const DumpConfig &dump_config, DumpProperties &d return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY 
FMK_FUNC_DEV_VISIBILITY Status DumpManager::SetDumpConf(const DumpConfig &dump_config) { +Status DumpManager::SetDumpConf(const DumpConfig &dump_config) { DumpProperties dump_properties; if (!NeedDoDump(dump_config, dump_properties)) { GELOGD("No need do dump process."); @@ -131,8 +132,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpManager::SetDumpConf return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const DumpProperties &DumpManager::GetDumpProperties( - uint64_t session_id) { +const DumpProperties &DumpManager::GetDumpProperties(uint64_t session_id) { std::lock_guard lock(mutex_); auto iter = dump_properties_map_.find(session_id); if (iter != dump_properties_map_.end()) { @@ -142,13 +142,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const DumpProperties &DumpManag return default_properties; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpManager::AddDumpProperties( - uint64_t session_id, const DumpProperties &dump_properties) { +void DumpManager::AddDumpProperties(uint64_t session_id, const DumpProperties &dump_properties) { std::lock_guard lock(mutex_); dump_properties_map_.emplace(session_id, dump_properties); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpManager::RemoveDumpProperties(uint64_t session_id) { +void DumpManager::RemoveDumpProperties(uint64_t session_id) { std::lock_guard lock(mutex_); auto iter = dump_properties_map_.find(session_id); if (iter != dump_properties_map_.end()) { diff --git a/ge/common/dump/dump_properties.cc b/ge/common/dump/dump_properties.cc index 099920e7..3bed76d9 100644 --- a/ge/common/dump/dump_properties.cc +++ b/ge/common/dump/dump_properties.cc @@ -38,9 +38,7 @@ const uint32_t kAtomicOverflow = (0x1 << 1); const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); } // namespace namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::Split(const std::string &s, - std::vector &result, - const char *delchar) { +void 
DumpProperties::Split(const std::string &s, std::vector &result, const char *delchar) { if (s.empty()) { return; } @@ -68,7 +66,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::Split(cons delete[] buffer; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDumpStep(const std::string &dump_step) { +Status DumpProperties::CheckDumpStep(const std::string &dump_step) { std::string modified_dum_step = dump_step + "|"; std::smatch result; std::vector match_vecs; @@ -126,7 +124,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDum return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDumpMode(const std::string &dump_mode) { +Status DumpProperties::CheckDumpMode(const std::string &dump_mode) { const std::set dump_mode_list = {"input", "output", "all"}; std::set::iterator iter; @@ -143,7 +141,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDum return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDumpPath(const std::string &input) { +Status DumpProperties::CheckDumpPath(const std::string &input) { if (mmIsDir(input.c_str()) != EN_OK) { REPORT_INPUT_ERROR("E10001", std::vector({"parameter", "value", "reason"}), std::vector({ @@ -175,7 +173,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDum return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckEnableDump(const std::string &input) { +Status DumpProperties::CheckEnableDump(const std::string &input) { std::set enable_dump_option_list = {"1", "0"}; auto it = enable_dump_option_list.find(input); if (it == enable_dump_option_list.end()) { @@ -191,17 +189,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckEna return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const 
DumpProperties &other) { +DumpProperties::DumpProperties(const DumpProperties &other) { CopyFrom(other); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &DumpProperties::operator=( - const DumpProperties &other) { +DumpProperties &DumpProperties::operator=(const DumpProperties &other) { CopyFrom(other); return *this; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::SetDumpOptions() { +Status DumpProperties::SetDumpOptions() { if (enable_dump_ == kEnableFlag) { std::string dump_step; if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS && !dump_step.empty()) { @@ -220,7 +217,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::SetDumpO return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::InitByOptions() { +Status DumpProperties::InitByOptions() { enable_dump_.clear(); enable_dump_debug_.clear(); dump_path_.clear(); @@ -281,8 +278,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::InitByOp } // The following is the new dump scenario of the fusion operator -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropertyValue( - const std::string &model, const std::set &layers) { +void DumpProperties::AddPropertyValue(const std::string &model, const std::set &layers) { for (const std::string &layer : layers) { GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str()); } @@ -290,18 +286,18 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropert model_dump_properties_map_[model] = layers; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::DeletePropertyValue(const std::string &model) { +void DumpProperties::DeletePropertyValue(const std::string &model) { auto iter = model_dump_properties_map_.find(model); if (iter != model_dump_properties_map_.end()) { model_dump_properties_map_.erase(iter); } } -FMK_FUNC_HOST_VISIBILITY 
FMK_FUNC_DEV_VISIBILITY void DumpProperties::ClearDumpPropertyValue() { +void DumpProperties::ClearDumpPropertyValue() { model_dump_properties_map_.clear(); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::ClearDumpInfo() { +void DumpProperties::ClearDumpInfo() { enable_dump_.clear(); enable_dump_debug_.clear(); dump_path_.clear(); @@ -314,7 +310,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::ClearDumpI op_debug_mode_ = 0; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetAllDumpModel() const { +std::set DumpProperties::GetAllDumpModel() const { std::set model_list; for (auto &iter : model_dump_properties_map_) { model_list.insert(iter.first); @@ -323,8 +319,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpPrope return model_list; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetPropertyValue( - const std::string &model) const { +std::set DumpProperties::GetPropertyValue(const std::string &model) const { auto iter = model_dump_properties_map_.find(model); if (iter != model_dump_properties_map_.end()) { return iter->second; @@ -332,8 +327,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpPrope return {}; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNeedDump( - const std::string &model, const std::string &om_name, const std::string &op_name) const { +bool DumpProperties::IsLayerNeedDump(const std::string &model, const std::string &om_name, + const std::string &op_name) const { // if dump all GELOGD("model name is %s om name is %s op is %s in layer need dump", model.c_str(), om_name.c_str(), op_name.c_str()); if (model_dump_properties_map_.find(DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { @@ -353,67 +348,66 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNee return model_iter->second.find(op_name) != model_iter->second.end(); } - GELOGD("Model %s 
is not seated to be dump.", model.c_str()); + GELOGD("Model %s is not seated to be dump", model.c_str()); return false; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpPath(const std::string &path) { +void DumpProperties::SetDumpPath(const std::string &path) { dump_path_ = path; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpPath() const { +const std::string &DumpProperties::GetDumpPath() const { return dump_path_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStep(const std::string &step) { +void DumpProperties::SetDumpStep(const std::string &step) { dump_step_ = step; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpStep() const { +const std::string &DumpProperties::GetDumpStep() const { return dump_step_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpMode(const std::string &mode) { +void DumpProperties::SetDumpMode(const std::string &mode) { dump_mode_ = mode; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpMode() const { +const std::string &DumpProperties::GetDumpMode() const { return dump_mode_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStatus(const std::string &status) { +void DumpProperties::SetDumpStatus(const std::string &status) { dump_status_ = status; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpStatus() const { +const std::string &DumpProperties::GetDumpStatus() const { return dump_status_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitInferOpDebug() { +void DumpProperties::InitInferOpDebug() { is_infer_op_debug_ = true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetOpDebugMode(const uint32_t &op_debug_mode) { +void DumpProperties::SetOpDebugMode(const uint32_t 
&op_debug_mode) { op_debug_mode_ = op_debug_mode; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpOpSwitch( - const std::string &dump_op_switch) { +void DumpProperties::SetDumpOpSwitch(const std::string &dump_op_switch) { dump_op_switch_ = dump_op_switch; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperties::GetDumpOpSwitch() const { +const std::string &DumpProperties::GetDumpOpSwitch() const { return dump_op_switch_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsSingleOpNeedDump() const { +bool DumpProperties::IsSingleOpNeedDump() const { if (dump_op_switch_ == kDumpStatusOpen) { return true; } return false; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsDumpOpen() const { +bool DumpProperties::IsDumpOpen() const { if (enable_dump_ == kEnableFlag || dump_status_ == kDumpStatusOpen) { return true; } @@ -441,7 +435,7 @@ Status DumpProperties::SetDumpDebugOptions() { if (enable_dump_debug_ == kEnableFlag) { std::string dump_debug_mode; if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) { - GELOGD("Get ge.exec.dumpDebugMode %s successfully", dump_debug_mode.c_str()); + GELOGD("Get ge.exec.dumpDebugMode %s successfully.", dump_debug_mode.c_str()); } else { GELOGW("ge.exec.dumpDebugMode is not set."); return SUCCESS; @@ -469,7 +463,7 @@ Status DumpProperties::SetDumpDebugOptions() { return PARAM_INVALID; } } else { - GELOGI("ge.exec.enableDumpDebug is false or is not set."); + GELOGI("ge.exec.enableDumpDebug is false or is not set"); } return SUCCESS; } diff --git a/ge/common/fmk_error_codes.cc b/ge/common/fmk_error_codes.cc index ddb8089d..180af0e2 100755 --- a/ge/common/fmk_error_codes.cc +++ b/ge/common/fmk_error_codes.cc @@ -17,19 +17,18 @@ #include "framework/common/fmk_error_codes.h" namespace domi { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY StatusFactory *StatusFactory::Instance() { 
+StatusFactory *StatusFactory::Instance() { static StatusFactory instance; return &instance; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void StatusFactory::RegisterErrorNo(uint32_t err, - const std::string &desc) { +void StatusFactory::RegisterErrorNo(uint32_t err, const std::string &desc) { if (err_desc_.find(err) != err_desc_.end()) { return; } err_desc_[err] = desc; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string StatusFactory::GetErrDesc(uint32_t err) { +std::string StatusFactory::GetErrDesc(uint32_t err) { auto iter_find = err_desc_.find(err); if (iter_find == err_desc_.end()) { return ""; diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.h b/ge/common/formats/format_transfers/format_transfer_transpose.h index 7fa19ff0..b608777c 100755 --- a/ge/common/formats/format_transfers/format_transfer_transpose.h +++ b/ge/common/formats/format_transfers/format_transfer_transpose.h @@ -33,7 +33,6 @@ Status TransposeWithShapeCheck(const uint8_t *src, const std::vector &s Status GetPermByForamt(Format src_format, Format dst_format, std::vector &perm); - class FormatTransferTranspose : public FormatTransfer { public: Status TransFormat(const TransArgs &args, TransResult &result) override; diff --git a/ge/common/formats/formats.cc b/ge/common/formats/formats.cc index 9e97a4d2..5a454d60 100755 --- a/ge/common/formats/formats.cc +++ b/ge/common/formats/formats.cc @@ -17,6 +17,7 @@ #include "common/formats/formats.h" #include + #include #include #include @@ -32,7 +33,7 @@ namespace ge { namespace formats { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArgs &args, TransResult &result) { +Status TransFormat(const TransArgs &args, TransResult &result) { auto transfer = BuildFormatTransfer(args); if (transfer == nullptr) { std::string error = "Failed to trans data from format " + @@ -56,11 +57,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg return 
transfer->TransFormat(args, result); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_format, - const std::vector &src_shape, - DataType data_type, - Format dst_format, - std::vector &dst_shape) { +Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, + std::vector &dst_shape) { formats::TransArgs args; args.src_format = src_format; args.dst_format = dst_format; @@ -76,7 +74,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_form return transfer->TransShape(src_format, src_shape, data_type, dst_format, dst_shape); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastArgs &args, TransResult &result) { +Status TransDataType(const CastArgs &args, TransResult &result) { auto transfer = BuildDataTypeTransfer(args); if (transfer == nullptr) { std::string error = "Failed to trans data from datatype " + @@ -95,11 +93,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastAr return transfer->TransDataType(args, result); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransFormatSupport(const TransArgs &args) { +bool IsTransFormatSupport(const TransArgs &args) { return FormatTransferExists(args); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool IsTransDataTypeSupport(const CastArgs &args) { +bool IsTransDataTypeSupport(const CastArgs &args) { return DataTypeTransferExists(args); } } // namespace formats diff --git a/ge/common/formats/utils/formats_trans_utils.cc b/ge/common/formats/utils/formats_trans_utils.cc index db1812d0..63ad424f 100755 --- a/ge/common/formats/utils/formats_trans_utils.cc +++ b/ge/common/formats/utils/formats_trans_utils.cc @@ -41,15 +41,14 @@ int64_t GetCubeSizeByDataType(DataType data_type) { } } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string ShapeToString(const GeShape &shape) { +std::string ShapeToString(const GeShape &shape) { return 
ShapeToString(shape.GetDims()); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string ShapeToString(const std::vector &shape) { +std::string ShapeToString(const std::vector &shape) { return JoinToString(shape); } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string RangeToString(const std::vector> &ranges) { bool first = true; std::stringstream ss; diff --git a/ge/common/fp16_t.cc b/ge/common/fp16_t.cc index 2f94323d..adb55dfb 100755 --- a/ge/common/fp16_t.cc +++ b/ge/common/fp16_t.cc @@ -1180,20 +1180,40 @@ fp16_t &fp16_t::operator=(const double &d_val) { } // convert -fp16_t::operator float() const { return Fp16ToFloat(val); } -fp16_t::operator double() const { return Fp16ToDouble(val); } -fp16_t::operator int8_t() const { return Fp16ToInt8(val); } -fp16_t::operator uint8_t() const { return Fp16ToUInt8(val); } -fp16_t::operator int16_t() const { return Fp16ToInt16(val); } -fp16_t::operator uint16_t() const { return Fp16ToUInt16(val); } -fp16_t::operator int32_t() const { return Fp16ToInt32(val); } -fp16_t::operator uint32_t() const { return Fp16ToUInt32(val); } +fp16_t::operator float() const { + return Fp16ToFloat(val); +} +fp16_t::operator double() const { + return Fp16ToDouble(val); +} +fp16_t::operator int8_t() const { + return Fp16ToInt8(val); +} +fp16_t::operator uint8_t() const { + return Fp16ToUInt8(val); +} +fp16_t::operator int16_t() const { + return Fp16ToInt16(val); +} +fp16_t::operator uint16_t() const { + return Fp16ToUInt16(val); +} +fp16_t::operator int32_t() const { + return Fp16ToInt32(val); +} +fp16_t::operator uint32_t() const { + return Fp16ToUInt32(val); +} // Cannot be used, just in order to solve the compile error -fp16_t::operator int64_t() const { return 0; } +fp16_t::operator int64_t() const { + return 0; +} // Cannot be used, just in order to solve the compile error -fp16_t::operator uint64_t() const { return 0; } +fp16_t::operator uint64_t() const { + return 0; +} -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int 
fp16_t::IsInf() { +int fp16_t::IsInf() { if ((val & kFp16AbsMax) == kFp16ExpMask) { if (val & kFp16SignMask) { return -1; @@ -1205,12 +1225,28 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() { } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY float fp16_t::ToFloat() const { return Fp16ToFloat(val); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY double fp16_t::ToDouble() const { return Fp16ToDouble(val); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int8_t fp16_t::ToInt8() const { return Fp16ToInt8(val); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint8_t fp16_t::ToUInt8() const { return Fp16ToUInt8(val); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int16_t fp16_t::ToInt16() const { return Fp16ToInt16(val); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint16_t fp16_t::ToUInt16() const { return Fp16ToUInt16(val); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int32_t fp16_t::ToInt32() const { return Fp16ToInt32(val); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t fp16_t::ToUInt32() const { return Fp16ToUInt32(val); } +float fp16_t::ToFloat() const { + return Fp16ToFloat(val); +} +double fp16_t::ToDouble() const { + return Fp16ToDouble(val); +} +int8_t fp16_t::ToInt8() const { + return Fp16ToInt8(val); +} +uint8_t fp16_t::ToUInt8() const { + return Fp16ToUInt8(val); +} +int16_t fp16_t::ToInt16() const { + return Fp16ToInt16(val); +} +uint16_t fp16_t::ToUInt16() const { + return Fp16ToUInt16(val); +} +int32_t fp16_t::ToInt32() const { + return Fp16ToInt32(val); +} +uint32_t fp16_t::ToUInt32() const { + return Fp16ToUInt32(val); +} } // namespace ge diff --git a/ge/common/ge/datatype_util.h b/ge/common/ge/datatype_util.h index c3b41b81..82c8d259 100644 --- a/ge/common/ge/datatype_util.h +++ b/ge/common/ge/datatype_util.h @@ -42,7 +42,7 @@ static std::map CONST_OPDATA_TYPE_SIZE_MAP = { {ge::DT_UINT8, kGeSizeUint8}, {ge::DT_UINT16, kGeSizeUint16}, {ge::DT_UINT32, kGeSizeUint32}, 
{ge::DT_UINT64, kGeSizeUint64}, {ge::DT_DOUBLE, kGeSizeDouble}, {ge::DT_BOOL, kGeSizeBool}}; -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY DataTypeUtil { +class DataTypeUtil { public: static bool DataTypeTranslatable(const ge::DataType &src_out_data_type, const ge::DataType &dst_in_data_type); static const std::vector &GetTranslatableDataTypesBySrc(const ge::DataType &src_out_data_type); diff --git a/ge/common/ge/tbe_plugin_manager.cc b/ge/common/ge/tbe_plugin_manager.cc index 70c1ab94..3680a8bb 100755 --- a/ge/common/ge/tbe_plugin_manager.cc +++ b/ge/common/ge/tbe_plugin_manager.cc @@ -42,7 +42,7 @@ const int kBaseInt = 10; std::map TBEPluginManager::options_ = {}; // Get Singleton Instance -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEPluginManager &TBEPluginManager::Instance() { +TBEPluginManager &TBEPluginManager::Instance() { static TBEPluginManager instance_ptr_; return instance_ptr_; } @@ -61,7 +61,7 @@ Status TBEPluginManager::ClearHandles_() { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status TBEPluginManager::Finalize() { +Status TBEPluginManager::Finalize() { Status ret = ClearHandles_(); return ret; } @@ -207,7 +207,6 @@ void TBEPluginManager::LoadCustomOpLib() { } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::LoadPluginSo(const std::map &options) { vector file_list; string caffe_parser_path; @@ -246,7 +245,6 @@ void TBEPluginManager::LoadPluginSo(const std::map &options) { } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::InitPreparation(const std::map &options) { options_.insert(options.begin(), options.end()); // Load TBE plugin diff --git a/ge/graph/common/ge_call_wrapper.h b/ge/common/ge_call_wrapper.h similarity index 100% rename from ge/graph/common/ge_call_wrapper.h rename to ge/common/ge_call_wrapper.h diff --git a/ge/common/ge_format_util.cc b/ge/common/ge_format_util.cc index f3dee571..0ffa686f 100755 --- a/ge/common/ge_format_util.cc +++ 
b/ge/common/ge_format_util.cc @@ -18,9 +18,7 @@ #include "common/formats/formats.h" namespace ge { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status GeFormatUtil::TransShape(const TensorDesc &src_desc, - Format dst_format, - std::vector &dst_shape) { +Status GeFormatUtil::TransShape(const TensorDesc &src_desc, Format dst_format, std::vector &dst_shape) { return formats::TransShape(src_desc.GetFormat(), src_desc.GetShape().GetDims(), src_desc.GetDataType(), dst_format, dst_shape); } diff --git a/ge/common/helper/model_cache_helper.h b/ge/common/helper/model_cache_helper.h index 13253cbe..f0831075 100755 --- a/ge/common/helper/model_cache_helper.h +++ b/ge/common/helper/model_cache_helper.h @@ -24,7 +24,7 @@ #include "external/ge/ge_api_error_codes.h" #include "graph/compute_graph.h" #include "graph/manager/graph_var_manager.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" namespace ge { using Json = nlohmann::json; diff --git a/ge/common/helper/model_helper.cc b/ge/common/helper/model_helper.cc index 4e760a4a..2608b1e1 100644 --- a/ge/common/helper/model_helper.cc +++ b/ge/common/helper/model_helper.cc @@ -33,7 +33,7 @@ const uint32_t kStatiOmFileModelNum = 1; namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } +ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } Status ModelHelper::SaveModelPartition(std::shared_ptr &om_file_save_helper, ModelPartitionType type, const uint8_t *data, size_t size, size_t model_index) { @@ -108,8 +108,8 @@ Status ModelHelper::SaveSizeToModelDef(const GeModelPtr &ge_model) { return SUCCESS; } -Status ModelHelper::SaveModelDef(std::shared_ptr &om_file_save_helper, - const GeModelPtr &ge_model, ge::Buffer &model_buffer, size_t model_index) { +Status ModelHelper::SaveModelDef(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + ge::Buffer &model_buffer, size_t model_index) { ModelPtr model_tmp = 
ge::MakeShared(ge_model->GetName(), ge_model->GetPlatformVersion()); if (model_tmp == nullptr) { GELOGE(FAILED, "[Creat][Model]Failed, Model %s Ptr", ge_model->GetName().c_str()); @@ -143,8 +143,8 @@ Status ModelHelper::SaveModelDef(std::shared_ptr &om_file_save return SUCCESS; } -Status ModelHelper::SaveModelWeights(std::shared_ptr &om_file_save_helper, - const GeModelPtr &ge_model, size_t model_index) { +Status ModelHelper::SaveModelWeights(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + size_t model_index) { auto ge_model_weight = ge_model->GetWeight(); GELOGD("WEIGHTS_DATA size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); // weight is not necessary @@ -187,8 +187,8 @@ Status ModelHelper::SaveModelCustAICPU(std::shared_ptr &om_fil return SUCCESS; } -Status ModelHelper::SaveModelTaskDef(std::shared_ptr &om_file_save_helper, - const GeModelPtr &ge_model, ge::Buffer &task_buffer, size_t model_index) { +Status ModelHelper::SaveModelTaskDef(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + ge::Buffer &task_buffer, size_t model_index) { std::shared_ptr model_task_def = ge_model->GetModelTaskDefPtr(); if (model_task_def == nullptr) { GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Creat][ModelTaskDef]Failed, it is nullptr, " @@ -231,8 +231,8 @@ Status ModelHelper::SaveModelTaskDef(std::shared_ptr &om_file_ return SUCCESS; } -Status ModelHelper::SaveModelHeader(std::shared_ptr &om_file_save_helper, - const GeModelPtr &ge_model, size_t model_num) { +Status ModelHelper::SaveModelHeader(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + size_t model_num) { // Save target/version to model_header ModelFileHeader &model_header = om_file_save_helper->GetModelFileHeader(); model_header.platform_type = ge_model->GetPlatformType(); @@ -246,8 +246,10 @@ Status ModelHelper::SaveModelHeader(std::shared_ptr &om_file_s if (err != EOK) { GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Save][Model]Failed while allocating 
memory for platform_version %s, model %s, " - "errno %d", platform_version.c_str(), ge_model->GetName().c_str(), err); - REPORT_CALL_ERROR("E19999", "ModelHelper save model %s failed while " + "errno %d", + platform_version.c_str(), ge_model->GetName().c_str(), err); + REPORT_CALL_ERROR("E19999", + "ModelHelper save model %s failed while " "allocating memory for platform_version %s, errno %d", ge_model->GetName().c_str(), platform_version.c_str(), err); return ACL_ERROR_GE_MEMORY_ALLOCATION; @@ -271,9 +273,9 @@ Status ModelHelper::SaveModelHeader(std::shared_ptr &om_file_s return SUCCESS; } -Status ModelHelper::SaveAllModelPartiton(std::shared_ptr& om_file_save_helper, - const GeModelPtr &ge_model, ge::Buffer &model_buffer, - ge::Buffer &task_buffer, size_t model_index) { +Status ModelHelper::SaveAllModelPartiton(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, ge::Buffer &model_buffer, ge::Buffer &task_buffer, + size_t model_index) { if (SaveModelDef(om_file_save_helper, ge_model, model_buffer, model_index) != SUCCESS) { GELOGE(FAILED, "[Save][ModelDef]Failed, model %s, model index %zu", ge_model->GetName().c_str(), model_index); @@ -316,10 +318,8 @@ Status ModelHelper::SaveAllModelPartiton(std::shared_ptr& om_f return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model, - const SaveParam &save_param, - const std::string &output_file, - ModelBufferData& model) { +Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, + const std::string &output_file, ModelBufferData &model) { if (output_file.empty()) { GELOGE(FAILED, "[Save][Model]GraphBuilder SaveModel received invalid file name prefix, " "model %s", ge_model->GetName().c_str()); @@ -367,13 +367,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRootModel( - 
const GeRootModelPtr &ge_root_model, - const SaveParam &save_param, - const std::string &output_file, - ModelBufferData& model, - bool is_unknown_shape) { - +Status ModelHelper::SaveToOmRootModel(const GeRootModelPtr &ge_root_model, const SaveParam &save_param, + const std::string &output_file, ModelBufferData &model, bool is_unknown_shape) { GE_CHECK_NOTNULL(ge_root_model); GE_IF_BOOL_EXEC(ge_root_model == nullptr, GELOGE(FAILED, "[Check][GERootModel]Ge_root_model is nullptr"); @@ -466,8 +461,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRoo return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file) { +Status ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file) { if (output_file.empty()) { GELOGE(FAILED, "[Save][Model]Received invalid file name prefix, output_file %s", output_file.c_str()); REPORT_INNER_ERROR("E19999", "Save model received invalid file name prefix, output_file %s", output_file.c_str()); @@ -545,7 +539,7 @@ ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::strin return (ret == SUCCESS ? 
SUCCESS : FAILED); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(const ge::ModelData &model_data) { +Status ModelHelper::LoadModel(const ge::ModelData &model_data) { if (model_data.model_data == nullptr || model_data.model_len == 0) { GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "[Load][Model]Model_data is nullptr or model_data_size is 0"); @@ -597,7 +591,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) { +Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) { if (model_data.model_data == nullptr || model_data.model_len == 0) { GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "[Load][RootModel] " "Model_data is nullptr or model data is empty."); @@ -783,7 +777,6 @@ Status ModelHelper::LoadModelData(OmFileLoadHelper &om_load_helper, GeModelPtr & return SUCCESS; } - Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) { ModelPartition partition; if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition) != SUCCESS) { @@ -814,7 +807,7 @@ Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper, GeModelPtr &cu return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper) { +Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper) { ModelPartition task_partition; if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { GELOGE(FAILED, "[Get][ModelTaskPartition]Failed, task_partition size:%u", task_partition.size); @@ -838,9 +831,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(Om return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper, - GeModelPtr &cur_model, - size_t mode_index) { 
+Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) { ModelPartition task_partition; if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition, mode_index) != SUCCESS) { GELOGE(FAILED, "Get task model partition failed."); @@ -915,8 +906,8 @@ Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) { return SUCCESS; } -Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper, - GeModelPtr &cur_model, size_t mode_index) { +Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, + size_t mode_index) { // Load cust aicpu kernels ModelPartition partition_kernel_def; CustAICPUKernelStore kernel_store; @@ -933,7 +924,7 @@ Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper, return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeModel() { +GeModelPtr ModelHelper::GetGeModel() { if (model_ != nullptr) { return model_; } @@ -946,7 +937,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo return out_model; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeRootModelPtr ModelHelper::GetGeRootModel() { +GeRootModelPtr ModelHelper::GetGeRootModel() { if (root_model_ != nullptr) { return root_model_; } @@ -959,7 +950,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeRootModelPtr ModelHelper::Get return out_model; } - Status ModelHelper::ReleaseLocalModelData() noexcept { Status result = SUCCESS; if (model_addr_tmp_ != nullptr) { @@ -976,8 +966,7 @@ Status ModelHelper::ReleaseLocalModelData() noexcept { return result; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetBaseNameFromFileName( - const string &file_name, string &base_name) { +Status ModelHelper::GetBaseNameFromFileName(const string &file_name, string &base_name) { GELOGD("Get base_name from file, file_name:%s", 
file_name.c_str()); GE_CHK_BOOL_EXEC_WARN(!file_name.empty(), return FAILED, "File path may not valid, check params --output"); size_t start_position = 0; @@ -992,8 +981,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetBaseName return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetModelNameFromMergedGraphName( - const string &graph_name, string &model_name) { +Status ModelHelper::GetModelNameFromMergedGraphName(const string &graph_name, string &model_name) { GELOGD("Get model_name from graph_name, graph_name:%s", graph_name.c_str()); // this can only be used after merged graph(graph name will be append with "_x", x is index); GE_CHK_BOOL_EXEC_WARN(!graph_name.empty(), return FAILED, "File path may not valid, check params --output"); @@ -1035,8 +1023,7 @@ Status ModelTool::GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, {"om", model_file, "invalid om file, can't be parsed"}); GELOGE(ACL_ERROR_GE_PARAM_INVALID, - "[Parse][ModelContent]Failed because of invalid om file %s, please check om param", - model_file); + "[Parse][ModelContent]Failed because of invalid om file %s, please check om param", model_file); return ret; } diff --git a/ge/common/helper/om_file_helper.cc b/ge/common/helper/om_file_helper.cc index cd13c5d8..a42316ff 100644 --- a/ge/common/helper/om_file_helper.cc +++ b/ge/common/helper/om_file_helper.cc @@ -18,10 +18,11 @@ #include #include -#include "common/math/math_util.h" + #include "common/auth/file_saver.h" -#include "framework/common/debug/log.h" +#include "common/math/math_util.h" #include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/util.h" @@ -32,7 +33,7 @@ const int32_t kOptionalNum = 2; } namespace ge { // For Load -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY 
Status OmFileLoadHelper::Init(const ge::ModelData &model) { +Status OmFileLoadHelper::Init(const ge::ModelData &model) { if (CheckModelValid(model) != SUCCESS) { return FAILED; } @@ -42,8 +43,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(c return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data, - const uint32_t model_data_size) { +Status OmFileLoadHelper::Init(uint8_t *model_data, const uint32_t model_data_size) { Status status = LoadModelPartitionTable(model_data, model_data_size); if (status != SUCCESS) { return status; @@ -52,9 +52,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(u return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data, - uint32_t model_data_size, - uint32_t model_num) { +Status OmFileLoadHelper::Init(uint8_t *model_data, uint32_t model_data_size, uint32_t model_num) { Status status = LoadModelPartitionTable(model_data, model_data_size, model_num); if (status != SUCCESS) { return status; @@ -64,8 +62,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(u } // Use both -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, - ModelPartition &partition) { +Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, ModelPartition &partition) { if (!is_inited_) { GELOGE(PARAM_INVALID, "OmFileLoadHelper has not been initialized!"); return PARAM_INVALID; @@ -90,9 +87,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetMod return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, - ModelPartition &partition, - size_t model_index) { +Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, ModelPartition &partition, size_t model_index) { if 
(!is_inited_) { GELOGE(PARAM_INVALID, "OmFileLoadHelper has not been initialized!"); return PARAM_INVALID; @@ -248,12 +243,11 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, uint32_t m return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::vector - &OmFileSaveHelper::GetModelPartitions() const { +const std::vector &OmFileSaveHelper::GetModelPartitions() const { return context_.partition_datas_; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSaveHelper::GetPartitionTable() { +ModelPartitionTable *OmFileSaveHelper::GetPartitionTable() { auto partition_size = static_cast(context_.partition_datas_.size()); // Build ModelPartitionTable, flex array context_.partition_table_.clear(); @@ -272,8 +266,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSave return partition_table; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSaveHelper::GetPartitionTable( - size_t cur_ctx_index) { +ModelPartitionTable *OmFileSaveHelper::GetPartitionTable(size_t cur_ctx_index) { auto &cur_ctx = model_contexts_[cur_ctx_index]; auto partition_size = static_cast(cur_ctx.partition_datas_.size()); // Build ModelPartitionTable, flex array @@ -293,8 +286,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSave return partition_table; } - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPartition(ModelPartition &partition) { +Status OmFileSaveHelper::AddPartition(ModelPartition &partition) { if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) { GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size); return FAILED; @@ -379,8 +371,8 @@ Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferDat #endif } -Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char *output_file, - 
ModelBufferData &model, bool is_offline) { +Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model, + bool is_offline) { (void)save_param.cert_file; (void)save_param.ek_file; (void)save_param.encode_mode; @@ -409,8 +401,8 @@ Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char * model_header_.length += size_of_table + cur_model_data_len; model_partition_tabels.push_back(tmp_table); all_model_partitions.push_back(cur_ctx.partition_datas_); - GELOGD("sizeof(ModelPartitionTable):%u, cur_model_data_len:%u, cur_context_index:%zu", - size_of_table, cur_model_data_len, ctx_index); + GELOGD("sizeof(ModelPartitionTable):%u, cur_model_data_len:%u, cur_context_index:%zu", size_of_table, + cur_model_data_len, ctx_index); } Status ret; if (is_offline) { diff --git a/ge/common/kernel_store.h b/ge/common/kernel_store.h index b3f4a62e..e7b867a3 100755 --- a/ge/common/kernel_store.h +++ b/ge/common/kernel_store.h @@ -48,7 +48,7 @@ struct KernelStoreItemHead { uint32_t bin_len; }; -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY KernelStore { +class KernelStore { public: KernelStore() = default; virtual ~KernelStore() = default; diff --git a/ge/graph/common/local_context.cc b/ge/common/local_context.cc similarity index 97% rename from ge/graph/common/local_context.cc rename to ge/common/local_context.cc index bd747021..e31f2342 100644 --- a/ge/graph/common/local_context.cc +++ b/ge/common/local_context.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "framework/common/debug/ge_log.h" diff --git a/ge/graph/common/local_context.h b/ge/common/local_context.h similarity index 100% rename from ge/graph/common/local_context.h rename to ge/common/local_context.h diff --git a/ge/common/math/fp16_math.cc b/ge/common/math/fp16_math.cc index 6a9c2fb3..c2dfeb61 100755 --- a/ge/common/math/fp16_math.cc +++ b/ge/common/math/fp16_math.cc @@ -18,7 +18,7 @@ #include "external/register/register_types.h" namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t sqrt(fp16_t fp) { +fp16_t sqrt(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -29,7 +29,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t sqrt(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t rsqrt(fp16_t fp) { +fp16_t rsqrt(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -40,7 +40,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t rsqrt(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t rcp(fp16_t fp) { +fp16_t rcp(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -51,7 +51,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t rcp(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t exp(fp16_t fp) { +fp16_t exp(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -63,7 +63,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t exp(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t pow2(fp16_t fp) { +fp16_t pow2(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -75,7 +75,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t pow2(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY 
FMK_FUNC_DEV_VISIBILITY fp16_t pow10(fp16_t fp) { +fp16_t pow10(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -87,7 +87,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t pow10(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t ln(fp16_t fp) { +fp16_t ln(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -99,7 +99,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t ln(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t log2(fp16_t fp) { +fp16_t log2(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -111,7 +111,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t log2(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t log10(fp16_t fp) { +fp16_t log10(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -123,7 +123,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t log10(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t cos(fp16_t fp) { +fp16_t cos(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -135,7 +135,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t cos(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t sin(fp16_t fp) { +fp16_t sin(fp16_t fp) { fp16_t ret; // Convert half precision float number to double double dVal = fp; @@ -147,13 +147,13 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t sin(fp16_t fp) { return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t abs(fp16_t fp) { +fp16_t abs(fp16_t fp) { fp16_t ret; ret.val = (fp.val & kFp16AbsMax); return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t max(fp16_t fp1, fp16_t fp2) { +fp16_t max(fp16_t fp1, fp16_t fp2) { if (fp1 >= fp2) { 
return fp1; } else { @@ -161,7 +161,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t max(fp16_t fp1, fp16_t f } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY fp16_t min(fp16_t fp1, fp16_t fp2) { +fp16_t min(fp16_t fp1, fp16_t fp2) { if (fp1 <= fp2) { return fp1; } else { diff --git a/ge/model/ge_model.cc b/ge/common/model/ge_model.cc similarity index 99% rename from ge/model/ge_model.cc rename to ge/common/model/ge_model.cc index 1bf35afc..7fc58b6d 100755 --- a/ge/model/ge_model.cc +++ b/ge/common/model/ge_model.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include #include "framework/common/debug/log.h" #include "graph/debug/ge_attr_define.h" diff --git a/ge/model/ge_model.h b/ge/common/model/ge_model.h similarity index 90% rename from ge/model/ge_model.h rename to ge/common/model/ge_model.h index 6356c621..0e791746 100755 --- a/ge/model/ge_model.h +++ b/ge/common/model/ge_model.h @@ -31,7 +31,7 @@ namespace ge { const uint32_t INVALID_MODEL_ID = 0xFFFFFFFFUL; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder { +class GeModel : public AttrHolder { public: GeModel(); ~GeModel() = default; @@ -82,13 +82,13 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder private: void Init(); - ProtoAttrMapHelper attrs_; + ProtoAttrMapHelper attrs_; /*lint !e148*/ Graph graph_; - std::shared_ptr task_; - TBEKernelStore tbe_kernal_store_; - CustAICPUKernelStore cust_aicpu_kernal_store_; - Buffer weights_buffer_; + std::shared_ptr task_; /*lint !e148*/ + TBEKernelStore tbe_kernal_store_; /*lint !e148*/ + CustAICPUKernelStore cust_aicpu_kernal_store_; /*lint !e148*/ + Buffer weights_buffer_; /*lint !e148*/ std::string name_; uint32_t version_ = {0}; diff --git a/ge/model/ge_root_model.cc b/ge/common/model/ge_root_model.cc similarity index 95% rename from ge/model/ge_root_model.cc rename to ge/common/model/ge_root_model.cc 
index b6a1e175..3fe10991 100644 --- a/ge/model/ge_root_model.cc +++ b/ge/common/model/ge_root_model.cc @@ -14,8 +14,9 @@ * limitations under the License. */ -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" #include "graph/debug/ge_attr_define.h" + namespace ge { void GeRootModel::SetSubgraphInstanceNameToModel(string instance_name, GeModelPtr ge_model) { subgraph_instance_name_to_model_.insert(std::pair(instance_name, ge_model)); diff --git a/ge/model/ge_root_model.h b/ge/common/model/ge_root_model.h similarity index 98% rename from ge/model/ge_root_model.h rename to ge/common/model/ge_root_model.h index 9e8e116e..e9ba3da6 100755 --- a/ge/model/ge_root_model.h +++ b/ge/common/model/ge_root_model.h @@ -15,7 +15,7 @@ */ #include #include "graph/compute_graph.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #ifndef GE_MODEL_GE_ROOT_MODEL_H_ #define GE_MODEL_GE_ROOT_MODEL_H_ diff --git a/ge/common/model_parser/model_parser.cc b/ge/common/model_parser/model_parser.cc index 7447cdf8..5d1869be 100644 --- a/ge/common/model_parser/model_parser.cc +++ b/ge/common/model_parser/model_parser.cc @@ -23,12 +23,10 @@ #include "framework/common/helper/model_helper.h" namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelParserBase::ModelParserBase() {} -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelParserBase::~ModelParserBase() {} +ModelParserBase::ModelParserBase() {} +ModelParserBase::~ModelParserBase() {} -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::LoadFromFile(const char *model_path, - int32_t priority, - ge::ModelData &model_data) { +Status ModelParserBase::LoadFromFile(const char *model_path, int32_t priority, ge::ModelData &model_data) { std::string real_path = RealPath(model_path); if (real_path.empty()) { GELOGE(ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID, "[Check][Param]Model file path %s is invalid", @@ -81,9 +79,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status 
ModelParserBase::LoadFro return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::ParseModelContent(const ge::ModelData &model, - uint8_t *&model_data, - uint32_t &model_len) { +Status ModelParserBase::ParseModelContent(const ge::ModelData &model, uint8_t *&model_data, uint32_t &model_len) { // Parameter validity check GE_CHECK_NOTNULL(model.model_data); diff --git a/ge/common/model_saver.cc b/ge/common/model_saver.cc index 24e837f7..56045030 100755 --- a/ge/common/model_saver.cc +++ b/ge/common/model_saver.cc @@ -29,8 +29,7 @@ namespace ge { const uint32_t kInteval = 2; -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFile(const char *file_path, - const Json &model) { +Status ModelSaver::SaveJsonToFile(const char *file_path, const Json &model) { Status ret = SUCCESS; if (file_path == nullptr || SUCCESS != CheckPath(file_path)) { GELOGE(FAILED, "[Check][OutputFile]Failed, file %s", file_path); diff --git a/ge/graph/common/omg_util.cc b/ge/common/omg_util.cc similarity index 95% rename from ge/graph/common/omg_util.cc rename to ge/common/omg_util.cc index b2017e4d..31e4270a 100644 --- a/ge/graph/common/omg_util.cc +++ b/ge/common/omg_util.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" @@ -59,8 +59,8 @@ Status SetStreamLabel(const ge::NodePtr &node, const std::string &label) { if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_STREAM_LABEL, label)) { REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_STREAM_LABEL.c_str(), node->GetName().c_str(), node->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_STREAM_LABEL.c_str(), - node->GetName().c_str(), node->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_STREAM_LABEL.c_str(), node->GetName().c_str(), + node->GetType().c_str()); return FAILED; } @@ -100,8 +100,8 @@ Status SetActiveLabelList(const ge::NodePtr &node, const std::vectorGetName().c_str(), node->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_ACTIVE_LABEL_LIST.c_str(), - node->GetName().c_str(), node->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_ACTIVE_LABEL_LIST.c_str(), node->GetName().c_str(), + node->GetType().c_str()); return FAILED; } @@ -163,8 +163,8 @@ Status SetOriginalNodeName(const ge::NodePtr &node, const std::string &orig_name if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_ORIG_NODE_NAME, orig_name)) { REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_ORIG_NODE_NAME.c_str(), node->GetName().c_str(), node->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_ORIG_NODE_NAME.c_str(), - node->GetName().c_str(), node->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_ORIG_NODE_NAME.c_str(), node->GetName().c_str(), + node->GetType().c_str()); return FAILED; } @@ -207,8 +207,8 @@ Status SetNextIteration(const NodePtr &node, const NodePtr &next) { if (!AttrUtils::SetStr(op_desc, ATTR_NAME_NEXT_ITERATION, name)) { REPORT_INNER_ERROR("E19999", 
"Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), op_desc->GetName().c_str(), + op_desc->GetType().c_str()); return FAILED; } return SUCCESS; @@ -290,8 +290,8 @@ void SetControlFlowGroup(const NodePtr &node, int64_t group) { if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group)) { REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), node->GetName().c_str(), node->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), - node->GetName().c_str(), node->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), node->GetName().c_str(), + node->GetType().c_str()); } } } // namespace ge diff --git a/ge/graph/common/omg_util.h b/ge/common/omg_util.h similarity index 100% rename from ge/graph/common/omg_util.h rename to ge/common/omg_util.h diff --git a/ge/common/op/attr_value_util.cc b/ge/common/op/attr_value_util.cc index 8be0ecd1..fd5b842a 100644 --- a/ge/common/op/attr_value_util.cc +++ b/ge/common/op/attr_value_util.cc @@ -77,37 +77,33 @@ DEFINE_SET_ATTR_VALUE_LIST(const std::string &, s); } \ } while (0); -#define DEFINE_ADD_ATTR_VALUE(KEY_TYPE, VALUE_TYPE) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttr(KEY_TYPE map_key, VALUE_TYPE value, OpDef *op_def) { \ - GE_CHECK_NOTNULL_JUST_RETURN(op_def); \ - auto attr = op_def->mutable_attr(); \ - ADD_TO_ATTR_MAP(map_key, value, attr) \ - } \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttr(KEY_TYPE map_key, VALUE_TYPE value, \ - AttrDefMap *attr_map) { \ - ADD_TO_ATTR_MAP(map_key, value, attr_map) \ - } \ - 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddModelAttr(KEY_TYPE map_key, VALUE_TYPE value, \ - ModelDef *model_def) { \ - GE_CHECK_NOTNULL_JUST_RETURN(model_def); \ - auto attr = model_def->mutable_attr(); \ - ADD_TO_ATTR_MAP(map_key, value, attr) \ +#define DEFINE_ADD_ATTR_VALUE(KEY_TYPE, VALUE_TYPE) \ + void AddOpAttr(KEY_TYPE map_key, VALUE_TYPE value, OpDef *op_def) { \ + GE_CHECK_NOTNULL_JUST_RETURN(op_def); \ + auto attr = op_def->mutable_attr(); \ + ADD_TO_ATTR_MAP(map_key, value, attr) \ + } \ + void AddOpAttr(KEY_TYPE map_key, VALUE_TYPE value, AttrDefMap *attr_map) { \ + ADD_TO_ATTR_MAP(map_key, value, attr_map) \ + } \ + void AddModelAttr(KEY_TYPE map_key, VALUE_TYPE value, ModelDef *model_def) { \ + GE_CHECK_NOTNULL_JUST_RETURN(model_def); \ + auto attr = model_def->mutable_attr(); \ + ADD_TO_ATTR_MAP(map_key, value, attr) \ } -#define DEFINE_ADD_ATTR_VALUE_LIST(KEY_TYPE, VALUE_TYPE) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttrList(KEY_TYPE map_key, VALUE_TYPE value, \ - OpDef *op_def) { \ - GE_CHECK_NOTNULL_JUST_RETURN(op_def); \ - auto attr = op_def->mutable_attr(); \ - ADD_TO_ATTR_MAP_LIST(map_key, value, attr) \ - } \ - FMK_FUNC_DEV_VISIBILITY void AddOpAttrList(KEY_TYPE map_key, VALUE_TYPE value, AttrDefMap *attr_map) { \ - ADD_TO_ATTR_MAP_LIST(map_key, value, attr_map) \ - } \ - FMK_FUNC_DEV_VISIBILITY void AddModelAttrList(KEY_TYPE map_key, VALUE_TYPE value, ModelDef *model_def) { \ - GE_CHECK_NOTNULL_JUST_RETURN(model_def); \ - auto attr = model_def->mutable_attr(); \ - ADD_TO_ATTR_MAP_LIST(map_key, value, attr) \ +#define DEFINE_ADD_ATTR_VALUE_LIST(KEY_TYPE, VALUE_TYPE) \ + void AddOpAttrList(KEY_TYPE map_key, VALUE_TYPE value, OpDef *op_def) { \ + GE_CHECK_NOTNULL_JUST_RETURN(op_def); \ + auto attr = op_def->mutable_attr(); \ + ADD_TO_ATTR_MAP_LIST(map_key, value, attr) \ + } \ + void AddOpAttrList(KEY_TYPE map_key, VALUE_TYPE value, AttrDefMap *attr_map) { \ + ADD_TO_ATTR_MAP_LIST(map_key, value, attr_map)} 
FMK_FUNC_DEV_VISIBILITY void \ + AddModelAttrList(KEY_TYPE map_key, VALUE_TYPE value, ModelDef *model_def) { \ + GE_CHECK_NOTNULL_JUST_RETURN(model_def); \ + auto attr = model_def->mutable_attr(); \ + ADD_TO_ATTR_MAP_LIST(map_key, value, attr) \ } DEFINE_ADD_ATTR_VALUE(const std::string &, const std::string &); @@ -127,46 +123,42 @@ DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const bool); DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const int64_t); DEFINE_ADD_ATTR_VALUE_LIST(const std::string &, const std::string &); -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpAttr(const std::string &map_key, AttrDef &attr, - OpDef *op_def) { +void AddOpAttr(const std::string &map_key, AttrDef &attr, OpDef *op_def) { GE_CHECK_NOTNULL_JUST_RETURN(op_def); GE_CHECK_NOTNULL_JUST_RETURN(op_def->mutable_attr()); (void)op_def->mutable_attr()->insert(AttrDefPair(map_key, attr)); } -#define DEFINE_GET_ATTR_VALUE(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetAttrDefValue(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, \ - const AttrDefMap &attr) { \ - auto it = attr.find(map_key); \ - if (it != attr.end()) { \ - *value = it->second.FIELD(); \ - return true; \ - } \ - return false; \ +#define DEFINE_GET_ATTR_VALUE(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ + bool GetAttrDefValue(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, const AttrDefMap &attr) { \ + auto it = attr.find(map_key); \ + if (it != attr.end()) { \ + *value = it->second.FIELD(); \ + return true; \ + } \ + return false; \ } -#define DEFINE_GET_ATTR_POINT_REF(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetAttrDefValue(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE *&value, \ - AttrDefMap *attr) { \ - GE_RT_FALSE_CHECK_NOTNULL(attr); \ - auto it = attr->find(map_key); \ - if (it != attr->end()) { \ - value = it->second.mutable_##FIELD(); \ - return true; \ - } \ - return false; \ +#define 
DEFINE_GET_ATTR_POINT_REF(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ + bool GetAttrDefValue(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE *&value, AttrDefMap *attr) { \ + GE_RT_FALSE_CHECK_NOTNULL(attr); \ + auto it = attr->find(map_key); \ + if (it != attr->end()) { \ + value = it->second.mutable_##FIELD(); \ + return true; \ + } \ + return false; \ } -#define DEFINE_GET_ATTR_CONST_POINT_REF(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetAttrDefValue( \ - ARG_TYPE_KEY map_key, const ARG_TYPE_VALUE *&value, const AttrDefMap &attr) { \ - auto it = attr.find(map_key); \ - if (it == attr.end()) { \ - return false; \ - } \ - \ - value = &(it->second.FIELD()); \ - return true; \ +#define DEFINE_GET_ATTR_CONST_POINT_REF(ARG_TYPE_KEY, ARG_TYPE_VALUE, FIELD) \ + bool GetAttrDefValue(ARG_TYPE_KEY map_key, const ARG_TYPE_VALUE *&value, const AttrDefMap &attr) { \ + auto it = attr.find(map_key); \ + if (it == attr.end()) { \ + return false; \ + } \ + \ + value = &(it->second.FIELD()); \ + return true; \ } #define DEFINE_GET_BYTES_ATTR_VALUE(ARG_TYPE_KEY, ARG_TYPE_VALUE) \ @@ -216,16 +208,14 @@ DEFINE_GET_ATTR_CONST_POINT_REF(const std::string &, NamedAttrs, func); DEFINE_GET_BYTES_ATTR_VALUE(const std::string &, std::string *); -#define DEFINE_GET_OP_ATTR(ARG_TYPE_KEY, ARG_TYPE_VALUE) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetOpAttr(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, \ - const OpDef *op_def) { \ - GE_RT_FALSE_CHECK_NOTNULL(op_def); \ - return GetAttrDefValue(map_key, value, op_def->attr()); \ - } \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetModelAttr(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, \ - const ModelDef *model_def) { \ - GE_RT_FALSE_CHECK_NOTNULL(model_def); \ - return GetAttrDefValue(map_key, value, model_def->attr()); \ +#define DEFINE_GET_OP_ATTR(ARG_TYPE_KEY, ARG_TYPE_VALUE) \ + bool GetOpAttr(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, const OpDef *op_def) { \ + 
GE_RT_FALSE_CHECK_NOTNULL(op_def); \ + return GetAttrDefValue(map_key, value, op_def->attr()); \ + } \ + bool GetModelAttr(ARG_TYPE_KEY map_key, ARG_TYPE_VALUE value, const ModelDef *model_def) { \ + GE_RT_FALSE_CHECK_NOTNULL(model_def); \ + return GetAttrDefValue(map_key, value, model_def->attr()); \ } DEFINE_GET_OP_ATTR(const std::string &, std::string *); @@ -238,8 +228,7 @@ DEFINE_GET_OP_ATTR(const std::string &, bool *); DEFINE_GET_OP_ATTR(const std::string &, AttrDef_ListValue *); #define DEFINE_GET_BT_ATTR(ARG_TYPE_KEY, ARG_TYPE_VALUE) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool GetBytesAttr(ARG_TYPE_KEY key, ARG_TYPE_VALUE value, \ - const OpDef *op_def) { \ + bool GetBytesAttr(ARG_TYPE_KEY key, ARG_TYPE_VALUE value, const OpDef *op_def) { \ GE_RT_FALSE_CHECK_NOTNULL(op_def); \ return GetBytesValue(key, value, op_def->attr()); \ } \ @@ -250,7 +239,7 @@ DEFINE_GET_OP_ATTR(const std::string &, AttrDef_ListValue *); DEFINE_GET_BT_ATTR(const std::string &, std::string *); -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool HasOpAttr(const OpDef *op_def, const std::string &attr_name) { +bool HasOpAttr(const OpDef *op_def, const std::string &attr_name) { if (op_def == nullptr) { return false; } @@ -263,8 +252,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool HasOpAttr(const OpDef *op_ return false; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddModelAttr(const std::string &map_key, const void *value, - size_t size, ModelDef *model_def) { +void AddModelAttr(const std::string &map_key, const void *value, size_t size, ModelDef *model_def) { if (model_def == nullptr) { return; } @@ -280,8 +268,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddModelAttr(const std::st } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void AddOpBytesAttr(const std::string &key, const void *value, - size_t size, OpDef *op_def) { +void AddOpBytesAttr(const std::string &key, const void *value, size_t size, OpDef *op_def) { if (op_def 
== nullptr) { return; } diff --git a/ge/common/op/ge_op_utils.cc b/ge/common/op/ge_op_utils.cc index 99b5733c..429ce909 100644 --- a/ge/common/op/ge_op_utils.cc +++ b/ge/common/op/ge_op_utils.cc @@ -115,8 +115,7 @@ const int NORMAL_TENSOR_SIZE = 4; #define AIPP_CONVERT_LIST_FLOAT(KEY, REQUIRED) AIPP_CONVERT_LIST_FORMAT(KEY, float, REQUIRED, GeAttrValue::FLOAT) -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -OpUtils::ConvertAippParams(const GeAttrValue::NAMED_ATTRS &aipp_attr, domi::AippOpParams *aipp_params) { +Status OpUtils::ConvertAippParams(const GeAttrValue::NAMED_ATTRS &aipp_attr, domi::AippOpParams *aipp_params) { GE_CHECK_NOTNULL(aipp_params); AIPP_CONVERT_FORMAT_EX(aipp_mode, domi::AippOpParams::AippMode, int32_t, GeAttrValue::INT); AIPP_CONVERT_INT(related_input_rank); @@ -178,8 +177,7 @@ OpUtils::ConvertAippParams(const GeAttrValue::NAMED_ATTRS &aipp_attr, domi::Aipp return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::TransferDim(const std::vector &dim, - std::vector &dim_vector) { +Status OpUtils::TransferDim(const std::vector &dim, std::vector &dim_vector) { size_t input_shape_size = dim.size(); std::list new_dim_list; for (auto dim_temp : dim) { @@ -301,9 +299,9 @@ Status OpUtils::SetOutputSliceDataByDataType(void *data, int64_t data_size, cons return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::SetOutputSliceData( - void *data, int64_t data_size, int32_t data_type, std::vector &input_dims, std::vector &begin, - std::vector &output_dims, GeTensor *output, std::vector &stride) { +Status OpUtils::SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector &input_dims, + std::vector &begin, std::vector &output_dims, GeTensor *output, + std::vector &stride) { if (data == nullptr || output == nullptr) { GELOGE(PARAM_INVALID, "[Check][Param]Input param is nullptr"); REPORT_INNER_ERROR("E19999", "Input param is nullptr"); @@ -352,9 +350,7 @@ FMK_FUNC_HOST_VISIBILITY 
FMK_FUNC_DEV_VISIBILITY Status OpUtils::SetOutputSliceD return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataHWCK2KCHW(const void *input, int64_t h, - int64_t w, int64_t c, int64_t k, - void **output) { +void OpUtils::TransDataHWCK2KCHW(const void *input, int64_t h, int64_t w, int64_t c, int64_t k, void **output) { if (input == nullptr) { return; } @@ -386,9 +382,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataHWCK2KCH *output = buf; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataKCHW2HWCK(const void *input, int64_t k, - int64_t c, int64_t h, int64_t w, - void *output) { +void OpUtils::TransDataKCHW2HWCK(const void *input, int64_t k, int64_t c, int64_t h, int64_t w, void *output) { if ((input == nullptr) || (output == nullptr)) { GELOGD("%s[%d]: input param is nullptr.", __FILE__, __LINE__); return; @@ -417,31 +411,22 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataKCHW2HWC vector OpUtils::GetWeights(const ge::Node &node) { return OpDescUtils::GetWeights(node); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector OpUtils::GetWeights(ge::ConstNodePtr node) { - return OpDescUtils::GetWeights(node); -} +vector OpUtils::GetWeights(ge::ConstNodePtr node) { return OpDescUtils::GetWeights(node); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector OpUtils::MutableWeights(const ge::Node &node) { - return OpDescUtils::MutableWeights(node); -} +vector OpUtils::MutableWeights(const ge::Node &node) { return OpDescUtils::MutableWeights(node); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector OpUtils::MutableWeights(const ge::NodePtr node) { - return OpDescUtils::MutableWeights(node); -} +vector OpUtils::MutableWeights(const ge::NodePtr node) { return OpDescUtils::MutableWeights(node); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::SetWeights(ge::Node &node, - const vector &weights) { +Status OpUtils::SetWeights(ge::Node 
&node, const vector &weights) { return OpDescUtils::SetWeights(node, weights); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::SetWeights(ge::NodePtr node, - const vector &weights) { +Status OpUtils::SetWeights(ge::NodePtr node, const vector &weights) { return OpDescUtils::SetWeights(node, weights); } // The caller guarantees that the input sensor is constant -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -OpUtils::GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, DataType type, std::vector &dims) { +Status OpUtils::GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, DataType type, std::vector &dims) { if (tensor == nullptr) { GELOGE(PARAM_INVALID, "[Check][Param]Input tensor is nullptr"); REPORT_INNER_ERROR("E19999", "Input tensor is nullptr"); diff --git a/ge/common/profiling/ge_profiling.cc b/ge/common/profiling/ge_profiling.cc index fcd01a12..a5857b35 100644 --- a/ge/common/profiling/ge_profiling.cc +++ b/ge/common/profiling/ge_profiling.cc @@ -23,7 +23,7 @@ #include "graph/ge_context.h" #include "init/gelib.h" #include "framework/common/ge_inner_error_codes.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include "framework/omg/omg_inner_types.h" namespace { diff --git a/ge/common/profiling/profiling_manager.cc b/ge/common/profiling/profiling_manager.cc index 0464491d..e8f41cc4 100644 --- a/ge/common/profiling/profiling_manager.cc +++ b/ge/common/profiling/profiling_manager.cc @@ -77,12 +77,12 @@ ProfilingManager::ProfilingManager() ProfilingManager::~ProfilingManager() {} -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager &ProfilingManager::Instance() { +ProfilingManager &ProfilingManager::Instance() { static ProfilingManager profiling_manager; return profiling_manager; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::Init(const Options &options) { +ge::Status ProfilingManager::Init(const Options &options) { #ifdef DAVINCI_SUPPORT_PROFILING 
vector().swap(device_id_); subscribe_count_ = 0; @@ -221,7 +221,7 @@ ge::Status ProfilingManager::ParseOptions(const std::string &options) { return ge::SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProfiling() { +void ProfilingManager::StopProfiling() { #ifdef DAVINCI_SUPPORT_PROFILING uint64_t module = GetProfilingModule(); // The following if case will not be executed in normal case, inc case of ProfStopProfiling is abnormal @@ -259,8 +259,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingOpInputOutInfo( - const TaskDescInfo &task, Json &task_json) { +void ProfilingManager::ProfilingOpInputOutInfo(const TaskDescInfo &task, Json &task_json) { #ifdef DAVINCI_SUPPORT_PROFILING for (size_t i = 0; i < task.input_format.size(); i++) { Json tmp_input; @@ -286,8 +285,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( - uint32_t model_id, const std::vector &task_desc_info, const int32_t &device_id) { +void ProfilingManager::ProfilingTaskDescInfo(uint32_t model_id, const std::vector &task_desc_info, + const int32_t &device_id) { #ifdef DAVINCI_SUPPORT_PROFILING for (const auto &task : task_desc_info) { Json task_info; @@ -324,8 +323,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfileStepInfo( - uint64_t index_id, uint64_t model_id, uint16_t tag_id, rtStream_t stream, int32_t device_id) { +Status ProfilingManager::ProfileStepInfo(uint64_t index_id, uint64_t model_id, uint16_t tag_id, rtStream_t stream, + int32_t device_id) { #ifdef DAVINCI_SUPPORT_PROFILING if (!is_load_profiling_ && subscribe_count_ == 0) { GELOGD("Profiling is not turned on, no 
need to profile step info."); @@ -385,8 +384,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::Profil return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportData( - const int32_t &device_id, const string &data, const string &tag_name) { +void ProfilingManager::ReportData(const int32_t &device_id, const string &data, const string &tag_name) { #ifdef DAVINCI_SUPPORT_PROFILING ReporterData reporter_data{}; int ret = -1; @@ -426,8 +424,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportDa #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( - uint32_t model_id, const std::vector &task_desc_info) { +void ProfilingManager::ReportProfilingData(uint32_t model_id, const std::vector &task_desc_info) { #ifdef DAVINCI_SUPPORT_PROFILING int32_t logic_device_id = 0; rtError_t rt_ret = rtGetDevice(&logic_device_id); @@ -443,7 +440,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportPr #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetProfilingModule() { +uint64_t ProfilingManager::GetProfilingModule() { uint64_t module = PROF_MODEL_EXECUTE_MASK | PROF_RUNTIME_API_MASK | PROF_RUNTIME_TRACE_MASK | @@ -485,8 +482,7 @@ void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type, uin #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfModelSubscribe( - uint64_t module, void *model) { +Status ProfilingManager::ProfModelSubscribe(uint64_t module, void *model) { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard lock(mutex_); uint64_t model_load_mask = module & PROF_MODEL_LOAD_MASK; @@ -526,8 +522,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfModelUnsubscribe( - void *model) { +Status 
ProfilingManager::ProfModelUnsubscribe(void *model) { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard lock(mutex_); if (subscribe_count_ == 0) { @@ -568,7 +563,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfInit(uint64_t module) { +Status ProfilingManager::ProfInit(uint64_t module) { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard lock(mutex_); uint64_t model_load_mask = module & PROF_MODEL_LOAD_MASK; @@ -602,7 +597,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfIn return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfFinalize() { +Status ProfilingManager::ProfFinalize() { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard lock(mutex_); is_load_profiling_ = false; @@ -697,8 +692,8 @@ Status ProfilingManager::ProfParseDeviceId(const std::map &config_para, - int32_t &device_num, vector &device_list) { +Status ProfilingManager::ProfParseParam(const std::map &config_para, int32_t &device_num, + vector &device_list) { #ifdef DAVINCI_SUPPORT_PROFILING // device num auto iter = config_para.find(kConfigNumsdev); @@ -747,8 +742,7 @@ Status ProfilingManager::ProfParseParam(const std::map return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfStartProfiling( - uint64_t module, const std::map &config_para) { +Status ProfilingManager::ProfStartProfiling(uint64_t module, const std::map &config_para) { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard lock(mutex_); uint64_t training_trace_mask = module & PROF_TRAINING_TRACE_MASK; @@ -803,8 +797,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfStopProfiling(uint64_t module, - const std::map &config_para) { +Status 
ProfilingManager::ProfStopProfiling(uint64_t module, const std::map &config_para) { #ifdef DAVINCI_SUPPORT_PROFILING std::lock_guard lock(mutex_); int32_t device_num = 0; @@ -855,8 +848,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::UpdateDeviceIdModuleMap(string prof_type, - uint64_t module, const vector &device_list) { +void ProfilingManager::UpdateDeviceIdModuleMap(string prof_type, uint64_t module, const vector &device_list) { #ifdef DAVINCI_SUPPORT_PROFILING if (prof_type == kProfStart) { for (uint32_t i = 0; i < device_list.size(); i++) { @@ -886,7 +878,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::UpdateDe #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ProfilingManager::ProfilingModelExecuteOn() const { +bool ProfilingManager::ProfilingModelExecuteOn() const { int32_t logic_device_id = 0; rtError_t rt_ret = rtGetDevice(&logic_device_id); if (rt_ret != RT_ERROR_NONE) { @@ -904,7 +896,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ProfilingManager::Profilin return execute_model_prof_on; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::PluginInit() { +Status ProfilingManager::PluginInit() { if (prof_cb_.msprofReporterCallback == nullptr) { GELOGE(ge::PARAM_INVALID, "[Check][Param]MsprofReporterCallback callback is nullptr"); REPORT_INNER_ERROR("E19999", "MsprofReporterCallback callback is nullptr"); @@ -933,7 +925,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::Plugin return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUnInit() const { +void ProfilingManager::PluginUnInit() const { #ifdef DAVINCI_SUPPORT_PROFILING if (prof_cb_.msprofReporterCallback == nullptr) { GELOGE(ge::PARAM_INVALID, "[Check][Param]MsprofReporterCallback callback is nullptr"); @@ -950,8 +942,7 @@ 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUn #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::CallMsprofReport( - ReporterData &reporter_data) const { +Status ProfilingManager::CallMsprofReport(ReporterData &reporter_data) const { if (prof_cb_.msprofReporterCallback == nullptr) { GELOGE(ge::PARAM_INVALID, "[Check][Param]MsprofReporterCallback callback is nullptr"); REPORT_INNER_ERROR("E19999", "MsprofReporterCallback callback is nullptr"); @@ -1007,14 +998,12 @@ void ProfilingManager::GetOpOutputInfo(const OpDescPtr &op, TaskDescInfo &task_d task_desc_info.output_data_type = output_data_type.empty() ? data_type_default : output_data_type; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetOpInputOutputInfo( - const OpDescPtr &op, TaskDescInfo &task_desc_info) const { +void ProfilingManager::GetOpInputOutputInfo(const OpDescPtr &op, TaskDescInfo &task_desc_info) const { GetOpInputInfo(op, task_desc_info); GetOpOutputInfo(op, task_desc_info); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpPoint( - std::string &fp_point, std::string &bp_point) { +void ProfilingManager::GetFpBpPoint(std::string &fp_point, std::string &bp_point) { // Env or options mode, fp_point_/bp_point_ have initiliazed on profiling init if (!fp_point_.empty() && !bp_point_.empty()) { fp_point = fp_point_; @@ -1025,7 +1014,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpP } // ProfApi mode and training trace is set // Parse options first - char env_profiling_options[MSPROF_OPTIONS_DEF_LEN_MAX] = { 0x00 }; + char env_profiling_options[MSPROF_OPTIONS_DEF_LEN_MAX] = {0x00}; bool is_profiling_valid = false; std::string profiling_options; if (ge::GetContext().GetOption(OPTION_EXEC_PROFILING_OPTIONS, profiling_options) == SUCCESS && diff --git a/ge/common/profiling/profiling_manager.h b/ge/common/profiling/profiling_manager.h index 
e5137562..86371d51 100755 --- a/ge/common/profiling/profiling_manager.h +++ b/ge/common/profiling/profiling_manager.h @@ -73,7 +73,7 @@ struct MsprofCallback { MsprofReporterCallback msprofReporterCallback; }; -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { +class ProfilingManager { public: ProfilingManager(); virtual ~ProfilingManager(); diff --git a/ge/common/properties_manager.cc b/ge/common/properties_manager.cc index 0c5ef1fe..aeabb008 100644 --- a/ge/common/properties_manager.cc +++ b/ge/common/properties_manager.cc @@ -35,13 +35,13 @@ PropertiesManager::PropertiesManager() : is_inited_(false), delimiter("=") {} PropertiesManager::~PropertiesManager() {} // singleton -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PropertiesManager &PropertiesManager::Instance() { +PropertiesManager &PropertiesManager::Instance() { static PropertiesManager instance; return instance; } // Initialize property configuration -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::Init(const std::string &file_path) { +bool PropertiesManager::Init(const std::string &file_path) { std::lock_guard lock(mutex_); if (is_inited_) { GELOGW("Already inited, will be initialized again"); @@ -139,8 +139,7 @@ std::string PropertiesManager::Trim(const std::string &str) { } // Get property value, if not found, return "" -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetPropertyValue( - const std::string &map_key) { +std::string PropertiesManager::GetPropertyValue(const std::string &map_key) { std::lock_guard lock(mutex_); auto iter = properties_map_.find(map_key); if (properties_map_.end() != iter) { @@ -151,21 +150,19 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager:: } // Set property value -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetPropertyValue(const std::string &map_key, - const std::string &value) { +void 
PropertiesManager::SetPropertyValue(const std::string &map_key, const std::string &value) { std::lock_guard lock(mutex_); properties_map_[map_key] = value; } // return properties_map_ -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map -PropertiesManager::GetPropertyMap() { +std::map PropertiesManager::GetPropertyMap() { std::lock_guard lock(mutex_); return properties_map_; } // Set separator -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetPropertyDelimiter(const std::string &de) { +void PropertiesManager::SetPropertyDelimiter(const std::string &de) { std::lock_guard lock(mutex_); delimiter = de; } diff --git a/ge/common/tbe_kernel_store.h b/ge/common/tbe_kernel_store.h index 6304af50..1492bdd9 100755 --- a/ge/common/tbe_kernel_store.h +++ b/ge/common/tbe_kernel_store.h @@ -21,7 +21,7 @@ namespace ge { -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEKernelStore : public KernelStore { +class TBEKernelStore : public KernelStore { public: TBEKernelStore(); ~TBEKernelStore() {} diff --git a/ge/common/thread_pool.cc b/ge/common/thread_pool.cc index f9b7bb99..56f8ee60 100644 --- a/ge/common/thread_pool.cc +++ b/ge/common/thread_pool.cc @@ -26,7 +26,7 @@ #include "external/register/register_types.h" namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false) { +ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false) { idle_thrd_num_ = size < 1 ? 
1 : size; for (uint32_t i = 0; i < idle_thrd_num_; ++i) { @@ -34,7 +34,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::ThreadPool(uint32_t } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::~ThreadPool() { +ThreadPool::~ThreadPool() { is_stoped_.store(true); { std::unique_lock lock{m_lock_}; diff --git a/ge/common/thread_pool.h b/ge/common/thread_pool.h index 7e52edcc..777a3c9b 100755 --- a/ge/common/thread_pool.h +++ b/ge/common/thread_pool.h @@ -37,7 +37,7 @@ namespace ge { using ThreadTask = std::function; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ThreadPool { +class ThreadPool { public: explicit ThreadPool(uint32_t size = 4); ~ThreadPool(); diff --git a/ge/graph/common/transop_util.cc b/ge/common/transop_util.cc similarity index 97% rename from ge/graph/common/transop_util.cc rename to ge/common/transop_util.cc index 871ecdb1..914e80aa 100755 --- a/ge/graph/common/transop_util.cc +++ b/ge/common/transop_util.cc @@ -14,9 +14,9 @@ * limitations under the License. 
*/ -#include "graph/common/transop_util.h" +#include "common/transop_util.h" -#include "framework/common/types.h" +#include "common/types.h" #include "graph/utils/type_utils.h" #include "framework/common/debug/ge_log.h" diff --git a/ge/graph/common/transop_util.h b/ge/common/transop_util.h similarity index 95% rename from ge/graph/common/transop_util.h rename to ge/common/transop_util.h index 883ae41b..57e4adad 100644 --- a/ge/graph/common/transop_util.h +++ b/ge/common/transop_util.h @@ -23,7 +23,7 @@ #include "graph/node.h" namespace ge { -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY TransOpUtil { +class TransOpUtil { public: static bool IsTransOp(const NodePtr &node); diff --git a/ge/common/util.cc b/ge/common/util.cc index dfb5bac4..6d77dbc8 100644 --- a/ge/common/util.cc +++ b/ge/common/util.cc @@ -70,7 +70,7 @@ static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Messag return proto->ParseFromCodedStream(&coded_stream); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) { +bool ReadProtoFromArray(const void *data, int size, Message *proto) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), return false, "incorrect parameter. proto is nullptr || data is nullptr || size is 0"); @@ -112,8 +112,7 @@ long GetFileLength(const std::string &input_file) { * @return false fail * @return true success */ -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, - int &length) { +bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file is nullptr"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), return false, "incorrect parameter. 
buffer is nullptr"); @@ -141,8 +140,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co return true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, - std::vector &buffer) { +bool ReadBytesFromBinaryFile(const char *file_name, std::vector &buffer) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file path is null"); std::string real_path = RealPath(file_name); @@ -177,7 +175,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co * @return -1 fail * @return 0 success */ -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std::string &directory_path) { +int CreateDirectory(const std::string &directory_path) { GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); auto dir_path_len = directory_path.length(); if (dir_path_len >= MMPA_MAX_PATH) { @@ -219,7 +217,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: return 0; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() { +std::string CurrentTimeInStr() { std::time_t now = std::time(nullptr); std::tm *ptm = std::localtime(&now); if (ptm == nullptr) { @@ -235,8 +233,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() return std::string(buffer); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file, - google::protobuf::Message *message) { +bool ReadProtoFromText(const char *file, google::protobuf::Message *message) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), return false, "incorrect parameter. 
nullptr == file || nullptr == message"); @@ -266,8 +263,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size, - google::protobuf::Message *message) { +bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), return false, "incorrect parameter. data is nullptr || message is nullptr"); std::string str(data, static_cast(size)); @@ -281,7 +277,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const cha return ret; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { +uint64_t GetCurrentTimestamp() { mmTimeval tv{}; int ret = mmGetTimeOfDay(&tv, nullptr); GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed, ret:%d, errmsg:%s", ret, strerror(errno)); @@ -289,7 +285,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() return static_cast(total_use_time); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimestap() { +uint32_t GetCurrentSecondTimestap() { mmTimeval tv{}; int ret = mmGetTimeOfDay(&tv, nullptr); GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed, ret:%d, errmsg:%s", ret, strerror(errno)); @@ -297,7 +293,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimest return static_cast(total_use_time); } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int64_t a, int64_t b) { +bool CheckInt64MulOverflow(int64_t a, int64_t b) { if (a > 0) { if (b > 0) { if (a > (INT64_MAX / b)) { @@ -322,7 +318,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int6 return true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { +std::string RealPath(const char *path) { 
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH, ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, @@ -349,8 +345,7 @@ void PathValidErrReport(const std::string &file_path, const std::string &atc_par } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path, - const std::string &atc_param) { +bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param) { // The specified path is empty std::map args_map; if (file_path.empty()) { @@ -395,8 +390,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const return true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string &file_path, - const std::string &atc_param) { +bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param) { // The specified path is empty if (file_path.empty()) { if (!atc_param.empty()) { @@ -552,7 +546,7 @@ FMK_FUNC_HOST_VISIBILITY bool IsValidFile(const char *file_path) { return true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status CheckPath(const char *path, size_t length) { +Status CheckPath(const char *path, size_t length) { if (path == nullptr) { GELOGE(PARAM_INVALID, "[Check][Param]Config path is invalid"); REPORT_CALL_ERROR("E19999", "Config path is invalid"); diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt index 44ba3131..54cb7639 100755 --- a/ge/executor/CMakeLists.txt +++ b/ge/executor/CMakeLists.txt @@ -1,13 +1,9 @@ set(SRC_LIST "ge_executor.cc" "../common/profiling/profiling_manager.cc" - "../common/ge/plugin_manager.cc" - "../common/ge/op_tiling_manager.cc" - "../common/dump/dump_properties.cc" - "../common/dump/exception_dumper.cc" - "../common/dump/dump_manager.cc" "../common/dump/dump_op.cc" "../common/dump/opdebug_register.cc" + 
"../common/dump/exception_dumper.cc" "../common/profiling/ge_profiling.cc" "../graph/load/graph_loader.cc" "../graph/execute/graph_execute.cc" @@ -22,8 +18,6 @@ set(SRC_LIST "../graph/manager/rdma_pool_allocator.cc" "../graph/manager/host_mem_allocator.cc" "../hybrid/node_executor/aicpu/aicpu_ext_info.cc" - "../model/ge_model.cc" - "../model/ge_root_model.cc" "../graph/load/model_manager/davinci_model.cc" "../graph/load/model_manager/model_manager.cc" "../graph/load/model_manager/tbe_handle_store.cc" @@ -55,7 +49,6 @@ set(SRC_LIST "../graph/load/model_manager/task_info/model_exit_task_info.cc" "../graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc" "../graph/load/model_manager/task_info/super_kernel/super_kernel.cc" - "../graph/common/local_context.cc" "../opskernel_manager/ops_kernel_builder_manager.cc" "../single_op/single_op_manager.cc" "../single_op/single_op_model.cc" @@ -102,7 +95,6 @@ set(SRC_LIST "../hybrid/node_executor/task_context.cc" "../hybrid/hybrid_davinci_model.cc" "../ge_local_engine/engine/host_cpu_engine.cc" - "../graph/common/omg_util.cc" "../graph/manager/host_mem_manager.cc" "../graph/build/memory/var_mem_assign_util.cc" "../host_kernels/transpose_kernel.cc" @@ -144,10 +136,6 @@ set(SRC_LIST "../host_kernels/transdata_kernel.cc" "../host_kernels/unpack_kernel.cc" "../graph/passes/pass_utils.cc" - "../graph/common/bcast.cc" - "../common/fp16_t.cc" - "../common/formats/format_transfers/format_transfer_transpose.cc" - "../common/formats/utils/formats_trans_utils.cc" ) ######## libge_executor.a ######## diff --git a/ge/executor/module.mk b/ge/executor/module.mk index 7a7e2b51..430efa75 100644 --- a/ge/executor/module.mk +++ b/ge/executor/module.mk @@ -63,7 +63,7 @@ local_ge_executor_src_files := \ ../single_op/task/aicpu_task_builder.cc \ ../single_op/task/aicpu_kernel_task_builder.cc \ ../hybrid/node_executor/aicpu/aicpu_ext_info.cc \ - ../graph/common/local_context.cc \ + ../common/local_context.cc \ 
../hybrid/common/tensor_value.cc \ ../hybrid/common/npu_memory_allocator.cc \ ../hybrid/executor/rt_callback_manager.cc \ @@ -102,7 +102,7 @@ local_ge_executor_src_files := \ ../hybrid/node_executor/task_context.cc \ ../hybrid/hybrid_davinci_model.cc \ ../ge_local_engine/engine/host_cpu_engine.cc \ - ../graph/common/omg_util.cc \ + ../common/omg_util.cc \ ../graph/manager/host_mem_manager.cc \ ../graph/build/memory/var_mem_assign_util.cc \ ../host_kernels/transpose_kernel.cc \ @@ -144,7 +144,7 @@ local_ge_executor_src_files := \ ../host_kernels/transdata_kernel.cc \ ../host_kernels/unpack_kernel.cc \ ../graph/passes/pass_utils.cc \ - ../graph/common/bcast.cc \ + ../common/bcast.cc \ ../common/fp16_t.cc \ ../common/formats/format_transfers/format_transfer_transpose.cc \ ../common/formats/utils/formats_trans_utils.cc \ diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index a56eaadf..3fd8be1a 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -80,7 +80,7 @@ ANALYZER_SRC_FILES:= \ OMG_HOST_SRC_FILES := \ model/ge_model.cc \ model/ge_root_model.cc \ - graph/common/transop_util.cc \ + common/transop_util.cc \ graph/passes/pass_manager.cc \ graph/passes/resource_pair_add_control_pass.cc \ graph/passes/resource_pair_remove_control_pass.cc \ @@ -115,9 +115,9 @@ OMG_HOST_SRC_FILES := \ graph/passes/mark_graph_unknown_status_pass.cc \ graph/passes/mark_node_unknown_shape_pass.cc \ graph/passes/mark_agnostic_pass.cc \ - graph/common/omg_util.cc \ - graph/common/bcast.cc \ - graph/common/local_context.cc \ + common/omg_util.cc \ + common/bcast.cc \ + common/local_context.cc \ graph/passes/dimension_compute_pass.cc \ graph/passes/dimension_adjust_pass.cc \ graph/passes/get_original_format_pass.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 8ca8572c..d6462542 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -43,10 +43,10 @@ LIBGE_LOCAL_SRC_FILES := \ graph/build/stream_allocator.cc \ graph/build/stream_graph_optimizer.cc \ 
graph/build/task_generator.cc \ - graph/common/bcast.cc \ - graph/common/local_context.cc \ - graph/common/omg_util.cc \ - graph/common/transop_util.cc \ + common/bcast.cc \ + common/local_context.cc \ + common/omg_util.cc \ + common/transop_util.cc \ graph/execute/graph_execute.cc \ graph/label/case_label_maker.cc \ graph/label/if_label_maker.cc \ diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 07355ab5..45eaed59 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -36,7 +36,7 @@ #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include "analyzer/analyzer.h" using std::map; diff --git a/ge/graph/build/graph_builder.cc b/ge/graph/build/graph_builder.cc index 96dea02e..e1398d1f 100644 --- a/ge/graph/build/graph_builder.cc +++ b/ge/graph/build/graph_builder.cc @@ -21,14 +21,14 @@ #include "graph/build/logical_stream_allocator.h" #include "graph/build/run_context.h" #include "graph/build/stream_graph_optimizer.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/ge_context.h" #include "graph/manager/graph_var_manager.h" #include "graph/passes/mark_same_addr_pass.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include "graph/ge_context.h" #include "opskernel_manager/ops_kernel_builder_manager.h" #include "graph/utils/op_desc_utils.h" diff --git a/ge/graph/build/graph_builder.h b/ge/graph/build/graph_builder.h index c4b16814..6ed14dae 100644 --- a/ge/graph/build/graph_builder.h +++ b/ge/graph/build/graph_builder.h @@ -38,7 +38,7 @@ #include "graph/partition/graph_partition.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" namespace ge { 
class GraphBuilder { diff --git a/ge/graph/build/logical_stream_allocator.cc b/ge/graph/build/logical_stream_allocator.cc index 58763aa9..3d6ca74a 100644 --- a/ge/graph/build/logical_stream_allocator.cc +++ b/ge/graph/build/logical_stream_allocator.cc @@ -22,7 +22,7 @@ #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" using std::map; using std::set; diff --git a/ge/graph/build/memory/block_mem_assigner.cc b/ge/graph/build/memory/block_mem_assigner.cc index 159e68a7..7d0db676 100755 --- a/ge/graph/build/memory/block_mem_assigner.cc +++ b/ge/graph/build/memory/block_mem_assigner.cc @@ -34,7 +34,7 @@ #include "graph/debug/ge_attr_define.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/optimize/common/params.h" #include "framework/omg/omg_inner_types.h" #include "runtime/mem.h" diff --git a/ge/graph/build/memory/buffer_pool_mem_assigner.cc b/ge/graph/build/memory/buffer_pool_mem_assigner.cc index d66fe038..ca197e02 100644 --- a/ge/graph/build/memory/buffer_pool_mem_assigner.cc +++ b/ge/graph/build/memory/buffer_pool_mem_assigner.cc @@ -15,7 +15,7 @@ */ #include "graph/build/memory/buffer_pool_mem_assigner.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/tensor_utils.h" #include "framework/common/util.h" #include "graph/compute_graph.h" diff --git a/ge/graph/build/memory/graph_mem_assigner.cc b/ge/graph/build/memory/graph_mem_assigner.cc index e086940a..f8878383 100755 --- a/ge/graph/build/memory/graph_mem_assigner.cc +++ b/ge/graph/build/memory/graph_mem_assigner.cc @@ -24,7 +24,7 @@ #include "graph/build/memory/hybrid_mem_assigner.h" #include "graph/build/memory/var_mem_assign_util.h" #include "graph/build/memory/block_mem_assigner.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" 
#include "graph/ge_attr_value.h" #include "graph/manager/graph_var_manager.h" diff --git a/ge/graph/build/memory/var_mem_assign_util.cc b/ge/graph/build/memory/var_mem_assign_util.cc index adddf6bd..dc7c3b01 100755 --- a/ge/graph/build/memory/var_mem_assign_util.cc +++ b/ge/graph/build/memory/var_mem_assign_util.cc @@ -18,7 +18,7 @@ #include #include "framework/common/types.h" #include "framework/common/debug/ge_log.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc index 2816f170..897be1f8 100755 --- a/ge/graph/build/model_builder.cc +++ b/ge/graph/build/model_builder.cc @@ -25,9 +25,9 @@ #include "external/graph/attr_value.h" #include "graph/buffer.h" #include "graph/build/stream_allocator.h" -#include "graph/common/omg_util.h" -#include "graph/common/ge_call_wrapper.h" -#include "graph/common/local_context.h" +#include "common/omg_util.h" +#include "common/ge_call_wrapper.h" +#include "common/local_context.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_attr_value.h" #include "graph/ge_context.h" diff --git a/ge/graph/build/model_builder.h b/ge/graph/build/model_builder.h index 151e6006..d87976dd 100644 --- a/ge/graph/build/model_builder.h +++ b/ge/graph/build/model_builder.h @@ -32,7 +32,7 @@ #include "graph/manager/graph_manager_utils.h" #include "graph/model.h" #include "graph/node.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include "framework/omg/omg_inner_types.h" namespace ge { diff --git a/ge/graph/build/run_context.cc b/ge/graph/build/run_context.cc index e7f07c0a..e629bddc 100644 --- a/ge/graph/build/run_context.cc +++ b/ge/graph/build/run_context.cc @@ -18,7 +18,7 @@ #include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include 
"graph/debug/ge_attr_define.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { RunContextUtil::~RunContextUtil() { DestroyRtModelResources(); } diff --git a/ge/graph/build/stream_allocator.cc b/ge/graph/build/stream_allocator.cc index bc34a228..987a77f7 100644 --- a/ge/graph/build/stream_allocator.cc +++ b/ge/graph/build/stream_allocator.cc @@ -22,7 +22,7 @@ #include "framework/common/fmk_error_codes.h" #include "framework/common/types.h" #include "graph/build/logical_stream_allocator.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index 67289f73..7bb2e2f6 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -29,7 +29,7 @@ #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "init/gelib.h" #include "graph/ge_local_context.h" #include "external/ge/ge_api_types.h" diff --git a/ge/graph/execute/model_executor.cc b/ge/graph/execute/model_executor.cc index d1683f1d..993ba8c3 100644 --- a/ge/graph/execute/model_executor.cc +++ b/ge/graph/execute/model_executor.cc @@ -18,8 +18,8 @@ #include "graph/ge_context.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/ge_call_wrapper.h" -#include "graph/common/local_context.h" +#include "common/ge_call_wrapper.h" +#include "common/local_context.h" #include "graph/manager/graph_var_manager.h" #include "graph/utils/tensor_adapter.h" #include "graph/load/graph_loader.h" diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index 1d6f7aff..aba06173 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc 
@@ -32,7 +32,7 @@ #include "common/thread_pool.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" @@ -57,9 +57,9 @@ #include "runtime/rt_model.h" #include "runtime/stream.h" #include "securec.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "common/formats/utils/formats_trans_utils.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/build/memory/block_mem_assigner.h" #include "graph/manager/session_scope_mem_allocator.h" #include "framework/omg/omg_inner_types.h" diff --git a/ge/graph/load/model_manager/davinci_model.h b/ge/graph/load/model_manager/davinci_model.h index 4ff36677..fe89f66f 100755 --- a/ge/graph/load/model_manager/davinci_model.h +++ b/ge/graph/load/model_manager/davinci_model.h @@ -49,7 +49,7 @@ #include "mmpa/mmpa_api.h" #include "proto/task.pb.h" #include "graph/load/model_manager/task_info/task_info.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" using std::mutex; using std::thread; diff --git a/ge/graph/load/model_manager/model_manager.cc b/ge/graph/load/model_manager/model_manager.cc index 5af503b2..d0d88e66 100755 --- a/ge/graph/load/model_manager/model_manager.cc +++ b/ge/graph/load/model_manager/model_manager.cc @@ -23,9 +23,9 @@ #include "common/dump/dump_manager.h" #include "framework/common/l2_cache_optimize.h" #include "common/profiling/profiling_manager.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/load/model_manager/davinci_model.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" #include "common/formats/utils/formats_trans_utils.h" namespace ge { diff --git a/ge/graph/load/model_manager/model_manager.h 
b/ge/graph/load/model_manager/model_manager.h index 63a03dd7..6389d6db 100755 --- a/ge/graph/load/model_manager/model_manager.h +++ b/ge/graph/load/model_manager/model_manager.h @@ -17,7 +17,7 @@ #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_MODEL_MANAGER_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_MODEL_MANAGER_H_ -#include +#include #include #include #include diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 84ed3ab0..7d72d85b 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -29,9 +29,9 @@ #include "common/dump/dump_manager.h" #include "ge_opt_info/ge_opt_info.h" #include "analyzer/analyzer.h" -#include "graph/common/ge_call_wrapper.h" -#include "graph/common/local_context.h" -#include "graph/common/transop_util.h" +#include "common/ge_call_wrapper.h" +#include "common/local_context.h" +#include "common/transop_util.h" #include "graph/ge_context.h" #include "graph/ge_global_options.h" #include "graph/manager/util/rt_context_util.h" @@ -103,8 +103,8 @@ #include "inc/pass_manager.h" #include "init/gelib.h" #include "ir_build/option_utils.h" -#include "graph/common/local_context.h" -#include "graph/common/omg_util.h" +#include "common/local_context.h" +#include "common/omg_util.h" #include "common/formats/utils/formats_trans_utils.h" #include "register/custom_pass_helper.h" #include "external/graph/types.h" diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index 763654bd..84d2b11e 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -38,7 +38,7 @@ #include "graph/partition/graph_partition.h" #include "graph/preprocess/graph_preprocess.h" #include "graph/tuning_utils.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include "common/executor.h" namespace ge { diff --git a/ge/graph/manager/graph_manager_utils.h b/ge/graph/manager/graph_manager_utils.h index 9cec6b6d..14eb67f2 100644 --- 
a/ge/graph/manager/graph_manager_utils.h +++ b/ge/graph/manager/graph_manager_utils.h @@ -33,11 +33,11 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/compute_graph.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "external/graph/graph.h" #include "graph/model.h" -#include "model/ge_model.h" -#include "model/ge_root_model.h" +#include "common/model/ge_model.h" +#include "common/model/ge_root_model.h" #include "external/register/register_fmk_types.h" #include "external/ge/ge_api_types.h" diff --git a/ge/graph/optimize/graph_optimize.cc b/ge/graph/optimize/graph_optimize.cc index 55f374eb..a321ed43 100644 --- a/ge/graph/optimize/graph_optimize.cc +++ b/ge/graph/optimize/graph_optimize.cc @@ -17,7 +17,7 @@ #include "graph/optimize/graph_optimize.h" #include "graph/ge_context.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/passes/dimension_adjust_pass.h" #include "inc/pass_manager.h" #include "init/gelib.h" diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index 7e7ab908..2edb1828 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -17,7 +17,7 @@ #include #include "common/ge/ge_util.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/optimize/graph_optimize.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/partition/dynamic_shape_partition.cc b/ge/graph/partition/dynamic_shape_partition.cc index 8fc19ff2..cd98b6c5 100755 --- a/ge/graph/partition/dynamic_shape_partition.cc +++ b/ge/graph/partition/dynamic_shape_partition.cc @@ -31,7 +31,7 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" 
#define REQUIRE(cond, ...) \ do { \ diff --git a/ge/graph/partition/graph_partition.cc b/ge/graph/partition/graph_partition.cc index 6f221d97..86c9f1fd 100755 --- a/ge/graph/partition/graph_partition.cc +++ b/ge/graph/partition/graph_partition.cc @@ -28,7 +28,7 @@ #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_manager_utils.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" diff --git a/ge/graph/partition/graph_partition.h b/ge/graph/partition/graph_partition.h index 560aa9e7..6c21fabe 100644 --- a/ge/graph/partition/graph_partition.h +++ b/ge/graph/partition/graph_partition.h @@ -70,6 +70,8 @@ class GraphPartitioner { // Return all subgraphs const Graph2SubGraphInfoList &GetSubGraphMap(); + const Graph2InputNodesSubGraphInfo &GetSubGraphInfoMap() {return graph_2_input_subgraph_; } + private: Status MergeSubGraph(ge::ComputeGraphPtr &output_merged_compute_graph, const ge::ComputeGraphPtr &original_compute_graph); diff --git a/ge/graph/passes/atomic_addr_clean_pass.cc b/ge/graph/passes/atomic_addr_clean_pass.cc index cc22d126..13700e2e 100755 --- a/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/ge/graph/passes/atomic_addr_clean_pass.cc @@ -24,7 +24,7 @@ #include "framework/common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/node_utils.h" #include "init/gelib.h" diff --git a/ge/graph/passes/attach_stream_label_pass.cc b/ge/graph/passes/attach_stream_label_pass.cc index bcf86bc2..71d74500 100644 --- a/ge/graph/passes/attach_stream_label_pass.cc +++ b/ge/graph/passes/attach_stream_label_pass.cc @@ -16,7 +16,7 @@ #include "graph/passes/attach_stream_label_pass.h" #include "external/ge/ge_api_types.h" -#include 
"graph/common/omg_util.h" +#include "common/omg_util.h" using std::string; diff --git a/ge/graph/passes/buffer_pool_memory_pass.cc b/ge/graph/passes/buffer_pool_memory_pass.cc index 8a64da59..deb25325 100644 --- a/ge/graph/passes/buffer_pool_memory_pass.cc +++ b/ge/graph/passes/buffer_pool_memory_pass.cc @@ -18,7 +18,7 @@ #include #include -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/op_desc_utils.h" diff --git a/ge/graph/passes/cast_remove_pass.cc b/ge/graph/passes/cast_remove_pass.cc index 564b311d..1e1f4eb4 100644 --- a/ge/graph/passes/cast_remove_pass.cc +++ b/ge/graph/passes/cast_remove_pass.cc @@ -18,7 +18,7 @@ #include #include #include "framework/common/debug/ge_log.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" diff --git a/ge/graph/passes/cast_translate_pass.cc b/ge/graph/passes/cast_translate_pass.cc index d49424c8..704faeda 100644 --- a/ge/graph/passes/cast_translate_pass.cc +++ b/ge/graph/passes/cast_translate_pass.cc @@ -22,7 +22,7 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/passes/pass_utils.h" #include "graph/utils/node_utils.h" diff --git a/ge/graph/passes/compile_nodes_pass.cc b/ge/graph/passes/compile_nodes_pass.cc index 1e734178..c5976f11 100755 --- a/ge/graph/passes/compile_nodes_pass.cc +++ b/ge/graph/passes/compile_nodes_pass.cc @@ -22,7 +22,7 @@ #include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/op_desc.h" using domi::ImplyType; diff --git 
a/ge/graph/passes/control_trigger_pass.cc b/ge/graph/passes/control_trigger_pass.cc index 85505dc5..d81edefd 100644 --- a/ge/graph/passes/control_trigger_pass.cc +++ b/ge/graph/passes/control_trigger_pass.cc @@ -17,7 +17,7 @@ #include "graph/passes/control_trigger_pass.h" #include #include "common/ge/ge_util.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/type_utils.h" namespace ge { diff --git a/ge/graph/passes/dimension_adjust_pass.h b/ge/graph/passes/dimension_adjust_pass.h index a84f0d8d..cba283ed 100755 --- a/ge/graph/passes/dimension_adjust_pass.h +++ b/ge/graph/passes/dimension_adjust_pass.h @@ -21,7 +21,7 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/passes/base_pass.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/passes/flow_ctrl_pass.cc b/ge/graph/passes/flow_ctrl_pass.cc index 87896dc3..e75a4592 100755 --- a/ge/graph/passes/flow_ctrl_pass.cc +++ b/ge/graph/passes/flow_ctrl_pass.cc @@ -22,7 +22,7 @@ #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "common/ge/ge_util.h" #include "graph/manager/graph_var_manager.h" #include "graph/passes/pass_utils.h" diff --git a/ge/graph/passes/get_original_format_pass.cc b/ge/graph/passes/get_original_format_pass.cc index 4b27dd0e..0da4c5cc 100644 --- a/ge/graph/passes/get_original_format_pass.cc +++ b/ge/graph/passes/get_original_format_pass.cc @@ -25,7 +25,7 @@ #include "framework/omg/omg_inner_types.h" #include "graph/utils/attr_utils.h" #include "graph/utils/op_desc_utils.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" using domi::DOMI_TENSOR_NCHW; using domi::DOMI_TENSOR_NHWC; diff --git 
a/ge/graph/passes/guarantee_const_pass.cc b/ge/graph/passes/guarantee_const_pass.cc index b1df73a9..06bc821c 100644 --- a/ge/graph/passes/guarantee_const_pass.cc +++ b/ge/graph/passes/guarantee_const_pass.cc @@ -21,7 +21,7 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" diff --git a/ge/graph/passes/hccl_tailing_optimization_pass.cc b/ge/graph/passes/hccl_tailing_optimization_pass.cc index d952885d..fe606067 100644 --- a/ge/graph/passes/hccl_tailing_optimization_pass.cc +++ b/ge/graph/passes/hccl_tailing_optimization_pass.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "graph/passes/hccl_tailing_optimization_pass.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" namespace ge { Status HcclTailingOptimizationPass::Run(ComputeGraphPtr graph) { diff --git a/ge/graph/passes/identity_pass.cc b/ge/graph/passes/identity_pass.cc index f0653983..0a346bb1 100755 --- a/ge/graph/passes/identity_pass.cc +++ b/ge/graph/passes/identity_pass.cc @@ -19,7 +19,7 @@ #include #include #include "framework/common/debug/ge_log.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/node_utils.h" #include "graph/utils/attr_utils.h" #include "graph/debug/ge_attr_define.h" diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 60a2f09a..a5e64519 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -22,7 +22,7 @@ #include "graph/shape_refiner.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/tensor_utils.h" #include 
"graph/utils/type_utils.h" diff --git a/ge/graph/passes/iterator_op_pass.cc b/ge/graph/passes/iterator_op_pass.cc index d1de809d..57416017 100644 --- a/ge/graph/passes/iterator_op_pass.cc +++ b/ge/graph/passes/iterator_op_pass.cc @@ -26,7 +26,7 @@ #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "external/graph/graph.h" #include "graph/node.h" #include "graph/passes/pass_utils.h" diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc index 67b6c617..3989e54f 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -17,7 +17,7 @@ #include "graph/passes/mark_force_unknown_for_cond_pass.h" #include "graph/utils/node_utils.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { namespace { diff --git a/ge/graph/passes/mark_node_unknown_shape_pass.cc b/ge/graph/passes/mark_node_unknown_shape_pass.cc index c040e846..eadd3ca7 100644 --- a/ge/graph/passes/mark_node_unknown_shape_pass.cc +++ b/ge/graph/passes/mark_node_unknown_shape_pass.cc @@ -17,7 +17,7 @@ #include "graph/passes/mark_node_unknown_shape_pass.h" #include "graph/utils/node_utils.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" namespace ge { namespace { diff --git a/ge/graph/passes/merge_input_memcpy_pass.cc b/ge/graph/passes/merge_input_memcpy_pass.cc index 044d4ad9..97a17d99 100644 --- a/ge/graph/passes/merge_input_memcpy_pass.cc +++ b/ge/graph/passes/merge_input_memcpy_pass.cc @@ -18,7 +18,7 @@ #include "common/ge/ge_util.h" #include "external/ge/ge_api_types.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { diff --git a/ge/graph/passes/merge_pass.cc 
b/ge/graph/passes/merge_pass.cc index fec9c6d0..2ddfcaab 100644 --- a/ge/graph/passes/merge_pass.cc +++ b/ge/graph/passes/merge_pass.cc @@ -22,7 +22,7 @@ #include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/passes/pass_utils.h" diff --git a/ge/graph/passes/merge_to_stream_merge_pass.cc b/ge/graph/passes/merge_to_stream_merge_pass.cc index c58def59..e91410e1 100644 --- a/ge/graph/passes/merge_to_stream_merge_pass.cc +++ b/ge/graph/passes/merge_to_stream_merge_pass.cc @@ -17,7 +17,7 @@ #include "graph/passes/merge_to_stream_merge_pass.h" #include "common/ge/ge_util.h" #include "external/ge/ge_api_types.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { Status MergeToStreamMergePass::Run(ComputeGraphPtr graph) { diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index d36b4186..b25239b1 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -18,14 +18,14 @@ #include "common/formats/utils/formats_trans_utils.h" #include "common/ge/ge_util.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/preprocess/multi_batch_options.h" #include "graph/utils/node_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "register/op_registry.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { namespace { diff --git a/ge/graph/passes/multi_batch_pass.cc b/ge/graph/passes/multi_batch_pass.cc index 25d629fa..9fba362c 100644 --- a/ge/graph/passes/multi_batch_pass.cc +++ b/ge/graph/passes/multi_batch_pass.cc @@ -19,7 +19,7 @@ #include #include #include "common/ge/ge_util.h" -#include "graph/common/omg_util.h" +#include 
"common/omg_util.h" #include "graph/utils/type_utils.h" #include "common/formats/utils/formats_trans_utils.h" diff --git a/ge/graph/passes/net_output_pass.cc b/ge/graph/passes/net_output_pass.cc index 30455fa0..9aea4863 100644 --- a/ge/graph/passes/net_output_pass.cc +++ b/ge/graph/passes/net_output_pass.cc @@ -27,7 +27,7 @@ #include "framework/common/ge_inner_error_codes.h" #include "framework/omg/omg_inner_types.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/passes/pass_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" diff --git a/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc index 1c2d7218..af3e4d2d 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -17,7 +17,7 @@ #include "graph/passes/next_iteration_pass.h" #include "common/ge/ge_util.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/node_utils.h" using std::string; diff --git a/ge/graph/passes/pass_manager.cc b/ge/graph/passes/pass_manager.cc index 7c9aa414..afd2e4a7 100644 --- a/ge/graph/passes/pass_manager.cc +++ b/ge/graph/passes/pass_manager.cc @@ -19,7 +19,7 @@ #include "framework/common/types.h" #include "framework/common/util.h" #include "graph/utils/node_utils.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "framework/omg/omg_inner_types.h" namespace ge { diff --git a/ge/graph/passes/pass_utils.cc b/ge/graph/passes/pass_utils.cc index d5306f5f..0e056a0f 100644 --- a/ge/graph/passes/pass_utils.cc +++ b/ge/graph/passes/pass_utils.cc @@ -27,7 +27,7 @@ #include "common/ge/ge_util.h" #include "framework/common/op/ge_op_utils.h" #include "framework/common/types.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" #include 
"graph/manager/graph_var_manager.h" diff --git a/ge/graph/passes/permute_pass.cc b/ge/graph/passes/permute_pass.cc index 21222b2c..f3045b1a 100644 --- a/ge/graph/passes/permute_pass.cc +++ b/ge/graph/passes/permute_pass.cc @@ -24,7 +24,7 @@ #include "inc/kernel.h" #include "inc/kernel_factory.h" #include "framework/omg/omg_inner_types.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" using domi::DOMI_TENSOR_ND; using domi::DOMI_TENSOR_NHWC; diff --git a/ge/graph/passes/placeholder_with_default_pass.cc b/ge/graph/passes/placeholder_with_default_pass.cc index 893ee798..bc51b217 100644 --- a/ge/graph/passes/placeholder_with_default_pass.cc +++ b/ge/graph/passes/placeholder_with_default_pass.cc @@ -18,7 +18,7 @@ #include #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { Status PlaceholderWithDefaultPass::Run(NodePtr &node) { diff --git a/ge/graph/passes/prevent_gradient_pass.cc b/ge/graph/passes/prevent_gradient_pass.cc index c531fd2f..8b8b17bd 100644 --- a/ge/graph/passes/prevent_gradient_pass.cc +++ b/ge/graph/passes/prevent_gradient_pass.cc @@ -19,7 +19,7 @@ #include #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { Status PreventGradientPass::Run(NodePtr &node) { diff --git a/ge/graph/passes/print_op_pass.h b/ge/graph/passes/print_op_pass.h index 7ee19d5d..96501dc5 100755 --- a/ge/graph/passes/print_op_pass.h +++ b/ge/graph/passes/print_op_pass.h @@ -20,7 +20,7 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "external/graph/graph.h" #include "graph/passes/base_pass.h" #include "graph/utils/graph_utils.h" diff --git 
a/ge/graph/passes/ref_identity_delete_op_pass.cc b/ge/graph/passes/ref_identity_delete_op_pass.cc index 7bc5804b..46bc7467 100644 --- a/ge/graph/passes/ref_identity_delete_op_pass.cc +++ b/ge/graph/passes/ref_identity_delete_op_pass.cc @@ -17,7 +17,7 @@ #include "graph/passes/ref_identity_delete_op_pass.h" #include #include -#include "graph/common/transop_util.h" +#include "common/transop_util.h" namespace ge { Status RefIdentityDeleteOpPass::Run(ComputeGraphPtr graph) { diff --git a/ge/graph/passes/replace_transshape_pass.cc b/ge/graph/passes/replace_transshape_pass.cc index c7844619..0e1701ab 100644 --- a/ge/graph/passes/replace_transshape_pass.cc +++ b/ge/graph/passes/replace_transshape_pass.cc @@ -21,7 +21,7 @@ #include "common/ge/ge_util.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/utils/graph_utils.h" namespace ge { diff --git a/ge/graph/passes/snapshot_pass.cc b/ge/graph/passes/snapshot_pass.cc index 95733e67..a6cd79a3 100644 --- a/ge/graph/passes/snapshot_pass.cc +++ b/ge/graph/passes/snapshot_pass.cc @@ -18,7 +18,7 @@ #include #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { Status SnapshotPass::Run(NodePtr &node) { diff --git a/ge/graph/passes/stop_gradient_pass.h b/ge/graph/passes/stop_gradient_pass.h index 5132b889..5f022200 100755 --- a/ge/graph/passes/stop_gradient_pass.h +++ b/ge/graph/passes/stop_gradient_pass.h @@ -20,7 +20,7 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/types.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/passes/base_pass.h" namespace ge { diff --git a/ge/graph/passes/switch_dead_branch_elimination.cc b/ge/graph/passes/switch_dead_branch_elimination.cc index 
3c6c57d0..284111ba 100644 --- a/ge/graph/passes/switch_dead_branch_elimination.cc +++ b/ge/graph/passes/switch_dead_branch_elimination.cc @@ -19,7 +19,7 @@ #include #include #include "framework/common/debug/ge_log.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/passes/pass_utils.h" #include "graph/utils/graph_utils.h" diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index 7fecae31..acbf27e3 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -18,7 +18,7 @@ #include #include "common/ge/ge_util.h" #include "external/ge/ge_api_types.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/ge_context.h" #include "graph/utils/type_utils.h" diff --git a/ge/graph/passes/transop_breadth_fusion_pass.cc b/ge/graph/passes/transop_breadth_fusion_pass.cc index 5b8e1940..88db9501 100644 --- a/ge/graph/passes/transop_breadth_fusion_pass.cc +++ b/ge/graph/passes/transop_breadth_fusion_pass.cc @@ -20,7 +20,7 @@ #include #include "framework/common/types.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "graph/utils/node_utils.h" namespace ge { diff --git a/ge/graph/passes/transop_depth_fusion_pass.cc b/ge/graph/passes/transop_depth_fusion_pass.cc index 66ce346a..3ce54e50 100755 --- a/ge/graph/passes/transop_depth_fusion_pass.cc +++ b/ge/graph/passes/transop_depth_fusion_pass.cc @@ -23,7 +23,7 @@ #include "graph/ge_tensor.h" #include "graph/op_desc.h" #include "graph/utils/graph_utils.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "graph/utils/node_utils.h" namespace ge { diff --git a/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc b/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc index 483575a4..437926ef 100644 --- a/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc +++ 
b/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc @@ -19,7 +19,7 @@ #include "framework/common/debug/log.h" #include "framework/common/types.h" #include "graph/utils/graph_utils.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" namespace ge { Status TransOpNearbyAllreduceFusionPass::Run(NodePtr &node) { diff --git a/ge/graph/passes/transop_symmetry_elimination_pass.cc b/ge/graph/passes/transop_symmetry_elimination_pass.cc index fe0e48f9..2bd00206 100644 --- a/ge/graph/passes/transop_symmetry_elimination_pass.cc +++ b/ge/graph/passes/transop_symmetry_elimination_pass.cc @@ -18,7 +18,7 @@ #include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" diff --git a/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/ge/graph/passes/transop_without_reshape_fusion_pass.cc index 10e619b9..58145fe7 100644 --- a/ge/graph/passes/transop_without_reshape_fusion_pass.cc +++ b/ge/graph/passes/transop_without_reshape_fusion_pass.cc @@ -22,7 +22,7 @@ #include "common/ge/ge_util.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" diff --git a/ge/graph/passes/variable_op_pass.h b/ge/graph/passes/variable_op_pass.h index d442fdf4..e314fd12 100755 --- a/ge/graph/passes/variable_op_pass.h +++ b/ge/graph/passes/variable_op_pass.h @@ -18,7 +18,7 @@ #define GE_GRAPH_PASSES_VARIABLE_OP_PASS_H_ #include #include -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "external/graph/graph.h" #include "graph/manager/graph_var_manager.h" #include 
"graph/manager/util/variable_accelerate_ctrl.h" diff --git a/ge/graph/passes/variable_prepare_op_pass.cc b/ge/graph/passes/variable_prepare_op_pass.cc index 3bb9a2fa..288ff185 100644 --- a/ge/graph/passes/variable_prepare_op_pass.cc +++ b/ge/graph/passes/variable_prepare_op_pass.cc @@ -21,7 +21,7 @@ #include "common/ge/ge_util.h" #include "external/graph/graph.h" #include "framework/common/debug/ge_log.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/node.h" #include "graph/utils/tensor_utils.h" diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 8d59d9f9..2efe623e 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -28,9 +28,9 @@ #include "common/math/math_util.h" #include "framework/common/op/ge_op_utils.h" #include "ir_build/option_utils.h" -#include "graph/common/ge_call_wrapper.h" -#include "graph/common/local_context.h" -#include "graph/common/transop_util.h" +#include "common/ge_call_wrapper.h" +#include "common/local_context.h" +#include "common/transop_util.h" #include "graph/ge_context.h" #include "graph/shape_refiner.h" #include "graph/manager/graph_var_manager.h" diff --git a/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/ge/graph/preprocess/insert_op/ge_aipp_op.cc index 48bfa3e6..7a89a1f4 100755 --- a/ge/graph/preprocess/insert_op/ge_aipp_op.cc +++ b/ge/graph/preprocess/insert_op/ge_aipp_op.cc @@ -39,7 +39,7 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "proto/insert_op.pb.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #define SAVE_AIPP_ATTR(KEY, SAVE_TYPE) \ do { \ diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index fd3a4e91..d88cf6cd 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ 
-38,8 +38,8 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "inc/pass_manager.h" -#include "graph/common/local_context.h" -#include "graph/common/omg_util.h" +#include "common/local_context.h" +#include "common/omg_util.h" using std::set; using std::string; diff --git a/ge/graph/preprocess/multi_batch_options.cc b/ge/graph/preprocess/multi_batch_options.cc index 21cbc0c2..9cda6194 100644 --- a/ge/graph/preprocess/multi_batch_options.cc +++ b/ge/graph/preprocess/multi_batch_options.cc @@ -25,11 +25,11 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/node_utils.h" #include "graph/ge_context.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "framework/common/types.h" #include "graph/compute_graph.h" #include "graph/utils/graph_utils.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { namespace multibatch { diff --git a/ge/host_kernels/add_kernel.cc b/ge/host_kernels/add_kernel.cc index 1c206018..eb0ea86d 100644 --- a/ge/host_kernels/add_kernel.cc +++ b/ge/host_kernels/add_kernel.cc @@ -19,7 +19,7 @@ #include #include "common/math/math_util.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/broadcast_args_kernel.cc b/ge/host_kernels/broadcast_args_kernel.cc index 796142f4..660717ad 100644 --- a/ge/host_kernels/broadcast_args_kernel.cc +++ b/ge/host_kernels/broadcast_args_kernel.cc @@ -22,7 +22,7 @@ #include "framework/common/types.h" #include "framework/common/util.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/broadcast_gradient_args_kernel.cc b/ge/host_kernels/broadcast_gradient_args_kernel.cc index 59993171..8b9e3fb5 100644 --- 
a/ge/host_kernels/broadcast_gradient_args_kernel.cc +++ b/ge/host_kernels/broadcast_gradient_args_kernel.cc @@ -22,7 +22,7 @@ #include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/cast_kernel.cc b/ge/host_kernels/cast_kernel.cc index 3f09974f..2d2f463c 100644 --- a/ge/host_kernels/cast_kernel.cc +++ b/ge/host_kernels/cast_kernel.cc @@ -28,7 +28,7 @@ #include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/floormod_kernel.cc b/ge/host_kernels/floormod_kernel.cc index bef6d014..1d101667 100644 --- a/ge/host_kernels/floormod_kernel.cc +++ b/ge/host_kernels/floormod_kernel.cc @@ -23,7 +23,7 @@ #include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/greater_kernel.cc b/ge/host_kernels/greater_kernel.cc index 3e62db04..0cc895c4 100644 --- a/ge/host_kernels/greater_kernel.cc +++ b/ge/host_kernels/greater_kernel.cc @@ -25,7 +25,7 @@ #include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/maximum_kernel.cc b/ge/host_kernels/maximum_kernel.cc index 314bc7be..0e28fcdc 100644 --- a/ge/host_kernels/maximum_kernel.cc +++ 
b/ge/host_kernels/maximum_kernel.cc @@ -25,7 +25,7 @@ #include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/mul_kernel.cc b/ge/host_kernels/mul_kernel.cc index e3657197..608f351d 100644 --- a/ge/host_kernels/mul_kernel.cc +++ b/ge/host_kernels/mul_kernel.cc @@ -25,7 +25,7 @@ #include "framework/common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/permute_kernel.cc b/ge/host_kernels/permute_kernel.cc index 93d56415..9e9462b6 100755 --- a/ge/host_kernels/permute_kernel.cc +++ b/ge/host_kernels/permute_kernel.cc @@ -24,7 +24,7 @@ #include "framework/common/op/ge_op_utils.h" #include "framework/common/types.h" #include "framework/common/util.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" #include "common/formats/formats.h" diff --git a/ge/host_kernels/sub_kernel.cc b/ge/host_kernels/sub_kernel.cc index 84c334b0..0aebb946 100644 --- a/ge/host_kernels/sub_kernel.cc +++ b/ge/host_kernels/sub_kernel.cc @@ -23,7 +23,7 @@ #include "framework/common/debug/log.h" #include "common/math/math_util.h" #include "framework/common/op/ge_op_utils.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/host_kernels/transdata_kernel.cc b/ge/host_kernels/transdata_kernel.cc index a06db78b..7d44fdae 100644 --- a/ge/host_kernels/transdata_kernel.cc +++ b/ge/host_kernels/transdata_kernel.cc @@ -28,7 +28,7 @@ #include "framework/common/util.h" #include 
"framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "graph/common/bcast.h" +#include "common/bcast.h" #include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" diff --git a/ge/hybrid/hybrid_davinci_model.h b/ge/hybrid/hybrid_davinci_model.h index 34503b01..abab74f6 100644 --- a/ge/hybrid/hybrid_davinci_model.h +++ b/ge/hybrid/hybrid_davinci_model.h @@ -20,7 +20,7 @@ #include #include "external/ge/ge_api_error_codes.h" #include "graph/load/model_manager/data_inputer.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" namespace ge { namespace hybrid { diff --git a/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h index 77246e20..3cb936f6 100644 --- a/ge/hybrid/model/hybrid_model.h +++ b/ge/hybrid/model/hybrid_model.h @@ -27,7 +27,7 @@ #include "hybrid/common/tensor_value.h" #include "hybrid/model/node_item.h" #include "hybrid/model/graph_item.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" namespace ge { namespace hybrid { diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index c722d269..44115240 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -21,7 +21,7 @@ #include "graph/ge_context.h" #include "graph/build/memory/var_mem_assign_util.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" #include "graph/load/model_manager/model_utils.h" #include "graph/load/model_manager/model_manager.h" #include "graph/manager/graph_var_manager.h" diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 05830e82..3592d3d2 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -25,7 +25,7 @@ #include "graph/node.h" #include "hybrid/model/hybrid_model.h" #include 
"hybrid/model/node_item.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" namespace ge { class VarManager; diff --git a/ge/init/gelib.cc b/ge/init/gelib.cc index 1a2f0d5b..2491715b 100644 --- a/ge/init/gelib.cc +++ b/ge/init/gelib.cc @@ -34,7 +34,7 @@ #include "analyzer/analyzer.h" #include "external/ge/ge_api_types.h" #include "ge_local_engine/engine/host_cpu_engine.h" -#include "graph/common/ge_call_wrapper.h" +#include "common/ge_call_wrapper.h" #include "graph/ge_context.h" #include "graph/ge_global_options.h" #include "graph/manager/graph_mem_manager.h" diff --git a/ge/ir_build/attr_options/utils.cc b/ge/ir_build/attr_options/utils.cc index 5398c220..23bb0b7b 100644 --- a/ge/ir_build/attr_options/utils.cc +++ b/ge/ir_build/attr_options/utils.cc @@ -17,7 +17,7 @@ #include #include "graph/debug/ge_attr_define.h" #include "framework/common/debug/ge_log.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { namespace { const std::string CFG_PRE_OPTYPE = "OpType::"; diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index e1cd5d29..cafc534d 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -33,7 +33,7 @@ #include "graph/ge_global_options.h" #include "init/gelib.h" #include "ir_build/option_utils.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #include "graph/shape_refiner.h" #include "graph/opsproto_manager.h" #include "inc/pass_manager.h" diff --git a/ge/session/inner_session.cc b/ge/session/inner_session.cc index 1dcc2996..b9c44ef1 100755 --- a/ge/session/inner_session.cc +++ b/ge/session/inner_session.cc @@ -29,7 +29,7 @@ #include "graph/ge_context.h" #include "graph/ge_global_options.h" #include "graph/ge_local_context.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/manager/graph_var_manager.h" #include "graph/manager/graph_mem_manager.h" #include "graph/utils/tensor_adapter.h" diff --git 
a/inc/framework/common/helper/model_helper.h b/inc/framework/common/helper/model_helper.h index e25d5d6f..2a63291c 100644 --- a/inc/framework/common/helper/model_helper.h +++ b/inc/framework/common/helper/model_helper.h @@ -22,10 +22,10 @@ #include "common/fmk_types.h" #include "common/helper/om_file_helper.h" +#include "common/model/ge_model.h" +#include "common/model/ge_root_model.h" #include "common/types.h" #include "graph/model.h" -#include "model/ge_model.h" -#include "model/ge_root_model.h" namespace ge { class GE_FUNC_VISIBILITY ModelHelper { @@ -42,13 +42,21 @@ class GE_FUNC_VISIBILITY ModelHelper { Status LoadRootModel(const ge::ModelData &model_data); Status GetModelBufferData(ge::ModelBufferData &model); - const ModelFileHeader *GetFileHeader() const { return file_header_; } + const ModelFileHeader *GetFileHeader() const { + return file_header_; + } GeModelPtr GetGeModel(); GeRootModelPtr GetGeRootModel(); - void SetSaveMode(bool val) { is_offline_ = val; } - bool GetSaveMode(void) const { return is_offline_; } - bool GetModelType() const { return is_unknown_shape_model_; }; + void SetSaveMode(bool val) { + is_offline_ = val; + } + bool GetSaveMode(void) const { + return is_offline_; + } + bool GetModelType() const { + return is_unknown_shape_model_; + }; Status GetBaseNameFromFileName(const std::string &file_name, std::string &base_name); Status GetModelNameFromMergedGraphName(const std::string &graph_name, std::string &model_name); diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 42fa6128..856d9d43 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -104,8 +104,8 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_manager.cc" "${GE_CODE_DIR}/ge/generator/ge_generator.cc" "${GE_CODE_DIR}/ge/generator/generator_api.cc" - "${GE_CODE_DIR}/ge/graph/common/omg_util.cc" - "${GE_CODE_DIR}/ge/graph/common/bcast.cc" + "${GE_CODE_DIR}/ge/common/omg_util.cc" + 
"${GE_CODE_DIR}/ge/common/bcast.cc" "${GE_CODE_DIR}/ge/common/util.cc" "${GE_CODE_DIR}/ge/common/ge/op_tiling_manager.cc" "${GE_CODE_DIR}/ge/init/gelib.cc" @@ -124,12 +124,12 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/common/dump/opdebug_register.cc" "${GE_CODE_DIR}/ge/common/dump/dump_op.cc" "${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc" - "${GE_CODE_DIR}/ge/model/ge_root_model.cc" + "${GE_CODE_DIR}/ge/common/model/ge_root_model.cc" "${GE_CODE_DIR}/ge/common/model_parser/model_parser.cc" "${GE_CODE_DIR}/ge/common/dump/dump_server.cc" "${GE_CODE_DIR}/ge/graph/preprocess/multi_batch_copy_graph.cc" "${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc" - "${GE_CODE_DIR}/ge/model/ge_model.cc" + "${GE_CODE_DIR}/ge/common/model/ge_model.cc" "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" "${GE_CODE_DIR}/ge/common/kernel_store.cc" "${GE_CODE_DIR}/ge/common/tbe_kernel_store.cc" @@ -169,10 +169,10 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/manager/graph_var_manager.cc" "${GE_CODE_DIR}/ge/analyzer/analyzer.cc" "${GE_CODE_DIR}/ge/common/thread_pool.cc" - "${GE_CODE_DIR}/ge/graph/common/transop_util.cc" + "${GE_CODE_DIR}/ge/common/transop_util.cc" "${GE_CODE_DIR}/ge/graph/manager/graph_manager_utils.cc" "${GE_CODE_DIR}/ge/graph/manager/trans_var_data_utils.cc" - "${GE_CODE_DIR}/ge/graph/common/local_context.cc" + "${GE_CODE_DIR}/ge/common/local_context.cc" "${GE_CODE_DIR}/ge/graph/manager/graph_caching_allocator.cc" "${GE_CODE_DIR}/ge/graph/manager/session_scope_mem_allocator.cc" "${GE_CODE_DIR}/ge/graph/manager/rdma_pool_allocator.cc" @@ -648,6 +648,7 @@ set(MULTI_PARTS_TEST_FILES "graph/transop_util_unittest.cc" "common/datatype_transfer_unittest.cc" "common/util_unittest.cc" + "common/fp16_unittest.cc" "common/dump_manager_unittest.cc" "common/dump_op_unittest.cc" "common/dump_properties_unittest.cc" diff --git a/tests/ut/ge/common/fp16_unittest.cc b/tests/ut/ge/common/fp16_unittest.cc new file mode 100644 index 00000000..a9590fe2 --- /dev/null +++ 
b/tests/ut/ge/common/fp16_unittest.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "common/fp16_t.h" + +namespace ge { +namespace formats { +class UtestFP16 : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UtestFP16, fp16_to_other) { + fp16_t test; + float num = test.ToFloat(); + EXPECT_EQ(num, 0.0); + + double num2 = test.ToDouble(); + EXPECT_EQ(num2, 0); + + int16_t num3 = test.ToInt16(); + EXPECT_EQ(num3, 0); + + int32_t num4 = test.ToInt32(); + EXPECT_EQ(num4, 0); + + int8_t num5 = test.ToInt8(); + EXPECT_EQ(num5, 0); + + uint16_t num6 = test.ToUInt16(); + EXPECT_EQ(num6, 0); + + uint32_t num7 = test.ToUInt16(); + EXPECT_EQ(num7, 0); + + uint8_t num8 = test.ToUInt8(); + EXPECT_EQ(num8, 0); +} +} // namespace formats +} // namespace ge diff --git a/tests/ut/ge/graph/build/model_builder_unittest.cc b/tests/ut/ge/graph/build/model_builder_unittest.cc index d544e1a3..4f061e27 100644 --- a/tests/ut/ge/graph/build/model_builder_unittest.cc +++ b/tests/ut/ge/graph/build/model_builder_unittest.cc @@ -17,7 +17,7 @@ #include #include -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/anchor.h" #include "graph/attr_value.h" #include "graph/debug/ge_attr_define.h" diff --git a/tests/ut/ge/graph/graph_load_unittest.cc b/tests/ut/ge/graph/graph_load_unittest.cc index 
cbcefd03..93282a5e 100644 --- a/tests/ut/ge/graph/graph_load_unittest.cc +++ b/tests/ut/ge/graph/graph_load_unittest.cc @@ -36,7 +36,7 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/load/model_manager/model_manager.h" #include "graph/manager/graph_manager_utils.h" -#include "model/ge_model.h" +#include "common/model/ge_model.h" #undef private #undef protected diff --git a/tests/ut/ge/graph/load/model_helper_unittest.cc b/tests/ut/ge/graph/load/model_helper_unittest.cc index 8fd8f014..8af329ed 100644 --- a/tests/ut/ge/graph/load/model_helper_unittest.cc +++ b/tests/ut/ge/graph/load/model_helper_unittest.cc @@ -20,7 +20,7 @@ #include "framework/common/helper/model_helper.h" #include "framework/omg/model_tool.h" #include "framework/omg/ge_init.h" -#include "ge/model/ge_model.h" +#include "ge/common/model/ge_model.h" #undef private #undef protected diff --git a/tests/ut/ge/graph/manager/graph_manager_unittest.cc b/tests/ut/ge/graph/manager/graph_manager_unittest.cc index 9663e90f..518cfdcd 100644 --- a/tests/ut/ge/graph/manager/graph_manager_unittest.cc +++ b/tests/ut/ge/graph/manager/graph_manager_unittest.cc @@ -39,9 +39,9 @@ #include "common/thread_pool.h" #include "common/dump/dump_manager.h" #include "analyzer/analyzer.h" -#include "graph/common/ge_call_wrapper.h" -#include "graph/common/local_context.h" -#include "graph/common/transop_util.h" +#include "common/ge_call_wrapper.h" +#include "common/local_context.h" +#include "common/transop_util.h" #include "graph/ge_context.h" #include "graph/ge_global_options.h" #include "graph/manager/util/rt_context_util.h" @@ -108,8 +108,8 @@ #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" #include "ir_build/option_utils.h" -#include "graph/common/local_context.h" -#include "graph/common/omg_util.h" +#include "common/local_context.h" +#include "common/omg_util.h" #include "common/formats/utils/formats_trans_utils.h" #include "../passes/graph_builder_utils.h" #include 
"register/custom_pass_helper.h" diff --git a/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc b/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc index da1abd0f..1d19a8bd 100644 --- a/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc +++ b/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc @@ -24,7 +24,7 @@ #include "inc/framework/common/types.h" #include "utils/graph_utils.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" namespace ge { namespace { diff --git a/tests/ut/ge/graph/passes/mark_node_unknown_shape_pass_unittest.cc b/tests/ut/ge/graph/passes/mark_node_unknown_shape_pass_unittest.cc index c7d36582..7d4663b3 100644 --- a/tests/ut/ge/graph/passes/mark_node_unknown_shape_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/mark_node_unknown_shape_pass_unittest.cc @@ -24,7 +24,7 @@ #include "common/ge_inner_error_codes.h" #include "inc/pass_manager.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #undef private namespace ge { diff --git a/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc b/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc index c752cea4..9ec254d7 100644 --- a/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/multi_batch_clone_pass_unittest.cc @@ -22,7 +22,7 @@ #include "inc/pass_manager.h" #include "graph/utils/tensor_utils.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/passes/multi_batch_pass.h" #include "graph/preprocess/multi_batch_copy_graph.h" #include "graph/preprocess/insert_op/util_insert_aipp_op.h" diff --git a/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc b/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc index c633c0e1..6565295c 100644 --- a/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc +++ 
b/tests/ut/ge/graph/passes/subgraph_const_migration_pass_unittest.cc @@ -20,7 +20,7 @@ #include #include "framework/omg/omg_inner_types.h" -#include "graph/common/local_context.h" +#include "common/local_context.h" #include "graph/passes/subgraph_const_migration_pass.h" #include "inc/pass_manager.h" #include "register/op_registry.h" diff --git a/tests/ut/ge/graph/transop_util_unittest.cc b/tests/ut/ge/graph/transop_util_unittest.cc index 9f645c22..02aa97bf 100644 --- a/tests/ut/ge/graph/transop_util_unittest.cc +++ b/tests/ut/ge/graph/transop_util_unittest.cc @@ -16,7 +16,7 @@ #include -#include "graph/common/transop_util.h" +#include "common/transop_util.h" #include "common/debug/log.h" #include "common/types.h" diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index b09211cb..782a06d6 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -25,8 +25,8 @@ #include "hybrid/model/hybrid_model_builder.h" #include "hybrid/model/hybrid_model.h" #include "hybrid/node_executor/node_executor.h" -#include "model/ge_model.h" -#include "model/ge_root_model.h" +#include "common/model/ge_model.h" +#include "common/model/ge_root_model.h" #include "hybrid/node_executor/aicore/aicore_op_task.h" #include "framework/common/taskdown_common.h" #include "framework/common/debug/log.h" diff --git a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc index 10f7c0fe..eb6030dc 100644 --- a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc +++ b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc @@ -29,7 +29,7 @@ #include "graph/utils/graph_utils.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_local_context.h" -#include "graph/common/omg_util.h" +#include "common/omg_util.h" using namespace std; using namespace testing; diff --git 
a/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc index e4d211f9..53b28762 100644 --- a/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc @@ -22,7 +22,7 @@ #define protected public #include "hybrid/executor/subgraph_context.h" #include "hybrid/node_executor/ge_local/ge_local_node_executor.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" #undef protected #undef private diff --git a/tests/ut/ge/hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc b/tests/ut/ge/hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc index b113fa9b..bb134175 100644 --- a/tests/ut/ge/hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc @@ -22,7 +22,7 @@ #define protected public #include "hybrid/executor/subgraph_context.h" #include "hybrid/node_executor/host_cpu/host_cpu_node_executor.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" #include "graph/passes/graph_builder_utils.h" #include "aicpu/common/aicpu_task_struct.h" #include "graph/manager/graph_mem_manager.h" diff --git a/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc b/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc index 109e5192..d21ae1e0 100644 --- a/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc @@ -22,7 +22,7 @@ #define protected public #include "hybrid/executor/subgraph_context.h" #include "hybrid/node_executor/rts/rts_node_executor.h" -#include "model/ge_root_model.h" +#include "common/model/ge_root_model.h" using namespace std; using namespace testing; From b8882fd650f8080026a31d4721fc7fbab75fc783 Mon Sep 17 00:00:00 2001 From: 
zhangxiaokun Date: Tue, 13 Jul 2021 20:36:08 +0800 Subject: [PATCH 04/33] Sort common\CMakeLists.txt --- ge/common/CMakeLists.txt | 102 +++++++++++++++++++++++------------------------ 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index 0d41b86f..99d6ead3 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -1,55 +1,55 @@ set(SRC_LIST - "context/ctx.cc" - "model_saver.cc" - "ge/datatype_util.cc" - "ge/plugin_manager.cc" - "ge/op_tiling_manager.cc" - "helper/om_file_helper.cc" - "helper/model_helper.cc" - "model/ge_model.cc" - "model/ge_root_model.cc" - "bcast.cc" - "local_context.cc" - "omg_util.cc" - "transop_util.cc" - "auth/file_saver.cc" - "fp16_t.cc" - "math/fp16_math.cc" - "debug/memory_dumper.cc" - "formats/utils/formats_trans_utils.cc" - "dump/dump_properties.cc" - "dump/dump_manager.cc" - "formats/format_transfers/datatype_transfer.cc" - "formats/format_transfers/format_transfer_transpose.cc" - "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" - "formats/format_transfers/format_transfer_fractal_z.cc" - "formats/format_transfers/format_transfer_fractal_nz.cc" - "formats/format_transfers/format_transfer_fractal_zz.cc" - "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" - "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" - "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" - "formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" - "formats/format_transfers/format_transfer_fracz_nchw.cc" - "formats/format_transfers/format_transfer_fracz_nhwc.cc" - "formats/format_transfers/format_transfer_fracz_hwcn.cc" - "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" - "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" - "formats/format_transfers/format_transfer_nchw_fz_c04.cc" - "formats/formats.cc" - "ge_format_util.cc" - "fmk_error_codes.cc" - "util.cc" 
- "properties_manager.cc" - "types.cc" - "model_parser/model_parser.cc" - "kernel_store.cc" - "tbe_kernel_store.cc" - "cust_aicpu_kernel_store.cc" - "op/attr_value_util.cc" - "op/ge_op_utils.cc" - "thread_pool.cc" - "ge/tbe_plugin_manager.cc" + "${GE_CODE_DIR}/ge/common/auth/file_saver.cc" + "${GE_CODE_DIR}/ge/common/bcast.cc" + "${GE_CODE_DIR}/ge/common/context/ctx.cc" + "${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc" + "${GE_CODE_DIR}/ge/common/debug/memory_dumper.cc" + "${GE_CODE_DIR}/ge/common/dump/dump_manager.cc" + "${GE_CODE_DIR}/ge/common/dump/dump_properties.cc" + "${GE_CODE_DIR}/ge/common/fmk_error_codes.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/datatype_transfer.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fractal_z.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" + "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" 
+ "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_transpose.cc" + "${GE_CODE_DIR}/ge/common/formats/formats.cc" + "${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc" + "${GE_CODE_DIR}/ge/common/fp16_t.cc" + "${GE_CODE_DIR}/ge/common/ge/datatype_util.cc" + "${GE_CODE_DIR}/ge/common/ge/op_tiling_manager.cc" + "${GE_CODE_DIR}/ge/common/ge/plugin_manager.cc" + "${GE_CODE_DIR}/ge/common/ge/tbe_plugin_manager.cc" + "${GE_CODE_DIR}/ge/common/ge_format_util.cc" + "${GE_CODE_DIR}/ge/common/helper/model_helper.cc" + "${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc" + "${GE_CODE_DIR}/ge/common/kernel_store.cc" + "${GE_CODE_DIR}/ge/common/local_context.cc" + "${GE_CODE_DIR}/ge/common/math/fp16_math.cc" + "${GE_CODE_DIR}/ge/common/model/ge_model.cc" + "${GE_CODE_DIR}/ge/common/model/ge_root_model.cc" + "${GE_CODE_DIR}/ge/common/model_parser/model_parser.cc" + "${GE_CODE_DIR}/ge/common/model_saver.cc" + "${GE_CODE_DIR}/ge/common/omg_util.cc" + "${GE_CODE_DIR}/ge/common/op/attr_value_util.cc" + "${GE_CODE_DIR}/ge/common/op/ge_op_utils.cc" + "${GE_CODE_DIR}/ge/common/properties_manager.cc" + "${GE_CODE_DIR}/ge/common/tbe_kernel_store.cc" + "${GE_CODE_DIR}/ge/common/thread_pool.cc" + "${GE_CODE_DIR}/ge/common/transop_util.cc" + "${GE_CODE_DIR}/ge/common/types.cc" + "${GE_CODE_DIR}/ge/common/util.cc" ) if (NOT ENABLE_D AND NOT ENABLE_ACL) From 3128929306fcc3983fa1f6e230814df2cf917c3d Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Wed, 14 Jul 2021 12:45:41 +0800 Subject: [PATCH 05/33] aicore_task_compiler.cc to runner --- ge/CMakeLists.txt | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index cb4c84b1..d1a0da0f 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -109,7 +109,6 @@ endif () ################################################################## set(EXECUTOR_SRC_LIST - "analyzer/analyzer.cc" "common/dump/dump_manager.cc" "common/dump/dump_op.cc" 
"common/dump/exception_dumper.cc" @@ -121,7 +120,6 @@ set(EXECUTOR_SRC_LIST "common/ge/plugin_manager.cc" "common/profiling/ge_profiling.cc" "common/profiling/profiling_manager.cc" - "engine_manager/dnnengine_manager.cc" "executor/ge_executor.cc" "ge_local_engine/engine/host_cpu_engine.cc" "graph/build/memory/var_mem_assign_util.cc" @@ -236,7 +234,6 @@ set(EXECUTOR_SRC_LIST "hybrid/node_executor/aicore/aicore_node_executor.cc" "hybrid/node_executor/aicore/aicore_op_task.cc" "hybrid/node_executor/aicore/aicore_task_builder.cc" - "hybrid/node_executor/aicore/aicore_task_compiler.cc" "hybrid/node_executor/aicpu/aicpu_ext_info.cc" "hybrid/node_executor/aicpu/aicpu_node_executor.cc" "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" @@ -250,9 +247,7 @@ set(EXECUTOR_SRC_LIST "hybrid/node_executor/rts/rts_node_task.cc" "hybrid/node_executor/rts/rts_task_factory.cc" "hybrid/node_executor/task_context.cc" - "init/gelib.cc" "opskernel_manager/ops_kernel_builder_manager.cc" - "opskernel_manager/ops_kernel_manager.cc" "single_op/single_op.cc" "single_op/single_op_manager.cc" "single_op/single_op_model.cc" @@ -510,6 +505,7 @@ set(RUNNER_SRC_LIST "graph/manager/util/hcom_util.cc" "graph/load/model_manager/task_info/hccl_task_info.cc" "hybrid/node_executor/hccl/hccl_node_executor.cc" + "hybrid/node_executor/aicore/aicore_task_compiler.cc" ) if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) @@ -750,7 +746,6 @@ target_link_libraries(ge_executor PRIVATE $<$>:$> $<$>:$> json - ge_proto_client ascend_protobuf_static c_sec $<$>:-lrt> @@ -813,7 +808,6 @@ target_link_libraries(ge_executor_shared PRIVATE $<$>:$> -Wl,--no-as-needed ge_common - ge_proto_client runtime slog graph From 5bf4d4c4245b801b754efd46c3c865f8659260f8 Mon Sep 17 00:00:00 2001 From: lichun Date: Wed, 14 Jul 2021 15:38:58 +0800 Subject: [PATCH 06/33] assign graph memory max size and variable memory max size adaptively --- ge/graph/manager/graph_var_manager.cc | 47 
++++++++++++++++++++++++++----- ge/graph/manager/graph_var_manager.h | 3 ++ tests/depends/runtime/src/runtime_stub.cc | 6 ++++ 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc index 89a4e45b..5138a0f5 100755 --- a/ge/graph/manager/graph_var_manager.cc +++ b/ge/graph/manager/graph_var_manager.cc @@ -20,6 +20,7 @@ #include "graph/manager/graph_mem_manager.h" #include "graph/manager/trans_var_data_utils.h" #include "graph/utils/type_utils.h" +#include "graph/ge_context.h" using std::map; using std::string; @@ -767,24 +768,54 @@ Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &grap return var_resource_->GetChangedGraphId(var_name, graph_id); } +Status VarManager::GetTotalMemorySize(size_t &total_mem_size) { + rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", + GetContext().DeviceId(), rt_ret); + GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); + return RT_FAILED; + } + size_t free_mem = 0; + rt_ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free_mem, &total_mem_size); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtMemGetInfo failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtMemGetInfo] failed, ret:0x%X", rt_ret); + return RT_FAILED; + } + rt_ret = rtDeviceReset(GetContext().DeviceId()); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", + GetContext().DeviceId(), rt_ret); + GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); + return RT_FAILED; + } + return SUCCESS; +} + Status VarManager::SetMemoryMallocSize(const map &options) { + size_t total_mem_size = 0; + Status ret = VarManager::GetTotalMemorySize(total_mem_size); + if (ret != 
SUCCESS) { + return ret; + } + GEEVENT("Total memory size is %zu", total_mem_size); + + graph_mem_max_size_ = floor(total_mem_size * kGraphMemoryManagerMallocRatio); + var_mem_max_size_ = floor(total_mem_size * kVarMemoryManagerMallocRatio); + auto it = options.find(GRAPH_MEMORY_MAX_SIZE); - if (it == options.end()) { - graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize; - } else { + if (it != options.end()) { string graph_memory_manager_malloc_max_size = it->second; ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_); if (ret != SUCCESS) { GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); return ge::GE_GRAPH_OPTIONS_INVALID; } - GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); } it = options.find(VARIABLE_MEMORY_MAX_SIZE); - if (it == options.end()) { - var_mem_max_size_ = kMemoryVarManagerMallocSize; - } else { + if (it != options.end()) { string memory_var_manager_malloc_size = it->second; ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_); if (ret != SUCCESS) { @@ -793,6 +824,8 @@ Status VarManager::SetMemoryMallocSize(const map &options) { } } + GEEVENT("The graph_mem_max_size is %zu and the var_mem_max_size is %zu", graph_mem_max_size_, var_mem_max_size_); + var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer; if (var_mem_logic_base_ > kMaxMemorySize) { REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid", diff --git a/ge/graph/manager/graph_var_manager.h b/ge/graph/manager/graph_var_manager.h index f2b68e79..a1b45959 100755 --- a/ge/graph/manager/graph_var_manager.h +++ b/ge/graph/manager/graph_var_manager.h @@ -43,6 +43,8 @@ const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL; const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; const uint64_t kSessionMemAlignSize = 512; const size_t kSessionMemAlignUnit = 2; 
+const double kGraphMemoryManagerMallocRatio = 26.0 / 32.0; +const double kVarMemoryManagerMallocRatio = 5.0 / 32.0; enum MemStatus { NORMAL = 0, @@ -301,6 +303,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { mutable std::recursive_mutex mutex_; Status ParseMemoryMallocSize(std::string &memory_size, size_t &my_size); + Status GetTotalMemorySize(size_t &total_mem_size); }; class VarManagerPool { diff --git a/tests/depends/runtime/src/runtime_stub.cc b/tests/depends/runtime/src/runtime_stub.cc index 0c9e2c27..a8f7e59a 100644 --- a/tests/depends/runtime/src/runtime_stub.cc +++ b/tests/depends/runtime/src/runtime_stub.cc @@ -193,6 +193,12 @@ rtError_t rtMemGetInfo(size_t *free, size_t *total) { return RT_ERROR_NONE; } +rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) { + *free = 512UL * 1024UL * 1024UL; + *total = 1024UL * 1024UL * 1024UL; + return RT_ERROR_NONE; +} + rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag) { return RT_ERROR_NONE; } rtError_t rtMemFreeManaged(void *ptr) { return RT_ERROR_NONE; } From e271d40b9ce1e673b2d550a70a27b2b328ce4433 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Wed, 14 Jul 2021 15:41:17 +0800 Subject: [PATCH 07/33] delete common file form compoler --- ge/CMakeLists.txt | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index d1a0da0f..f83d2607 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -109,15 +109,9 @@ endif () ################################################################## set(EXECUTOR_SRC_LIST - "common/dump/dump_manager.cc" "common/dump/dump_op.cc" "common/dump/exception_dumper.cc" "common/dump/opdebug_register.cc" - "common/formats/format_transfers/format_transfer_transpose.cc" - "common/formats/utils/formats_trans_utils.cc" - "common/fp16_t.cc" - "common/ge/op_tiling_manager.cc" - "common/ge/plugin_manager.cc" "common/profiling/ge_profiling.cc" 
"common/profiling/profiling_manager.cc" "executor/ge_executor.cc" @@ -264,29 +258,6 @@ set(EXECUTOR_SRC_LIST set(COMPILER_SRC_LIST "analyzer/analyzer.cc" "common/dump/dump_op.cc" - "common/dump/dump_properties.cc" - "common/formats/format_transfers/datatype_transfer.cc" - "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" - "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" - "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" - "common/formats/format_transfers/format_transfer_fractal_nz.cc" - "common/formats/format_transfers/format_transfer_fractal_z.cc" - "common/formats/format_transfers/format_transfer_fractal_zz.cc" - "common/formats/format_transfers/format_transfer_fracz_hwcn.cc" - "common/formats/format_transfers/format_transfer_fracz_nchw.cc" - "common/formats/format_transfers/format_transfer_fracz_nhwc.cc" - "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" - "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_transpose.cc" - "common/formats/formats.cc" - "common/formats/utils/formats_trans_utils.cc" - "common/fp16_t.cc" - "common/ge/op_tiling_manager.cc" - "common/ge/plugin_manager.cc" "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" "engine_manager/dnnengine_manager.cc" From fec3176277626c8e80c69a2d76ed36f08cd7ae43 Mon Sep 17 00:00:00 2001 From: lichun Date: Wed, 14 Jul 2021 16:31:09 +0800 Subject: [PATCH 08/33] assign graph memory max size and variable memory max size adaptively --- tests/ut/ge/CMakeLists.txt | 1 + .../ge/graph/manager/graph_var_manager_unittest.cc | 63 ++++++++++++++++++++++ 2 files changed, 64 
insertions(+) create mode 100644 tests/ut/ge/graph/manager/graph_var_manager_unittest.cc diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 856d9d43..773a2686 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -690,6 +690,7 @@ set(MULTI_PARTS_TEST_FILES "graph/manager/run_graph_unittest.cc" "graph/partition/dynamic_shape_partition_unittest.cc" "graph/manager/graph_manager_unittest.cc" + "graph/manager/graph_var_manager_unittest.cc" "graph/optimize/mem_rw_conflict_optimize_unittest.cc" "graph/optimize/graph_optimize_unittest.cc" "session/omg_omg_unittest.cc" diff --git a/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc b/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc new file mode 100644 index 00000000..3eda6c47 --- /dev/null +++ b/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include + +#define protected public +#define private public +#include "graph/manager/graph_var_manager.h" +#include "graph/ge_context.h" +#undef protected +#undef private + +namespace ge { +class UtestGraphVarManagerTest : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UtestGraphVarManagerTest, test_get_total_memory_size) { + size_t total_mem_size = 0; + Status ret = VarManager::Instance(0)->GetTotalMemorySize(total_mem_size); + EXPECT_EQ(total_mem_size, 1024UL * 1024UL * 1024UL); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_no_related_option) { + const map options{}; + Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); + EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (26.0f / 32.0f))); + EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (5.0f / 32.0f))); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_graph_mem_max_size) { + const map options{{"ge.graphMemoryMaxSize", 1024UL * 1024UL * 1024UL / 2}}; + Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); + EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL / 2)); + EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (5.0f / 32.0f))); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_var_mem_max_size) { + const map options{{"ge.variableMemoryMaxSize", 1024UL * 1024UL * 1024UL / 2}}; + Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); + EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (26.0f / 32.0f))); + EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL / 2)); + EXPECT_EQ(ret, SUCCESS); +} +} // namespace ge From 
3fd107039848df8f7e6a403beee4cf73ca30ee1c Mon Sep 17 00:00:00 2001 From: lichun Date: Wed, 14 Jul 2021 16:59:49 +0800 Subject: [PATCH 09/33] assign graph memory max size and variable memory max size adaptively --- tests/ut/ge/graph/manager/graph_var_manager_unittest.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc b/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc index 3eda6c47..c20e786d 100644 --- a/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc +++ b/tests/ut/ge/graph/manager/graph_var_manager_unittest.cc @@ -46,7 +46,7 @@ TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_no_related_option) } TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_graph_mem_max_size) { - const map options{{"ge.graphMemoryMaxSize", 1024UL * 1024UL * 1024UL / 2}}; + const map options{{"ge.graphMemoryMaxSize", "536870912"}}; Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL / 2)); EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (5.0f / 32.0f))); @@ -54,7 +54,7 @@ TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_g } TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_var_mem_max_size) { - const map options{{"ge.variableMemoryMaxSize", 1024UL * 1024UL * 1024UL / 2}}; + const map options{{"ge.variableMemoryMaxSize", "536870912"}}; Status ret = VarManager::Instance(0)->SetMemoryMallocSize(options); EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, floor(1024UL * 1024UL * 1024UL * (26.0f / 32.0f))); EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, floor(1024UL * 1024UL * 1024UL / 2)); From 191f381cc5251711c4f65ef11f7262f47e583068 Mon Sep 17 00:00:00 2001 From: wangxiaotian22 Date: Tue, 13 Jul 2021 11:49:50 +0800 Subject: [PATCH 10/33] runtime api 
transfer --- .../model_manager/task_info/kernel_ex_task_info.cc | 3 +- .../model_manager/task_info/kernel_ex_task_info.h | 1 + .../model_manager/task_info/kernel_task_info.cc | 7 +- .../node_executor/aicpu/aicpu_node_executor.cc | 15 +- inc/external/OWNERS | 10 ++ tests/depends/runtime/src/runtime_stub.cc | 15 ++ tests/ut/ge/CMakeLists.txt | 1 + .../aicpu/aicpu_node_executor_unittest.cc | 168 +++++++++++++++++++++ 8 files changed, 209 insertions(+), 11 deletions(-) create mode 100644 inc/external/OWNERS create mode 100644 tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc diff --git a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc index a4b3de75..ee358b5c 100644 --- a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc +++ b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc @@ -106,6 +106,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin // 1. 
Copy context from kernelExDef.private to workspace uint32_t op_index = kernel_ex_def.op_index(); OpDescPtr op_desc = davinci_model_->GetOpByIndex(op_index); + op_desc_ = op_desc; if (op_desc == nullptr) { REPORT_INNER_ERROR("E19999", "Can't get op_desc from davinci_model by index:%u", op_index); GELOGE(INTERNAL_ERROR, "[Get][Op] by index failed, index:%u is out of range!", op_index); @@ -422,7 +423,7 @@ Status KernelExTaskInfo::Distribute() { if (topic_type_flag_ > 0) { dump_flag_ = dump_flag_ | topic_type_flag_; } - rtError_t rt_ret = rtKernelLaunchEx(kernel_buf_, kernel_buf_size_, dump_flag_, stream_); + rtError_t rt_ret = rtKernelLaunchFwk(op_desc_->GetName().c_str(), kernel_buf_, kernel_buf_size_, dump_flag_, stream_); if (rt_ret != RT_ERROR_NONE) { REPORT_CALL_ERROR("E19999", "Call rtKernelLaunchEx failed, ret:0x%X", rt_ret); GELOGE(RT_FAILED, "[Call][RtKernelLaunchEx] failed, ret:0x%X", rt_ret); diff --git a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h index 1b77b715..7d07eb7f 100644 --- a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h +++ b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h @@ -70,6 +70,7 @@ class KernelExTaskInfo : public TaskInfo { uint32_t dump_flag_; uint32_t kernel_buf_size_; DavinciModel *davinci_model_; + OpDescPtr op_desc_; void *kernel_buf_; void *input_output_addr_; void *ext_info_addr_; diff --git a/ge/graph/load/model_manager/task_info/kernel_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_task_info.cc index 07ad63ca..63f4257c 100755 --- a/ge/graph/load/model_manager/task_info/kernel_task_info.cc +++ b/ge/graph/load/model_manager/task_info/kernel_task_info.cc @@ -440,9 +440,10 @@ Status KernelTaskInfo::Distribute() { } GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_); // blockDim is reserved parameter, set to 1 - rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name_.c_str()), - 
reinterpret_cast(kernel_name_.c_str()), 1, args_, args_size_, - nullptr, stream_, dump_flag_); + std::string op_name = op_desc_->GetName(); + rtKernelLaunchNames_t launch_name = {so_name_.c_str(), kernel_name_.c_str(), op_name.c_str()}; + rt_ret = rtAicpuKernelLaunchWithFlag(&launch_name, 1, args_, args_size_, + nullptr, stream_, dump_flag_); call_save_dump_ = true; } else { /* default: not skt launch */ diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc index 820c9b56..cf20303c 100755 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc +++ b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc @@ -477,7 +477,7 @@ Status AicpuTfNodeTask::CopyDataToHbm(TaskContext &context, GE_CHK_STATUS_RET_NOLOG(PrepareCopyInputs(context, out_shape_hbm)); RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[LaunchCopy] Start"); - GE_CHK_RT_RET(rtKernelLaunchEx(copy_task_args_buf_->GetData(), sizeof(STR_FWK_OP_KERNEL), + GE_CHK_RT_RET(rtKernelLaunchFwk(node_name_.c_str(), copy_task_args_buf_->GetData(), sizeof(STR_FWK_OP_KERNEL), RT_KERNEL_DEFAULT, context.GetStream())); RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[LaunchCopy] End"); @@ -638,7 +638,8 @@ Status AicpuTfNodeTask::LaunchTask(TaskContext &context) { GELOGD("Node[%s] launch task start, unknown_type=%d.", node_name_.c_str(), unknown_type_); uint32_t flag = RT_KERNEL_DEFAULT; RECORD_EXECUTION_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[AicpuTfNodertKernelLaunchEx] Start"); - GE_CHK_RT_RET(rtKernelLaunchEx(kernel_buf_->GetData(), kernel_buf_->GetSize(), flag, context.GetStream())); + GE_CHK_RT_RET(rtKernelLaunchFwk(node_name_.c_str(), kernel_buf_->GetData(), + kernel_buf_->GetSize(), flag, context.GetStream())); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[AicpuTfNodertKernelLaunchEx] End"); GELOGD("Node[%s] launch end.", node_name_.c_str()); if 
(need_sync_) { @@ -819,11 +820,11 @@ Status AicpuNodeTask::LaunchTask(TaskContext &context) { if (kernel_type == ccKernelType::CUST_AI_CPU) { flag |= static_cast(RT_KERNEL_CUSTOM_AICPU); } - auto rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name.c_str()), - reinterpret_cast(kernel_name.c_str()), - 1, // default core dim is 1 - args_.get(), args_size_, - nullptr, context.GetStream(), flag); + rtKernelLaunchNames_t launch_name = {so_name.c_str(), kernel_name.c_str(), node_name_.c_str()}; + auto rt_ret = rtAicpuKernelLaunchWithFlag(&launch_name, + 1, // default core dim is 1 + args_.get(), args_size_, + nullptr, context.GetStream(), flag); GE_CHK_RT_RET(rt_ret); GELOGD("Node[%s] launch task end.", node_name_.c_str()); return SUCCESS; diff --git a/inc/external/OWNERS b/inc/external/OWNERS new file mode 100644 index 00000000..934272a6 --- /dev/null +++ b/inc/external/OWNERS @@ -0,0 +1,10 @@ +approvers: +- gegenhua +reviewers: +- wqtshg +- ji_chen +- xchu42 +- sheng-nan +- wangxiaotian22 +- zhangxiaokun9 +- tangqunzhang diff --git a/tests/depends/runtime/src/runtime_stub.cc b/tests/depends/runtime/src/runtime_stub.cc index 0c9e2c27..510eb1ad 100644 --- a/tests/depends/runtime/src/runtime_stub.cc +++ b/tests/depends/runtime/src/runtime_stub.cc @@ -460,6 +460,21 @@ rtError_t rtDebugUnRegisterForStream(rtStream_t stream) { rtError_t rtFftsTaskLaunch(rtFftsTaskInfo_t *fftsTaskInfo, rtStream_t stream) { return RT_ERROR_NONE; } + +rtError_t rtKernelLaunchFwk(const char *opName, void *args, uint32_t argSize, uint32_t flags, rtStream_t rtStream) { + return RT_ERROR_NONE; +} + +rtError_t rtAicpuKernelLaunchWithFlag(const rtKernelLaunchNames_t *launchNames, uint32_t blockDim, const void *args, + uint32_t argSize, rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags) { + return RT_ERROR_NONE; +} + +rtError_t rtAicpuKernelLaunch(const rtKernelLaunchNames_t *launchNames, uint32_t blockDim, const void *args, + uint32_t argSize, rtSmDesc_t *smDesc, rtStream_t stream) { + 
return RT_ERROR_NONE; +} + #ifdef __cplusplus } #endif diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 42fa6128..ebaee921 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -735,6 +735,7 @@ set(HYBRID_TEST_FILES "hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc" "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" "hybrid/node_executor/hccl/hccl_node_executor_unittest.cc" + "hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc" "hybrid/executor/hybrid_model_async_executor_unittest.cc" "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" diff --git a/tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc new file mode 100644 index 00000000..b225949b --- /dev/null +++ b/tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc @@ -0,0 +1,168 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include + +#define private public +#define protected public +#include "graph/runtime_inference_context.h" +#include "aicpu/common/aicpu_task_struct.h" +#include "hybrid/executor/subgraph_context.h" +#include "hybrid/node_executor/aicpu/aicpu_node_executor.h" +#undef protected +#undef private + +using namespace std; +using namespace testing; + +namespace { +struct AicpuTaskStruct { + aicpu::AicpuParamHead head; + uint64_t io_addrp[6]; +}__attribute__((packed)); +} // namespace + +namespace ge { +using namespace hybrid; + +class UtestAicpuNodeExecutor : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +static NodePtr CreateNode(ComputeGraphPtr graph, const string &name, const string &type, int in_num, int out_num) { + OpDescPtr op_desc = std::make_shared(name, type); + op_desc->SetStreamId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); + TensorUtils::SetSize(tensor, 64); + vector input_offset; + for (int i = 0; i < in_num; i++) { + op_desc->AddInputDesc(tensor); + input_offset.emplace_back(i * 64); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; i++) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(in_num * 64 + i * 64); + } + op_desc->SetOutputOffset(output_offset); + + return graph->AddNode(op_desc); +} + +TEST_F(UtestAicpuNodeExecutor, aicpu_tf_node_task) { + ComputeGraphPtr graph = std::make_shared("test"); + GeModelPtr ge_sub_model = std::make_shared(); + GeRootModelPtr ge_root_model = std::make_shared(graph); + ge_root_model->SetModelName("test_name"); + ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); + HybridModel hybrid_model(ge_root_model); + + NodePtr node = CreateNode(graph, "frameworkop", FRAMEWORK_OP_TYPE, 4, 2); + + std::unique_ptr new_node; + ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); + NodeItem *node_item = new_node.get(); 
+ hybrid_model.node_items_[node] = std::move(new_node); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->is_dynamic = true; + node_item->shape_inference_type = DEPEND_COMPUTE; + + GraphItem graph_item; + graph_item.node_items_.emplace_back(node_item); + graph_item.total_inputs_ = 4; + graph_item.total_outputs_ = 2; + + GraphExecutionContext graph_context; + SubgraphContext subgraph_context(&graph_item, &graph_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); + graph_context.callback_manager = std::unique_ptr(new CallbackManager()); + + auto node_state = subgraph_context.GetOrCreateNodeState(node_item); + ASSERT_NE(node_state, nullptr); + + for (int i=0; i<4; ++i) { + uint64_t value_0 = 512; + TensorValue in_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetInput(*node_item, 0, in_tensor0); + } + + uint64_t value_0 = 512; + TensorValue out_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetOutput(*node_item, 0, out_tensor0); + + uint64_t value_1 = 512; + TensorValue out_tensor1(&value_1, sizeof(value_1)); + subgraph_context.SetOutput(*node_item, 1, out_tensor1); + + // task + domi::TaskDef task_def; + domi::KernelExDef *kernel_ex_def = task_def.mutable_kernel_ex(); + kernel_ex_def->set_kernel_ext_info_size(12); + + AicpuExtInfo aicpu_ext_info; + aicpu_ext_info.infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_SHAPE_TYPE; + aicpu_ext_info.infoLen = sizeof(int32_t); + int32_t type = node_item->shape_inference_type; + memcpy_s(aicpu_ext_info.infoMsg, sizeof(int32_t), &type, sizeof(int32_t)); + char *ext_mem = (char*)malloc(sizeof(AicpuExtInfo) + sizeof(int32_t)); + memcpy_s(ext_mem, sizeof(AicpuExtInfo) + sizeof(int32_t), &aicpu_ext_info, sizeof(AicpuExtInfo) + sizeof(int32_t)); + std::string ext_info(ext_mem, sizeof(AicpuExtInfo) + sizeof(int32_t)); + + std::string *mutable_ext_info = kernel_ex_def->mutable_kernel_ext_info(); + (*mutable_ext_info) = ext_info; + + hybrid_model.task_defs_[node] = std::vector({task_def, task_def}); 
+ + AicpuTfNodeTask aicpu_tf_node_task(node_item, task_def); + + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), SUCCESS); + ASSERT_EQ(aicpu_tf_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); + + AicpuTaskStruct args; + args.head.length = sizeof(args); + args.head.ioAddrNum = 6; + + domi::TaskDef task_def2; + task_def2.set_type(RT_MODEL_TASK_ALL_KERNEL); + task_def2.mutable_kernel()->set_args(reinterpret_cast(&args), args.head.length); + task_def2.mutable_kernel()->set_args_size(args.head.length); + + hybrid_model.task_defs_[node] = std::vector({task_def2}); + + AicpuNodeTask aicpu_node_task(node_item, task_def); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); + + + //kernel_ex_def->set_allocated_kernel_ext_info(nullptr); + + free(ext_mem); + +} + +} // namespace ge + From 45be54175221c42c1bcb7e46b230fc79c8aa2d02 Mon Sep 17 00:00:00 2001 From: lichun Date: Wed, 14 Jul 2021 19:59:55 +0800 Subject: [PATCH 11/33] assign graph memory max size and variable memory max size adaptively --- ge/graph/manager/graph_var_manager.cc | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc index 5138a0f5..d0669254 100755 --- a/ge/graph/manager/graph_var_manager.cc +++ b/ge/graph/manager/graph_var_manager.cc @@ -795,18 +795,15 @@ Status VarManager::GetTotalMemorySize(size_t &total_mem_size) { Status VarManager::SetMemoryMallocSize(const map &options) { size_t total_mem_size = 0; - Status ret = VarManager::GetTotalMemorySize(total_mem_size); - if (ret != SUCCESS) { - return ret; - } + GE_CHK_STATUS_RET_NOLOG(VarManager::GetTotalMemorySize(total_mem_size)); GEEVENT("Total memory size is %zu", total_mem_size); graph_mem_max_size_ = floor(total_mem_size * kGraphMemoryManagerMallocRatio); var_mem_max_size_ = floor(total_mem_size * kVarMemoryManagerMallocRatio); - auto it = 
options.find(GRAPH_MEMORY_MAX_SIZE); - if (it != options.end()) { - string graph_memory_manager_malloc_max_size = it->second; + auto it1 = options.find(GRAPH_MEMORY_MAX_SIZE); + if (it1 != options.end()) { + string graph_memory_manager_malloc_max_size = it1->second; ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_); if (ret != SUCCESS) { GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); @@ -814,9 +811,9 @@ Status VarManager::SetMemoryMallocSize(const map &options) { } } - it = options.find(VARIABLE_MEMORY_MAX_SIZE); - if (it != options.end()) { - string memory_var_manager_malloc_size = it->second; + auto it2 = options.find(VARIABLE_MEMORY_MAX_SIZE); + if (it2 != options.end()) { + string memory_var_manager_malloc_size = it2->second; ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_); if (ret != SUCCESS) { GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); From f6755b5681a5ed5a3618b3aa79d77b1e8c1680c2 Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Thu, 15 Jul 2021 18:51:46 +0800 Subject: [PATCH 12/33] revert --- ge/CMakeLists.txt | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index f83d2607..d1a0da0f 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -109,9 +109,15 @@ endif () ################################################################## set(EXECUTOR_SRC_LIST + "common/dump/dump_manager.cc" "common/dump/dump_op.cc" "common/dump/exception_dumper.cc" "common/dump/opdebug_register.cc" + "common/formats/format_transfers/format_transfer_transpose.cc" + "common/formats/utils/formats_trans_utils.cc" + "common/fp16_t.cc" + "common/ge/op_tiling_manager.cc" + "common/ge/plugin_manager.cc" "common/profiling/ge_profiling.cc" "common/profiling/profiling_manager.cc" "executor/ge_executor.cc" @@ 
-258,6 +264,29 @@ set(EXECUTOR_SRC_LIST set(COMPILER_SRC_LIST "analyzer/analyzer.cc" "common/dump/dump_op.cc" + "common/dump/dump_properties.cc" + "common/formats/format_transfers/datatype_transfer.cc" + "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" + "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" + "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" + "common/formats/format_transfers/format_transfer_fractal_nz.cc" + "common/formats/format_transfers/format_transfer_fractal_z.cc" + "common/formats/format_transfers/format_transfer_fractal_zz.cc" + "common/formats/format_transfers/format_transfer_fracz_hwcn.cc" + "common/formats/format_transfers/format_transfer_fracz_nchw.cc" + "common/formats/format_transfers/format_transfer_fracz_nhwc.cc" + "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" + "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" + "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" + "common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" + "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" + "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" + "common/formats/format_transfers/format_transfer_transpose.cc" + "common/formats/formats.cc" + "common/formats/utils/formats_trans_utils.cc" + "common/fp16_t.cc" + "common/ge/op_tiling_manager.cc" + "common/ge/plugin_manager.cc" "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" "engine_manager/dnnengine_manager.cc" From 4be882056625080d270df9e9eecebb428da3bed6 Mon Sep 17 00:00:00 2001 From: wqtshg Date: Thu, 15 Jul 2021 19:05:14 +0800 Subject: [PATCH 13/33] delete compiling macros --- CMakeLists.txt | 8 ++------ cmake/external_libs/gflags.cmake | 15 ++++++++++----- ge/CMakeLists.txt | 4 +--- ge/offline/CMakeLists.txt | 6 ++---- metadef | 2 +- 5 files changed, 16 insertions(+), 19 deletions(-) diff --git a/CMakeLists.txt 
b/CMakeLists.txt index ac0240d9..60509838 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,11 +88,9 @@ else () find_module(hccl libhccl.so ${GE_LIB_PATH}) find_module(adump_server libadump_server.a ${GE_LIB_PATH}) find_module(runtime libruntime.so ${GE_LIB_PATH}) - find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH}) find_module(resource libresource.so ${GE_LIB_PATH}) find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH}) - #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) else() find_module(slog libalog.so ${ASCEND_ATC_DIR}) find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR}) @@ -108,7 +106,6 @@ else () elseif(PLATFORM STREQUAL "inference") find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) - find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) if(PRODUCT STREQUAL "flr3") elseif(PRODUCT STREQUAL "flr1") @@ -120,10 +117,9 @@ else () endif() elseif(PLATFORM STREQUAL "all") find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) - find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) + find_module(runtime libruntime.so ${ASCEND_ATC_DIR}) find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) - find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) - find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) + find_module(ascend_hal_stub libascend_hal.so ${ASCEND_ATC_DIR}/stub) find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) else() message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") diff --git a/cmake/external_libs/gflags.cmake b/cmake/external_libs/gflags.cmake index 50cfb2bc..b4b57dd7 100755 --- a/cmake/external_libs/gflags.cmake +++ b/cmake/external_libs/gflags.cmake @@ -10,12 +10,17 @@ if 
((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") endif() -if (ENABLE_GITEE) - set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") - set(MD5 "") +if (GE_PB_PKG) + set(REQ_URL "${GE_PB_PKG}/libs/gflags/v2.2.2.tar.gz") + set(MD5 "1a865b93bacfa963201af3f75b7bd64c") else() - set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") - set(MD5 "") + if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") + set(MD5 "") + else() + set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") + set(MD5 "1a865b93bacfa963201af3f75b7bd64c") + endif () endif () set (gflags_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=ascend_private") diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index f83d2607..cd255c79 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -593,7 +593,6 @@ target_compile_definitions(ge_compiler PRIVATE REUSE_MEMORY=1 FMK_SUPPORT_DUMP FMK_HOST_INFER - COMPILE_OMG_PACKAGE google=ascend_private FUNC_VISIBILITY $<$:ONLY_COMPILE_OPEN_SRC> @@ -655,8 +654,7 @@ target_link_libraries(ge_compiler PRIVATE c_sec error_manager slog - $<$>:$> - $<$:$> + runtime opt_feature -Wl,--as-needed json diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt index e11e4a03..935d8a30 100644 --- a/ge/offline/CMakeLists.txt +++ b/ge/offline/CMakeLists.txt @@ -22,7 +22,6 @@ target_compile_options(atc_atc.bin PRIVATE target_compile_definitions(atc_atc.bin PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 - COMPILE_OMG_PACKAGE google=ascend_private LOG_CPP FUNC_VISIBILITY @@ -48,6 +47,7 @@ target_include_directories(atc_atc.bin PRIVATE target_link_options(atc_atc.bin PRIVATE -Wl,-Bsymbolic + -Wl,-rpath-link,${ASCEND_ATC_DIR}/stub ) target_link_libraries(atc_atc.bin PRIVATE @@ -62,8 +62,7 @@ target_link_libraries(atc_atc.bin PRIVATE parser_common gflags json - $<$>:$> - $<$:$> + runtime slog static_mmpa 
-lrt @@ -92,7 +91,6 @@ target_compile_options(fwk_atc.bin PRIVATE target_compile_definitions(fwk_atc.bin PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 - COMPILE_OMG_PACKAGE google=ascend_private LOG_CPP FUNC_VISIBILITY diff --git a/metadef b/metadef index 5a9605f6..a725349b 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 5a9605f6cb1204a729a51fe36bc614cf1d94a496 +Subproject commit a725349b65aef2940555af2ddb7b9461fbe0d5fd From 207bf69c20a5953ae01499434922244161e67206 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Thu, 15 Jul 2021 20:01:58 +0800 Subject: [PATCH 14/33] bugfix for taskdef's random variation in offline case --- ge/graph/build/task_generator.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index 7bb2e2f6..1adcd0aa 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -50,6 +50,7 @@ const char *const kIsInputVar = "INPUT_IS_VAR"; const char *const kIsOutputVar = "OUTPUT_IS_VAR"; const char *const kProfilingMode = "PROFILING_MODE"; const char *const kIteratorV2 = "IteratorV2"; +const char *const kKernelInfoNameHccl = "ops_kernel_info_hccl"; const uint32_t kProfilingArStep = 2; const uint64_t kProfilingFpStartLogid = 1; const uint64_t kProfilingBpEndLogid = 2; @@ -437,14 +438,15 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra } // Reset stream id to ge stream id, as graph load must use ge stream to reassign stream - void *ops_kernel_info_store_ptr = kernel_info_store.get(); for (size_t idx = task_list_size_before; idx < task_list_size_after; ++idx) { task_def_list[idx].set_stream_id(static_cast(stream_id)); op_name_map[idx] = name; - // Set opsKernelInfoStorePtr and op_index, the two fields be use in DistributeTask and InitTaskInfo TaskDef *task_def_ptr = &task_def_list[idx]; GE_CHECK_NOTNULL(task_def_ptr); - 
task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast(ops_kernel_info_store_ptr)); + // Set opsKernelInfoStorePtr for hccl which will be use in DistributeTask and InitTaskInfo + if (op_kernel_lib_name == kKernelInfoNameHccl) { + task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast(kernel_info_store.get())); + } } GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %zu task(s).", op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, From 4132d6dcd22bad7c9c73a5f3e12a62051478a528 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Thu, 15 Jul 2021 20:19:28 +0800 Subject: [PATCH 15/33] Delete common format_transfers files --- ge/CMakeLists.txt | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index d9ef5eef..0236e8bd 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -109,13 +109,9 @@ endif () ################################################################## set(EXECUTOR_SRC_LIST - "common/dump/dump_manager.cc" "common/dump/dump_op.cc" "common/dump/exception_dumper.cc" "common/dump/opdebug_register.cc" - "common/formats/format_transfers/format_transfer_transpose.cc" - "common/formats/utils/formats_trans_utils.cc" - "common/fp16_t.cc" "common/ge/op_tiling_manager.cc" "common/ge/plugin_manager.cc" "common/profiling/ge_profiling.cc" @@ -264,27 +260,6 @@ set(EXECUTOR_SRC_LIST set(COMPILER_SRC_LIST "analyzer/analyzer.cc" "common/dump/dump_op.cc" - "common/dump/dump_properties.cc" - "common/formats/format_transfers/datatype_transfer.cc" - "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" - "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" - "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" - "common/formats/format_transfers/format_transfer_fractal_nz.cc" - "common/formats/format_transfers/format_transfer_fractal_z.cc" - 
"common/formats/format_transfers/format_transfer_fractal_zz.cc" - "common/formats/format_transfers/format_transfer_fracz_hwcn.cc" - "common/formats/format_transfers/format_transfer_fracz_nchw.cc" - "common/formats/format_transfers/format_transfer_fracz_nhwc.cc" - "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" - "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_transpose.cc" - "common/formats/formats.cc" - "common/formats/utils/formats_trans_utils.cc" - "common/fp16_t.cc" "common/ge/op_tiling_manager.cc" "common/ge/plugin_manager.cc" "common/helper/model_cache_helper.cc" From 051d0e9fab55a2530b364ecea1e98c1705e308de Mon Sep 17 00:00:00 2001 From: zhaozhixuan Date: Thu, 15 Jul 2021 20:43:09 +0800 Subject: [PATCH 16/33] Fix bug of single_op. 
--- ge/single_op/single_op.cc | 4 +++- ge/single_op/task/op_task.cc | 25 ++++++++++++++++++++---- ge/single_op/task/op_task.h | 5 +++-- ge/single_op/task/tbe_task_builder.cc | 2 +- tests/ut/ge/single_op/single_op_task_unittest.cc | 20 +++++++++++++++++++ 5 files changed, 48 insertions(+), 8 deletions(-) diff --git a/ge/single_op/single_op.cc b/ge/single_op/single_op.cc index a82c30ba..23f4cfad 100755 --- a/ge/single_op/single_op.cc +++ b/ge/single_op/single_op.cc @@ -433,11 +433,13 @@ Status DynamicSingleOp::ExecuteAsync(const vector &input_desc, if (!inputs_size.empty()) { StreamResource *stream_resource = SingleOpManager::GetInstance().GetResource(resource_id_, stream_); GE_CHK_STATUS_RET_NOLOG(UpdateInputsBufferAddr(stream_resource, stream_, inputs_size, update_buffers)); - GE_CHK_STATUS_RET_NOLOG(SetHostTensorValue(input_desc, input_buffers)); } if (hybrid_model_executor_ != nullptr) { GELOGD("Execute multi-task dynamic single op by hybrid model executor"); + if (!inputs_size.empty()) { + GE_CHK_STATUS_RET_NOLOG(SetHostTensorValue(input_desc, input_buffers)); + } hybrid::HybridModelExecutor::ExecuteArgs args; GE_CHK_STATUS_RET_NOLOG(InitHybridModelArgs(update_buffers, output_buffers, input_desc, args)); diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index dbc90ac5..fd6639a5 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -294,16 +294,15 @@ Status TbeOpTask::UpdateNodeByShape(const vector &input_desc, cons Status TbeOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) { if (tiling_buffer != nullptr) { - uintptr_t *arg_base = nullptr; - size_t arg_num = 0; - GetIoAddr(arg_base, arg_num); + uintptr_t *arg_base = reinterpret_cast(args_.get()); + size_t arg_num = arg_size_ / sizeof(void *); GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); uint32_t inputs_num = node->GetOpDesc()->GetInputsSize(); uint32_t outputs_num = node->GetOpDesc()->GetOutputsSize(); 
uint32_t workspace_nums = node->GetOpDesc()->GetWorkspace().size(); uint32_t tiling_index = inputs_num + outputs_num + workspace_nums; - if (arg_num == 0 || arg_num < tiling_index) { + if (arg_num == 0 || arg_num <= tiling_index) { GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.", tiling_index, arg_num); return ACL_ERROR_GE_INTERNAL_ERROR; @@ -481,6 +480,24 @@ void TbeOpTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) { } } +Status AtomicAddrCleanOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) { + if (tiling_buffer != nullptr) { + uintptr_t *arg_base = reinterpret_cast(args_.get()); + size_t arg_num = arg_size_ / sizeof(void *); + uint32_t tiling_index = atomic_output_indices_.size(); + if (arg_num == 0 || arg_num <= tiling_index) { + GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.", + tiling_index, arg_num); + return ACL_ERROR_GE_INTERNAL_ERROR; + } + arg_base[tiling_index] = reinterpret_cast(tiling_buffer); + } + node_ = node; + tiling_buffer_ = tiling_buffer; + max_tiling_size_ = max_tiling_size; + return SUCCESS; +} + Status AtomicAddrCleanOpTask::UpdateNodeByShape(const vector &input_desc, const vector &output_desc) { return SUCCESS; diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index 132672b0..4a839389 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -97,7 +97,7 @@ class TbeOpTask : public OpTask { const void *GetArgs() const; size_t GetArgSize() const; const std::string &GetStubName() const; - Status EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size); + virtual Status EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size); const std::string &GetTaskType() const override; void SetHandle(void *handle); @@ -149,6 +149,7 @@ class TbeOpTask : public OpTask { class AtomicAddrCleanOpTask : 
public TbeOpTask { public: Status InitAtomicAddrCleanIndices(); + Status EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) override; private: Status UpdateNodeByShape(const vector &input_desc, @@ -156,8 +157,8 @@ class AtomicAddrCleanOpTask : public TbeOpTask { Status UpdateIoAddr(const vector &inputs, const vector &outputs) override; Status UpdateTilingArgs(rtStream_t stream) override; Status CalcTilingInfo(optiling::utils::OpRunInfo &run_info) override; - std::vector atomic_output_indices_; + std::vector atomic_output_indices_; }; class AiCpuBaseTask : public OpTask { diff --git a/ge/single_op/task/tbe_task_builder.cc b/ge/single_op/task/tbe_task_builder.cc index 017dac25..f947ca57 100644 --- a/ge/single_op/task/tbe_task_builder.cc +++ b/ge/single_op/task/tbe_task_builder.cc @@ -425,7 +425,7 @@ Status TbeTaskBuilder::InitTilingInfo(TbeOpTask &task) { GELOGD("[%s] Done allocating tiling buffer, size=%ld.", op_desc_->GetName().c_str(), max_size); } - task.EnableDynamicSupport(node_, tiling_buffer, static_cast(max_size)); + GE_CHK_STATUS_RET_NOLOG(task.EnableDynamicSupport(node_, tiling_buffer, static_cast(max_size))); return SUCCESS; } diff --git a/tests/ut/ge/single_op/single_op_task_unittest.cc b/tests/ut/ge/single_op/single_op_task_unittest.cc index 8964df74..5960fbbc 100644 --- a/tests/ut/ge/single_op/single_op_task_unittest.cc +++ b/tests/ut/ge/single_op/single_op_task_unittest.cc @@ -237,3 +237,23 @@ TEST_F(UtestSingleOpTask, test_aicpu_task_update_io_addr) { ASSERT_EQ(ret, PARAM_INVALID); } } + +TEST_F(UtestSingleOpTask, test_dynamic_support) { + auto graph = make_shared("graph"); + auto op_desc = make_shared("Add", "Add"); + auto node = graph->AddNode(op_desc); + AtomicAddrCleanOpTask atomic_task; + TbeOpTask tbe_task; + + ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR); + ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR); + 
+ tbe_task.arg_size_ = sizeof(void *); + tbe_task.args_.reset(new (std::nothrow) uint8_t[tbe_task.arg_size_]); + atomic_task.arg_size_ = sizeof(void *); + atomic_task.args_.reset(new (std::nothrow) uint8_t[atomic_task.arg_size_]); + ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS); + ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS); + tbe_task.tiling_buffer_ = nullptr; + atomic_task.tiling_buffer_ = nullptr; +} From 927439cb92722d36401af139899288d824f333c2 Mon Sep 17 00:00:00 2001 From: lichun Date: Fri, 16 Jul 2021 11:14:50 +0800 Subject: [PATCH 17/33] fix error code and add complex128 support --- ge/generator/ge_generator.cc | 1 - ge/graph/build/memory/graph_mem_assigner.cc | 2 +- ge/graph/build/memory/memory_assigner.cc | 5 +++-- ge/graph/manager/graph_manager.cc | 4 ++-- ge/offline/single_op_parser.cc | 3 ++- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 45eaed59..d35d7d6e 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -1157,7 +1157,6 @@ Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector if (ret != SUCCESS) { REPORT_CALL_ERROR("E19999", "build graph failed, graph id:%u, ret:%d", graph_id, ret); GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "[Build][Graph] fail, graph id: %u", graph_id); - ret = GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; } RtContextUtil::GetInstance().DestroyRtContexts(session_id); diff --git a/ge/graph/build/memory/graph_mem_assigner.cc b/ge/graph/build/memory/graph_mem_assigner.cc index f8878383..542b6215 100755 --- a/ge/graph/build/memory/graph_mem_assigner.cc +++ b/ge/graph/build/memory/graph_mem_assigner.cc @@ -275,7 +275,7 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, map({"size", "item", "maxsize"}), std::vector({std::to_string(total_mem_offset), "featuremap", 
std::to_string(VarManager::Instance(session_id)->GetGraphMemoryMaxSize())})); - return ge::FAILED; + return ACL_ERROR_GE_MEMORY_ALLOCATION; } return SUCCESS; } diff --git a/ge/graph/build/memory/memory_assigner.cc b/ge/graph/build/memory/memory_assigner.cc index 6e49827f..5846e922 100755 --- a/ge/graph/build/memory/memory_assigner.cc +++ b/ge/graph/build/memory/memory_assigner.cc @@ -29,9 +29,10 @@ Status MemoryAssigner::AssignMemory(bool is_loop_graph, map &m } // Reassign memory for special nodes - if (graph_mem_assigner.ReAssignMemory(is_loop_graph, mem_offset) != ge::SUCCESS) { + Status ret = graph_mem_assigner.ReAssignMemory(is_loop_graph, mem_offset) + if (ret != ge::SUCCESS) { GELOGE(ge::FAILED, "[ReAssign][Memory] failed, graph:%s", compute_graph_->GetName().c_str()); - return ge::FAILED; + return ret; } // Assign memory (block and offset) for zero copy nodes diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 7d72d85b..9749010a 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -1482,8 +1482,8 @@ Status GraphManager::BuildGraph(const GraphId &graph_id, const std::vectorSetRunFlag(false); if (ret != SUCCESS) { - GELOGE(GE_GRAPH_PRERUN_FAILED, "[Call][StartForRunGraph] failed! graph_id:%u.", graph_id); - return GE_GRAPH_PRERUN_FAILED; + GELOGE(ret, "[Call][StartForRunGraph] failed! 
graph_id:%u.", graph_id); + return ret; } GELOGI("[BuildGraph] build graph success, graph_id=%u.", graph_id); diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index 6bc5cb3d..aeb73116 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -89,7 +89,8 @@ map kDataTypeDict = { {"float", DT_FLOAT}, {"float32", DT_FLOAT}, {"double", DT_DOUBLE}, - {"complex64", DT_COMPLEX64} + {"complex64", DT_COMPLEX64}, + {"complex128", DT_COMPLEX128} }; map kFormatDict = { From a5137fb87f65cb8bd6c940f4d153f430692b767f Mon Sep 17 00:00:00 2001 From: lichun Date: Fri, 16 Jul 2021 11:23:15 +0800 Subject: [PATCH 18/33] fix error code and add complex128 support --- ge/graph/build/memory/memory_assigner.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/graph/build/memory/memory_assigner.cc b/ge/graph/build/memory/memory_assigner.cc index 5846e922..41171164 100755 --- a/ge/graph/build/memory/memory_assigner.cc +++ b/ge/graph/build/memory/memory_assigner.cc @@ -29,7 +29,7 @@ Status MemoryAssigner::AssignMemory(bool is_loop_graph, map &m } // Reassign memory for special nodes - Status ret = graph_mem_assigner.ReAssignMemory(is_loop_graph, mem_offset) + Status ret = graph_mem_assigner.ReAssignMemory(is_loop_graph, mem_offset); if (ret != ge::SUCCESS) { GELOGE(ge::FAILED, "[ReAssign][Memory] failed, graph:%s", compute_graph_->GetName().c_str()); return ret; From 21886e608e12983bbe3aecf74e053ed1707ce121 Mon Sep 17 00:00:00 2001 From: zhaozhixuan Date: Fri, 16 Jul 2021 12:00:31 +0800 Subject: [PATCH 19/33] Fix review advice. 
--- ge/single_op/task/op_task.cc | 26 +++++++++++++----------- tests/ut/ge/single_op/single_op_task_unittest.cc | 8 ++++++-- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index fd6639a5..ee752022 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -293,25 +293,26 @@ Status TbeOpTask::UpdateNodeByShape(const vector &input_desc, cons } Status TbeOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) { + node_ = node; + tiling_buffer_ = tiling_buffer; + max_tiling_size_ = max_tiling_size; if (tiling_buffer != nullptr) { - uintptr_t *arg_base = reinterpret_cast(args_.get()); - size_t arg_num = arg_size_ / sizeof(void *); + uintptr_t *arg_base = nullptr; + size_t arg_num = 0; + GetIoAddr(arg_base, arg_num); GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); uint32_t inputs_num = node->GetOpDesc()->GetInputsSize(); uint32_t outputs_num = node->GetOpDesc()->GetOutputsSize(); uint32_t workspace_nums = node->GetOpDesc()->GetWorkspace().size(); uint32_t tiling_index = inputs_num + outputs_num + workspace_nums; - if (arg_num == 0 || arg_num <= tiling_index) { + if (arg_num == 0 || arg_num < tiling_index) { GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.", tiling_index, arg_num); return ACL_ERROR_GE_INTERNAL_ERROR; } arg_base[tiling_index] = reinterpret_cast(tiling_buffer); } - node_ = node; - tiling_buffer_ = tiling_buffer; - max_tiling_size_ = max_tiling_size; return SUCCESS; } @@ -481,20 +482,21 @@ void TbeOpTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) { } Status AtomicAddrCleanOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) { + node_ = node; + tiling_buffer_ = tiling_buffer; + max_tiling_size_ = max_tiling_size; if (tiling_buffer != nullptr) { - uintptr_t *arg_base = reinterpret_cast(args_.get()); - size_t 
arg_num = arg_size_ / sizeof(void *); + uintptr_t *arg_base = nullptr; + size_t arg_num = 0; + GetIoAddr(arg_base, arg_num); uint32_t tiling_index = atomic_output_indices_.size(); - if (arg_num == 0 || arg_num <= tiling_index) { + if (arg_num == 0 || arg_num < tiling_index) { GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.", tiling_index, arg_num); return ACL_ERROR_GE_INTERNAL_ERROR; } arg_base[tiling_index] = reinterpret_cast(tiling_buffer); } - node_ = node; - tiling_buffer_ = tiling_buffer; - max_tiling_size_ = max_tiling_size; return SUCCESS; } diff --git a/tests/ut/ge/single_op/single_op_task_unittest.cc b/tests/ut/ge/single_op/single_op_task_unittest.cc index 5960fbbc..9a0381cd 100644 --- a/tests/ut/ge/single_op/single_op_task_unittest.cc +++ b/tests/ut/ge/single_op/single_op_task_unittest.cc @@ -245,12 +245,16 @@ TEST_F(UtestSingleOpTask, test_dynamic_support) { AtomicAddrCleanOpTask atomic_task; TbeOpTask tbe_task; + tbe_task.arg_size_ = sizeof(void *) * 1; + tbe_task.args_.reset(new (std::nothrow) uint8_t[tbe_task.arg_size_]); + atomic_task.arg_size_ = sizeof(void *) * 1; + atomic_task.args_.reset(new (std::nothrow) uint8_t[atomic_task.arg_size_]); ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR); ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR); - tbe_task.arg_size_ = sizeof(void *); + tbe_task.arg_size_ = sizeof(void *) * 2; tbe_task.args_.reset(new (std::nothrow) uint8_t[tbe_task.arg_size_]); - atomic_task.arg_size_ = sizeof(void *); + atomic_task.arg_size_ = sizeof(void *) * 2; atomic_task.args_.reset(new (std::nothrow) uint8_t[atomic_task.arg_size_]); ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS); ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS); From a029050b65f62ee5f52643babc092ec99efc7e6d Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Tue, 13 Jul 2021 
21:10:15 +0800 Subject: [PATCH 20/33] support v1 infershape modified: ge/graph/passes/base_pass.cc modified: ge/graph/passes/base_pass.h modified: ge/graph/passes/infer_base_pass.cc modified: ge/graph/passes/infershape_pass.cc modified: ge/graph/passes/infershape_pass.h modified: ge/graph/preprocess/graph_preprocess.cc modified: tests/ut/ge/graph/passes/addn_pass_unittest.cc modified: tests/ut/ge/graph/passes/base_pass_unittest.cc modified: tests/ut/ge/graph/passes/infershape_pass_unittest.cc modified: ge/graph/passes/base_pass.cc modified: ge/graph/passes/base_pass.h modified: ge/graph/passes/infer_base_pass.cc modified: ge/graph/passes/infershape_pass.cc modified: ge/graph/passes/infershape_pass.h modified: ge/graph/preprocess/graph_preprocess.cc modified: tests/ut/ge/graph/passes/addn_pass_unittest.cc modified: tests/ut/ge/graph/passes/base_pass_unittest.cc modified: tests/ut/ge/graph/passes/infershape_pass_unittest.cc --- ge/graph/passes/base_pass.cc | 849 +++++++----- ge/graph/passes/base_pass.h | 121 +- ge/graph/passes/infer_base_pass.cc | 3 + ge/graph/passes/infershape_pass.cc | 545 +++++--- ge/graph/passes/infershape_pass.h | 94 +- ge/graph/preprocess/graph_preprocess.cc | 16 + tests/ut/ge/graph/passes/addn_pass_unittest.cc | 2 +- tests/ut/ge/graph/passes/base_pass_unittest.cc | 1426 +++++++++++++------- .../ut/ge/graph/passes/infershape_pass_unittest.cc | 423 +++--- 9 files changed, 2187 insertions(+), 1292 deletions(-) diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index a1551eb2..8b4a8b88 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -1,374 +1,475 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/base_pass.h" - -#include -#include - -#include "framework/common/debug/log.h" -#include "framework/common/debug/ge_log.h" -#include "graph/compute_graph.h" -#include "graph/utils/graph_utils.h" - -namespace ge { -namespace { -constexpr int kMaxRePassTimes = 10000; -constexpr size_t kMaxOneInNodes = 1000; -// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later -constexpr int kMaxRecursiveDepth = 20; -struct DuringPassNodeSets { - std::unordered_set nodes_seen; - std::unordered_set nodes_deleted; - std::unordered_set nodes_re_pass; - std::unordered_set nodes_re_pass_immediately; - std::unordered_set nodes_last; - std::unordered_set nodes_suspend; - std::unordered_set nodes_resume; -}; - -void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &input_edge_nodes, - std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { - nodes_last.clear(); - for (auto &node : graph->GetDirectNode()) { - if (node == nullptr) { - continue; - } - size_t in_nums = node->GetInNodes().size(); - if (in_nums == 0) { - input_edge_nodes.push_back(node); - nodes_seen.insert(node.get()); - } else if (in_nums > kMaxOneInNodes) { - nodes_last.insert(node); - } - } -} - -bool IsAllInNodesAlive(const Node::Vistor &nodes, const std::unordered_set &nodes_suspend) { - return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; }); -} - -void AddNextIterNodes(const Node::Vistor &nodes, std::deque &nodes_to_pass, - DuringPassNodeSets 
&during_pass_node_set) { - auto &nodes_seen = during_pass_node_set.nodes_seen; - const auto &nodes_last = during_pass_node_set.nodes_last; - const auto &nodes_suspend = during_pass_node_set.nodes_suspend; - for (auto &node : nodes) { - if (node == nullptr) { - continue; - } - if (nodes_last.count(node) != 0) { - continue; - } - if (nodes_suspend.count(node) > 0) { - GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str()); - continue; - } - - bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend); - bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); - if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) { - nodes_to_pass.push_back(node); - } - } -} - -void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque &nodes) { - for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { - GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); - nodes.push_front(node); - } - during_pass_node_set.nodes_re_pass_immediately.clear(); -} - -void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque &nodes) { - for (auto &node : during_pass_node_set.nodes_resume) { - const auto &it = during_pass_node_set.nodes_suspend.find(node); - if (it != during_pass_node_set.nodes_suspend.end()) { - during_pass_node_set.nodes_suspend.erase(node); - GELOGD("The node %s resumed by pass.", node->GetName().c_str()); - nodes.push_back(node); - } else { - GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str()); - } - } - during_pass_node_set.nodes_resume.clear(); -} - -void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, - const std::unordered_set &nodes_suspend, - const std::unordered_set &nodes_resume) { - for (const auto &node : nodes_suspend) { - GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); - 
during_pass_node_set.nodes_suspend.emplace(node); - } - - for (const auto &node : nodes_resume) { - GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); - during_pass_node_set.nodes_resume.emplace(node); - } -} - -void PushToRePassIfSeen(NodePtr &node, const std::pair &name_to_pass, - std::unordered_set &nodes_seen, const std::unordered_set &nodes_to_re_pass, - std::unordered_set &nodes_re_pass) { - for (const auto &node_to_re_pass : nodes_to_re_pass) { - if (node_to_re_pass == nullptr) { - GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), - node->GetName().c_str(), node->GetType().c_str()); - continue; - } - if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { - GELOGD("The node %s will be re-pass.", node_to_re_pass->GetName().c_str()); - nodes_re_pass.insert(node_to_re_pass); - } else { - GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str()); - } - } -} - -Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNodeSets &during_pass_node_set) { - if (node == nullptr) { - REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); - GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); - return FAILED; - } - GELOGD("Begin to run pass for node %s", node->GetName().c_str()); - for (const auto &name_to_pass : names_to_passes) { - if (name_to_pass.second == nullptr) { - GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s), skip it", name_to_pass.first.c_str()); - continue; - } - - GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); - name_to_pass.second->init(); - auto result = name_to_pass.second->Run(node); - if (result != SUCCESS) { - REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", - name_to_pass.first.c_str(), 
node->GetName().c_str(), result); - GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result " - "%u, the passes will be terminated immediately.", - name_to_pass.first.c_str(), node->GetName().c_str(), result); - return result; - } - - const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); - PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, - during_pass_node_set.nodes_re_pass); - - const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); - PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, - during_pass_node_set.nodes_re_pass_immediately); - - PushToSuspendNodes(during_pass_node_set, name_to_pass.first, - name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume()); - - const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); - during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); - if (nodes_deleted_by_pass.count(node) > 0) { - GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), - name_to_pass.first.c_str()); - break; - } - } - - return SUCCESS; -} - -void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { - for (auto &name_to_pass : names_to_pass) { - name_to_pass.second->SetOption(option, ""); - } -} - -void ClearOption(NamesToPass names_to_pass) { - for (auto &name_to_pass : names_to_pass) { - name_to_pass.second->ClearOptions(); - } -} - -bool CheckNode(const NodePtr &node, const DuringPassNodeSets &during_pass_node_set) { - if (node == nullptr) { - GELOGW("node is null"); - return false; - } - if (during_pass_node_set.nodes_deleted.count(node) > 0) { - GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); - return false; - } - if (during_pass_node_set.nodes_suspend.count(node) > 0) { - GELOGD("The node %s has been added to 
suspend-iteration nodes list, the iteration of it will be suspend.", - node->GetName().c_str()); - return false; - } - - return true; -} -} // namespace - -Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map) { - if (node == nullptr) { - REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); - GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); - return FAILED; - } - GELOGI("Prepare to isolate and delete node, name:%s, type:%s.", node->GetName().c_str(), - node->GetType().c_str()); - ComputeGraphPtr graph = node->GetOwnerComputeGraph(); - if (graph == nullptr) { - REPORT_INNER_ERROR("E19999", "The owner graph of node:%s must not be null.", node->GetName().c_str()); - GELOGE(FAILED, "[Get][OwnerComputeGraph] failed, The owner graph of node:%s must not be null.", - node->GetName().c_str()); - return FAILED; - } - - AddRePassNodesWithInOut(node); - - if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str()); - GELOGE(FAILED, "[Isolate][Node] %s failed.", node->GetName().c_str()); - return FAILED; - } - - if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) { - REPORT_CALL_ERROR("E19999", "call RemoveNodeWithoutRelink for node:%s failed.", node->GetName().c_str()); - GELOGE(FAILED, "[Call][RemoveNodeWithoutRelink] for node:%s failed.", node->GetName().c_str()); - return FAILED; - } - - AddNodeDeleted(node); - return SUCCESS; -} - -Status GEPass::Run(const NamesToPass &names_to_passes) { - if (graph_ == nullptr) { - REPORT_INNER_ERROR("E19999", "graph_ is nullptr, check invalid."); - GELOGE(INTERNAL_ERROR, "[Check][Param] The graph is nullptr"); - return INTERNAL_ERROR; - } - if (names_to_passes.empty()) { - GELOGW("No passes input, the GEPass will do nothing"); - return INTERNAL_ERROR; - } - - if (depth_ > kMaxRecursiveDepth) { - GELOGE(PARAM_INVALID, - "[Check][Param] The pass for root graph %s will be terminated 
because too many nesting" - " levels(%d) of subgraphs, last subgraph is %s", - root_graph_->GetName().c_str(), depth_, graph_->GetName().c_str()); - return PARAM_INVALID; - } - - return RunPassesOneGraph(names_to_passes); -} - -Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { - GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); - std::deque nodes; - DuringPassNodeSets during_pass_node_set; - GetAllNodesNoInputEdge(graph_, nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); - GELOGD("Start points count %zu", nodes.size()); - int re_pass_times = 0; - - do { - for (auto &node : during_pass_node_set.nodes_re_pass) { - nodes.push_back(node); - during_pass_node_set.nodes_seen.insert(node.get()); - } - during_pass_node_set.nodes_re_pass.clear(); - - while (!nodes.empty()) { - NodePtr node = nodes.front(); - nodes.pop_front(); - - (void)during_pass_node_set.nodes_re_pass.erase(node); - if (!CheckNode(node, during_pass_node_set)) { - continue; - } - AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); - - auto ret = RunPasses(node, names_to_passes, during_pass_node_set); - if (ret != SUCCESS) { - GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", - node->GetName().c_str(), node->GetType().c_str(), ret); - return ret; - } - - bool has_sub_graph = false; - ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph); - if (ret != SUCCESS) { - GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str()); - return ret; - } - - if (has_sub_graph) { - GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); - SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); - ret = RunPasses(node, names_to_passes, during_pass_node_set); - if (ret != SUCCESS) { - GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", - node->GetName().c_str(), node->GetType().c_str(), ret); - return 
ret; - } - - // There is only one option scene, so set and clear options around the `RunPasses` func. - // if there are more than one scene to set options, the `ClearOption` function - // should be called each time at the begin of the iteration - ClearOption(names_to_passes); - } - - AddRepassNodes(during_pass_node_set, nodes); - AddResumeNodes(during_pass_node_set, nodes); - } - - for (auto &node : during_pass_node_set.nodes_last) { - bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen); - if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) { - nodes.push_back(node); - } - } - during_pass_node_set.nodes_last.clear(); - } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); - - if (re_pass_times == kMaxRePassTimes) { - GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); - } - GELOGD("All passes runs end"); - - return SUCCESS; -} -Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { - auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); - has_sub_graph = false; - for (const auto &name : sub_graph_names) { - auto graph = root_graph_->GetSubgraph(name); - if (graph == nullptr) { - GELOGW("Can not find the sub graph %s from node %s, the pass-process will skip it", - name.c_str(), node->GetName().c_str()); - continue; - } - has_sub_graph = true; - GELOGI("Begin to run passes on the sub graph %s of node %s", name.c_str(), node->GetName().c_str()); - GEPass pass(graph, root_graph_, depth_ + 1); - auto ret = pass.Run(names_to_passes); - if (ret != SUCCESS) { - GELOGE(ret, "[Run][Passes] for sub graph:%s from node:%s failed", name.c_str(), node->GetName().c_str()); - return ret; - } - } - return SUCCESS; -} -} // namespace ge +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use 
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/base_pass.h" + +#include +#include + +#include "common/debug/log.h" +#include "graph/utils/graph_utils.h" + +namespace ge { +namespace { +constexpr int kMaxRePassTimes = 10000; +constexpr size_t kMaxOneInNodes = 1000; +// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later +constexpr int kMaxRecursiveDepth = 20; + +void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, + GEPass::GraphLevelState &g_state) { + for (auto &node : graph->GetDirectNode()) { + if (node == nullptr) { + continue; + } + size_t in_nums = node->GetInNodes().size(); + if (in_nums == 0) { + g_state.AddNodeToQueueIfNotSeen(node); + } else if (in_nums > kMaxOneInNodes) { + g_state.nodes_last.insert(node); + } + } +} + +bool AnyNodesIn(const Node::Vistor &nodes, const std::unordered_set &nodes_set) { + return std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { + return nodes_set.count(n) > 0; + }); +} + + +bool IsNodeReadyToQueue(const NodePtr &node, GEPass::GraphLevelState &g_state) { + if (node == nullptr) { + GELOGW("node is null"); + return false; + } + if (g_state.nodes_deleted.count(node) > 0) { + GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); + return false; + } + + if (g_state.nodes_last.count(node) != 0) { + return false; + } + + // all in_node seen && all in_node not suspend + if (!node->IsAllInNodesSeen(g_state.nodes_seen)) { + return false; + } + + if (g_state.nodes_suspend.count(node) > 0) { + 
GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", + node->GetName().c_str()); + return false; + } + + if (AnyNodesIn(node->GetInAllNodes(), g_state.nodes_suspend)) { + GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", + node->GetName().c_str()); + return false; + } + return true; +} + +void AddNextIterNodes(const NodePtr &cur_node, + std::unordered_set &out_nodes_before_pass, + GEPass::GraphLevelState &g_state) { + for (auto &node : cur_node->GetOutNodes()) { + if (node == nullptr) { + continue; + } + if(out_nodes_before_pass.erase(node) == 0) { + // after pass node, new output node come up + GELOGI("New output node %s come up after pass %s.", + node->GetName().c_str(), cur_node->GetName().c_str()); + } + + // all in_node seen && all in_node not suspend + if (IsNodeReadyToQueue(node, g_state)) { + g_state.AddNodeToQueueIfNotSeen(node); + } + } + + // + for (const auto &node : out_nodes_before_pass) { + // A-->B-->C if B was + // unlink edge may happend, add these node to queue if needed + if (node->GetInAllNodes().empty() && IsNodeReadyToQueue(node, g_state)) { + GELOGI("Node %s may lost from cur node, add to queue if not seen.", + node->GetName().c_str(), cur_node->GetName().c_str()); + g_state.AddNodeToQueueIfNotSeen(node); + } + } +} + +void AddImmediateRepassNodesToQueue(NodePtr &cur_node, + std::unordered_map re_pass_imm_nodes_to_pass_names, + GEPass::GraphLevelState &g_state) { + for (const auto &node_2_pass_names : re_pass_imm_nodes_to_pass_names) { + auto imme_repass_node = node_2_pass_names.first; + if (imme_repass_node == nullptr) { + GELOGW("Found null immediately re-pass node when executing pass %s on node %s type %s", + node_2_pass_names.second.c_str(), + cur_node->GetName().c_str(), cur_node->GetType().c_str()); + continue; + } + if (g_state.nodes_passed.count(imme_repass_node) > 0) { + GELOGD("The node %s specified by pass %s has been 
passed, it will repass immediately", + imme_repass_node->GetName().c_str(), node_2_pass_names.second.c_str()); + g_state.AddNodeToQueueFront(imme_repass_node); + continue; + } + GELOGW("The node %s specified by pass %s has un-passed, it will not repass immediately", + node_2_pass_names.first->GetName().c_str(), node_2_pass_names.second.c_str()); + } +} + +void AddLastNodesToQueue(GEPass::GraphLevelState &g_state) { + for (auto &node : g_state.nodes_last) { + if (node->IsAllInNodesSeen(g_state.nodes_seen)) { + g_state.AddNodeToQueueIfNotSeen(node); + } + } + g_state.nodes_last.clear(); +} + +void AddResumeNodesToQueue(const std::unordered_map resume_node_2_pass_names, + GEPass::GraphLevelState &g_state) { + // Now base pass doesnt record the order of suspend & resume, so we dont know which one come first in a node pass. + // Here if one node pass suspend and resume a node ,consider it resume that node. + // Better way to record the order, and here suspend or resume in order. + for (const auto &node_2_pass_names : resume_node_2_pass_names) { + auto node = node_2_pass_names.first; + if (g_state.nodes_suspend.erase(node) > 0) { + if (g_state.nodes_seen.count(node.get()) > 0 || node->IsAllInNodesSeen(g_state.nodes_seen)) { + g_state.nodes.push_back(node); + GELOGD("Node %s has been resumed by pass %s, and add to pass queue", + node->GetName().c_str(), node_2_pass_names.second.c_str()); + } + } + } +} + +void PushToRePassIfSeen(NodePtr &node, const std::pair &name_to_pass, + std::unordered_set &nodes_seen, const std::vector &nodes_to_re_pass, + GEPass::RepassLevelState &rp_state) { + for (const auto &node_to_re_pass : nodes_to_re_pass) { + if (node_to_re_pass == nullptr) { + GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), + node->GetName().c_str(), node->GetType().c_str()); + continue; + } + if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { + if 
(rp_state.AddNodeToRepass(node_to_re_pass)) { + GELOGD("The node %s will be re-pass.", node_to_re_pass->GetName().c_str()); + continue; + } + GELOGD("Node %s has been added to repass queue, no need to add again.", node_to_re_pass->GetName().c_str()); + } else { + GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str()); + } + } +} + +void SetFlagOption(NodePassOption option, NamesToPass names_to_pass) { + for (auto &name_to_pass : names_to_pass) { + name_to_pass.second->SetOption(option, ""); + } +} + +void ClearOption(NamesToPass names_to_pass) { + for (auto &name_to_pass : names_to_pass) { + name_to_pass.second->ClearOptions(); + } +} +} // namespace + +Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map, + bool is_repass_io_immediately) { + if (node == nullptr) { + REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); + GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); + return FAILED; + } + GELOGI("Prepare to isolate and delete node, name:%s, type:%s.", node->GetName().c_str(), + node->GetType().c_str()); + ComputeGraphPtr graph = node->GetOwnerComputeGraph(); + if (graph == nullptr) { + REPORT_INNER_ERROR("E19999", "The owner graph of node:%s must not be null.", node->GetName().c_str()); + GELOGE(FAILED, "[Get][OwnerComputeGraph] failed, The owner graph of node:%s must not be null.", + node->GetName().c_str()); + return FAILED; + } + + is_repass_io_immediately ? 
AddImmediateRePassNodesWithInOut(node) : AddRePassNodesWithInOut(node); + + if (GraphUtils::IsolateNode(node, io_map) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Isolate Node:%s failed", node->GetName().c_str()); + GELOGE(FAILED, "[Isolate][Node] %s failed.", node->GetName().c_str()); + return FAILED; + } + + if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "call RemoveNodeWithoutRelink for node:%s failed.", node->GetName().c_str()); + GELOGE(FAILED, "[Call][RemoveNodeWithoutRelink] for node:%s failed.", node->GetName().c_str()); + return FAILED; + } + + AddNodeDeleted(node); + return SUCCESS; +} + +Status GEPass::Run(const NamesToPass &names_to_passes) { + if (graph_ == nullptr) { + REPORT_INNER_ERROR("E19999", "graph_ is nullptr, check invalid."); + GELOGE(INTERNAL_ERROR, "[Check][Param] The graph is nullptr"); + return INTERNAL_ERROR; + } + if (names_to_passes.empty()) { + GELOGW("No passes input, the GEPass will do nothing"); + return INTERNAL_ERROR; + } + for (const auto &name_to_pass : names_to_passes) { + if (name_to_pass.second == nullptr) { + GELOGE(INTERNAL_ERROR, "[Check][Param] There is null pointer in passes(%s)", name_to_pass.first.c_str()); + return INTERNAL_ERROR; + } + } + + if (depth_ > kMaxRecursiveDepth) { + GELOGE(PARAM_INVALID, + "[Check][Param] The pass for root graph %s will be terminated because too many nesting" + " levels(%d) of subgraphs, last subgraph is %s", + root_graph_->GetName().c_str(), depth_, graph_->GetName().c_str()); + return PARAM_INVALID; + } + + return RunPassesOneGraph(names_to_passes); + // todo debug mode is on, find first node in topo order which is not passed. 
and give a warning +} + +void NotifyPassGraphStart(const ComputeGraphPtr &graph, const NamesToPass &names_to_pass) { + for (auto &name_to_pass : names_to_pass) { + name_to_pass.second->OnStartPassGraph(graph); + } +} + +Status GEPass::HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state) { + std::unordered_map resume_nodes_to_pass_names; + for (auto &name_to_pass : names_to_passes) { + name_to_pass.second->init(); + auto ret = name_to_pass.second->OnSuspendNodesLeaked(); + if (ret != SUCCESS) { + GELOGE(ret, "Internal error with OnSuspendNodesLeaked on pass %s.", name_to_pass.first.c_str()); + return ret; + } + for (const auto &resume_node : name_to_pass.second->GetNodesResume()){ + resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ","); + } + } + AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state); + return SUCCESS; +} + +Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { + GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); + NotifyPassGraphStart(graph_, names_to_passes); + GraphLevelState g_state; + g_state.re_pass_times = 0; + GetAllNodesNoInputEdge(graph_, g_state); + GELOGD("Start points count %zu", g_state.nodes.size()); + + do { + if (!g_state.nodes_suspend.empty()) { + auto ret = HandleLeakedSuspendNodes(names_to_passes, g_state); + if (ret != SUCCESS) { + // log inside upper function + return ret; + } + if (g_state.nodes.empty()) { + GELOGE(INTERNAL_ERROR, "There are some suspended nodes leaked and no pass resume them."); + return INTERNAL_ERROR; + } + } + auto ret = RunPassesGraphRepass(names_to_passes, g_state); + if (ret != SUCCESS) { + return ret; + } + } while (!g_state.nodes_suspend.empty()); + + return SUCCESS; +} + + +Status GEPass::RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state) { + RepassLevelState rp_state; + do { + for (auto &node : rp_state.nodes_re_pass) { + if (rp_state.nodes_re_pass_set.count(node) > 
0) { + GELOGD("Add node %s to queue for re-pass", node->GetName().c_str()); + g_state.AddNodeToQueue(node); + } + } + rp_state.ClearRepass(); + + while (!g_state.nodes.empty()) { + auto node = g_state.PopFront(); + if (g_state.nodes_deleted.count(node) > 0) { + GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); + continue; + } + rp_state.EraseNodeFromRepass(node); + g_state.nodes_seen.insert(node.get()); + + // collect out nodes before pass + std::unordered_set out_nodes_before_pass; + for (const auto &out_node : node->GetOutNodes()) { + out_nodes_before_pass.insert(out_node); + } + auto ret = RunPassesNodeOnce(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; + } + AddNextIterNodes(node, out_nodes_before_pass, g_state); + + } + AddLastNodesToQueue(g_state); + } while ((!rp_state.nodes_re_pass.empty() || !g_state.nodes.empty()) && ++g_state.re_pass_times < kMaxRePassTimes); + + if (g_state.re_pass_times == kMaxRePassTimes) { + GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); + } + GELOGD("All passes runs end"); + return SUCCESS; +} + +Status GEPass::RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph) { + auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); + has_sub_graph = false; + for (const auto &name : sub_graph_names) { + auto graph = root_graph_->GetSubgraph(name); + if (graph == nullptr) { + GELOGW("Can not find the sub graph %s from node %s, the pass-process will skip it", + name.c_str(), node->GetName().c_str()); + continue; + } + has_sub_graph = true; + GELOGI("Begin to run passes on the sub graph %s of node %s", name.c_str(), node->GetName().c_str()); + GEPass pass(graph, root_graph_, depth_ + 1); + auto ret = pass.Run(names_to_passes); + if (ret != SUCCESS) { + GELOGE(ret, "[Run][Passes] 
for sub graph:%s from node:%s failed", name.c_str(), node->GetName().c_str()); + return ret; + } + } + return SUCCESS; +} + +Status GEPass::RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, + GraphLevelState &g_state, RepassLevelState &rp_state) { + auto ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code:%u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; + } + + bool has_sub_graph = false; + ret = RunPassesOnSubGraph(node, names_to_passes, has_sub_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[Run][Passes] on the sub graph of node %s failed", node->GetName().c_str()); + return ret; + } + + if (has_sub_graph) { + GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); + SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); + ret = RunPassesOnNode(node, names_to_passes, g_state, rp_state); + if (ret != SUCCESS) { + GELOGE(ret, "[Process][Passes] on node %s type %s failed, error code: %u", node->GetName().c_str(), + node->GetType().c_str(), ret); + return ret; + } + + // There is only one option scene, so set and clear options around the `RunPasses` func. 
+ // if there are more than one scene to set options, the `ClearOption` function + // should be called each time at the begin of the iteration + ClearOption(names_to_passes); + } + return SUCCESS; +} + +Status GEPass::RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, + RepassLevelState &rp_state) { + if (node == nullptr) { + REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid."); + GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); + return FAILED; + } + GELOGD("Begin to run pass for node %s", node->GetName().c_str()); + for (const auto &name_to_pass : names_to_passes) { + GELOGD("Begin to run pass %s for node %s", name_to_pass.first.c_str(), node->GetName().c_str()); + name_to_pass.second->init(); + auto result = name_to_pass.second->Run(node); + if (result != SUCCESS) { + REPORT_CALL_ERROR("E19999", "process pass %s on node:%s failed, ret:%u", + name_to_pass.first.c_str(), node->GetName().c_str(), result); + GELOGE(INTERNAL_ERROR, "[Process][Pass] %s on node %s failed, result " + "%u, the passes will be terminated immediately.", + name_to_pass.first.c_str(), node->GetName().c_str(), result); + return result; + } + if (name_to_pass.second->GetNodesDeleted().count(node) > 0) { + GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), + name_to_pass.first.c_str()); + break; + } + } + + g_state.nodes_passed.insert(node); + + std::unordered_map re_pass_imm_nodes_to_pass_names; + std::unordered_map resume_nodes_to_pass_names; + // if muti psss repass one same node, it will add to queue many times, so collect and duplicate + for (const auto &name_to_pass : names_to_passes) { + PushToRePassIfSeen(node, name_to_pass, g_state.nodes_seen, + name_to_pass.second->GetNodesNeedRePass(), + rp_state); + // collect imm_node && resume_node among these passes + for (const auto &imm_node : name_to_pass.second->GetNodesNeedRePassImmediately()){ + 
re_pass_imm_nodes_to_pass_names[imm_node].append(name_to_pass.first + ","); + } + for (const auto &resume_node : name_to_pass.second->GetNodesResume()){ + resume_nodes_to_pass_names[resume_node].append(name_to_pass.first + ","); + } + + for (const auto &suspend_node : name_to_pass.second->GetNodesSuspend()) { + GELOGD("The iteration suspend of node %s has been set by pass %s", suspend_node->GetName().c_str(), + name_to_pass.first.c_str()); + g_state.nodes_suspend.insert(suspend_node); + } + const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); + g_state.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); + } + + AddImmediateRepassNodesToQueue(node, re_pass_imm_nodes_to_pass_names, g_state); + AddResumeNodesToQueue(resume_nodes_to_pass_names, g_state); + + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h index d0f125b2..093e2dce 100644 --- a/ge/graph/passes/base_pass.h +++ b/ge/graph/passes/base_pass.h @@ -22,7 +22,6 @@ #include #include #include - #include "framework/common/ge_inner_error_codes.h" #include "framework/common/types.h" #include "graph/compute_graph.h" @@ -40,6 +39,7 @@ enum NodePassOption { }; class BaseNodePass { + // todo comments public: /// /// Optimize on one node. 
the function can add nodes to the graph, change @@ -51,7 +51,7 @@ class BaseNodePass { virtual ~BaseNodePass() = default; - const std::unordered_set &GetNodesNeedRePass() { return nodes_need_re_pass_; } + const std::vector &GetNodesNeedRePass() { return nodes_need_re_pass_; } const std::unordered_set &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } @@ -61,23 +61,32 @@ class BaseNodePass { const std::unordered_set &GetNodesResume() { return nodes_resume_; } + virtual Status OnSuspendNodesLeaked() { return SUCCESS; } + void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } void ClearOptions() { options_.clear(); } void init() { nodes_need_re_pass_.clear(); - nodes_deleted_.clear(); nodes_need_re_pass_immediately_.clear(); + nodes_deleted_.clear(); nodes_suspend_.clear(); nodes_resume_.clear(); } + virtual void OnStartPassGraph(const ComputeGraphPtr &graph) { + current_graph_name_ = graph->GetName(); + } + protected: - Status IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map); + const string &GetCurrentGraphName() const { + return current_graph_name_; + } + Status IsolateAndDeleteNode(NodePtr &node, const std::vector &io_map, bool is_repass_io_immediately = false); - Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list &io_map) { - return IsolateAndDeleteNode(node, std::vector(io_map)); + Status IsolateAndDeleteNode(NodePtr &node, const std::initializer_list &io_map, bool is_repass_io_immediately = false) { + return IsolateAndDeleteNode(node, std::vector(io_map), is_repass_io_immediately); } /// @@ -86,7 +95,7 @@ class BaseNodePass { /// optimized by other passes, call this function. /// @param node /// - void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); } + void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.emplace_back(node); } /// /// Add a node to be optimized immediately again. 
If you add a new node to the graph, or @@ -101,14 +110,30 @@ class BaseNodePass { /// @param node /// void AddRePassNodesWithInOut(const NodePtr &node) { + auto in_nodes = node->GetInNodes(); + for (auto &in_node : in_nodes) { + AddRePassNode(in_node); + } AddRePassNode(node); auto out_nodes = node->GetOutNodes(); for (auto &out_node : out_nodes) { AddRePassNode(out_node); } + } + + /// + /// Add a node and it's input/output data nodes to be optimized immediately again. + /// @param node + /// + void AddImmediateRePassNodesWithInOut(const NodePtr &node) { auto in_nodes = node->GetInNodes(); for (auto &in_node : in_nodes) { - AddRePassNode(in_node); + AddImmediateRePassNode(in_node); + } + AddImmediateRePassNode(node); + auto out_nodes = node->GetOutNodes(); + for (auto &out_node : out_nodes) { + AddImmediateRePassNode(out_node); } } @@ -123,34 +148,27 @@ class BaseNodePass { void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } /// - /// If you suspend a node from the graph, especially following node. The remain - /// iterate passes will stop process on the suspend node(if it can be + /// If you postpone a node from the graph, especially following node. The remain + /// iterate passes will stop process on the postpone node(if it can be /// reached by edge connections) till the last one. Obviously it is a waste of - /// time. You can add the suspend nodes by calling this function, to stop the + /// time. You can add the postpone nodes by calling this function, to stop the /// next iterations. /// @param node /// void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } - /// - /// If you resume a node from the graph, especially following node. The remain - /// iterate passes will continue process on the resume node(if it can be - /// reached by edge connections) till the last one. - /// You can add the resume nodes by calling this function, to resume the - /// next iterations. 
- /// @param node - /// void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } private: - std::unordered_set nodes_need_re_pass_; + std::vector nodes_need_re_pass_; std::unordered_set nodes_need_re_pass_immediately_; std::unordered_set nodes_deleted_; std::unordered_set nodes_suspend_; std::unordered_set nodes_resume_; std::map options_; + std::string current_graph_name_; }; using NamesToPass = std::vector>; @@ -160,12 +178,75 @@ class GEPass { explicit GEPass(ComputeGraphPtr &graph) : graph_(graph), root_graph_(graph), depth_(1) {} virtual ~GEPass() = default; Status Run(const NamesToPass &names_to_passes); + /* +* todo +* OneGraph: nodes_deleted, nodes_seen, nodes_passed, nodes_suspended +* RePass: nodes_re_pass +* GraphOneTime: nodes_last +* NodeOneTime: nodes_re_pass_immediately, nodes_resume +*/ + struct GraphLevelState { + std::unordered_set nodes_deleted; + std::unordered_set nodes_seen; + std::unordered_set nodes_passed; + std::unordered_set nodes_suspend; + std::unordered_set nodes_last; + std::deque nodes; + int re_pass_times; + + void AddNodeToQueueFront(NodePtr node) { + nodes_seen.insert(node.get()); + nodes.emplace_front(std::move(node)); + } + + void AddNodeToQueue(NodePtr node) { + nodes_seen.insert(node.get()); + nodes.emplace_back(std::move(node)); + } + void AddNodeToQueueIfNotSeen(NodePtr node) { + if (nodes_seen.insert(node.get()).second) { + nodes.emplace_back(std::move(node)); + } + } + NodePtr PopFront() { + NodePtr node = nodes.front(); + nodes.pop_front(); + return node; + } + }; + struct RepassLevelState { + std::vector nodes_re_pass; + std::unordered_set nodes_re_pass_set; + bool AddNodeToRepass(NodePtr node) { + if (!nodes_re_pass_set.insert(node).second) { + return false; + } + nodes_re_pass.emplace_back(node); + return true; + } + void EraseNodeFromRepass(NodePtr node) { + nodes_re_pass_set.erase(node); + } + void ClearRepass() { + 
nodes_re_pass_set.clear(); + nodes_re_pass.clear(); + } + }; + struct GraphOneTimeLevelState { + std::unordered_set nodes_last; + }; private: GEPass(ComputeGraphPtr &graph, ComputeGraphPtr &root_graph, int depth) : graph_(graph), root_graph_(root_graph), depth_(depth) {} + Status RunPassesNodeOnce(NodePtr &node, const NamesToPass &names_to_passes, + GraphLevelState &g_state, RepassLevelState &rp_state); + Status RunPassesGraphRepass(const NamesToPass &names_to_passes, GraphLevelState &g_state); Status RunPassesOneGraph(const NamesToPass &names_to_passes); Status RunPassesOnSubGraph(const NodePtr &node, const NamesToPass &names_to_passes, bool &has_sub_graph); + Status RunPassesOnNode(NodePtr &node, const NamesToPass &names_to_passes, GraphLevelState &g_state, + RepassLevelState &rp_state); + Status HandleLeakedSuspendNodes(const NamesToPass &names_to_passes, GraphLevelState &g_state); ComputeGraphPtr graph_; ComputeGraphPtr root_graph_; int depth_; diff --git a/ge/graph/passes/infer_base_pass.cc b/ge/graph/passes/infer_base_pass.cc index 25c45677..636cf2ab 100644 --- a/ge/graph/passes/infer_base_pass.cc +++ b/ge/graph/passes/infer_base_pass.cc @@ -86,6 +86,9 @@ bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; } void InferBasePass::AddChangedNodesImmediateRepass(const std::set &changed_nodes) { // need passed_nodes set to solve the problem that multi-input operators do repass in advance. // when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes. 
+ for (const auto &node_ele : changed_nodes) { + AddImmediateRePassNode(node_ele); + } } graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set &changed_nodes) { diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index a5e64519..deaebf4f 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -1,175 +1,370 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/passes/infershape_pass.h" -#include "common/util/error_manager/error_manager.h" -#include "framework/common/debug/ge_log.h" -#include "analyzer/analyzer.h" -#include "framework/common/util.h" -#include "graph/shape_refiner.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils.h" -#include "common/omg_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/type_utils.h" - -namespace ge { - -void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { - desc_str += "["; - std::vector> shape_range; - (void)desc->GetShapeRange(shape_range); - for (const auto &pair : shape_range) { - desc_str += "{"; - desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); - desc_str += "},"; - } - desc_str += "]"; - shape_range.clear(); - (void)desc->GetOriginShapeRange(shape_range); - for (const auto &pair : shape_range) { - desc_str += ",{"; - desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); - desc_str += "},"; - } -} - -std::string GetInTensorInfoWithString(const ge::NodePtr &node) { - ge::OpDescPtr op_desc = node->GetOpDesc(); - std::stringstream ss; - ss << "{"; - int32_t in_idx = 0; - for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { - if (input_desc == nullptr) { - in_idx++; - continue; - } - if (in_idx > 0) { - ss << " "; - } - ss << "input_" << in_idx << " " << "tensor: ["; - ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; - ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; - ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; - ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),"; - ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),"; - ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),"; - string 
range_str; - SerialShapeRange(input_desc, range_str); - ss << "(shape_range:" << range_str << ")]"; - in_idx++; - } - return ss.str(); -} - -Status InferShapePass::Run(NodePtr &node) { - // kOptimizeAfterSubGraph exist means after subgraph - auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); - if (ret != GRAPH_SUCCESS) { - // select INFERSHAPE failed info - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - auto root_graph = ge::GraphUtils::FindRootGraph(graph); - GE_CHECK_NOTNULL(root_graph); - analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), - analyzer::INFER_SHAPE, node, "InferShapeFailed!"}; - (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); - (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), - root_graph->GetGraphID()); - - REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed, input_tensor:%s", - node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); - GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed, input_tensor:%s", - node->GetName().c_str(), node->GetType().c_str(), GetInTensorInfoWithString(node).c_str()); - return GE_GRAPH_INFERSHAPE_FAILED; - } - - GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node)); - bool need_repass = false; - auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass); - if (has_attr) { - if (!OptionExists(kOptimizeAfterSubGraph)) { - return SUCCESS; - } - if (need_repass) { - AddImmediateRePassNode(node); - GELOGD("Node %s need repass immediately.", node->GetName().c_str()); - } else { - // clear attr on while - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - } - } - return SUCCESS; -} - -Status InferShapePass::RePassLoopNode(const NodePtr &node) { - const auto RePassNode = [&](const std::set &re_pass_types) { - for (auto &n : node->GetOutDataNodes()) { - 
GE_CHECK_NOTNULL(n); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str()); - if (re_pass_types.count(node_type) > 0) { - AddImmediateRePassNode(n); - (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); - GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); - } - } - return SUCCESS; - }; - - const auto ExProcNode = [&](const std::set &proc_types, - const std::function &proc_func, - const std::string &info) { - for (auto &n : node->GetOutDataNodes()) { - GE_CHECK_NOTNULL(n); - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "[Get][OriginalType] of node:%s failed.", n->GetName().c_str()); - if (proc_types.count(node_type) > 0) { - proc_func(this, n); - GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); - } - } - return SUCCESS; - }; - - std::string node_type; - GE_CHK_STATUS_RET(GetOriginalType(node, node_type), - "[Get][OriginalType] of node:%s failed.", node->GetName().c_str()); - if (kNextIterationOpTypes.count(node_type) > 0) { - return RePassNode(kMergeOpTypes); // Re-Pass Merge - } - - if (kMergeOpTypes.count(node_type) > 0) { - if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - return RePassNode(kSwitchOpTypes); // Re-Pass Switch - } - return SUCCESS; - } - - if (kSwitchOpTypes.count(node_type) > 0) { - if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { - node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); - return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit - } else { - return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit - } - } - - return SUCCESS; -} -} // namespace ge +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + *+ + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/infershape_pass.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "analyzer/analyzer.h" +#include "framework/common/util.h" +#include "graph/shape_refiner.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "common/omg_util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" + +#include "external/graph/operator_factory.h" + +namespace ge { +namespace { +constexpr int kSwitchExitAnchorIndex = 0; +constexpr int kSwitchPredAnchorIndex = 1; +void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { + desc_str += "["; + std::vector> shape_range; + (void)desc->GetShapeRange(shape_range); + for (const auto &pair : shape_range) { + desc_str += "{"; + desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); + desc_str += "},"; + } + desc_str += "]"; + shape_range.clear(); + (void)desc->GetOriginShapeRange(shape_range); + for (const auto &pair : shape_range) { + desc_str += ",{"; + desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); + desc_str += "},"; + } +} +void UpdateShapeAndDType(const GeTensorDescPtr &src, GeTensorDescPtr &dst) { + dst->SetOriginShape(src->GetOriginShape()); + dst->SetShape(src->GetShape()); + dst->SetDataType(src->GetDataType()); + 
dst->SetOriginDataType(src->GetOriginDataType()); + vector> src_shape_range; + src->GetShapeRange(src_shape_range); + dst->SetShapeRange(src_shape_range); + dst->SetOriginShapeRange(src_shape_range); + ge::TensorUtils::SetRealDimCnt(*dst, static_cast(src->GetShape().GetDims().size())); +} +} // namespace + +std::string InferShapePass::SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const { + std::stringstream ss; + ss << "(shape:[" << tensor_desc->MutableShape().ToString() << "]),"; + ss << "(format:" << TypeUtils::FormatToSerialString(tensor_desc->GetFormat()) << "),"; + ss << "(dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()) << "),"; + ss << "(origin_shape:" << tensor_desc->GetOriginShape().ToString() << "),"; + ss << "(origin_format:" << TypeUtils::FormatToSerialString(tensor_desc->GetOriginFormat()) << "),"; + ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(tensor_desc->GetOriginDataType()) << "),"; + string range_str; + SerialShapeRange(tensor_desc, range_str); + ss << "(shape_range:" << range_str << ")"; + return ss.str(); +} +Status InferShapePass::SuspendV1LoopExitNodes(const NodePtr &node) { + if (node->GetType() != SWITCH) { + return SUCCESS; + } + auto pred_node = NodeUtils::GetInDataNodeByIndex(*node, kSwitchPredAnchorIndex); + GE_CHECK_NOTNULL(pred_node); + if (pred_node->GetType() != LOOPCOND) { + return SUCCESS; + } + + for (const auto &anchor_2_node : NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kSwitchExitAnchorIndex)) { + GELOGI("Found v1 loop when infershape, suspend Exit node %s, type %s.", anchor_2_node.second->GetName().c_str(), + anchor_2_node.second->GetType().c_str()); + auto &suspend_nodes = graphs_2_suspend_nodes_[GetCurrentGraphName()]; + if (suspend_nodes.nodes_set.insert(anchor_2_node.second).second) { + suspend_nodes.nodes.push(anchor_2_node.second); + AddNodeSuspend(anchor_2_node.second); + } + } + return SUCCESS; +} + +Status InferShapePass::Infer(NodePtr &node) { + auto ret = 
InferShapeAndType(node); + if (ret != GRAPH_SUCCESS) { + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + auto root_graph = ge::GraphUtils::FindRootGraph(graph); + GE_CHECK_NOTNULL(root_graph); + analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), + analyzer::INFER_SHAPE, node, "InferShapeFailed!"}; + (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); + (void)Analyzer::GetInstance()->SaveAnalyzerDataToFile(root_graph->GetSessionID(), + root_graph->GetGraphID()); + REPORT_CALL_ERROR("E19999", "Call InferShapeAndType for node:%s(%s) failed", node->GetName().c_str(), + node->GetType().c_str()); + GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "[Call][InferShapeAndType] for node:%s(%s) failed", node->GetName().c_str(), + node->GetType().c_str()); + return GE_GRAPH_INFERSHAPE_FAILED; + } + return SUCCESS; +} + +graphStatus InferShapePass::InferShapeAndType(NodePtr &node) { + auto ret = SuspendV1LoopExitNodes(node); + if (ret != SUCCESS) { + GELOGE(ret, "Suspend V1 loop exit nodes failed."); + return ret; + } + bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + auto opdesc = node->GetOpDesc(); + if (node->Verify() != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Verifying %s failed.", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + Operator op = OpDescUtils::CreateOperatorFromNode(node); + + if (!is_unknown_graph) { + auto inference_context = ShapeRefiner::CreateInferenceContext(node); + GE_CHECK_NOTNULL(inference_context); + GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); + op.SetInferenceContext(inference_context); + } + + graphStatus status = CallInferShapeFunc(node, op); + if (status != GRAPH_NODE_NEED_REPASS && status != GRAPH_PARAM_INVALID && status != GRAPH_SUCCESS) { + // node like netoutput return param_invalid, but valid ? 
+ return GE_GRAPH_INFERSHAPE_FAILED; + } + UpdateCurNodeOutputDesc(node); + if (!is_unknown_graph) { + auto ctx_after_infer = op.GetInferenceContext(); + if (ctx_after_infer != nullptr) { + GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); + if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { + GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), + ctx_after_infer->GetMarks().size()); + ShapeRefiner::PushToContextMap(node, ctx_after_infer); + } + } + } + + return (status == GRAPH_NODE_NEED_REPASS) ? GRAPH_NODE_NEED_REPASS : GRAPH_SUCCESS; +} + +void InferShapePass::UpdateCurNodeOutputDesc(NodePtr &node) { + auto op_desc = node->GetOpDesc(); + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + GE_IF_BOOL_EXEC(output_tensor == nullptr, continue); + GE_IF_BOOL_EXEC(output_tensor->MutableShape().GetDims().empty(), + output_tensor->SetOriginShape(output_tensor->GetShape())); + + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetOriginShape().GetDims() + .size())); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + // set output origin shape range + std::vector> range; + (void)output_tensor->GetShapeRange(range); + output_tensor->SetOriginShapeRange(range); + GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", + node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + } +} + +bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { + // check shape range + vector> src_shape_range; + vector> dst_shape_range; + src->GetShapeRange(src_shape_range); + 
dst->GetShapeRange(dst_shape_range); + if (src_shape_range.size() != dst_shape_range.size()) { + GELOGI("Src shape range size is %zu, dst shape range size is %zu, not same.", src_shape_range.size(), + dst_shape_range.size()); + return false; + } + for (size_t i = 0; i < src_shape_range.size(); ++i) { + if (src_shape_range[i].first != dst_shape_range[i].first || + src_shape_range[i].second != dst_shape_range[i].second) { + GELOGI("Current dim %zu. Src shape range is [%lu-%lu], dst shape range is [%lu-%lu], not same.", + i, src_shape_range[i].first, src_shape_range[i].second, dst_shape_range[i].first, dst_shape_range[i].second); + return false; + } + } + + // check shape + auto src_shape = src->GetShape(); + auto dst_shape = dst->GetShape(); + if (src_shape.GetDims() != dst_shape.GetDims() || src->GetOriginShape().GetDims() != dst->GetOriginShape().GetDims() || + src->GetDataType() != dst->GetDataType() || src->GetOriginDataType() != dst->GetOriginDataType()) { + GELOGD( + "Src shape is %s, origin_shape is %s, data_type is %s, origin data_type is %s; " + "Dst shape is %s, origin_shape is %s, data_type is %s, original data_type is %s, not same.", + src_shape.ToString().c_str(), src->GetOriginShape().ToString().c_str(), + TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst_shape.ToString().c_str(), + dst->GetOriginShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); + return false; + } + return true; +} + +graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { + changed = !SameTensorDesc(src, dst); + // refresh src itself + src->SetOriginShape(src->GetShape()); + src->SetOriginDataType(src->GetDataType()); + TensorUtils::SetRealDimCnt(*src, static_cast(src->GetOriginShape().GetDims().size())); + vector> src_shape_range; 
+ src->GetShapeRange(src_shape_range); + src->SetOriginShapeRange(src_shape_range); + + if (!changed) { + GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); + return SUCCESS; + } + UpdateShapeAndDType(src, dst); + GELOGD( + "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s." + "To dst Node: shape: [%s], datatype: %s, original datatype is %s.", + src->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(src->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(src->GetOriginDataType()).c_str(), dst->GetShape().ToString().c_str(), + TypeUtils::DataTypeToSerialString(dst->GetDataType()).c_str(), + TypeUtils::DataTypeToSerialString(dst->GetOriginDataType()).c_str()); + return SUCCESS; +} + +graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { + auto op_desc = node->GetOpDesc(); + const auto &op_type = op_desc->GetType(); + auto ret = op_desc->CallInferFunc(op); + if (ret == GRAPH_PARAM_INVALID) { + // Op ir no infer func, try to get infer func from operator factory + auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); + if (node_op.IsEmpty()) { + GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); + return ret; + } + + GELOGD("get op from OperatorFactory success. 
opType: %s", op_type.c_str()); + auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + node_op.BreakConnect(); + if (temp_op_desc == nullptr) { + REPORT_CALL_ERROR("E19999", "GetOpDescFromOperator failed, return nullptr."); + GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null"); + return GRAPH_FAILED; + } + if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { + GELOGW("InferShapeAndType UpdateInputName failed"); + for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { + if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { + break; + } + return GRAPH_SUCCESS; + } + } + if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { + GELOGW("InferShapeAndType UpdateOutputName failed"); + } + op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); + ret = op_desc->CallInferFunc(op); + GELOGI("op CallInferFunc second. ret: %u", ret); + } + return ret; +} + +graphStatus InferShapePass::UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) { + GELOGD("Enter update parent node shape for class branch op process"); + // check sub_graph shape.If not same ,do unknown shape process + auto ref_out_tensor = src.at(0); + ge::GeShape &ref_out_tensor_shape = ref_out_tensor->MutableShape(); + for (auto &tensor : src) { + if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { + REPORT_INNER_ERROR("E19999", "Does not support diff dtype among all ref output, shape:%s", + ref_out_tensor_shape.ToString().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype output"); + return GRAPH_FAILED; + } + auto shape = tensor->MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGD("Shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", shape.GetShapeSize(), + ref_out_tensor_shape.GetShapeSize()); + ref_out_tensor_shape = GeShape(UNKNOWN_RANK); + break; + } + for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { + if 
(ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { + continue; + } + GELOGD("j: %zu ,shape from subgraph size: %lu, ref_out_tensor_shape size: %lu", j, shape.GetShapeSize(), + ref_out_tensor_shape.GetShapeSize()); + (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); + } + } + UpdateShapeAndDType(ref_out_tensor, dst); + return GRAPH_SUCCESS; +} +graphStatus InferShapePass::UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, + GeTensorDescPtr &dst) { + // check sub_graph shape. Get max for update. + if (src.empty()) { + GELOGI("Src subgraph shape is empty."); + return SUCCESS; + } + + int64_t max_size = 0; + size_t max_shape_index = 0; + auto &ref_out_tensor = src.at(0); + for (size_t j = 0; j < src.size(); ++j) { + auto &tensor = src.at(j); + if (ref_out_tensor->GetDataType() != tensor->GetDataType()) { + REPORT_INNER_ERROR("E19999", "node does not support diff dtype among all ref output"); + GELOGE(GRAPH_FAILED, "[Check][Param] node does not support diff dtype among all ref output"); + return GRAPH_FAILED; + } + + auto shape = tensor->MutableShape(); + int64_t size = 1; + for (auto dim : shape.GetDims()) { + if (dim != 0 && INT64_MAX / dim < size) { + REPORT_INNER_ERROR("E19999", "The shape:%s size overflow", shape.ToString().c_str()); + GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow"); + return PARAM_INVALID; + } + size *= dim; + } + + if (size > max_size) { + max_size = size; + max_shape_index = j; + } + } + UpdateShapeAndDType(src.at(max_shape_index), dst); + return GRAPH_SUCCESS; +} +Status InferShapePass::OnSuspendNodesLeaked() { + auto iter = graphs_2_suspend_nodes_.find(GetCurrentGraphName()); + if (iter == graphs_2_suspend_nodes_.end()) { + GELOGI("Current graph %s no suspend node.", GetCurrentGraphName().c_str()); + return SUCCESS; + } + if (!iter->second.nodes.empty()) { + AddNodeResume(iter->second.PopSuspendedNode()); + } + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/infershape_pass.h 
b/ge/graph/passes/infershape_pass.h index 9c5d432d..00d90775 100644 --- a/ge/graph/passes/infershape_pass.h +++ b/ge/graph/passes/infershape_pass.h @@ -1,38 +1,56 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ -#define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ - -#include "graph/passes/base_pass.h" - -namespace ge { -class InferShapePass : public BaseNodePass { - public: - /// - /// Entry of the InferShapePass optimizer - /// @param [in] graph: Input ComputeGraph - /// @return SUCCESS: Execution succeed - /// @return OTHERS: Execution failed - /// @author - /// - Status Run(ge::NodePtr &node) override; - - private: - Status RePassLoopNode(const NodePtr &node); -}; -} // namespace ge -#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ +#define GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ + +#include "graph/passes/infer_base_pass.h" +#include + +namespace ge { +class InferShapePass : public InferBasePass { + public: + std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override; + graphStatus Infer(NodePtr &node) override; + + graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override; + graphStatus UpdateOutputFromSubgraphs(const std::vector &src, GeTensorDescPtr &dst) override; + graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector &src, + GeTensorDescPtr &dst) override; + + Status OnSuspendNodesLeaked() override; + + private: + graphStatus InferShapeAndType(NodePtr &node); + graphStatus CallInferShapeFunc(NodePtr &node, Operator &op); + bool SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDescPtr &dst); + void UpdateCurNodeOutputDesc(NodePtr &node); + Status SuspendV1LoopExitNodes(const NodePtr &node); + struct SuspendNodes { + std::stack nodes; + std::unordered_set nodes_set; + + NodePtr PopSuspendedNode() { + auto top_node = nodes.top(); + nodes.pop(); + nodes_set.erase(top_node); + return top_node; + } + }; + std::map graphs_2_suspend_nodes_; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_ diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 2efe623e..446af9bf 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -1999,6 +1999,22 @@ Status GraphPrepare::CheckUserInput(const std::vector &user_input) { Status GraphPrepare::InferShapeForPreprocess() { GELOGI("Start infershape for preprocess."); + // Prepare dummy_shape for v1 control_flow op before infershape + for (const auto &node : compute_graph_->GetAllNodes()) { + string type; + GetOriginalType(node, type); + if (type == MERGE || type == REFMERGE) { + for (size_t i = 0; i < 
node->GetAllInDataAnchorsSize(); ++i) { + GELOGD("Prepare for infershape: update %s input_shape as dummy.", node->GetName().c_str()); + NodeUtils::UpdateInputShape(*node, i, GeShape(DUMMY_SHAPE)); + } + } else if (type == WHILE) { + for (size_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + GELOGD("Prepare for infershape: update %s output_shape as dummy.", node->GetName().c_str()); + NodeUtils::UpdateOutputShape(*node, i, GeShape(DUMMY_SHAPE)); + } + } + } GEPass ge_passes(compute_graph_); NamesToPass names_to_passes; AssertPass assert_pass; diff --git a/tests/ut/ge/graph/passes/addn_pass_unittest.cc b/tests/ut/ge/graph/passes/addn_pass_unittest.cc index 6107a7d8..39029e8c 100644 --- a/tests/ut/ge/graph/passes/addn_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/addn_pass_unittest.cc @@ -72,7 +72,7 @@ TEST(UtestGraphPassesAddnPass, null_pass) { AddNPass *addn_pass = nullptr; NamesToPass names_to_pass; names_to_pass.emplace_back("Test", addn_pass); - EXPECT_EQ(pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR); } TEST(UtestGraphPassesAddnPass, null_graph) { diff --git a/tests/ut/ge/graph/passes/base_pass_unittest.cc b/tests/ut/ge/graph/passes/base_pass_unittest.cc index c687e07f..3b0235f5 100644 --- a/tests/ut/ge/graph/passes/base_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/base_pass_unittest.cc @@ -1,523 +1,903 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include - -#include "gtest/gtest.h" - -#define protected public -#include "graph/passes/base_pass.h" -#undef protected - -#include "external/graph/ge_error_codes.h" -#include "framework/common/ge_inner_error_codes.h" -#include "framework/common/types.h" -#include "graph/node.h" -#include "graph/utils/graph_utils.h" -#include "graph_builder_utils.h" - -template class std::unordered_set; - -namespace ge { -class UtestTestPass : public BaseNodePass { - public: - UtestTestPass() = default; - UtestTestPass(bool dead_loop) : dead_loop_(dead_loop), run_times_(0) {} - - Status Run(NodePtr &node) override { - ++run_times_; - iter_nodes_.push_back(node); - auto iter = names_to_add_del_.find(node->GetName()); - if (iter != names_to_add_del_.end()) { - for (const auto &node_name : iter->second) { - auto del_node = node->GetOwnerComputeGraph()->FindNode(node_name); - GraphUtils::IsolateNode(del_node, {0}); - AddNodeDeleted(del_node); - } - } - iter = names_to_add_repass_.find(node->GetName()); - if (iter != names_to_add_repass_.end()) { - auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); - for (const auto &node_name : iter->second) { - for (auto &node_re_pass : all_nodes) { - if (node_re_pass->GetName() == node_name) { - AddRePassNode(node_re_pass); - break; - } - } - } - if (!dead_loop_) { - names_to_add_repass_.erase(iter); - } - } - // simulate infershape pass - if(node->GetType() == WHILE){ - bool need_repass = false; - AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass); - if(!OptionExists(kOptimizeAfterSubGraph)){ - return SUCCESS; - } - if(need_repass){ - AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false); - AddImmediateRePassNode(node); - } - else{ - // clear attr on while - node->GetOpDesc()->DelAttr("_need_infer_again"); - } - } - return SUCCESS; - } - void clear() { 
iter_nodes_.clear(); } - std::vector GetIterNodes() { return iter_nodes_; } - - void AddRePassNodeName(const std::string &iter_node, const std::string &re_pass_node) { - names_to_add_repass_[iter_node].insert(re_pass_node); - } - void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { - names_to_add_del_[iter_node].insert(del_node); - } - unsigned int GetRunTimes() { return run_times_; } - - private: - std::vector iter_nodes_; - std::map> names_to_add_del_; - std::map> names_to_add_repass_; - bool dead_loop_; - unsigned int run_times_; -}; - -class TestDelPass : public BaseNodePass { - public: - Status Run(NodePtr &node) override { return SUCCESS; } -}; - -class UTESTGraphPassesBasePass : public testing::Test { - protected: - UTESTGraphPassesBasePass() { - auto p1 = new UtestTestPass; - names_to_pass_.push_back(std::make_pair("test1", p1)); - } - void SetUp() override { - for (auto &name_to_pass : names_to_pass_) { - dynamic_cast(name_to_pass.second)->clear(); - } - } - ~UTESTGraphPassesBasePass() override { - for (auto &name_to_pass : names_to_pass_) { - delete name_to_pass.second; - } - } - NamesToPass names_to_pass_; -}; -/// reshape1 -/// | -/// add1 -/// / \. -/// | | -/// data1 const1 -ComputeGraphPtr BuildGraph1() { - auto builder = ut::GraphBuilder("g1"); - auto data = builder.AddNode("data1", DATA, 0, 1); - auto a1 = builder.AddNode("add1", ADD, 2, 1); - auto c1 = builder.AddNode("const1", CONSTANT, 0, 1); - auto r1 = builder.AddNode("reshape1", RESHAPE, 1, 1); - - builder.AddDataEdge(data, 0, a1, 0); - builder.AddDataEdge(c1, 0, a1, 1); - builder.AddDataEdge(a1, 0, r1, 0); - - return builder.GetGraph(); -} - -/// sum1 -/// / \. -/// / \. -/// / \. 
-/// reshape1 addn1 -/// | c | -/// add1 <--- shape1 -/// / \ | -/// | | | -/// data1 const1 const2 -ComputeGraphPtr BuildGraph2() { - auto builder = ut::GraphBuilder("g1"); - auto data1 = builder.AddNode("data1", DATA, 0, 1); - auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); - auto const2 = builder.AddNode("const2", CONSTANT, 0, 1); - auto add1 = builder.AddNode("add1", ADD, 2, 1); - auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1); - auto reshape1 = builder.AddNode("reshape1", RESHAPE, 1, 1); - auto addn1 = builder.AddNode("addn1", ADDN, 1, 1); - auto sum1 = builder.AddNode("sum1", SUM, 2, 1); - - builder.AddDataEdge(data1, 0, add1, 0); - builder.AddDataEdge(const1, 0, add1, 1); - builder.AddDataEdge(const2, 0, shape1, 0); - builder.AddControlEdge(shape1, add1); - builder.AddDataEdge(add1, 0, reshape1, 0); - builder.AddDataEdge(shape1, 0, addn1, 0); - builder.AddDataEdge(reshape1, 0, sum1, 0); - builder.AddDataEdge(addn1, 0, sum1, 1); - - return builder.GetGraph(); -} - -/// rnextiteration -/// | | -/// merge -/// | -/// data1 -ComputeGraphPtr BuildGraph3() { - auto builder = ut::GraphBuilder("g1"); - auto data1 = builder.AddNode("data1", DATA, 0, 1); - auto merge1 = builder.AddNode("merge1", MERGE, 2, 1); - auto next1 = builder.AddNode("next1", NEXTITERATION, 1, 1); - - builder.AddDataEdge(data1, 0, merge1, 0); - builder.AddDataEdge(merge1, 0, next1, 0); - builder.AddDataEdge(next1, 0, merge1, 1); - builder.AddControlEdge(merge1, next1); - builder.AddControlEdge(next1, merge1); - - return builder.GetGraph(); -} - -void CheckIterOrder(UtestTestPass *pass, std::vector> &nodes_layers) { - std::unordered_set layer_nodes; - size_t layer_index = 0; - for (const auto &node : pass->GetIterNodes()) { - layer_nodes.insert(node->GetName()); - EXPECT_LT(layer_index, nodes_layers.size()); - if (layer_nodes == nodes_layers[layer_index]) { - layer_index++; - layer_nodes.clear(); - } - } - EXPECT_EQ(layer_index, nodes_layers.size()); -} - -/// Op1 -/// | -/// Merge 
-/// / \. -/// Op2 Op3 -TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { - auto builder = ut::GraphBuilder("g1"); - auto merge_node = builder.AddNode("Merge", MERGE, 1, 1); - auto node1 = builder.AddNode("Op1", RELU, 1, 1); - auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1); - auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1); - - GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); - - EXPECT_EQ(node1->GetOutDataNodes().size(), 1); - - TestDelPass del_pass; - auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1}); - EXPECT_EQ(ret, FAILED); - - OpDescPtr op_desc = std::make_shared("merge", MERGE); - NodePtr node = shared_ptr(new (std::nothrow) Node(op_desc, nullptr)); - ret = del_pass.IsolateAndDeleteNode(node, {0, -1}); - EXPECT_EQ(ret, FAILED); -} - -/// Op1 -/// | -/// Merge -/// / \. 
-/// Op2 Op3 -TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { - auto builder = ut::GraphBuilder("g1"); - auto merge_node = builder.AddNode("Merge", MERGE, 1, 2); - auto node1 = builder.AddNode("Op1", RELU, 1, 1); - auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1); - auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1); - - GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); - - EXPECT_EQ(node1->GetOutDataNodes().size(), 1); - - TestDelPass del_pass; - auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1}); - EXPECT_EQ(ret, SUCCESS); -} - -TEST_F(UTESTGraphPassesBasePass, data_graph) { - auto graph = BuildGraph1(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); - auto *pass = dynamic_cast(names_to_pass_[0].second); - - EXPECT_EQ(pass->GetIterNodes().size(), 4); - std::vector> layers; - layers.push_back({"data1", "const1"}); - layers.push_back({"add1"}); - layers.push_back({"reshape1"}); - CheckIterOrder(pass, layers); -} - -TEST_F(UTESTGraphPassesBasePass, graph_with_control_link) { - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); - auto *pass = dynamic_cast(names_to_pass_[0].second); - - EXPECT_EQ(pass->GetIterNodes().size(), 8); - EXPECT_EQ(pass->GetIterNodes().at(3)->GetName(), "shape1"); - - std::vector> layers; - layers.push_back({"data1", "const1", "const2"}); - layers.push_back({"shape1"}); - layers.push_back({"add1", "addn1", "reshape1"}); - layers.push_back({"sum1"}); - CheckIterOrder(pass, layers); -} - -TEST_F(UTESTGraphPassesBasePass, re_pass_after) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddRePassNodeName("add1", "sum1"); - 
test_pass.AddRePassNodeName("shape1", "sum1"); - test_pass.AddRePassNodeName("shape1", "add1"); - test_pass.AddRePassNodeName("data1", "add1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 8); -} - -TEST_F(UTESTGraphPassesBasePass, re_pass_before) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddRePassNodeName("add1", "data1"); - - auto graph = BuildGraph1(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 5); - EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1"); - EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1"); - EXPECT_EQ(test_pass.GetIterNodes().at(4)->GetName(), "data1"); -} - -TEST_F(UTESTGraphPassesBasePass, re_pass_before_multi_times) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddRePassNodeName("add1", "data1"); - test_pass.AddRePassNodeName("add1", "const1"); - test_pass.AddRePassNodeName("reshape1", "data1"); - - auto graph = BuildGraph1(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 6); - EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1"); - EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1"); -} - -TEST_F(UTESTGraphPassesBasePass, del_after) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddDelNodeName("add1", "sum1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 7); -} - -TEST_F(UTESTGraphPassesBasePass, del_after_multiple) { - 
NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddDelNodeName("add1", "sum1"); - test_pass.AddDelNodeName("add1", "reshape1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 6); -} - -TEST_F(UTESTGraphPassesBasePass, del_after_break_link) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddDelNodeName("shape1", "add1"); - test_pass.AddDelNodeName("shape1", "addn1"); - test_pass.AddRePassNodeName("shape1", "shape1"); - test_pass.AddRePassNodeName("shape1", "reshape1"); - test_pass.AddRePassNodeName("shape1", "sum1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 7); -} - -TEST_F(UTESTGraphPassesBasePass, del_self_and_after) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddDelNodeName("shape1", "add1"); - test_pass.AddDelNodeName("shape1", "addn1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 4); -} - -TEST_F(UTESTGraphPassesBasePass, del_before) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddDelNodeName("reshape1", "add1"); - test_pass.AddDelNodeName("sum1", "addn1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 8); -} - -TEST_F(UTESTGraphPassesBasePass, re_pass_and_del) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(); - 
names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddRePassNodeName("add1", "sum1"); - test_pass.AddDelNodeName("reshape1", "sum1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetIterNodes().size(), 7); -} -/* -TEST_F(UTESTGraphPassesBasePass, dead_loop) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(true); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - test_pass.AddRePassNodeName("add1", "sum1"); - test_pass.AddRePassNodeName("sum1", "add1"); - - auto graph = BuildGraph2(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); - EXPECT_EQ(test_pass.GetRunTimes(), 1007); -} -*/ - -TEST_F(UTESTGraphPassesBasePass, while_loop) { - NamesToPass names_to_pass; - auto test_pass = UtestTestPass(true); - names_to_pass.push_back(std::make_pair("test", &test_pass)); - - auto graph = BuildGraph3(); - auto ge_pass = GEPass(graph); - EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); -} - -/// data1 const -/// \ / -/// while -/// / \. 
-/// | | -/// cast1 cast2 -ComputeGraphPtr BuildWhileGraph1() { - // build sub graph - auto builder_sub = ut::GraphBuilder("sub"); - auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1); - auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1); - auto add = builder_sub.AddNode("add", ADD, 2, 1); - - builder_sub.AddDataEdge(data_1, 0, add, 0); - builder_sub.AddDataEdge(data_2, 0, add, 1); - auto sub_graph = builder_sub.GetGraph(); - sub_graph->SetName("while_sub"); - // build root graph - auto builder = ut::GraphBuilder("g1"); - auto data = builder.AddNode("data1", DATA, 0, 1); - auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1); - auto c1 = builder.AddNode("cast1", CAST, 1, 1); - auto c2 = builder.AddNode("cast2", CAST, 1, 1); - // add while op - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1,1,1,1})); - tensor_desc->SetFormat(FORMAT_ND); - tensor_desc->SetDataType(DT_INT32); - - auto op_desc = std::make_shared("while", WHILE); - for (int i = 0; i < 2; ++i) { - op_desc->AddInputDesc(tensor_desc->Clone()); - } - for (int i = 0; i < 2; ++i) { - op_desc->AddOutputDesc(tensor_desc->Clone()); - } - AttrUtils::SetBool(op_desc,"_need_infer_again", true); - op_desc->AddSubgraphName(sub_graph->GetName()); - op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); - auto root_graph = builder.GetGraph(); - auto while_op = root_graph->AddNode(op_desc); - - builder.AddDataEdge(data, 0, while_op, 0); - builder.AddDataEdge(const_op, 0, while_op, 1); - builder.AddDataEdge(while_op, 0, c1, 0); - builder.AddDataEdge(while_op, 1, c2, 0); - sub_graph->SetParentGraph(root_graph); - sub_graph->SetParentNode(while_op); - root_graph->AddSubgraph(sub_graph); - return root_graph; -} - -TEST_F(UTESTGraphPassesBasePass, while_infershape) { -NamesToPass names_to_pass; -auto test_pass = UtestTestPass(); -names_to_pass.push_back(std::make_pair("test", &test_pass)); - -auto graph = BuildWhileGraph1(); -auto ge_pass = GEPass(graph); -auto while_node = 
graph->FindNode("while"); -EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); -EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); -} - -} // namespace ge +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#define protected public +#include "graph/passes/base_pass.h" +#undef protected + +#include "framework/common/types.h" +#include "graph/node.h" +#include "graph/utils/graph_utils.h" +#include "graph_builder_utils.h" + +template class std::unordered_set; + +namespace ge { +class UtestTestPass : public BaseNodePass { + public: + UtestTestPass() = default; + UtestTestPass(bool dead_loop) : dead_loop_(dead_loop), run_times_(0) {} + + Status Run(NodePtr &node) override { + ++run_times_; + iter_nodes_.push_back(node); + auto iter = names_to_add_del_.find(node->GetName()); + if (iter != names_to_add_del_.end()) { + for (const auto &node_name : iter->second) { + auto del_node = node->GetOwnerComputeGraph()->FindNode(node_name); + GraphUtils::IsolateNode(del_node, {0}); + AddNodeDeleted(del_node); + } + } + iter = names_to_add_repass_.find(node->GetName()); + if (iter != names_to_add_repass_.end()) { + auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); + for (const auto &node_name : iter->second) { + for (auto &node_re_pass : all_nodes) { + if (node_re_pass->GetName() == node_name) { + 
AddRePassNode(node_re_pass); + break; + } + } + } + if (!dead_loop_) { + names_to_add_repass_.erase(iter); + } + } + + iter = names_to_add_repass_immediate_.find(node->GetName()); + if (iter != names_to_add_repass_immediate_.end()) { + auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); + for (const auto &node_name : iter->second) { + for (auto &node_re_pass : all_nodes) { + if (node_re_pass->GetName() == node_name) { + AddImmediateRePassNode(node_re_pass); + break; + } + } + } + if (!dead_loop_) { + names_to_add_repass_immediate_.erase(iter); + } + } + + iter = names_to_add_suspend_.find(node->GetName()); + if (iter != names_to_add_suspend_.end()) { + auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); + for (const auto &node_name : iter->second) { + for (auto &node_re_pass : all_nodes) { + if (node_re_pass->GetName() == node_name) { + AddNodeSuspend(node_re_pass); + break; + } + } + } + if (!dead_loop_) { + names_to_add_suspend_.erase(iter); + } + } + + iter = names_to_add_resume_.find(node->GetName()); + if (iter != names_to_add_resume_.end()) { + auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes(); + for (const auto &node_name : iter->second) { + for (auto &node_re_pass : all_nodes) { + if (node_re_pass->GetName() == node_name) { + AddNodeResume(node_re_pass); + break; + } + } + } + if (!dead_loop_) { + names_to_add_resume_.erase(iter); + } + } + // simulate infershape pass + if(node->GetType() == WHILE){ + bool need_repass = false; + AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass); + if(!OptionExists(kOptimizeAfterSubGraph)){ + return SUCCESS; + } + if(need_repass){ + AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false); + AddImmediateRePassNode(node); + } + else{ + // clear attr on while + node->GetOpDesc()->DelAttr("_need_infer_again"); + } + } + return SUCCESS; + } + + Status OnSuspendNodesLeaked() override { + // resume all node remain in suspend_nodes when leaked + auto compute_graph = 
(iter_nodes_.size() > 0) ? iter_nodes_[0]->GetOwnerComputeGraph() : nullptr; + if (compute_graph == nullptr) { + return SUCCESS; + } + + for (const auto &node_name : names_to_add_resume_onleaked_) { + auto node_to_resume = compute_graph->FindNode(node_name); + AddNodeResume(node_to_resume); + } + return SUCCESS; + } + void clear() { iter_nodes_.clear(); } + std::vector GetIterNodes() { return iter_nodes_; } + + void AddRePassNodeName(const std::string &iter_node, const std::string &re_pass_node) { + names_to_add_repass_[iter_node].insert(re_pass_node); + } + void AddDelNodeName(const std::string &iter_node, const std::string &del_node) { + names_to_add_del_[iter_node].insert(del_node); + } + void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) { + names_to_add_repass_immediate_[iter_node].insert(re_pass_node); + } + + void AddSuspendNodeName(const std::string &iter_node, const std::string &suspend_node) { + names_to_add_suspend_[iter_node].insert(suspend_node); + } + void AddResumeNodeName(const std::string &iter_node, const std::string &resume_node) { + names_to_add_resume_[iter_node].insert(resume_node); + } + void AddResumeNodeNameOnLeaked(const std::string &resume_node) { + names_to_add_resume_onleaked_.insert(resume_node); + } + + unsigned int GetRunTimes() { return run_times_; } + + private: + std::vector iter_nodes_; + std::map> names_to_add_del_; + std::map> names_to_add_repass_; + std::map> names_to_add_repass_immediate_; + std::map> names_to_add_suspend_; + std::map> names_to_add_resume_; + std::unordered_set names_to_add_resume_onleaked_; + + bool dead_loop_; + unsigned int run_times_; +}; + +class TestDelPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override { return SUCCESS; } +}; + +class UTESTGraphPassesBasePass : public testing::Test { + protected: + UTESTGraphPassesBasePass() { + auto p1 = new UtestTestPass; + names_to_pass_.push_back(std::make_pair("test1", p1)); + } + void SetUp() 
override { + for (auto &name_to_pass : names_to_pass_) { + dynamic_cast(name_to_pass.second)->clear(); + } + } + ~UTESTGraphPassesBasePass() override { + for (auto &name_to_pass : names_to_pass_) { + delete name_to_pass.second; + } + } + NamesToPass names_to_pass_; +}; +/// reshape1 +/// | +/// add1 +/// / \. +/// | | +/// data1 const1 +ComputeGraphPtr BuildGraph1() { + auto builder = ut::GraphBuilder("g1"); + auto data = builder.AddNode("data1", DATA, 0, 1); + auto a1 = builder.AddNode("add1", ADD, 2, 1); + auto c1 = builder.AddNode("const1", CONSTANT, 0, 1); + auto r1 = builder.AddNode("reshape1", RESHAPE, 1, 1); + + builder.AddDataEdge(data, 0, a1, 0); + builder.AddDataEdge(c1, 0, a1, 1); + builder.AddDataEdge(a1, 0, r1, 0); + + return builder.GetGraph(); +} + +/// sum1 +/// / \. +/// / \. +/// / \. +/// reshape1 addn1 +/// | c | +/// add1 <--- shape1 +/// / \ | +/// | | | +/// data1 const1 const2 +ComputeGraphPtr BuildGraph2() { + auto builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", DATA, 0, 1); + auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); + auto const2 = builder.AddNode("const2", CONSTANT, 0, 1); + auto add1 = builder.AddNode("add1", ADD, 2, 1); + auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1); + auto reshape1 = builder.AddNode("reshape1", RESHAPE, 1, 1); + auto addn1 = builder.AddNode("addn1", ADDN, 1, 1); + auto sum1 = builder.AddNode("sum1", SUM, 2, 1); + + builder.AddDataEdge(data1, 0, add1, 0); + builder.AddDataEdge(const1, 0, add1, 1); + builder.AddDataEdge(const2, 0, shape1, 0); + builder.AddControlEdge(shape1, add1); + builder.AddDataEdge(add1, 0, reshape1, 0); + builder.AddDataEdge(shape1, 0, addn1, 0); + builder.AddDataEdge(reshape1, 0, sum1, 0); + builder.AddDataEdge(addn1, 0, sum1, 1); + + return builder.GetGraph(); +} + +/// rnextiteration +/// | | +/// merge +/// | +/// data1 +ComputeGraphPtr BuildGraph3() { + auto builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", DATA, 0, 1); 
+ auto merge1 = builder.AddNode("merge1", MERGE, 2, 1); + auto next1 = builder.AddNode("next1", NEXTITERATION, 1, 1); + + builder.AddDataEdge(data1, 0, merge1, 0); + builder.AddDataEdge(merge1, 0, next1, 0); + builder.AddDataEdge(next1, 0, merge1, 1); + builder.AddControlEdge(merge1, next1); + builder.AddControlEdge(next1, merge1); + + return builder.GetGraph(); +} + +/// cast1--shape1 +/// / +/// data1 +/// \ +/// transdata1--shape2 +ComputeGraphPtr BuildGraph4() { + auto builder = ut::GraphBuilder("g1"); + auto data1 = builder.AddNode("data1", DATA, 0, 1); + auto cast1 = builder.AddNode("cast1", CAST, 1, 1); + auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1); + auto transdata1 = builder.AddNode("transdata1", TRANSDATA, 1, 1); + auto shape2 = builder.AddNode("shape2", SHAPE, 1, 1); + + builder.AddDataEdge(data1, 0, cast1, 0); + builder.AddDataEdge(data1, 0, transdata1, 0); + builder.AddDataEdge(cast1, 0, shape1, 0); + builder.AddDataEdge(transdata1, 0, shape2, 0); + return builder.GetGraph(); +} + +void CheckIterOrder(UtestTestPass *pass, std::vector> &nodes_layers) { + std::unordered_set layer_nodes; + size_t layer_index = 0; + for (const auto &node : pass->GetIterNodes()) { + layer_nodes.insert(node->GetName()); + EXPECT_LT(layer_index, nodes_layers.size()); + if (layer_nodes == nodes_layers[layer_index]) { + layer_index++; + layer_nodes.clear(); + } + } + EXPECT_EQ(layer_index, nodes_layers.size()); +} + +/// Op1 +/// | +/// Merge +/// / \. 
+/// Op2 Op3 +TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { + auto builder = ut::GraphBuilder("g1"); + auto merge_node = builder.AddNode("Merge", MERGE, 1, 1); + auto node1 = builder.AddNode("Op1", RELU, 1, 1); + auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1); + auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1); + + GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); + + EXPECT_EQ(node1->GetOutDataNodes().size(), 1); + + TestDelPass del_pass; + auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1}); + EXPECT_EQ(ret, FAILED); + + OpDescPtr op_desc = std::make_shared("merge", MERGE); + NodePtr node = shared_ptr(new (std::nothrow) Node(op_desc, nullptr)); + ret = del_pass.IsolateAndDeleteNode(node, {0, -1}); + EXPECT_EQ(ret, FAILED); +} + +/// Op1 +/// | +/// Merge +/// / \. 
+/// Op2 Op3 +TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { + auto builder = ut::GraphBuilder("g1"); + auto merge_node = builder.AddNode("Merge", MERGE, 1, 2); + auto node1 = builder.AddNode("Op1", RELU, 1, 1); + auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1); + auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1); + + GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); + + EXPECT_EQ(node1->GetOutDataNodes().size(), 1); + + TestDelPass del_pass; + auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1}); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UTESTGraphPassesBasePass, data_graph) { + auto graph = BuildGraph1(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + auto *pass = dynamic_cast(names_to_pass_[0].second); + + EXPECT_EQ(pass->GetIterNodes().size(), 4); + std::vector> layers; + layers.push_back({"data1", "const1"}); + layers.push_back({"add1"}); + layers.push_back({"reshape1"}); + CheckIterOrder(pass, layers); +} + +TEST_F(UTESTGraphPassesBasePass, graph_with_control_link) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + auto *pass = dynamic_cast(names_to_pass_[0].second); + + EXPECT_EQ(pass->GetIterNodes().size(), 8); + EXPECT_EQ(pass->GetIterNodes().at(3)->GetName(), "shape1"); + + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1", "reshape1"}); + layers.push_back({"sum1"}); + CheckIterOrder(pass, layers); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_after) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddRePassNodeName("add1", "sum1"); + 
test_pass.AddRePassNodeName("shape1", "sum1"); + test_pass.AddRePassNodeName("shape1", "add1"); + test_pass.AddRePassNodeName("data1", "add1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 8); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_before) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddRePassNodeName("add1", "data1"); + + auto graph = BuildGraph1(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 5); + EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1"); + EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1"); + EXPECT_EQ(test_pass.GetIterNodes().at(4)->GetName(), "data1"); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_before_multi_times) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddRePassNodeName("add1", "data1"); + test_pass.AddRePassNodeName("add1", "const1"); + test_pass.AddRePassNodeName("reshape1", "data1"); + + auto graph = BuildGraph1(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 6); + EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1"); + EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1"); +} + +TEST_F(UTESTGraphPassesBasePass, del_after) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddDelNodeName("add1", "sum1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 7); +} + +TEST_F(UTESTGraphPassesBasePass, del_after_multiple) { + 
NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddDelNodeName("add1", "sum1"); + test_pass.AddDelNodeName("add1", "reshape1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 6); +} + +TEST_F(UTESTGraphPassesBasePass, del_after_break_link) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddDelNodeName("shape1", "add1"); + test_pass.AddDelNodeName("shape1", "addn1"); + test_pass.AddRePassNodeName("shape1", "shape1"); + test_pass.AddRePassNodeName("shape1", "reshape1"); + test_pass.AddRePassNodeName("shape1", "sum1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 7); +} + +TEST_F(UTESTGraphPassesBasePass, del_self_and_after) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddDelNodeName("shape1", "add1"); + test_pass.AddDelNodeName("shape1", "addn1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 6); +} + +TEST_F(UTESTGraphPassesBasePass, del_before) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddDelNodeName("reshape1", "add1"); + test_pass.AddDelNodeName("sum1", "addn1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 8); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_and_del) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + 
names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddRePassNodeName("add1", "sum1"); + test_pass.AddDelNodeName("reshape1", "sum1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 7); +} +/* +TEST_F(UTESTGraphPassesBasePass, dead_loop) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(true); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + test_pass.AddRePassNodeName("add1", "sum1"); + test_pass.AddRePassNodeName("sum1", "add1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetRunTimes(), 1007); +} +*/ + +TEST_F(UTESTGraphPassesBasePass, while_loop) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(true); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + auto graph = BuildGraph3(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); +} + +/// data1 const +/// \ / +/// while +/// / \. 
+/// | | +/// cast1 cast2 +ComputeGraphPtr BuildWhileGraph1() { + // build sub graph + auto builder_sub = ut::GraphBuilder("sub"); + auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1); + auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1); + auto add = builder_sub.AddNode("add", ADD, 2, 1); + + builder_sub.AddDataEdge(data_1, 0, add, 0); + builder_sub.AddDataEdge(data_2, 0, add, 1); + auto sub_graph = builder_sub.GetGraph(); + sub_graph->SetName("while_sub"); + // build root graph + auto builder = ut::GraphBuilder("g1"); + auto data = builder.AddNode("data1", DATA, 0, 1); + auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1); + auto c1 = builder.AddNode("cast1", CAST, 1, 1); + auto c2 = builder.AddNode("cast2", CAST, 1, 1); + // add while op + auto tensor_desc = std::make_shared(); + tensor_desc->SetShape(GeShape({1,1,1,1})); + tensor_desc->SetFormat(FORMAT_ND); + tensor_desc->SetDataType(DT_INT32); + + auto op_desc = std::make_shared("while", WHILE); + for (int i = 0; i < 2; ++i) { + op_desc->AddInputDesc(tensor_desc->Clone()); + } + for (int i = 0; i < 2; ++i) { + op_desc->AddOutputDesc(tensor_desc->Clone()); + } + AttrUtils::SetBool(op_desc,"_need_infer_again", true); + op_desc->AddSubgraphName(sub_graph->GetName()); + op_desc->SetSubgraphInstanceName(0,sub_graph->GetName()); + auto root_graph = builder.GetGraph(); + auto while_op = root_graph->AddNode(op_desc); + + builder.AddDataEdge(data, 0, while_op, 0); + builder.AddDataEdge(const_op, 0, while_op, 1); + builder.AddDataEdge(while_op, 0, c1, 0); + builder.AddDataEdge(while_op, 1, c2, 0); + sub_graph->SetParentGraph(root_graph); + sub_graph->SetParentNode(while_op); + root_graph->AddSubgraph(sub_graph); + return root_graph; +} + +TEST_F(UTESTGraphPassesBasePass, while_infershape) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + auto graph = BuildWhileGraph1(); + auto ge_pass = GEPass(graph); + auto while_node 
= graph->FindNode("while"); + EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // repass pre_node immediately + test_pass->AddRePassImmediateNodeName("reshape1", "add1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + + EXPECT_EQ(test_pass->GetIterNodes().size(), 9);// todo + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "add1", "sum1"}); + CheckIterOrder(test_pass, layers); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // repass cur_node immediately + test_pass->AddRePassImmediateNodeName("reshape1", "reshape1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + + EXPECT_EQ(test_pass->GetIterNodes().size(), 9); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1"}); + layers.push_back({"reshape1", "sum1"}); + CheckIterOrder(test_pass, layers); +} + +TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // repass next_node immediately + test_pass->AddRePassImmediateNodeName("reshape1", "sum1"); + // repass node after next_node immediately + test_pass->AddRePassImmediateNodeName("add1", "sum1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + + EXPECT_EQ(test_pass->GetIterNodes().size(), 8); + std::vector> layers; + layers.push_back({"data1", "const1", 
"const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + CheckIterOrder(test_pass, layers); +} +/** + * A->B->C + * if node B suspend its pre_node A, and C resume A, it is a useless operation, so iter_order should follow normal order + * when C resuem A, A will pass again. + */ +TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_C_resume_A) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // add1->reshape1->sum1 + test_pass->AddSuspendNodeName("reshape1", "add1"); + test_pass->AddResumeNodeName("sum1", "add1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 9); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + layers.push_back({"add1"}); + CheckIterOrder(test_pass, layers); +} + +/** + * A->B->C + * if node B suspend its pre_node A, and B resume A, it is a useless operation, so iter_order should follow normal order + * when B resuem A, A will pass again. + */ +TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_B_resume_A) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // add1->reshape1->sum1 + test_pass->AddSuspendNodeName("reshape1", "add1"); + test_pass->AddResumeNodeName("reshape1", "add1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 9); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1", "add1"}); + CheckIterOrder(test_pass, layers); +} + +/** + * A->B->C + * if node B resume C(which is not suspended), it is a useless operation, C will not pass. 
+ */ +TEST_F(UTESTGraphPassesBasePass, B_resume_node_not_suspended) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // add1->reshape1->sum1 + test_pass->AddResumeNodeName("reshape1", "sum1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 8); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + CheckIterOrder(test_pass, layers); +} + +/** + * A->B->C + * if node B suspend its pre_node A, it is a useless operation, so iter_order should follow normal order + * because nobody resume it ,which means A is a leaked node, so return fail + */ +TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_nobody_resume_it_return_failed) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + // suspend pre_node immediately + test_pass.AddSuspendNodeName("reshape1", "add1"); + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), INTERNAL_ERROR); +} + +/** + * A->B->C + * if node B suspend its pre_node A, it is a useless operation, + * so iter_order should follow normal order + * resume A on leaked, which means A will pass again + */ +TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_resume_it_onleaked) { + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("reshape1", "add1"); + test_pass->AddResumeNodeNameOnLeaked("add1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + 
layers.push_back({"add1"}); + CheckIterOrder(test_pass, layers); +} + + +/// cast1--shape1 +/// / +/// data1 +/// \ +/// transdata1--shape2 +/** + * suspend cur node + * cast1 suspend itself, shape2 resume cast1 + * iter order follows : data1; cast1,transdata1; shape2; cast1 ; shape1 + */ +TEST_F(UTESTGraphPassesBasePass, cast1_suspend_cur_node_shape2_resume_cast1) { + auto graph = BuildGraph4(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("cast1", "cast1"); + test_pass->AddResumeNodeName("shape2", "cast1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 6); + std::vector> layers; + layers.push_back({"data1"}); + layers.push_back({"cast1","transdata1"}); + layers.push_back({"shape2"}); + layers.push_back({"cast1", "shape1"}); + CheckIterOrder(test_pass, layers); +} +/** + * suspend cur node + * cast1 suspend itself, then resume cast1 + * iter order follows : data1; cast1,cast1,transdata1; shape2; shape1. + */ +TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_itself) { + auto graph = BuildGraph4(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("cast1", "cast1"); + test_pass->AddResumeNodeName("cast1", "cast1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 6); + std::vector> layers; + layers.push_back({"data1"}); + layers.push_back({"cast1","transdata1","cast1","shape1", "shape2"}); + CheckIterOrder(test_pass, layers); +} +/** + * suspend cur node + * cast1 suspend itself, then resume cast1 on leaked + * iter order follows : data1; cast1,cast1,transdata1; shape2; shape1. 
+ */ +TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_onleaked) { + auto graph = BuildGraph4(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("cast1", "cast1"); + test_pass->AddResumeNodeNameOnLeaked("cast1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 6); + std::vector> layers; + layers.push_back({"data1"}); + layers.push_back({"cast1","transdata1", "shape2"}); + layers.push_back({"cast1","shape1"}); + CheckIterOrder(test_pass, layers); +} +/** + * suspend next node + * data1 suspend cast1, then resume cast1 on leaked + * iter order follows : data1; transdata1, shape2; cast1, shape1. + */ +TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_resume_cast1_onleaked) { + auto graph = BuildGraph4(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("data1", "cast1"); + test_pass->AddResumeNodeNameOnLeaked("cast1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS); + EXPECT_EQ(test_pass->GetIterNodes().size(), 5); + std::vector> layers; + layers.push_back({"data1"}); + layers.push_back({"transdata1", "shape2"}); + layers.push_back({"cast1","shape1"}); + CheckIterOrder(test_pass, layers); +} + +/** + * suspend next node + * data1 suspend cast1, nobody resume it + * iter order follows : data1; transdata1, shape2; + * run ret is failed ,because node leaked + */ +TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_nobody_resume) { + auto graph = BuildGraph4(); + auto ge_pass = GEPass(graph); + auto *test_pass = dynamic_cast(names_to_pass_[0].second); + // suspend pre_node immediately + test_pass->AddSuspendNodeName("data1", "cast1"); + EXPECT_EQ(ge_pass.Run(names_to_pass_), INTERNAL_ERROR); + EXPECT_EQ(test_pass->GetIterNodes().size(), 3); +} + +/* +TEST_F(UTESTGraphPassesBasePass, 
suspend_pre_node) { + NamesToPass names_to_pass; + auto test_pass = UtestTestPass(); + names_to_pass.push_back(std::make_pair("test", &test_pass)); + + // repass next_node immediately + test_pass.AddRePassNodeName("reshape1", "sum1"); + // repass node after next_node immediately + test_pass.AddRePassNodeName("add1", "sum1"); + + auto graph = BuildGraph2(); + auto ge_pass = GEPass(graph); + EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS); + EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo + std::vector> layers; + layers.push_back({"data1", "const1", "const2"}); + layers.push_back({"shape1"}); + layers.push_back({"add1", "addn1"}); + layers.push_back({"reshape1", "sum1"}); + CheckIterOrder(&test_pass, layers); +}*/ +} // namespace ge diff --git a/tests/ut/ge/graph/passes/infershape_pass_unittest.cc b/tests/ut/ge/graph/passes/infershape_pass_unittest.cc index 13e66c50..d84aff50 100644 --- a/tests/ut/ge/graph/passes/infershape_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/infershape_pass_unittest.cc @@ -1,161 +1,262 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#define protected public -#define private public -#include "graph/passes/infershape_pass.h" - -#include "graph/utils/tensor_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/operator_factory.h" -#include "graph/operator_reg.h" -#include "graph_builder_utils.h" - -using namespace std; -using namespace testing; -namespace ge { -class UtestGraphInfershapePass : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { - OpDescPtr op_desc = std::make_shared(name, type); - op_desc->SetStreamId(0); - static int32_t index = 0; - op_desc->SetId(index++); - - GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); - TensorUtils::SetSize(tensor, 512); - vector input_offset; - for (int i = 0; i < in_num; i++) { - op_desc->AddInputDesc(tensor); - input_offset.emplace_back(1024); - } - op_desc->SetInputOffset(input_offset); - - vector output_offset; - for (int i = 0; i < out_num; i++) { - op_desc->AddOutputDesc(tensor); - output_offset.emplace_back(1024); - } - op_desc->SetOutputOffset(output_offset); - - op_desc->SetWorkspace({}); - op_desc->SetWorkspaceBytes({}); - op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); - - const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; - op_desc->AddInferFunc(stub_func); - op_desc->AddInferFormatFunc(stub_func); - op_desc->AddVerifierFunc(stub_func); - - return graph.AddNode(op_desc); -} - -TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { - GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); - string type = "AddN"; - auto addn_op_desc = std::make_shared("AddN", type); - addn_op_desc->AddInputDesc(ge_tensor_desc); - addn_op_desc->AddOutputDesc(ge_tensor_desc); - auto graph = std::make_shared("test"); - auto addn_node = std::make_shared(addn_op_desc, graph); - addn_node->Init(); - - InferShapePass infershape_pass; - 
EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED); -} - -TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { - auto graph = std::make_shared("test"); - - auto no_op_desc = std::make_shared("No", "NoOp"); - auto no_op_node = graph->AddNode(no_op_desc); - AttrUtils::SetBool(no_op_desc, "_need_infer_again", false); - - InferShapePass infershape_pass; - infershape_pass.options_[kOptimizeAfterSubGraph] = "yes"; - EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); -} - -TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { -/******************************************************************************* - * Exit Identify - * \ / \. - * \ / \. - * Switch Add - * / | | - * / | | - * / | | - * LoopCond | | - * \ | | - * \ | | - * \ | | - * Less | | - * \ | NextIteration - * \ | | - * \ | | - * Merge <---------| - * | - * | - * Enter - ******************************************************************************/ - auto graph = std::make_shared("test_infer_shape"); - auto data1 = CreateNode(*graph, "data", DATA, 1, 1); - auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); - auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); - auto less1 = CreateNode(*graph, "less", LESS, 2, 1); - auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); - auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); - auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); - auto add1 = CreateNode(*graph, "add", ADD, 2, 1); - auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); - auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); - auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); - auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); - auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); - - GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); - GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), 
less1->GetInDataAnchor(0)); - GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); - GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); - - GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); - GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); - - GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); - GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); - - GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); - GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); - GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); - - GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); - GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); - - GEPass ge_passes(graph); - NamesToPass names_to_passes; - InferShapePass infer_shape_pass; - names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); - - EXPECT_EQ(ge_passes.Run(names_to_passes), SUCCESS); -} -} // namespace ge +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#define protected public +#define private public +#include "graph/passes/infershape_pass.h" + +#include "graph/utils/tensor_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/operator_factory.h" +#include "graph/operator_reg.h" +#include "graph_builder_utils.h" + +using namespace std; +using namespace testing; +namespace ge { +class UtestGraphInfershapePass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { + OpDescPtr op_desc = std::make_shared(name, type); + op_desc->SetStreamId(0); + static int32_t index = 0; + op_desc->SetId(index++); + + GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); + TensorUtils::SetSize(tensor, 512); + vector input_offset; + for (int i = 0; i < in_num; i++) { + op_desc->AddInputDesc(tensor); + input_offset.emplace_back(1024); + } + op_desc->SetInputOffset(input_offset); + + vector output_offset; + for (int i = 0; i < out_num; i++) { + op_desc->AddOutputDesc(tensor); + output_offset.emplace_back(1024); + } + op_desc->SetOutputOffset(output_offset); + + op_desc->SetWorkspace({}); + op_desc->SetWorkspaceBytes({}); + op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); + + const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; + op_desc->AddInferFunc(stub_func); + op_desc->AddInferFormatFunc(stub_func); + op_desc->AddVerifierFunc(stub_func); + + return graph.AddNode(op_desc); +} + +TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { + GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + string type = "AddN"; + auto addn_op_desc = std::make_shared("AddN", type); + addn_op_desc->AddInputDesc(ge_tensor_desc); + addn_op_desc->AddOutputDesc(ge_tensor_desc); + auto graph = std::make_shared("test"); + auto addn_node = std::make_shared(addn_op_desc, graph); + addn_node->Init(); + + InferShapePass infershape_pass; + 
EXPECT_EQ(infershape_pass.Run(addn_node), GRAPH_FAILED); +} + +TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { +/******************************************************************************* + * Exit Identify + * \ / \. + * \ / \. + * Switch Add + * / | | + * / | | + * / | | + * LoopCond | | + * \ | | + * \ | | + * \ | | + * Less | | + * \ | NextIteration + * \ | | + * \ | | + * Merge <---------| + * | + * | + * Enter + ******************************************************************************/ + auto graph = std::make_shared("test_infer_shape"); + auto data1 = CreateNode(*graph, "data", DATA, 1, 1); + auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); + auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); + auto less1 = CreateNode(*graph, "less", LESS, 2, 1); + auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); + auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); + auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); + auto add1 = CreateNode(*graph, "add", ADD, 2, 1); + auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); + auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); + auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); + auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); + auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); + GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); + GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); + 
GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); + GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); + GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); + + GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); + GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); + + GEPass ge_passes(graph); + NamesToPass names_to_passes; + InferShapePass infer_shape_pass; + names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); + + EXPECT_EQ(infer_shape_pass.Run(switch1), SUCCESS); + auto suspend_nodes = infer_shape_pass.GetNodesSuspend(); + auto exit_node = graph->FindNode("exit"); + EXPECT_EQ(suspend_nodes.count(exit_node), 1); + infer_shape_pass.OnSuspendNodesLeaked(); + auto resume_nodes = infer_shape_pass.GetNodesResume(); + EXPECT_EQ(resume_nodes.count(exit_node), 1); +} +TEST_F(UtestGraphInfershapePass, update_tensordesc_when_changed) { + GeTensorDesc src_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + GeTensorDesc dst_ge_tensor_desc(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + GeTensorDescPtr src_tensor_desc_ptr = std::make_shared(src_ge_tensor_desc); + GeTensorDescPtr dst_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + bool changed = false; + infershape_pass.UpdateTensorDesc(src_tensor_desc_ptr, dst_tensor_desc_ptr, changed); + EXPECT_EQ(changed, true); + EXPECT_EQ(dst_tensor_desc_ptr->GetShape().GetDims(), std::vector({1, 2, 3, 4})); +} + +TEST_F(UtestGraphInfershapePass, update_tensordesc_when_not_changed) { + GeTensorDesc src_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + GeTensorDescPtr src_tensor_desc_ptr = std::make_shared(src_ge_tensor_desc); + GeTensorDescPtr 
dst_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + bool changed = false; + infershape_pass.UpdateTensorDesc(src_tensor_desc_ptr, dst_tensor_desc_ptr, changed); + EXPECT_EQ(changed, false); +} + +TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_failed) { + // ref output has different dtype + GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + GeTensorDesc ge_tensor_desc2(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared(ge_tensor_desc1); + GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared(ge_tensor_desc2); + GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + auto ret = infershape_pass.UpdateOutputFromSubgraphs({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr}, dst_ge_tensor_desc_ptr); + EXPECT_EQ(ret, GRAPH_FAILED); +} + +TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_get_unknown_rank) { + // ref output has different dtype + GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc ge_tensor_desc2(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared(ge_tensor_desc1); + GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared(ge_tensor_desc2); + GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + auto ret = infershape_pass.UpdateOutputFromSubgraphs({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr}, dst_ge_tensor_desc_ptr); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(dst_ge_tensor_desc_ptr->GetShape().GetDims(), UNKNOWN_RANK); +} + +TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_get_unknown_shape) { + // ref output has different 
dtype + GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc ge_tensor_desc2(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared(ge_tensor_desc1); + GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared(ge_tensor_desc2); + GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + auto ret = infershape_pass.UpdateOutputFromSubgraphs({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr}, dst_ge_tensor_desc_ptr); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(dst_ge_tensor_desc_ptr->GetShape().GetDims(), std::vector({-1,2,3,4})); + // todo shape range? +} + +TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_failed) { + // ref output has different dtype + GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); + GeTensorDesc ge_tensor_desc2(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared(ge_tensor_desc1); + GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared(ge_tensor_desc2); + GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + auto ret = infershape_pass.UpdateOutputFromSubgraphsForMultiDims({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr}, + dst_ge_tensor_desc_ptr); + EXPECT_EQ(ret, GRAPH_FAILED); +} + +TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_failed_shape_size_overflow) { + // ref output has different dtype + GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc ge_tensor_desc2(GeShape({INT64_MAX, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + 
GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared(ge_tensor_desc1); + GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared(ge_tensor_desc2); + GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + auto ret = infershape_pass.UpdateOutputFromSubgraphsForMultiDims({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr}, + dst_ge_tensor_desc_ptr); + EXPECT_EQ(ret, PARAM_INVALID); +} + +TEST_F(UtestGraphInfershapePass, update_output_from_subgraphs_for_multiDims_success) { + // ref output has different dtype + GeTensorDesc ge_tensor_desc1(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc ge_tensor_desc2(GeShape({2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDesc dst_ge_tensor_desc(GeShape({1, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT); + GeTensorDescPtr ge_tensor_desc1_ptr = std::make_shared(ge_tensor_desc1); + GeTensorDescPtr ge_tensor_desc2_ptr = std::make_shared(ge_tensor_desc2); + GeTensorDescPtr dst_ge_tensor_desc_ptr = std::make_shared(dst_ge_tensor_desc); + InferShapePass infershape_pass; + auto ret = infershape_pass.UpdateOutputFromSubgraphsForMultiDims({ge_tensor_desc1_ptr, ge_tensor_desc2_ptr}, + dst_ge_tensor_desc_ptr); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(dst_ge_tensor_desc_ptr->GetShape().GetDims(), std::vector({2,2,3,4})); +} +} // namespace ge From eee6bc92d1b2bcd150b18621274222ba511ee5bd Mon Sep 17 00:00:00 2001 From: lichun Date: Fri, 16 Jul 2021 16:13:45 +0800 Subject: [PATCH 21/33] fix error code and add complex128 support --- tests/ut/ge/graph_ir/ge_ir_build_unittest.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc index 60f33ed3..500dbc2a 100644 --- a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc +++ b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc @@ -367,7 +367,7 @@ TEST(UtestIrBuild, check_data_op_attr_index_valid) { }; ModelBufferData model; graphStatus ret = 
aclgrphBuildModel(graph, build_options, model); - EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); + EXPECT_EQ(ret, ge::FAILED); } // set attr index invalid, when not set input shape range @@ -377,7 +377,7 @@ TEST(UtestIrBuild, check_data_attr_index_succ_no_input_range) { const map build_options; ModelBufferData model; graphStatus ret = aclgrphBuildModel(graph, build_options, model); - EXPECT_EQ(ret, GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); + EXPECT_EQ(ret, ge::FAILED); } TEST(UtestIrBuild, check_modify_mixlist_param) { From 4cb9aff399e1a8564cc12827bffda9a3a540749d Mon Sep 17 00:00:00 2001 From: lichun Date: Fri, 16 Jul 2021 17:21:22 +0800 Subject: [PATCH 22/33] fix error code and add complex128 support --- .../ge/graph/build/graph_mem_assigner_unittest.cc | 85 ++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc diff --git a/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc b/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc new file mode 100644 index 00000000..983f1763 --- /dev/null +++ b/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include "graph/anchor.h" +#include "graph/attr_value.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "omg/omg_inner_types.h" +#include "../passes/graph_builder_utils.h" + +#define protected public +#define private public +#include "graph/build/memory/binary_block_mem_assigner.h" +#include "graph/build/memory/graph_mem_assigner_unittest.h" +#include "graph/build/memory/hybrid_mem_assigner.h" +#include "graph/build/memory/max_block_mem_assigner.h" +#include "graph/manager/graph_var_manager.h" +#undef protected +#undef private + +using namespace std; +using namespace testing; +using namespace ge; +using domi::GetContext; + +class UtestTaskGeneratorTest : public testing::Test { + public: + ge::ComputeGraphPtr BuildGraphWithVar(int64_t session_id) { + // init + MemManager::Instance().Initialize(std::vector({RT_MEMORY_HBM})); + VarManager::Instance(session_id)->Init(0, 0, 0, 0); + ge::ut::GraphBuilder builder("graph"); + auto var_input = builder.AddNode("var", "Variable", 1, 1); + auto const_input = builder.AddNode("const", "Const", 1, 1); + auto assign = builder.AddNode("assgin", "Assign", 2, 1); + // add link + builder.AddDataEdge(var_input, 0, assign, 0); + builder.AddDataEdge(const_input, 0, assign, 1); + // set offset + var_input->GetOpDesc()->SetOutputOffset({10000}); + const_input->GetOpDesc()->SetOutputOffset({1000}); + assign->GetOpDesc()->SetInputOffset({10100, 1000}); + assign->GetOpDesc()->SetOutputOffset({10100}); + // set inner offset + int64_t inner_offset = 100; + ge::AttrUtils::SetInt(assign->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_INNER_OFFSET, inner_offset); + ge::AttrUtils::SetInt(assign->GetOpDesc()->MutableOutputDesc(0), ATTR_NAME_INNER_OFFSET, inner_offset); + // add var addr + VarManager::Instance(session_id)->var_resource_->var_offset_map_.emplace(10000, 
RT_MEMORY_HBM); + + return builder.GetGraph(); + } + +protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UtestMemoryAssignerTest, graph_memory_assign_continuous_input) { + ge::ComputeGraphPtr compute_graph = make_shared(""); + GraphMemoryAssigner graph_mem_assigner(compute_graph); + map mem_type_to_offset = {}; + Status ret = ReAssignMemory(false, mem_type_to_offset); + EXPECT_EQ(ret, ge::FAILED); +} + From 7e6461f7f17fcc824f1b91c98ceb54a01cc2df15 Mon Sep 17 00:00:00 2001 From: lichun Date: Fri, 16 Jul 2021 17:32:17 +0800 Subject: [PATCH 23/33] fix error code and add complex128 support --- tests/ut/ge/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 773a2686..a5b3942d 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -677,6 +677,7 @@ set(MULTI_PARTS_TEST_FILES "graph/build/stream_allocator_unittest.cc" "graph/build/model_builder_unittest.cc" "graph/build/mem_assigner_unittest.cc" + "graph/build/graph_mem_assigner_unittest.cc" "graph/build/task_generator_unittest.cc" "graph/build/buffer_pool_mem_assigner_unittest.cc" "graph/execute/graph_execute_unittest.cc" From 8b7ae630863eac1d2aef3b20cf0496139554c420 Mon Sep 17 00:00:00 2001 From: lichun Date: Fri, 16 Jul 2021 17:35:08 +0800 Subject: [PATCH 24/33] fix error code and add complex128 support --- tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc b/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc index 983f1763..4bff9b38 100644 --- a/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc +++ b/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc @@ -30,7 +30,7 @@ #define protected public #define private public #include "graph/build/memory/binary_block_mem_assigner.h" -#include "graph/build/memory/graph_mem_assigner_unittest.h" +#include 
"graph/build/memory/graph_mem_assigner.h" #include "graph/build/memory/hybrid_mem_assigner.h" #include "graph/build/memory/max_block_mem_assigner.h" #include "graph/manager/graph_var_manager.h" From a928df8eaaaaf0c8bea944d674e5eb58700854bd Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Fri, 16 Jul 2021 18:17:25 +0800 Subject: [PATCH 25/33] fix sc problem --- ge/graph/passes/infershape_pass.cc | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index deaebf4f..05b1b5fc 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -138,7 +138,9 @@ graphStatus InferShapePass::InferShapeAndType(NodePtr &node) { if (!is_unknown_graph) { auto inference_context = ShapeRefiner::CreateInferenceContext(node); GE_CHECK_NOTNULL(inference_context); - GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); + std::vector marks; + inference_context->GetMarks(marks); + GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size()); op.SetInferenceContext(inference_context); } @@ -151,10 +153,12 @@ graphStatus InferShapePass::InferShapeAndType(NodePtr &node) { if (!is_unknown_graph) { auto ctx_after_infer = op.GetInferenceContext(); if (ctx_after_infer != nullptr) { - GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); - if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { + std::vector marks; + ctx_after_infer->GetMarks(marks); + GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size()); + if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !marks.empty()) { GELOGD("[%s] set inference context after. 
mark:%zu", node->GetName().c_str(), - ctx_after_infer->GetMarks().size()); + marks.size()); ShapeRefiner::PushToContextMap(node, ctx_after_infer); } } @@ -254,7 +258,8 @@ graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { auto ret = op_desc->CallInferFunc(op); if (ret == GRAPH_PARAM_INVALID) { // Op ir no infer func, try to get infer func from operator factory - auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); + + auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str()); if (node_op.IsEmpty()) { GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); return ret; From f98a41081af2d7f23434cc1c63be84d0bef9d49c Mon Sep 17 00:00:00 2001 From: lichun Date: Fri, 16 Jul 2021 18:25:05 +0800 Subject: [PATCH 26/33] fix error code and add complex128 support --- tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc b/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc index 4bff9b38..703ac3b4 100644 --- a/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc +++ b/tests/ut/ge/graph/build/graph_mem_assigner_unittest.cc @@ -34,6 +34,7 @@ #include "graph/build/memory/hybrid_mem_assigner.h" #include "graph/build/memory/max_block_mem_assigner.h" #include "graph/manager/graph_var_manager.h" +#include "graph/manager/graph_mem_manager.h" #undef protected #undef private @@ -42,7 +43,7 @@ using namespace testing; using namespace ge; using domi::GetContext; -class UtestTaskGeneratorTest : public testing::Test { +class UtestGraphMemAssigner : public testing::Test { public: ge::ComputeGraphPtr BuildGraphWithVar(int64_t session_id) { // init @@ -75,11 +76,15 @@ protected: void TearDown() {} }; -TEST_F(UtestMemoryAssignerTest, graph_memory_assign_continuous_input) { +TEST_F(UtestGraphMemAssigner, graph_memory_assign_fail_case) { 
ge::ComputeGraphPtr compute_graph = make_shared(""); GraphMemoryAssigner graph_mem_assigner(compute_graph); + MemoryOffset mem_offset(2, 10000); + graph_mem_assigner.memory_offset_.insert({2, mem_offset}); + VarManager::Instance(0)->graph_mem_max_size_ = 0; + map mem_type_to_offset = {}; - Status ret = ReAssignMemory(false, mem_type_to_offset); - EXPECT_EQ(ret, ge::FAILED); + Status ret = graph_mem_assigner.ReAssignMemory(false, mem_type_to_offset); + EXPECT_EQ(ret, ACL_ERROR_GE_MEMORY_ALLOCATION); } From 02c3500ceff3bcc491d536b944a0821635de0770 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Fri, 16 Jul 2021 19:28:02 +0800 Subject: [PATCH 27/33] delete model_cache_helper.cc --- ge/CMakeLists.txt | 2 - ge/common/helper/model_cache_helper.cc | 1721 -------------------- ge/common/helper/model_cache_helper.h | 123 -- ge/executor/ge_executor.cc | 1 + ge/generator/ge_generator.cc | 1 + ge/graph/manager/graph_manager.cc | 209 +-- ge/graph/manager/graph_manager.h | 10 - ge/graph/manager/graph_manager_utils.cc | 38 +- ge/graph/manager/graph_manager_utils.h | 2 - ge/graph/manager/graph_var_manager.cc | 44 - ge/graph/manager/graph_var_manager.h | 6 - ge/graph/manager/model_manager/event_manager.cc | 83 - ge/graph/manager/model_manager/event_manager.h | 98 -- ge/graph/manager/trans_var_data_utils.h | 1 - ge/graph/passes/global_step_insert_pass.cc | 1 - ge/init/gelib.h | 11 +- tests/ut/ge/CMakeLists.txt | 7 - .../ut/ge/graph/execute/model_executor_unittest.cc | 1 + tests/ut/ge/graph/graph_load_unittest.cc | 93 -- .../new_model_manager_data_inputer_unittest.cc | 64 - .../new_model_manager_davinci_model_unittest.cc | 1433 ---------------- .../new_model_manager_event_manager_unittest.cc | 117 -- .../load/new_model_manager_task_build_unittest.cc | 115 -- .../ut/ge/graph/load/output_net_output_unittest.cc | 300 ---- .../ut/ge/graph/manager/graph_manager_unittest.cc | 3 - 25 files changed, 23 insertions(+), 4461 deletions(-) delete mode 100755 
ge/common/helper/model_cache_helper.cc delete mode 100755 ge/common/helper/model_cache_helper.h delete mode 100644 ge/graph/manager/model_manager/event_manager.cc delete mode 100644 ge/graph/manager/model_manager/event_manager.h delete mode 100644 tests/ut/ge/graph/graph_load_unittest.cc delete mode 100644 tests/ut/ge/graph/load/new_model_manager_data_inputer_unittest.cc delete mode 100644 tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc delete mode 100644 tests/ut/ge/graph/load/new_model_manager_event_manager_unittest.cc delete mode 100644 tests/ut/ge/graph/load/new_model_manager_task_build_unittest.cc delete mode 100644 tests/ut/ge/graph/load/output_net_output_unittest.cc diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 0236e8bd..f98297d8 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -262,7 +262,6 @@ set(COMPILER_SRC_LIST "common/dump/dump_op.cc" "common/ge/op_tiling_manager.cc" "common/ge/plugin_manager.cc" - "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" "engine_manager/dnnengine_manager.cc" "ge_local_engine/engine/host_cpu_engine.cc" @@ -300,7 +299,6 @@ set(COMPILER_SRC_LIST "graph/manager/graph_var_manager.cc" "graph/manager/host_mem_allocator.cc" "graph/manager/host_mem_manager.cc" - "graph/manager/model_manager/event_manager.cc" "graph/manager/rdma_pool_allocator.cc" "graph/manager/session_scope_mem_allocator.cc" "graph/manager/trans_var_data_utils.cc" diff --git a/ge/common/helper/model_cache_helper.cc b/ge/common/helper/model_cache_helper.cc deleted file mode 100755 index 0e6c6329..00000000 --- a/ge/common/helper/model_cache_helper.cc +++ /dev/null @@ -1,1721 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/helper/model_cache_helper.h" - -#include -#include -#include - -#include "common/model_parser/model_parser.h" -#include "framework/common/helper/model_helper.h" -#include "graph/detail/model_serialize_imp.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/tensor_utils.h" -#include "init/gelib.h" -#include "proto/ge_ir.pb.h" - -using namespace std; - -namespace { -const char *const kTbeKernelInfoStoreName = "AIcoreEngine"; -const char *const kGraphName = "temp_name"; -// Keys of json -const char *const kNodeNum = "nodeNum"; -const char *const kEdgeNum = "edgeNum"; -const char *const kGraphHash = "graphHash"; -const char *const kNodeHash = "nodeHash"; -const char *const kHash = "hash"; -const char *const kSessionId = "sessionId"; -const char *const kDeviceId = "deviceId"; -const char *const kJobId = "jobId"; -const char *const kGraphMemMaxSize = "graphMemMaxSize"; -const char *const kVarMemMaxSize = "varMemMaxSize"; -const char *const kVarMemLogicBase = "varMemLogicBase"; -const char *const kUseMaxMemSize = "useMaxMemSize"; -const char *const kMemResourceMap = "memResourceMap"; -const char *const kMemType = "memType"; -const char *const kTotalSize = "totalSize"; -const char *const kVarMemSize = "varMemSize"; -const char *const kVarResource = "varResource"; -const char *const kVarAddrMgrMap = "varAddrMgrMap"; -const char *const kName = "name"; -const char *const kAddress = "address"; -const char *const kOffset = "offset"; -const char *const kMemoryType = "memoryType"; -const char *const kTensorDesc = "tensorDesc"; 
-const char *const kDataType = "dataType"; -const char *const kShape = "shape"; -const char *const kLayout = "layout"; -const char *const kOriginDataType = "originDataType"; -const char *const kOriginShape = "originShape"; -const char *const kOriginLayout = "originLayout"; -const char *const kRealDimCnt = "realDimCnt"; -const char *const kCurVarTensorDescMap = "curVarTensorDescMap"; -const char *const kTransRoads = "transRoads"; -const char *const kTransRoad = "transRoad"; -const char *const kNodeType = "nodeType"; -const char *const kInputTensorDesc = "inputTensorDesc"; -const char *const kOutputTensorDesc = "outputTensorDesc"; -const char *const kChangedGraphId = "changedGraphId"; -const char *const kAllocatedGraphId = "allocatedGraphId"; -const char *const kGraphId = "graphId"; -const char *const kVarBroadcastInfo = "varBroadcastInfo"; -const char *const kBroadcastName = "broadcastName"; -const char *const kIdx = "idx"; -const char *const kInputOffset = "inputOffset"; -const char *const kInputSize = "inputSize"; -const char *const kOutputOffset = "outputOffset"; -const char *const kOutputSize = "outputSize"; -// Suffix of cache files -const char *const kBeforeVarManagerSuffix = "_before_build_var_manager.json"; -const char *const kAfterVarManagerSuffix = "_after_build_var_manager.json"; -const char *const kManifestSuffix = ".manifest"; -const char *const kOmSuffix = ".om"; -} // namespace - -namespace ge { -map ModelCacheHelper::graph_id_run_times_; -ModelCacheHelper::ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph) - : session_id_(session_id), - graph_id_(graph_id), - compute_graph_(compute_graph), - is_cache_path_valid_for_output(false) { - if (graph_id_run_times_.count(graph_id) == 0) { - graph_id_run_times_[graph_id] = 1; - } else { - graph_id_run_times_[graph_id] = graph_id_run_times_[graph_id] + 1; - } - for (const auto &node : compute_graph_->GetDirectNode()) { - bool is_variable = (node->GetType() == VARIABLE) || 
(node->GetType() == VARIABLEV2) || - (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP); - if (!is_variable) { - continue; - } - var_names_.insert(node->GetName()); - } - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) { - std::string cache_path = instance_ptr->GetIncreBuildCachePath(); - GELOGD("Incre build path conf: %s", cache_path.c_str()); - string fake_file_path = cache_path + to_string(graph_id_) + kManifestSuffix; - if (CheckOutputPathValid(fake_file_path)) { - is_cache_path_valid_for_output = true; - } else { - GELOGW("Invalid cache path for output."); - } - std::string real_cache_path = RealPath(cache_path.c_str()); - if (real_cache_path.empty()) { - GELOGW("Invalid incre build cache path conf: %s", cache_path.c_str()); - return; - } - cache_path_ = real_cache_path + '/'; - GELOGD("Try to use incre build cache path: %s", cache_path_.c_str()); - } -} - -ModelCacheHelper::~ModelCacheHelper() { var_names_.clear(); } - -bool ModelCacheHelper::IsModelCacheHit() const { - CacheInfo cache_info; - if (GetCacheInfo(cache_info) != SUCCESS) { - GELOGI("Get cache info of graph id[%u] failed.", graph_id_); - return false; - } - // Check number of nodes and edges first. 
- if (cache_info.node_num != compute_graph_->GetDirectNodesSize()) { - GELOGI("Graph id[%u] cache miss: the node number of the graph does not match the cache info.", graph_id_); - return false; - } - size_t edge_num = 0; - for (const auto &node : compute_graph_->GetDirectNode()) { - for (const auto &anchor : node->GetAllInAnchors()) { - edge_num += anchor->GetPeerAnchors().size(); - } - } - if (cache_info.edge_num != edge_num) { - GELOGI("Graph id[%u] cache miss: the edge number of the graph does not match the cache info.", graph_id_); - return false; - } - size_t compute_graph_hash; - auto ret = GetComputeGraphHash(compute_graph_hash); - if (ret != SUCCESS || cache_info.graph_hash != compute_graph_hash) { - GELOGI("Graph id[%u] cache miss: the hash code of the graph does not match the cache info.", graph_id_); - return false; - } - if (!IsNodeHashSameAsCache(cache_info.nodes_hash)) { - GELOGI("Graph id[%u] cache miss: the hash code of node does not match the cache info.", graph_id_); - return false; - } - - string var_manager_cache = - to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kBeforeVarManagerSuffix; - Json var_manager_json; - if (LoadJsonFromFile(var_manager_cache, var_manager_json) != SUCCESS) { - GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str()); - return false; - } - if (!IsVarManagerSameAsCache(var_manager_json)) { - GELOGI("Graph id[%u] cache miss: the VarManager does not match the cache info.", graph_id_); - return false; - } - GELOGI("Graph id[%u] cache hit.", graph_id_); - return true; -} - -Status ModelCacheHelper::RefreshComputeGraph(const ComputeGraphPtr &compute_graph) { - if (compute_graph->IsValid()) { - compute_graph_ = compute_graph; - var_names_.clear(); - for (const auto &node : compute_graph_->GetDirectNode()) { - bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || - (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP); - if (!is_variable) 
{ - continue; - } - var_names_.insert(node->GetName()); - } - return SUCCESS; - } else { - GELOGW("Invalid compute graph."); - return FAILED; - } -} - -Status ModelCacheHelper::ClearCache(uint32_t graph_id) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return SUCCESS; - } - string manifest_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string manifest_file_path = RealPath(manifest_file.c_str()); - int ret; - if (!manifest_file_path.empty()) { - ret = remove(manifest_file_path.c_str()); - // If remove file failed, print the warning log - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", manifest_file_path.c_str()); - } - } - string before_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string before_var_manager_file_path = RealPath(before_var_manager_file.c_str()); - if (!before_var_manager_file_path.empty()) { - ret = remove(before_var_manager_file_path.c_str()); - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", before_var_manager_file_path.c_str()); - } - } - string after_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string after_var_manager_file_path = RealPath(after_var_manager_file.c_str()); - if (!after_var_manager_file_path.empty()) { - ret = remove(after_var_manager_file_path.c_str()); - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", after_var_manager_file_path.c_str()); - } - } - string om_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string om_file_path = RealPath(om_file.c_str()); - if (!om_file_path.empty()) { - ret = remove(om_file_path.c_str()); - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", om_file_path.c_str()); - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverVarManagerFromCache() const { - string var_manager_cache = - to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kAfterVarManagerSuffix; - Json var_manager_json; - if (LoadJsonFromFile(var_manager_cache, 
var_manager_json) != SUCCESS) { - GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str()); - return FAILED; - } - - Json mem_resource_json = move(var_manager_json[kMemResourceMap]); - auto ret = RecoverMemResource(mem_resource_json); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[MemResource]"); - return FAILED; - } - Json var_resource_json = move(var_manager_json[kVarResource]); - ret = RecoverAllocatedGraphId(var_resource_json[kAllocatedGraphId]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[AllocatedGraphId]"); - return FAILED; - } - ret = RecoverChangedGraphId(var_resource_json[kChangedGraphId]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[ChangedGraphId]"); - return FAILED; - } - ret = RecoverBroadcastInfo(var_resource_json[kVarBroadcastInfo]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[VarBroadcastInfo]"); - return FAILED; - } - ret = RecoverVarAddrAndTensorDesc(var_resource_json[kVarAddrMgrMap]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[VarAddrMgrMap & CurVarTensorDesc]"); - return FAILED; - } - ret = RecoverTransRoads(var_resource_json[kTransRoads]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[TransRoads]"); - return FAILED; - } - GELOGI("Recover VarManager from cache[%s] success.", cache_path_.c_str()); - return SUCCESS; -} - -Status ModelCacheHelper::GetNodesNeedRecompile(ComputeGraphPtr &graph, vector &nodes) { - std::shared_ptr instance = ge::GELib::GetInstance(); - if (instance == nullptr || !instance->InitFlag()) { - GELOGW("RecompileNodes failed."); - return ge::GE_CLI_GE_NOT_INITIALIZED; - } - // Collect aicore ops for recompile - for (auto &node : graph->GetDirectNode()) { - if (node == nullptr) { - continue; - } - auto op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - // Get op kernel lib name - string kernel_lib_name = 
op_desc->GetOpKernelLibName(); - if (kernel_lib_name.empty()) { - // reset op kernel lib - (void)instance->DNNEngineManagerObj().GetDNNEngineName(node); - kernel_lib_name = op_desc->GetOpKernelLibName(); - if (kernel_lib_name.empty()) { - GELOGW("Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), op_desc->GetType().c_str()); - continue; - } - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecompileNodes(GeModelPtr &ge_model) { - std::shared_ptr instance = ge::GELib::GetInstance(); - if (instance == nullptr || !instance->InitFlag()) { - GELOGW("RecompileNodes failed."); - return ge::GE_CLI_GE_NOT_INITIALIZED; - } - // Get aicore ops kernel info store. - OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kTbeKernelInfoStoreName); - if (kernel_info == nullptr) { - GELOGW("Get %s ops kernel info store failed", kTbeKernelInfoStoreName); - return INTERNAL_ERROR; - } - - auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); - vector node_vec; - auto ret = GetNodesNeedRecompile(compute_graph, node_vec); - GE_CHK_BOOL_EXEC_WARN(ret == ge::SUCCESS, return ret, "Get nodes need recompiling failed"); - // Recompile aicore ops - ret = kernel_info->CompileOp(node_vec); - GE_CHK_BOOL_EXEC_WARN(ret == ge::SUCCESS, return ret, "Recompile op failed"); - const TBEKernelStore &tbekernel_store = ge_model->GetTBEKernelStore(); - TBEKernelStore tbe_kernel_store; - for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { - auto node_op_desc = n->GetOpDesc(); - GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); - TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); - if (tbe_kernel == nullptr) { - // Load tbe kernel from tbe_kernel_store to op if op was not recompiled - auto op_desc = n->GetOpDesc(); - tbekernel_store.LoadTBEKernelBinToOpDesc(op_desc); - GELOGD("LoadOmModelFromCache: Load tbe kernel bin to op desc[%s].", op_desc->GetName().c_str()); - } - 
tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); - GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); - // Refresh tbe kernel in tbe_kernel_store - tbe_kernel_store.AddTBEKernel(tbe_kernel); - GELOGD("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); - } - GE_CHK_BOOL_EXEC_WARN(tbe_kernel_store.Build(), return FAILED, "TBE Kernels store build failed!"); - ge_model->SetTBEKernelStore(tbe_kernel_store); - return SUCCESS; -} - -Status ModelCacheHelper::GetNodesHash(map &hash_map) const { - vector nodes; - GraphUtils::TopologicalSortingByName(compute_graph_, nodes); - ModelSerializeImp model_serialize_imp; - std::hash node_hash; - for (const auto &node : nodes) { - if (node == nullptr) { - continue; - } - proto::OpDef op_def; - bool is_framework_op = (node->GetType() == FRAMEWORKOP); - int32_t framework_type = 0; - if (is_framework_op) { - AttrUtils::GetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, framework_type); - AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, 0); - } - bool ret = model_serialize_imp.SerializeNode(node, &op_def, is_framework_op); - op_def.set_id(0); // Id of op is not stable because of parallel parsing - // Clear weights attr in constant. 
- auto attr = op_def.mutable_attr(); - if (op_def.type() == CONSTANT || op_def.type() == CONSTANTOP) { - attr->erase(ATTR_NAME_WEIGHTS); - } - if (is_framework_op) { - AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, framework_type); - } - if (!ret) { - GELOGW("Fail to serialize node[%s].", node->GetName().c_str()); - return INTERNAL_ERROR; - } - string prototxt; - ret = google::protobuf::TextFormat::PrintToString(op_def, &prototxt); - if (!ret) { - GELOGW("Print OpDef to string failed."); - hash_map.clear(); - return INTERNAL_ERROR; - } - size_t hash_code = node_hash(prototxt); - hash_map[node->GetName()] = hash_code; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetComputeGraphHash(size_t &hash) const { - proto::GraphDef graph_proto; - ModelSerializeImp model_serialize_imp; - // The name of compute graph may be generated randomly, so replace it temporarily. - const string origin_name = compute_graph_->GetName(); - compute_graph_->SetName(kGraphName); - bool serialize_ret = model_serialize_imp.SerializeGraph(compute_graph_, &graph_proto); - graph_proto.clear_op(); - if (!serialize_ret) { - GELOGW("Serialize graph failed."); - hash = 0; - return INTERNAL_ERROR; - } - compute_graph_->SetName(origin_name); - // Generate proto text of GraphDef - string prototxt; - bool print_ret = google::protobuf::TextFormat::PrintToString(graph_proto, &prototxt); - if (!print_ret) { - GELOGW("Print GraphDef to string failed."); - hash = 0; - return INTERNAL_ERROR; - } - // Get the hash code of proto text - std::hash graph_hash; - hash = graph_hash(prototxt); - return SUCCESS; -} - -Status ModelCacheHelper::SaveJsonToFile(const string &file_name, const Json &json) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return PARAM_INVALID; - } - // Check whether the manifest exists, if not, create it. - string real_path = RealPath(cache_path_.c_str()); - if (real_path.empty()) { - GELOGW("File path is invalid. 
please check cache path: %s", cache_path_.c_str()); - return FAILED; - } - const string path = cache_path_ + file_name; - const int FILE_AUTHORITY = 0600; - int fd = mmOpen2(path.c_str(), M_WRONLY | M_CREAT | O_TRUNC, FILE_AUTHORITY); - if (fd < 0) { - GELOGW("Fail to open the file:%s. errmsg:%s", path.c_str(), strerror(errno)); - return INTERNAL_ERROR; - } - if (mmClose(fd) != 0) { - GELOGW("Fail to close the file:%s. errmsg:%s", path.c_str(), strerror(errno)); - return INTERNAL_ERROR; - } - - // Write json into cache file - ofstream ofs; - ofs.open(path); - if (!ofs.is_open()) { - GELOGW("Fail to open the file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - ofs << json << std::endl; - ofs.close(); - return SUCCESS; -} - -Status ModelCacheHelper::LoadJsonFromFile(const string &file_name, Json &json) const { - if (!json.is_null()) { - GELOGW("Input param json type should be null."); - return PARAM_INVALID; - } - string real_path = RealPath(cache_path_.c_str()); - if (real_path.empty()) { - GELOGW("File path is invalid. 
please check cache path: %s", cache_path_.c_str()); - return FAILED; - } - const string path = cache_path_ + file_name; - if (!CheckInputPathValid(path)) { - GELOGW("Invalid cache path for input:%s.", path.c_str()); - return FAILED; - } - string cache_real_path = RealPath(path.c_str()); - if (cache_real_path.empty()) { - GELOGI("File[%s] is not found.", path.c_str()); - return FAILED; - } - // Read json from cache file - ifstream ifs; - ifs.open(path); - if (!ifs.is_open()) { - GELOGW("Fail to open the file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - try { - ifs >> json; - } catch (nlohmann::detail::parse_error e) { - GELOGW("Fail to load json from file, json throw an error:%s.", e.what()); - return INTERNAL_ERROR; - } catch (nlohmann::detail::invalid_iterator e) { - GELOGW("Fail to load json from file, json throw an error:%s.", e.what()); - return INTERNAL_ERROR; - } catch (nlohmann::detail::type_error e) { - GELOGW("Fail to load json from file, json throw an error:%s.", e.what()); - return INTERNAL_ERROR; - } catch (nlohmann::detail::out_of_range e) { - GELOGW("Fail to load json from file, json throw an error:%s.", e.what()); - return INTERNAL_ERROR; - } catch (nlohmann::detail::other_error e) { - GELOGW("Fail to load json from file, json throw an error:%s.", e.what()); - return INTERNAL_ERROR; - } - - if (!json.is_object()) { - GELOGW("Fail to load the json file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::SaveCacheInfoToCache() const { - // Generate cache json - // example: {"edgeNum":6,"nodeNum":7,"graphCache":134714827475991356} - Json cache_json; - try { - cache_json[kNodeNum] = compute_graph_->GetDirectNodesSize(); - size_t edge_num = 0; - for (const auto &node : compute_graph_->GetDirectNode()) { - for (const auto &anchor : node->GetAllInAnchors()) { - edge_num += anchor->GetPeerAnchors().size(); - } - } - cache_json[kEdgeNum] = edge_num; - size_t hash = 0; - auto ret = 
GetComputeGraphHash(hash); - if (ret != SUCCESS) { - GELOGW("Error occur when generate graph hash code."); - return ret; - } - cache_json[kGraphHash] = hash; - Json nodes_hash_json; - ret = GetNodesHashMapJson(nodes_hash_json); - if (ret != SUCCESS) { - GELOGW("Error occur when generate nodes hash code."); - return ret; - } - cache_json[kNodeHash] = nodes_hash_json; - } catch (const std::exception &e) { - GELOGW("Fail to generate cache info json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - string cache_manifest = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix; - - auto ret = SaveJsonToFile(cache_manifest, cache_json); - if (ret != SUCCESS) { - GELOGW("Fail to save cache info to json file, path: %s.", cache_path_.c_str()); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetCacheInfo(CacheInfo &cache_info) const { - string cache_manifest = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix; - Json cache_json; - if (LoadJsonFromFile(cache_manifest, cache_json) != SUCCESS) { - GELOGW("Fail to load json from cache file: %s", cache_manifest.c_str()); - return INTERNAL_ERROR; - } - if (!cache_json.is_object()) { - GELOGW("Manifest should be a json object"); - return INTERNAL_ERROR; - } - try { - cache_info.node_num = cache_json[kNodeNum]; - cache_info.edge_num = cache_json[kEdgeNum]; - cache_info.graph_hash = cache_json[kGraphHash]; - Json nodes_hash_json = cache_json[kNodeHash]; - if (!(nodes_hash_json.is_null() || nodes_hash_json.is_array())) { - GELOGW("Nodes hash in cache should be null or array."); - return FAILED; - } - for (const auto &iter : nodes_hash_json) { - cache_info.nodes_hash[iter[kName].get()] = iter[kHash].get(); - } - } catch (const std::exception &e) { - GELOGW("Fail to get info from json file. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -bool ModelCacheHelper::IsAllocatedGraphIdSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare allocated graph id info between json and VarManager - std::map allocated_graph_id; - auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse AllocatedGraphId from Json."); - return false; - } - for (const auto &iter : allocated_graph_id) { - uint32_t graph_id = 0; - ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(iter.first, graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to find allocated graph id of var[%s].", iter.first.c_str()); - return false; - } - if (graph_id != iter.second) { - GELOGW("The allocated graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsNodeHashSameAsCache(const map &hash_map) const { - map cur_hash_map; - GetNodesHash(cur_hash_map); - if (hash_map.size() != cur_hash_map.size()) { - GELOGI("The number of hash code is different from cache info."); - return false; - } - for (const auto &iter : cur_hash_map) { - if (hash_map.count(iter.first) == 0) { - GELOGI("Node[%s] is not found in cache info.", iter.first.c_str()); - return false; - } - if (hash_map.at(iter.first) != iter.second) { - GELOGI("The hash code of node[%s] is different from cache info.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsMemResourceSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare var mem size info between json and VarManager - std::map var_mem_size; - auto ret = ParseMemResourceFromJson(json, var_mem_size); - if (ret != SUCCESS) { - 
GELOGW("Fail to parse MemResource from Json."); - return false; - } - for (const auto &iter : var_mem_size) { - int64_t mem_size = VarManager::Instance(session_id_)->GetVarMemSize(iter.first); - if (mem_size != iter.second) { - GELOGW("The var mem size of memory_type[%u] in cache is different from VarManager.", iter.first); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsChangedGraphIdSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare variable changed graph id info between json and VarManager - std::map changed_graph_id; - auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse ChangedGraphId from Json."); - return false; - } - for (const auto &iter : changed_graph_id) { - uint32_t graph_id = 0; - ret = VarManager::Instance(session_id_)->GetChangedGraphId(iter.first, graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to find changed graph id of var[%s].", iter.first.c_str()); - return false; - } - if (graph_id != iter.second) { - GELOGW("The changed graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsCurVarTensorDescSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare variable tensor desc info between json and VarManager - std::unordered_map cur_var_tensor_desc; - auto ret = ParseCurVarTensorDescMapFromJson(json, cur_var_tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to parse CurVarTensorDesc from Json."); - return false; - } - for (const auto &iter : cur_var_tensor_desc) { - GeTensorDesc tensor_desc; - ret = VarManager::Instance(session_id_)->GetCurVarDesc(iter.first, tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to find tensor desc of 
var[%s].", iter.first.c_str()); - return false; - } - uint32_t l_real_dim_cnt = 0; - uint32_t r_real_dim_cnt = 0; - TensorUtils::GetRealDimCnt(tensor_desc, l_real_dim_cnt); - TensorUtils::GetRealDimCnt(iter.second, r_real_dim_cnt); - if ((tensor_desc.GetDataType() != iter.second.GetDataType()) || - (tensor_desc.GetOriginDataType() != iter.second.GetOriginDataType()) || - (tensor_desc.GetFormat() != iter.second.GetFormat()) || - (tensor_desc.GetOriginFormat() != iter.second.GetOriginFormat()) || - (tensor_desc.GetShape().ToString() != iter.second.GetShape().ToString()) || - (tensor_desc.GetOriginShape().ToString() != iter.second.GetOriginShape().ToString()) || - (l_real_dim_cnt != r_real_dim_cnt)) { - GELOGW("The var tensor desc of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsVarAddrMgrMapSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare variable address info between json and VarManager - std::vector> var_addr_mgr_vector; - std::set var_offset_set; - auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set); - if (ret != SUCCESS) { - GELOGW("Fail to parse VarAddrMgrMap from Json."); - return false; - } - for (const auto &iter : var_addr_mgr_vector) { - uint8_t *dev_ptr = nullptr; - rtMemType_t memory_type; - ret = VarManager::Instance(session_id_)->GetVarAddr(iter.first, iter.second.tensor_desc, &dev_ptr, memory_type); - if (ret != SUCCESS) { - GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str()); - return false; - } - // Compare memory type and logic address - if (iter.second.memory_type != memory_type || iter.second.address != dev_ptr) { - GELOGW("The VarAddrMgr of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool 
ModelCacheHelper::IsBroadcastInfoSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare broadcast info between json and VarManager - std::unordered_map var_broadcast_info; - auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info); - if (ret != SUCCESS) { - GELOGW("Fail to parse BroadcastInfo from Json."); - return false; - } - for (const auto &iter : var_broadcast_info) { - VarBroadCastInfo broadcast_info; - if (VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, iter.first, broadcast_info) != SUCCESS) { - GELOGW("Fail to find broadcast info of var[%s].", iter.first.c_str()); - return false; - } - if (iter.second.var_name != broadcast_info.var_name || iter.second.idx != broadcast_info.idx || - iter.second.input_size != broadcast_info.input_size || - iter.second.input_offset != broadcast_info.input_offset || - iter.second.output_size != broadcast_info.output_size || - iter.second.output_offset != broadcast_info.output_offset) { - GELOGW("The BroadcastInfo of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsTransRoadsSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare trans road between json and VarManager - std::unordered_map> trans_roads; - auto ret = ParseTransRoadsFromJson(json, trans_roads); - if (ret != SUCCESS) { - GELOGW("Fail to parse TransRoads from Json."); - return false; - } - for (const auto &iter : trans_roads) { - VarTransRoad *trans_road; - trans_road = VarManager::Instance(session_id_)->GetTransRoad(iter.first); - if (trans_road == nullptr) { - GELOGW("Fail to find trans road of var[%s].", iter.first.c_str()); - return false; - } - if (trans_road->size() != iter.second.size()) { - GELOGW("The TransRoad of 
variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - // Compare every trans node in trans road. - for (size_t idx = 0; idx < trans_road->size(); idx += 1) { - if (!(trans_road->at(idx).node_type == iter.second.at(idx).node_type && - trans_road->at(idx).input == iter.second.at(idx).input && - trans_road->at(idx).output == iter.second.at(idx).output)) { - GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - } - return true; -} - -bool ModelCacheHelper::IsVarManagerParamSameAsCache(Json &json) const { - if (!json.is_object()) { - GELOGW("Input param json type should be object."); - return false; - } - try { - if (json[kSessionId].get() != session_id_) { - GELOGW("Check VarManager cache failed.[sessionId]"); - return false; - } - if (json[kDeviceId].get() != VarManager::Instance(session_id_)->DeviceId()) { - GELOGW("Check VarManager cache failed.[deviceId]"); - return false; - } - if (json[kJobId].get() != VarManager::Instance(session_id_)->JobId()) { - GELOGW("Check VarManager cache failed.[jobId]"); - return false; - } - if (json[kGraphMemMaxSize].get() != VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()) { - GELOGW("Check VarManager cache failed.[graphMemMaxSize]"); - return false; - } - if (json[kVarMemMaxSize].get() != VarManager::Instance(session_id_)->GetVarMemMaxSize()) { - GELOGW("Check VarManager cache failed.[varMemMaxSize]"); - return false; - } - if (json[kVarMemLogicBase].get() != VarManager::Instance(session_id_)->GetVarMemLogicBase()) { - GELOGW("Check VarManager cache failed.[varMemLogicBase]"); - return false; - } - if (json[kUseMaxMemSize].get() != VarManager::Instance(session_id_)->GetUseMaxMemorySize()) { - GELOGW("Check VarManager cache failed.[useMaxMemSize]"); - return false; - } - } catch (const std::exception &e) { - GELOGW("Fail to check VarManager json. 
Error message: %s", e.what()); - return false; - } - return true; -} - -bool ModelCacheHelper::IsVarManagerSameAsCache(Json &json) const { - if (!json.is_object()) { - GELOGW("Input param json type should be object."); - return false; - } - try { - if (!IsVarManagerParamSameAsCache(json)) { - GELOGW("Check VarManager cache failed.[Param]"); - return false; - } - Json mem_resource_json = move(json[kMemResourceMap]); - auto ret = IsMemResourceSameAsCache(mem_resource_json); - if (!ret) { - GELOGW("Check VarManager cache failed.[MemResource]"); - return false; - } - Json var_resource_json = move(json[kVarResource]); - ret = IsAllocatedGraphIdSameAsCache(var_resource_json[kAllocatedGraphId]); - if (!ret) { - GELOGW("Check VarManager cache failed.[AllocatedGraphId]"); - return false; - } - ret = IsChangedGraphIdSameAsCache(var_resource_json[kChangedGraphId]); - if (!ret) { - GELOGW("Check VarManager cache failed.[ChangedGraphId]"); - return false; - } - ret = IsBroadcastInfoSameAsCache(var_resource_json[kVarBroadcastInfo]); - if (!ret) { - GELOGW("Check VarManager cache failed.[VarBroadcastInfo]"); - return false; - } - ret = IsCurVarTensorDescSameAsCache(var_resource_json[kCurVarTensorDescMap]); - if (!ret) { - GELOGW("Check VarManager cache failed.[CurVarTensorDesc]"); - return false; - } - ret = IsVarAddrMgrMapSameAsCache(var_resource_json[kVarAddrMgrMap]); - if (!ret) { - GELOGW("Check VarManager cache failed.[VarAddrMgrMap]"); - return false; - } - ret = IsTransRoadsSameAsCache(var_resource_json[kTransRoads]); - if (!ret) { - GELOGW("Check VarManager cache failed.[TransRoads]"); - return false; - } - } catch (const std::exception &e) { - GELOGW("Fail to check VarManager json. 
Error message: %s", e.what()); - return false; - } - return true; -} - -Status ModelCacheHelper::RecoverMemResource(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::map var_mem_size; - auto ret = ParseMemResourceFromJson(json, var_mem_size); - if (ret != SUCCESS) { - GELOGW("Fail to parse MemResource from Json."); - return ret; - } - for (const auto &iter : var_mem_size) { - ret = VarManager::Instance(session_id_)->UpdateVarMemSize(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover var mem size."); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverAllocatedGraphId(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::map allocated_graph_id; - auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse AllocatedGraphId from Json."); - return ret; - } - for (const auto &iter : allocated_graph_id) { - ret = VarManager::Instance(session_id_)->SetAllocatedGraphId(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover allocated graph id."); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverChangedGraphId(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::map changed_graph_id; - auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse AllocatedGraphId from Json."); - return ret; - } - for (const auto &iter : changed_graph_id) { - ret = VarManager::Instance(session_id_)->SetChangedGraphId(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover changed graph id."); - return ret; - } - } - return SUCCESS; 
-} - -Status ModelCacheHelper::RecoverVarAddrAndTensorDesc(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::vector> var_addr_mgr_vector; - std::set var_offset_set; - auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set); - if (ret != SUCCESS) { - GELOGW("Fail to parse VarAddrMgrMap from Json."); - return ret; - } - for (const auto &iter : var_addr_mgr_vector) { - const VarAddrMgr &tensor_addr_mgr = iter.second; - const bool var_exist = VarManager::Instance(session_id_)->IsVarExist(iter.first, tensor_addr_mgr.tensor_desc); - // SaveVarVddr if var does not exist, the logic address will be recorded by VarManager - if (!var_exist) { - auto logic_address = reinterpret_cast(reinterpret_cast(tensor_addr_mgr.address)); - auto offset = (tensor_addr_mgr.offset); - // Check logic address and offset - if (logic_address - offset != VarManager::Instance(session_id_)->GetVarMemLogicBase()) { - GELOGW("Check logic_address[%lu] and offset [%lu] of %s failed, var mem logic base is %lu, abandon", - logic_address, offset, iter.first.c_str(), VarManager::Instance(session_id_)->GetVarMemLogicBase()); - return PARAM_INVALID; - } - // Offset is needed by SaveVarVddr instead of logic address - ret = VarManager::Instance(session_id_)->SaveVarAddr(iter.first, tensor_addr_mgr.tensor_desc, - reinterpret_cast(reinterpret_cast(offset)), - tensor_addr_mgr.memory_type); - if (ret != SUCCESS) { - GELOGW("Fail to recover VarAddr or TensorDesc of var[%s].", iter.first.c_str()); - return ret; - } - } - // SetVarAddr to update cur_var_tensor_desc_map_ - ret = VarManager::Instance(session_id_) - ->SetVarAddr(iter.first, tensor_addr_mgr.tensor_desc, tensor_addr_mgr.address, tensor_addr_mgr.memory_type); - if (ret != SUCCESS) { - GELOGW("Fail to recover VarAddr or TensorDesc desc of var[%s].", iter.first.c_str()); - return ret; - } - } - return SUCCESS; -} - 
-Status ModelCacheHelper::RecoverBroadcastInfo(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map var_broadcast_info; - auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info); - if (ret != SUCCESS) { - GELOGW("Fail to parse BroadcastInfo from Json."); - return ret; - } - for (const auto &iter : var_broadcast_info) { - VarBroadCastInfo broadcast_info; - ret = VarManager::Instance(session_id_)->SaveBroadCastInfo(graph_id_, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover broadcast info of var[%s].", iter.first.c_str()); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverTransRoads(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map> trans_roads; - auto ret = ParseTransRoadsFromJson(json, trans_roads); - if (ret != SUCCESS) { - GELOGW("Fail to parse TransRoads from Json."); - return ret; - } - for (const auto &iter : trans_roads) { - ret = VarManager::Instance(session_id_)->SetTransRoad(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to find trans road of var[%s].", iter.first.c_str()); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json) { - if (!(json.is_null() || json.is_object())) { - GELOGW("Input param json type should be null or object."); - return PARAM_INVALID; - } - try { - json[kDataType] = static_cast(ge_tensor_desc.GetDataType()); - json[kOriginDataType] = static_cast(ge_tensor_desc.GetOriginDataType()); - json[kLayout] = static_cast(ge_tensor_desc.GetFormat()); - json[kOriginLayout] = static_cast(ge_tensor_desc.GetOriginFormat()); - json[kShape] = ge_tensor_desc.GetShape().GetDims(); - json[kOriginShape] = ge_tensor_desc.GetOriginShape().GetDims(); - 
uint32_t real_dim_cnt = 0; - (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value] - json[kRealDimCnt] = real_dim_cnt; - } catch (const std::exception &e) { - GELOGW("Fail to trans GeTensorDesc to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::JsonToTensorDesc(const Json &json, ge::GeTensorDesc &ge_tensor_desc) { - if (!json.is_object()) { - GELOGW("Input param json type should be object."); - return PARAM_INVALID; - } - try { - ge_tensor_desc.SetDataType(static_cast(json[kDataType].get())); - ge_tensor_desc.SetOriginDataType(static_cast(json[kOriginDataType].get())); - ge_tensor_desc.SetFormat(static_cast(json[kLayout].get())); - ge_tensor_desc.SetOriginFormat(static_cast(json[kOriginLayout].get())); - GeShape shape(json[kShape].get>()); - ge_tensor_desc.SetShape(shape); - GeShape origin_shape(json[kOriginShape].get>()); - ge_tensor_desc.SetOriginShape(origin_shape); - auto real_dim_cnt = json[kRealDimCnt].get(); - (void)TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value] - } catch (const std::exception &e) { - GELOGW("Fail to trans Json to GeTensorDesc. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetNodesHashMapJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - map hash_map; - GetNodesHash(hash_map); - for (const auto &iter : hash_map) { - Json node_hash_json; - try { - node_hash_json[kName] = iter.first; - node_hash_json[kHash] = iter.second; - json.emplace_back(move(node_hash_json)); - } catch (const std::exception &e) { - GELOGW("Fail to trans node cache to json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetMemResourceMap(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - const auto total_size = VarManager::Instance(session_id_)->GetVarMemMaxSize(); - const auto var_mem_size = VarManager::Instance(session_id_)->GetVarMemSize(RT_MEMORY_HBM); - Json mem_resource_json; - try { - mem_resource_json[kMemType] = RT_MEMORY_HBM; - mem_resource_json[kTotalSize] = total_size; - mem_resource_json[kVarMemSize] = var_mem_size; - json.emplace_back(move(mem_resource_json)); - } catch (const std::exception &e) { - GELOGW("Fail to trans MemResourceMap to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarAddrMgrMapJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map var_addr_mgr_map; - VarManager::Instance(session_id_)->GetAllVarAddrMgr(var_addr_mgr_map); - try { - for (const auto &iter : var_addr_mgr_map) { - Json var_addr_json; - string name; - GetVarNameFromVarKey(iter.first, iter.second.tensor_desc, name); - var_addr_json[kName] = name; - var_addr_json[kAddress] = static_cast(reinterpret_cast(iter.second.address)); - var_addr_json[kMemoryType] = iter.second.memory_type; - var_addr_json[kOffset] = iter.second.offset; - - // Copy tensor desc to json. - Json tensor_desc_json; - auto ret = TensorDescToJson(iter.second.tensor_desc, tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - var_addr_json[kTensorDesc] = move(tensor_desc_json); - - json.emplace_back(move(var_addr_json)); - } - } catch (const std::exception &e) { - GELOGW("Fail to trans VarAddrMgrMap to json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetCurVarTensorDescMapJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - try { - for (const auto &name : var_names_) { - Json cur_tensor_desc_json; - GeTensorDesc tensor_desc; - auto ret = VarManager::Instance(session_id_)->GetCurVarDesc(name, tensor_desc); - if (ret != SUCCESS) { - GELOGI("Get variable[%s] current tensor desc failed. It will be skipped.", name.c_str()); - continue; - } - cur_tensor_desc_json[kName] = name; - - Json tensor_desc_json; - ret = TensorDescToJson(tensor_desc, tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - cur_tensor_desc_json[kTensorDesc] = move(tensor_desc_json); - json.emplace_back(move(cur_tensor_desc_json)); - } - } catch (const std::exception &e) { - GELOGW("Fail to trans CurVarTensorDescMap to json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetTransRoadsJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - try { - for (const auto &name : var_names_) { - auto trans_road = VarManager::Instance(session_id_)->GetTransRoad(name); - if (trans_road == nullptr) { - continue; - } - // Json object, variable name and trans road - Json trans_road_map_json; - trans_road_map_json[kName] = name; - - Json trans_road_json; - Status ret; - // Add nodes' info to json - for (const auto &trans_node_info : *trans_road) { - Json trans_node_info_json; - trans_node_info_json[kNodeType] = trans_node_info.node_type; - Json input_tensor_desc_json; - ret = TensorDescToJson(trans_node_info.input, input_tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - trans_node_info_json[kInputTensorDesc] = move(input_tensor_desc_json); - Json output_tensor_desc_json; - ret = TensorDescToJson(trans_node_info.output, output_tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - trans_node_info_json[kOutputTensorDesc] = move(output_tensor_desc_json); - trans_road_json.emplace_back(move(trans_node_info_json)); - } - trans_road_map_json[kTransRoad] = move(trans_road_json); - json.emplace_back(move(trans_road_map_json)); - } - } catch (const std::exception &e) { - GELOGW("Fail to trans VarToTransRoad to json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetChangedGraphIdJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const auto &name : var_names_) { - uint32_t changed_graph_id = 0; - Status ret = VarManager::Instance(session_id_)->GetChangedGraphId(name, changed_graph_id); - if (ret != SUCCESS) { - continue; - } - Json name_and_changed_graph_id; - try { - name_and_changed_graph_id[kName] = name; - name_and_changed_graph_id[kGraphId] = changed_graph_id; - json.emplace_back(move(name_and_changed_graph_id)); - } catch (const std::exception &e) { - GELOGW("Fail to trans ChangedGraphId to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetAllocatedGraphIdJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const auto &name : var_names_) { - uint32_t allocated_graph_id = 0; - Status ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(name, allocated_graph_id); - if (ret != SUCCESS) { - continue; - } - Json name_and_allocated_graph_id; - try { - name_and_allocated_graph_id[kName] = name; - name_and_allocated_graph_id[kGraphId] = allocated_graph_id; - json.emplace_back(move(name_and_allocated_graph_id)); - } catch (const std::exception &e) { - GELOGW("Fail to trans AllocatedGraphId to json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetBroadcastInfoJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const auto &name : var_names_) { - VarBroadCastInfo var_broadcast_info; - Status ret = VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, name, var_broadcast_info); - if (ret != SUCCESS) { - continue; - } - Json var_broadcast_info_json; - try { - var_broadcast_info_json[kName] = name; - var_broadcast_info_json[kBroadcastName] = var_broadcast_info.broadcast_name; - var_broadcast_info_json[kIdx] = var_broadcast_info.idx; - var_broadcast_info_json[kInputOffset] = var_broadcast_info.input_offset; - var_broadcast_info_json[kInputSize] = var_broadcast_info.input_size; - var_broadcast_info_json[kOutputOffset] = var_broadcast_info.output_offset; - var_broadcast_info_json[kOutputSize] = var_broadcast_info.output_size; - json.emplace_back(move(var_broadcast_info_json)); - } catch (const std::exception &e) { - GELOGW("Fail to trans VarBroadcastInfo to json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarResourceJson(Json &json) const { - if (!(json.is_null() || json.is_object())) { - GELOGW("Input param json type should be null or object."); - return PARAM_INVALID; - } - Json var_addr_mgr_map_json; - Status ret = GetVarAddrMgrMapJson(var_addr_mgr_map_json); - if (ret != SUCCESS) { - GELOGW("GetVarAddrMgrMapJson failed."); - return INTERNAL_ERROR; - } - - Json cur_var_tensor_desc_map_json; - ret = GetCurVarTensorDescMapJson(cur_var_tensor_desc_map_json); - if (ret != SUCCESS) { - GELOGW("GetCurVarTensorDescMapJson failed."); - return INTERNAL_ERROR; - } - - Json trans_roads_json; - ret = GetTransRoadsJson(trans_roads_json); - if (ret != SUCCESS) { - GELOGW("GetTransRoadsJson failed."); - return INTERNAL_ERROR; - } - - Json changed_graph_id_json; - ret = GetChangedGraphIdJson(changed_graph_id_json); - if (ret != SUCCESS) { - GELOGW("GetChangedGraphIdJson failed."); - return INTERNAL_ERROR; - } - - Json allocated_graph_id_json; - ret = GetAllocatedGraphIdJson(allocated_graph_id_json); - if (ret != SUCCESS) { - GELOGW("GetAllocatedGraphIdJson failed."); - return INTERNAL_ERROR; - } - - Json var_broadcast_info_json; - ret = GetBroadcastInfoJson(var_broadcast_info_json); - if (ret != SUCCESS) { - GELOGW("GetBroadcastInfoJson failed."); - return INTERNAL_ERROR; - } - - try { - json[kVarAddrMgrMap] = move(var_addr_mgr_map_json); - json[kCurVarTensorDescMap] = move(cur_var_tensor_desc_map_json); - json[kTransRoads] = move(trans_roads_json); - json[kChangedGraphId] = move(changed_graph_id_json); - json[kAllocatedGraphId] = move(allocated_graph_id_json); - json[kVarBroadcastInfo] = move(var_broadcast_info_json); - } catch (const exception &e) { - GELOGW("Fail to generate VarResource json. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarManagerJson(Json &json) const { - if (!(json.is_null() || json.is_object())) { - GELOGW("Input param json type should be null or object."); - return PARAM_INVALID; - } - - Json mem_resource_map_json; - auto ret = GetMemResourceMap(mem_resource_map_json); - if (ret != SUCCESS) { - GELOGW("GetMemResourceMap failed."); - return INTERNAL_ERROR; - } - - Json var_resource_json; - ret = GetVarResourceJson(var_resource_json); - if (ret != SUCCESS) { - GELOGW("GetVarResourceJson failed."); - return INTERNAL_ERROR; - } - - try { - json[kSessionId] = session_id_; - json[kDeviceId] = VarManager::Instance(session_id_)->DeviceId(); - json[kJobId] = VarManager::Instance(session_id_)->JobId(); - json[kGraphMemMaxSize] = VarManager::Instance(session_id_)->GetGraphMemoryMaxSize(); - json[kVarMemMaxSize] = VarManager::Instance(session_id_)->GetVarMemMaxSize(); - json[kVarMemLogicBase] = VarManager::Instance(session_id_)->GetVarMemLogicBase(); - json[kUseMaxMemSize] = VarManager::Instance(session_id_)->GetUseMaxMemorySize(); - json[kMemResourceMap] = move(mem_resource_map_json); - json[kVarResource] = move(var_resource_json); - } catch (const exception &e) { - GELOGW("Fail to generate VarManager json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::SaveVarManagerToCache(bool before_build) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return FAILED; - } - Json var_manager_json; - auto ret = GetVarManagerJson(var_manager_json); - if (ret != SUCCESS) { - GELOGW("Fail to generate VarManager json."); - return FAILED; - } - string var_manager_path = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + - (before_build ? 
kBeforeVarManagerSuffix : kAfterVarManagerSuffix); - ret = SaveJsonToFile(var_manager_path, var_manager_json); - if (ret != SUCCESS) { - GELOGW("Fail to save VarManager info to json file, path: %s.", cache_path_.c_str()); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::SaveOmModelToCache(const GeModelPtr &ge_model) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return FAILED; - } - string om_path = RealPath(cache_path_.c_str()); - if (om_path.empty()) { - GELOGW("file path is invalid. please check path om: %s", cache_path_.c_str()); - return FAILED; - } - string cache_om_path = cache_path_; - cache_om_path += (to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix); - GELOGI("SaveOmModelToCache: start to save om model : %s", cache_om_path.c_str()); - ModelHelper model_helper; - SaveParam save_param; - ModelBufferData model; - Status ret = model_helper.SaveToOmModel(ge_model, save_param, cache_om_path, model); - if (ret != SUCCESS) { - GELOGW("SaveOmModelToCache: save mode failed. ret = %u", ret); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseMemResourceFromJson(const Json &json, map &mem_resource) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - mem_resource.clear(); - for (const Json &mem_resource_json : json) { - try { - rtMemType_t mem_type = mem_resource_json[kMemType].get(); - uint64_t var_mem_size = mem_resource_json[kVarMemSize].get(); - mem_resource[mem_type] = var_mem_size; - } catch (const exception &e) { - GELOGW("Fail to trans Json to MemResource. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseVarAddrMgrMapFromJson( - const Json &json, std::vector> &var_addr_mgr_vector, - std::set &var_offset_set) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - var_addr_mgr_vector.clear(); - var_offset_set.clear(); - for (const Json &var_addr_json : json) { - VarAddrMgr var_addr_mgr; - try { - auto logic_address = var_addr_json[kAddress].get(); - auto address = reinterpret_cast(reinterpret_cast(logic_address)); - var_addr_mgr.address = address; - var_addr_mgr.offset = var_addr_json[kOffset].get(); - var_addr_mgr.memory_type = var_addr_json[kMemoryType].get(); - auto ret = JsonToTensorDesc(var_addr_json[kTensorDesc], var_addr_mgr.tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - var_addr_mgr_vector.emplace_back(var_addr_json[kName].get(), move(var_addr_mgr)); - var_offset_set.insert(logic_address); - } catch (const exception &e) { - GELOGW("Fail to trans Json to VarAddrMgr. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseCurVarTensorDescMapFromJson( - const Json &json, std::unordered_map &cur_var_tensor_desc_map) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - cur_var_tensor_desc_map.clear(); - for (const Json &tensor_desc_json : json) { - GeTensorDesc tensor_desc; - try { - auto ret = JsonToTensorDesc(tensor_desc_json[kTensorDesc], tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - cur_var_tensor_desc_map[tensor_desc_json[kName].get()] = move(tensor_desc); - } catch (const exception &e) { - GELOGW("Fail to trans Json to VarAddrMgr. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseTransRoadsFromJson( - const Json &json, std::unordered_map> &trans_roads) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - trans_roads.clear(); - try { - for (const Json &name_trans_road_json : json) { - const Json &trans_road_json = name_trans_road_json[kTransRoad]; - if (!(trans_road_json.is_array() || trans_road_json.is_null())) { - GELOGW("%s json type should be null or object.", kTransRoad); - return PARAM_INVALID; - } - vector trans_road; - for (const Json &trans_node_json : trans_road_json) { - TransNodeInfo trans_node_info; - trans_node_info.node_type = trans_node_json[kNodeType]; - GeTensorDesc input_tensor_desc; - auto ret = JsonToTensorDesc(trans_node_json[kInputTensorDesc], input_tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - trans_node_info.input = move(input_tensor_desc); - GeTensorDesc output_tensor_desc; - ret = JsonToTensorDesc(trans_node_json[kOutputTensorDesc], output_tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - trans_node_info.output = move(output_tensor_desc); - trans_road.emplace_back(move(trans_node_info)); - } - trans_roads[name_trans_road_json[kName].get()] = move(trans_road); - } - } catch (const exception &e) { - GELOGW("Fail to trans Json to TransRoads. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseChangedGraphIdFromJson(const Json &json, - std::map &changed_graph_id) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - changed_graph_id.clear(); - for (const Json &name_graph_id_json : json) { - try { - changed_graph_id[name_graph_id_json[kName].get()] = name_graph_id_json[kGraphId].get(); - } catch (const exception &e) { - GELOGW("Fail to trans Json to changed graph id. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseAllocatedGraphIdFromJson(const Json &json, - std::map &allocated_graph_id) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - allocated_graph_id.clear(); - for (const Json &name_graph_id_json : json) { - try { - allocated_graph_id[name_graph_id_json[kName].get()] = name_graph_id_json[kGraphId].get(); - } catch (const exception &e) { - GELOGW("Fail to trans Json to allocated graph id. 
Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseBroadcastInfoFromJson( - const Json &json, std::unordered_map &var_broadcast_info) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const Json &broadcast_info_json : json) { - VarBroadCastInfo broadcast_info; - try { - broadcast_info.var_name = broadcast_info_json[kName].get(); - broadcast_info.broadcast_name = broadcast_info_json[kBroadcastName].get(); - broadcast_info.idx = broadcast_info_json[kIdx].get(); - broadcast_info.input_offset = broadcast_info_json[kInputOffset].get(); - broadcast_info.input_size = broadcast_info_json[kInputSize].get(); - broadcast_info.output_offset = broadcast_info_json[kOutputOffset].get(); - broadcast_info.output_size = broadcast_info_json[kOutputSize].get(); - } catch (const exception &e) { - GELOGW("Fail to trans Json to VarBroadCastInfo. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - var_broadcast_info[broadcast_info.var_name] = broadcast_info; - } - return SUCCESS; -} - -Status ModelCacheHelper::LoadOmModelFromCache(GeModelPtr &ge_model) const { - string cache_om = cache_path_ + to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix; - if (!CheckInputPathValid(cache_om)) { - GELOGW("Invalid cache path for input:%s.", cache_om.c_str()); - return FAILED; - } - string om_path = RealPath(cache_om.c_str()); - if (om_path.empty()) { - GELOGW("file path is invalid. please check file om: %s", om_path.c_str()); - return FAILED; - } - GELOGI("load model data from file: %s", om_path.c_str()); - Status ret; - int32_t priority = 0; - ModelData model_data; - ret = ModelParserBase::LoadFromFile(om_path.c_str(), priority, model_data); - if (ret != SUCCESS) { - GELOGW("LoadOmModelFromCache: Load model from file failed. 
ret = %u", ret); - return ret; - } - std::function callback = [&]() { - if (model_data.model_data != nullptr) { - delete[] reinterpret_cast(model_data.model_data); - model_data.model_data = nullptr; - } - }; - GE_MAKE_GUARD(release, callback); - - ModelHelper model_helper; - ret = model_helper.LoadModel(model_data); - if (ret != SUCCESS) { - GELOGW("LoadOmModelFromCache: Load model from data failed. ret = %u", ret); - return ret; - } - ge_model = model_helper.GetGeModel(); - ret = RecompileNodes(ge_model); - if (ret != SUCCESS) { - GELOGW("LoadOmModelFromCache: recompile nodes failed. ret = %u", ret); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, - string &var_name) { - std::string::size_type underline_idx = var_key.rfind('_'); - if (underline_idx == std::string::npos) { - GELOGW("Invalid var key: underline not found"); - return FAILED; - } - std::string::size_type format_idx = - var_key.rfind(std::to_string(static_cast(tensor_desc.GetFormat())), underline_idx); - if (format_idx == std::string::npos) { - GELOGW("Invalid var key: format not found"); - return FAILED; - } - var_name = var_key.substr(0, format_idx); - return SUCCESS; -} -} // namespace ge diff --git a/ge/common/helper/model_cache_helper.h b/ge/common/helper/model_cache_helper.h deleted file mode 100755 index f0831075..00000000 --- a/ge/common/helper/model_cache_helper.h +++ /dev/null @@ -1,123 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ -#define GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ - -#include -#include -#include - -#include "external/ge/ge_api_error_codes.h" -#include "graph/compute_graph.h" -#include "graph/manager/graph_var_manager.h" -#include "common/model/ge_model.h" - -namespace ge { -using Json = nlohmann::json; - -struct CacheInfo { - size_t node_num; - size_t edge_num; - size_t graph_hash; - map nodes_hash; - CacheInfo() : node_num(0), edge_num(0), graph_hash(0) {} -}; - -class ModelCacheHelper { - public: - ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph); - ~ModelCacheHelper(); - - Status SaveCacheInfoToCache () const; - Status SaveVarManagerToCache(bool before_build) const; - Status SaveOmModelToCache(const GeModelPtr &ge_model) const; - bool IsModelCacheHit() const; - Status RecoverVarManagerFromCache() const; - Status LoadOmModelFromCache(GeModelPtr &ge_model) const; - Status RefreshComputeGraph(const ComputeGraphPtr &compute_graph); - Status ClearCache(uint32_t graph_id) const; - - private: - Status GetComputeGraphHash(size_t &hash) const; - Status GetNodesHash(map &hash_map) const; - Status GetCacheInfo(CacheInfo &cache_info) const; - - Status RecoverMemResource(const Json &json) const; - Status RecoverAllocatedGraphId(const Json &json) const; - Status RecoverChangedGraphId(const Json &json) const; - Status RecoverVarAddrAndTensorDesc(const Json &json) const; - Status RecoverBroadcastInfo(const Json &json) const; - Status RecoverTransRoads(const Json &json) const; - 
static Status GetNodesNeedRecompile(ComputeGraphPtr &graph, vector &nodes); - static Status RecompileNodes(GeModelPtr &ge_model); - - bool IsNodeHashSameAsCache(const map &hash_map) const; - bool IsMemResourceSameAsCache(Json &json) const; - bool IsChangedGraphIdSameAsCache(Json &json) const; - bool IsAllocatedGraphIdSameAsCache(Json &json) const; - bool IsCurVarTensorDescSameAsCache(Json &json) const; - bool IsVarAddrMgrMapSameAsCache(Json &json) const; - bool IsBroadcastInfoSameAsCache(Json &json) const; - bool IsTransRoadsSameAsCache(Json &json) const; - bool IsVarManagerSameAsCache(Json &json) const; - bool IsVarManagerParamSameAsCache(Json &json) const; - - Status SaveJsonToFile(const string &file_name, const Json &json) const; - Status LoadJsonFromFile(const string &file_name, Json &json) const; - - Status GetNodesHashMapJson(Json &json) const; - Status GetMemResourceMap(Json &json) const; - Status GetVarAddrMgrMapJson(Json &json) const; - Status GetCurVarTensorDescMapJson(Json &json) const; - Status GetTransRoadsJson(Json &json) const; - Status GetChangedGraphIdJson(Json &json) const; - Status GetAllocatedGraphIdJson(Json &json) const; - Status GetBroadcastInfoJson(Json &json) const; - Status GetVarResourceJson(Json &json) const; - Status GetVarManagerJson(Json &json) const; - - static Status TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json); - static Status JsonToTensorDesc(const Json &json, GeTensorDesc &ge_tensor_desc); - static Status ParseMemResourceFromJson(const Json &json, map &mem_resource); - static Status ParseVarAddrMgrMapFromJson(const Json &json, - std::vector> &var_addr_mgr_vector, - std::set &var_offset_set); - static Status ParseCurVarTensorDescMapFromJson( - const Json &json, std::unordered_map &cur_var_tensor_desc_map); - static Status ParseTransRoadsFromJson(const Json &json, - std::unordered_map> &trans_roads); - static Status ParseChangedGraphIdFromJson(const Json &json, - std::map &changed_graph_id); - static Status 
ParseAllocatedGraphIdFromJson(const Json &json, - std::map &allocated_graph_id); - static Status ParseBroadcastInfoFromJson(const Json &json, - std::unordered_map &var_broadcast_info); - static Status GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, string &var_name); - - uint64_t session_id_; - uint32_t graph_id_; - string cache_path_; - ComputeGraphPtr compute_graph_; - std::set var_names_; - bool is_cache_path_valid_for_output; - static map graph_id_run_times_; -}; - -using ModelCacheHelperPtr = std::shared_ptr; -} // namespace ge - -#endif // GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ diff --git a/ge/executor/ge_executor.cc b/ge/executor/ge_executor.cc index 73cd7bb5..76cde2b9 100755 --- a/ge/executor/ge_executor.cc +++ b/ge/executor/ge_executor.cc @@ -27,6 +27,7 @@ #include "graph/load/graph_loader.h" #include "graph/load/model_manager/model_manager.h" #include "graph/manager/graph_mem_manager.h" +#include "graph/manager/graph_var_manager.h" #include "single_op/single_op_manager.h" #include "graph/load/model_manager/davinci_model.h" #include "opskernel_manager/ops_kernel_builder_manager.h" diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 45eaed59..1a80a3e0 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -30,6 +30,7 @@ #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/manager/graph_manager.h" +#include "graph/manager/graph_var_manager.h" #include "graph/manager/util/rt_context_util.h" #include "graph/operator_factory_impl.h" #include "graph/opsproto_manager.h" diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 7d72d85b..d1237f4e 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -248,7 +248,6 @@ Status GraphManager::Finalize() { Analyzer::GetInstance()->DestroyGraphJsonObject(session_id, graph_id); } graph_map_.clear(); - cache_helper_map_.clear(); 
graph_count_.clear(); // graph context @@ -1005,13 +1004,6 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorGetGraphId(), compute_graph, ge_model); - if (save_ret != SUCCESS) { - GELOGW("Fail to save cache."); - } GEEVENT("[GEPERFTRACE] GE PreRun End"); return SUCCESS; } @@ -1063,18 +1055,15 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: graph_node->GetGraphId()); return PARAM_INVALID; } - GeModelPtr ge_model = nullptr; - // check need incre build. - ret = IncreBuild(graph_node, ge_model); + + ret = PreRun(graph_node, inputs, ge_root_model, session_id); + // release rts generate context + RtContextUtil::GetInstance().DestroyRtContexts(session_id, graph_node->GetGraphId()); if (ret != SUCCESS) { - ret = PreRun(graph_node, inputs, ge_root_model, session_id); - // release rts generate context - RtContextUtil::GetInstance().DestroyRtContexts(session_id, graph_node->GetGraphId()); - if (ret != SUCCESS) { - GELOGE(ret, "[Call][PreRun] Failed, graph_id:%u, session_id:%lu.", graph_node->GetGraphId(), session_id); - return ret; - } + GELOGE(ret, "[Call][PreRun] Failed, graph_id:%u, session_id:%lu.", graph_node->GetGraphId(), session_id); + return ret; } + ret = LoadGraph(ge_root_model, graph_node); if (ret != SUCCESS) { GELOGE(ret, "[Load][Graph] Failed, graph_id:%u.", graph_node->GetGraphId()); @@ -1104,91 +1093,6 @@ Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphN return executor_->LoadGraph(ge_root_model, graph_node); } -Status GraphManager::LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, - GeModelPtr &ge_model) { - auto graph_id = graph_node->GetGraphId(); - auto ret = cache_helper->LoadOmModelFromCache(ge_model); - if (ret != SUCCESS) { - GELOGW("Fail to load om model from cache."); - if (cache_helper->ClearCache(graph_id) != SUCCESS) { - GELOGW("Fail to clear cache of graph %u.", graph_id); - } - return FAILED; - } - ret = 
cache_helper->RecoverVarManagerFromCache(); - if (ret != SUCCESS) { - GELOGW("Fail to recover VarManager from cache."); - if (cache_helper->ClearCache(graph_id) != SUCCESS) { - GELOGW("Fail to clear cache of graph %u.", graph_id); - } - return FAILED; - } - ComputeGraphPtr compute_graph_in_model = GraphUtils::GetComputeGraph(ge_model->GetGraph()); - if (compute_graph_in_model == nullptr) { - GELOGW("Error occurred when get compute graph from om, abandon."); - return FAILED; - } else { - graph_node->SetComputeGraph(compute_graph_in_model); - graph_node->SetGeModel(ge_model); - GELOGI("Load model and graph form cache om file."); - } - return SUCCESS; -} - -Status GraphManager::SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper) { - auto ret = cache_helper->SaveCacheInfoToCache(); - if (ret != SUCCESS) { - GELOGW("Fail to save cache info of graph[%d] to cache.", graph_id); - return FAILED; - } - ret = cache_helper->SaveVarManagerToCache(true); - if (ret != SUCCESS) { - GELOGW("Fail to save var manager to cache."); - cache_helper->ClearCache(graph_id); - return FAILED; - } - GELOGI("Cache files have been saved."); - return SUCCESS; -} - -Status GraphManager::SaveCacheAfterBuild(uint32_t graph_id, ge::ComputeGraphPtr graph, GeModelPtr &ge_model) { - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if ((instance_ptr == nullptr) || !instance_ptr->InitFlag()) { - GELOGW("GELib not initialized."); - return FAILED; - } - - if (instance_ptr->IsIncreBuild()) { - std::lock_guard lock(member_mutex_); - auto iter = cache_helper_map_.find(graph_id); - if (iter == cache_helper_map_.end()) { - GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); - return FAILED; - } else { - ModelCacheHelperPtr cache_helper = iter->second; - auto ret = cache_helper->RefreshComputeGraph(graph); - if (ret != SUCCESS) { - cache_helper->ClearCache(graph_id); - GELOGW("Fail to refresh cache helper's compute graph"); - return FAILED; - } - ret = 
cache_helper->SaveVarManagerToCache(false); - if (ret != SUCCESS) { - cache_helper->ClearCache(graph_id); - GELOGW("Fail to save VarManager to cache"); - return FAILED; - } - ret = cache_helper->SaveOmModelToCache(ge_model); - if (ret != SUCCESS) { - cache_helper->ClearCache(graph_id); - GELOGW("Fail to save om model to cache"); - return FAILED; - } - } - } - return SUCCESS; -} - Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, const std::vector &inputs, std::vector &outputs) { GE_CHECK_NOTNULL(executor_); @@ -1239,8 +1143,6 @@ Status GraphManager::RunGraphWithStreamAsync(const GraphId &graph_id, rtStream_t graph_node->SetIsSpecificStream(true); ComputeGraphPtr compute_graph_tmp = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); - // when set incre build, add cache helper map - AddModelCacheHelperToMap(graph_id, session_id, compute_graph_tmp); if (options_.local_fmk_op_flag) { GetCompilerStages(graph_id).optimizer.TranFrameOp(compute_graph_tmp); } @@ -1299,9 +1201,6 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector lock(member_mutex_); - auto iter = cache_helper_map_.find(graph_id); - if (iter != cache_helper_map_.end()) { - cache_helper_map_.erase(iter); - } else { - GELOGW("[GraphManager] cache helper does not exist, graph_id = %u", graph_id); - } -} - bool GraphManager::CheckModelLoad(const GeRootModelPtr &ge_root_model, bool load_flag) { return ((ge_root_model != nullptr) && (ge_root_model->GetModelId() != INVALID_MODEL_ID) && load_flag); } @@ -1555,7 +1444,6 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { var_acc_ctrl_.RemoveGraph(graph_id); RemoveGraphNode(graph_id); - RemoveModelCacheHelper(graph_id); auto ge_root_model = graph_node->GetGeRootModel(); if (CheckModelLoad(ge_root_model, graph_node->GetLoadFlag())) { @@ -2727,61 +2615,6 @@ Status GraphManager::RunGraphAsync(const GraphId &graph_id, const std::vector instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr != 
nullptr && instance_ptr->IsIncreBuild()) { - std::lock_guard lock(member_mutex_); - auto iter = cache_helper_map_.find(graph_id); - if (iter == cache_helper_map_.end()) { - ModelCacheHelperPtr cache_helper = MakeShared(session_id, graph_id, compute_graph); - if (cache_helper != nullptr) { - cache_helper_map_.emplace(std::make_pair(graph_id, cache_helper)); - } else { - GELOGW("Cache helper make shared failed, graph_id = %u.", graph_id); - } - } - } -} - -ModelCacheHelperPtr GraphManager::FindModelCacheHelper(GraphId graph_id) { - std::lock_guard lock(member_mutex_); - auto iter = cache_helper_map_.find(graph_id); - if (iter != cache_helper_map_.end()) { - return iter->second; - } - - return nullptr; -} - -Status GraphManager::IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model) { - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->IsIncreBuild()) { - return FAILED; - } - const uint32_t graph_id = graph_node->GetGraphId(); - ModelCacheHelperPtr cache_helper = FindModelCacheHelper(graph_id); - if (cache_helper == nullptr) { - GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); - return FAILED; - } - if (cache_helper->IsModelCacheHit()) { - GEEVENT("Model cache hit."); - Status ret = LoadFromCache(graph_node, cache_helper, ge_model); - if (ret == SUCCESS) { - return SUCCESS; - } else { - GELOGW("Error occurred when load from cache, abandon."); - } - } else { - GEEVENT("Model cache miss."); - } - if (SaveCacheBeforeBuild(graph_node->GetGraphId(), cache_helper) != SUCCESS) { - GELOGW("Error occurred when save cache."); - } - return FAILED; -} - Status GraphManager::CheckIncreBuildAndPreRun(const PreRunArgs &args, GraphNodePtr &graph_node, GeRootModelPtr &ge_root_model) { if (!IsGraphNeedBuild(graph_node)) { @@ -2796,20 +2629,18 @@ Status GraphManager::CheckIncreBuildAndPreRun(const PreRunArgs &args, return PARAM_INVALID; } // check need incre build. 
- GeModelPtr ge_model = nullptr; - if (IncreBuild(graph_node, ge_model) != SUCCESS) { - std::vector ge_inputs; - for (const auto &item: args.input_tensor) { - ge_inputs.emplace_back(TensorAdapter::AsGeTensor(item)); - } - Status ret = PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); - // release rts generate context - RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); - if (ret != SUCCESS) { - ReturnError(args.callback, ret, "PreRun Failed."); - return ret; - } + std::vector ge_inputs; + for (const auto &item: args.input_tensor) { + ge_inputs.emplace_back(TensorAdapter::AsGeTensor(item)); } + Status ret = PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); + // release rts generate context + RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); + if (ret != SUCCESS) { + ReturnError(args.callback, ret, "PreRun Failed."); + return ret; + } + graph_node->SetBuildFlag(true); var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); return SUCCESS; @@ -2878,10 +2709,6 @@ void GraphManager::PreRunThread() { graph_node->Unlock(); return; } - // when set incre build, save cache helper. 
- AddModelCacheHelperToMap(args.graph_id, args.session_id, compute_graph_tmp); - - std::vector ge_models; if (options_.local_fmk_op_flag) { GetCompilerStages(graph_node->GetGraphId()).optimizer.TranFrameOp(compute_graph_tmp); diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index 84d2b11e..e7cd88a9 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -27,7 +27,6 @@ #include "common/blocking_queue.h" #include "framework/common/ge_inner_error_codes.h" -#include "common/helper/model_cache_helper.h" #include "external/graph/types.h" #include "external/ge/ge_api_types.h" #include "graph/build/graph_builder.h" @@ -339,14 +338,6 @@ class GraphManager { bool IsGraphNeedBuild(const GraphNodePtr &graph_node); - Status LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, GeModelPtr &ge_model); - Status SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper); - Status SaveCacheAfterBuild(uint32_t graph_id, ComputeGraphPtr graph, GeModelPtr &ge_model); - void AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, ComputeGraphPtr &compute_graph); - Status IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model); - void RemoveModelCacheHelper(const GraphId &graph_id); - ModelCacheHelperPtr FindModelCacheHelper(GraphId graph_id); - void SetRunContext(const GraphNodePtr &graph_node); void PushGraph(const RunArgs &args); @@ -411,7 +402,6 @@ class GraphManager { std::thread prerun_thread_; ComputeGraphPtr compute_graph_; std::map graph_map_; - std::map cache_helper_map_; // summary and checkpoint callback function list for ME, key is summary or checkpoint std::map &)>> me_callback_map_; diff --git a/ge/graph/manager/graph_manager_utils.cc b/ge/graph/manager/graph_manager_utils.cc index 42251b10..225a748a 100644 --- a/ge/graph/manager/graph_manager_utils.cc +++ b/ge/graph/manager/graph_manager_utils.cc @@ -70,45 +70,9 @@ void 
GraphNode::IncreaseLoadCount() { ++load_count_; } -SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr), malloc_flag_(false) {} +SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr) {} SubGraphInfo::~SubGraphInfo() { - if (malloc_flag_) { - for (auto &buffer_addr : buffer_addr_) { - if (buffer_addr == nullptr) { - continue; - } - rtError_t rt_ret; - rt_ret = rtFreeHost(buffer_addr); - buffer_addr = nullptr; - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "[Call][RtFreeHost] subgraph free buffer failed, modelId = %u", - model_id_info_.model_id); - } - } - } -} - -Status SubGraphInfo::FreeInOutBuffer() { - if (malloc_flag_) { - for (auto iter = buffer_addr_.begin(); iter != buffer_addr_.end(); ++iter) { - rtError_t rt_ret; - rt_ret = rtFreeHost(*iter); - if (rt_ret != RT_ERROR_NONE) { - REPORT_CALL_ERROR("E19999", "Call rtFreeHost fail, ret:%d", rt_ret); - GELOGE(rt_ret, "[Call][RtFreeHost] subgraph free buffer failed, modelId = %u", model_id_info_.model_id); - buffer_addr_.erase(buffer_addr_.begin(), iter); - return GE_GRAPH_FREE_FAILED; - } - } - buffer_addr_.clear(); - - malloc_flag_ = false; - return SUCCESS; - } else { - GELOGI("[GraphManager] not malloc buffer, modelId = %u", model_id_info_.model_id); - return SUCCESS; - } } GraphModelListener::GraphModelListener(std::mutex &mutex, std::condition_variable &cond) diff --git a/ge/graph/manager/graph_manager_utils.h b/ge/graph/manager/graph_manager_utils.h index 14eb67f2..efdbecf8 100644 --- a/ge/graph/manager/graph_manager_utils.h +++ b/ge/graph/manager/graph_manager_utils.h @@ -86,8 +86,6 @@ class SubGraphInfo { void SetGeModelPtr(const GeModelPtr &ge_model_ptr) { ge_model_ptr_ = ge_model_ptr; } bool GeModelIsValid() const { return ge_model_ptr_ != nullptr; } - Status FreeInOutBuffer(); - void SetOutputContext(const std::string &output) { output_names_ = output; } std::string GetOutputContext() const { return output_names_; } diff --git 
a/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc index d0669254..ce5b335e 100755 --- a/ge/graph/manager/graph_var_manager.cc +++ b/ge/graph/manager/graph_var_manager.cc @@ -429,10 +429,6 @@ ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTenso return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type); } -void VarManager::GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map) { - var_resource_->GetAllVarAddrMgr(var_addr_mgr_map); -} - int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { std::lock_guard lock(mutex_); MemResource *mem_resource = nullptr; @@ -453,36 +449,6 @@ int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { return mem_resource->GetVarMemSize(); } -Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) { - std::lock_guard lock(mutex_); - MemResource *mem_resource = nullptr; - auto iter = mem_resource_map_.find(memory_type); - if (iter == mem_resource_map_.end()) { - mem_resource = MemResource::BuildMemResourceFromType(memory_type); - if (mem_resource == nullptr) { - REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu", - memory_type, session_id_); - GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu", - memory_type, session_id_); - return ge::INTERNAL_ERROR; - } else { - mem_resource_map_[memory_type] = mem_resource; - } - } else { - mem_resource = iter->second; - } - - if (mem_resource == nullptr) { - REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu", - memory_type, session_id_); - GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%u, session_id:%lu", - memory_type, session_id_); - return FAILED; - } - mem_resource->UpdateVarMemSize(mem_size); - return SUCCESS; -} - ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, rtMemType_t memory_type) { std::lock_guard 
lock(mutex_); @@ -638,16 +604,6 @@ ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastIn return SUCCESS; } -ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { - std::lock_guard lock(mutex_); - - if (var_resource_ == nullptr) { - GELOGW("VarManager has not been init."); - return ge::INTERNAL_ERROR; - } - return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info); -} - ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) { std::lock_guard lock(mutex_); GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str()); diff --git a/ge/graph/manager/graph_var_manager.h b/ge/graph/manager/graph_var_manager.h index a1b45959..f0e3b89b 100755 --- a/ge/graph/manager/graph_var_manager.h +++ b/ge/graph/manager/graph_var_manager.h @@ -223,14 +223,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, rtMemType_t &memory_type); - void GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map); - ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); - ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); - ge::Status GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc); ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); @@ -273,8 +269,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { int64_t GetVarMemSize(rtMemType_t memory_type); - Status UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size); - bool IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); bool IsVarExist(const std::string &var_name); diff --git 
a/ge/graph/manager/model_manager/event_manager.cc b/ge/graph/manager/model_manager/event_manager.cc deleted file mode 100644 index 339e9894..00000000 --- a/ge/graph/manager/model_manager/event_manager.cc +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/manager/model_manager/event_manager.h" - -#define RETURN_IF_COND_NOT_MET(condition, ...) \ - do { \ - if (!(condition)) { \ - GELOGE(FAILED, __VA_ARGS__); \ - return; \ - } \ - } while (0); - -namespace ge { -Status EventManager::Init(size_t event_num) { - if (this->inited_) { - return SUCCESS; - } - - rtEvent_t event = nullptr; - current_idx_ = 0; - for (size_t i = 0; i < event_num; ++i) { - GE_CHK_RT_RET(rtEventCreate(&event)); - this->event_list_.push_back(event); - } - - this->inited_ = true; - - return SUCCESS; -} - -void EventManager::Release() noexcept { - for (size_t i = 0; i < this->event_list_.size(); ++i) { - rtError_t rt_ret = rtEventDestroy(this->event_list_[i]); - RETURN_IF_COND_NOT_MET(rt_ret == RT_ERROR_NONE, "[Destroy][Event] failed, idx is %zu, ret is 0x%x.", i, rt_ret); - } - this->event_list_.clear(); - - this->inited_ = false; -} - -Status EventManager::EventRecord(size_t event_idx, rtStream_t stream) { - GE_CHK_BOOL_RET_STATUS_NOLOG(this->inited_, INTERNAL_ERROR); - - GE_CHK_BOOL_RET_STATUS_NOLOG(event_idx < this->event_list_.size(), PARAM_INVALID); - - 
GE_CHK_RT_RET(rtEventRecord(this->event_list_[event_idx], stream)); - - current_idx_ = static_cast(event_idx); - return SUCCESS; -} - -Status EventManager::EventElapsedTime(size_t start_event_idx, size_t stop_event_idx, float &time) { - GE_CHK_BOOL_RET_STATUS_NOLOG(this->inited_, INTERNAL_ERROR); - - GE_CHK_BOOL_RET_STATUS_NOLOG(start_event_idx < this->event_list_.size() && - stop_event_idx < this->event_list_.size() && start_event_idx <= stop_event_idx, - PARAM_INVALID); - - GE_CHK_RT_RET(rtEventElapsedTime(&time, this->event_list_[start_event_idx], this->event_list_[stop_event_idx])); - - return SUCCESS; -} - -Status EventManager::GetEvent(uint32_t index, rtEvent_t &event) { - GE_CHK_BOOL_RET_STATUS_NOLOG(index < this->event_list_.size(), PARAM_INVALID); - event = this->event_list_[index]; - return SUCCESS; -} -} // namespace ge diff --git a/ge/graph/manager/model_manager/event_manager.h b/ge/graph/manager/model_manager/event_manager.h deleted file mode 100644 index 2cb1c3f6..00000000 --- a/ge/graph/manager/model_manager/event_manager.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ -#define GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ - - -#include - -#include "framework/common/fmk_error_codes.h" -#include "framework/common/fmk_types.h" -#include "framework/common/util.h" -#include "runtime/event.h" - -namespace ge { -class EventManager { - public: - /// - /// @ingroup domi_ome - /// @brief constructor - /// - EventManager() : inited_(false), current_idx_(0) {} - /// - /// @ingroup domi_ome - /// @brief destructor - /// - ~EventManager() { this->Release(); } - - /// - /// @ingroup domi_ome - /// @brief init and create event list - /// @param [in] event_num event number created - /// @return exec result - /// - Status Init(size_t event_num); - - /// - /// @ingroup domi_ome - /// @brief event record - /// @param [in] event_idx event index - /// @param [in] stream related stream - /// @return exec result - /// - Status EventRecord(size_t event_idx, rtStream_t stream); - - /// - /// @ingroup domi_ome - /// @brief time between start and end in ms - /// @param [in] start_event_idx start event index - /// @param [in] stop_event_idx stop event index - /// @param [out] time - /// @return exec result - /// - Status EventElapsedTime(size_t start_event_idx, size_t stop_event_idx, float &time); - - /// - /// @ingroup domi_ome - /// @brief current event index - /// @return - /// - uint32_t CurrentIdx() const { return current_idx_; } - - /// - /// @ingroup domi_ome - /// @brief get event at specific loc - /// @param [in] index event index - /// @return - /// - Status GetEvent(uint32_t index, rtEvent_t &event); - - /// - /// @ingroup domi_ome - /// @brief release event list - /// @param [in] - /// @return - /// - void Release() noexcept; - - private: - std::vector event_list_; - bool inited_; - uint32_t current_idx_; -}; // EventManager -} // namespace ge -#endif // GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ diff --git a/ge/graph/manager/trans_var_data_utils.h 
b/ge/graph/manager/trans_var_data_utils.h index 174efbb3..f5a89a50 100755 --- a/ge/graph/manager/trans_var_data_utils.h +++ b/ge/graph/manager/trans_var_data_utils.h @@ -24,7 +24,6 @@ #include "graph/utils/tensor_utils.h" #include "graph/node.h" #include "runtime/context.h" -#include "graph/manager/graph_var_manager.h" namespace ge { class TransVarDataUtils { diff --git a/ge/graph/passes/global_step_insert_pass.cc b/ge/graph/passes/global_step_insert_pass.cc index 297e4ee2..ada4e12a 100755 --- a/ge/graph/passes/global_step_insert_pass.cc +++ b/ge/graph/passes/global_step_insert_pass.cc @@ -24,7 +24,6 @@ #include "framework/common/util.h" #include "graph/debug/ge_attr_define.h" #include "common/ge/ge_util.h" -#include "graph/manager/graph_var_manager.h" #include "graph/passes/pass_utils.h" #include "graph/ge_context.h" diff --git a/ge/init/gelib.h b/ge/init/gelib.h index 5e66be51..226dd4c8 100644 --- a/ge/init/gelib.h +++ b/ge/init/gelib.h @@ -28,7 +28,6 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/utils/anchor_utils.h" -#include "graph/manager/graph_var_manager.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_types.h" @@ -63,13 +62,7 @@ class GE_FUNC_VISIBILITY GELib { bool InitFlag() const { return init_flag_; } // get TrainMode flag - bool isTrainMode() { return is_train_mode_; } - - // get incre build flag - bool IsIncreBuild() const { return is_incre_build_; } - - // get incre build cache path - const std::string &GetIncreBuildCachePath() const { return incre_build_cache_path_; } + bool IsTrainMode() { return is_train_mode_; } void InitProfiling(Options &options); void ShutDownProfiling(); @@ -100,8 +93,6 @@ class GE_FUNC_VISIBILITY GELib { bool is_system_inited = false; bool is_shutdown = false; bool is_use_hcom = false; - bool is_incre_build_ = false; - std::string incre_build_cache_path_; }; } // namespace ge diff --git a/tests/ut/ge/CMakeLists.txt 
b/tests/ut/ge/CMakeLists.txt index 49c9161d..a0790cf2 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -140,7 +140,6 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" "${GE_CODE_DIR}/ge/graph/build/graph_builder.cc" "${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" - "${GE_CODE_DIR}/ge/common/helper/model_cache_helper.cc" "${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc" "${GE_CODE_DIR}/ge/ir_build/attr_options/utils.cc" "${GE_CODE_DIR}/ge/ir_build/attr_options/keep_dtype_option.cc" @@ -248,7 +247,6 @@ set(GRAPH_DAVINCI_MODEL_SRC_FILES "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel_factory.cc" "${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" - "${GE_CODE_DIR}/ge/graph/manager/model_manager/event_manager.cc" ) set(GRAPH_EXECUTE_COMMON_SRC_FILES @@ -520,13 +518,9 @@ set(COMMON_TEST_FILES set(DISTINCT_GRAPH_LOAD_TEST_FILES "graph/load/data_dumper_unittest.cc" - #"graph/load/new_model_manager_data_inputer_unittest.cc" - #"graph/load/new_model_manager_davinci_model_unittest.cc" "graph/load/model_manager_unittest.cc" "graph/load/new_model_manager_model_manager_aicpu_unittest.cc" "graph/load/end_graph_task_unittest.cc" - "graph/load/new_model_manager_event_manager_unittest.cc" - #"graph/load/output_net_output_unittest.cc" "graph/load/davinci_model_unittest.cc" "graph/load/tbe_handle_store_unittest.cc" "graph/load/hccl_task_info_unittest.cc" @@ -536,7 +530,6 @@ set(DISTINCT_GRAPH_LOAD_TEST_FILES "graph/load/memcpy_addr_async_task_info_unittest.cc" "graph/load/memcpy_async_task_info_unittest.cc" "graph/load/cpu_queue_schedule_unittest.cc" - #"graph/graph_load_unittest.cc" "graph/ge_executor_unittest.cc" "graph/load/model_helper_unittest.cc" "graph/load/model_utils_unittest.cc" diff --git a/tests/ut/ge/graph/execute/model_executor_unittest.cc 
b/tests/ut/ge/graph/execute/model_executor_unittest.cc index d4e0e3a4..cd907e99 100644 --- a/tests/ut/ge/graph/execute/model_executor_unittest.cc +++ b/tests/ut/ge/graph/execute/model_executor_unittest.cc @@ -20,6 +20,7 @@ #define private public #include "graph/execute/model_executor.h" #include "graph/manager/graph_manager.h" +#include "graph/manager/graph_var_manager.h" #include "graph/load/model_manager/model_manager.h" #include "graph/load/model_manager/davinci_model.h" diff --git a/tests/ut/ge/graph/graph_load_unittest.cc b/tests/ut/ge/graph/graph_load_unittest.cc deleted file mode 100644 index 93282a5e..00000000 --- a/tests/ut/ge/graph/graph_load_unittest.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include - -#include "common/debug/log.h" -#include "common/helper/model_helper.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "graph/op_desc.h" -#include "graph/types.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/op_desc_utils.h" - -#define protected public -#define private public -#include "graph/load/graph_loader.h" - -#include "framework/common/ge_inner_error_codes.h" -#include "graph/load/model_manager/model_manager.h" -#include "graph/manager/graph_manager_utils.h" -#include "common/model/ge_model.h" -#undef private -#undef protected - -using namespace testing; -namespace ge { - -class UtestGraphGraphLoad : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestGraphGraphLoad, load_graph_param_invalid1) { - std::shared_ptr graph_run_listener = nullptr; - SubGraphInfo sub_graph1; - ge::SubGraphInfoPtr sub_graph_ptr1 = std::make_shared(sub_graph1); - ModelIdInfo model_Id_info; - model_Id_info.model_id = 1; - - GeModelPtr ge_model_ptr = std::make_shared(); - sub_graph_ptr1->SetGeModelPtr(ge_model_ptr); - - std::vector input_flag; - input_flag.push_back(false); - sub_graph_ptr1->SetInputFlag(input_flag); - - ge::GraphLoader graph_load; - EXPECT_EQ(GE_GRAPH_PARAM_NULLPTR, graph_load.LoadGraph(sub_graph_ptr1->ge_model_ptr_, graph_run_listener, model_Id_info)); - sub_graph_ptr1->SetModelIdInfo(model_Id_info); -} - -TEST_F(UtestGraphGraphLoad, load_graph_param_invalid2) { - std::mutex sync_run_mutex; - std::condition_variable condition; - std::shared_ptr listener = std::make_shared(sync_run_mutex, condition); - - SubGraphInfo sub_graph1; - ge::SubGraphInfoPtr sub_graph_ptr1 = std::make_shared(sub_graph1); - ModelIdInfo model_Id_info; - model_Id_info.model_id = 1; - - GeModelPtr ge_model_ptr = std::make_shared(); - sub_graph_ptr1->SetGeModelPtr(ge_model_ptr); - - std::vector input_flag; - input_flag.push_back(false); - 
sub_graph_ptr1->SetInputFlag(input_flag); - - ge::GraphLoader graph_load; - EXPECT_EQ(GE_GRAPH_PARAM_NULLPTR, graph_load.LoadGraph(sub_graph_ptr1->ge_model_ptr_, listener, model_Id_info)); - sub_graph_ptr1->SetModelIdInfo(model_Id_info); -} -} // namespace ge diff --git a/tests/ut/ge/graph/load/new_model_manager_data_inputer_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_data_inputer_unittest.cc deleted file mode 100644 index 43c2ad15..00000000 --- a/tests/ut/ge/graph/load/new_model_manager_data_inputer_unittest.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - -#include - -#include "graph/load/model_manager/data_inputer.h" - -#include "common/debug/log.h" -#include "common/debug/memory_dumper.h" -#include "common/types.h" -#include "new_op_test_utils.h" - -using namespace std; -using namespace testing; - -namespace ge { - -class UtestModelManagerDataInputer : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -/// InputDataWrapper -/// constructor -/// GetInput -TEST_F(UtestModelManagerDataInputer, inputdatawrapper_construct) { - InputDataWrapper *input_data_wrapper = new InputDataWrapper(); - - input_data_wrapper->GetInput(); - - delete input_data_wrapper; -} - -/// InputDataWrapper -/// Init func with correct input -TEST_F(UtestModelManagerDataInputer, success_inputdatawrapper_init) { - InputDataWrapper *input_data_wrapper = new InputDataWrapper(); - ge::InputData input_data; - ge::OutputData output_data; - Status ret = input_data_wrapper->Init(input_data, output_data); - - EXPECT_EQ(ret, SUCCESS); - - delete input_data_wrapper; - input_data_wrapper = NULL; -} - -} // namespace ge diff --git a/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc deleted file mode 100644 index 38a250ad..00000000 --- a/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc +++ /dev/null @@ -1,1433 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "common/debug/log.h" -#include "common/debug/memory_dumper.h" -#include "common/types.h" - -#define private public -#define protected public -#include "graph/compute_graph.h" -#include "graph/utils/graph_utils.h" -#include "graph/model_serialize.h" -#include "graph/load/model_manager/davinci_model.h" -#include "common/properties_manager.h" -#include "common/op/ge_op_utils.h" -#include -#include "runtime/dev.h" -#include "runtime/kernel.h" -#include "cce/fwk_adpt_struct.h" -#include "graph/load/model_manager/task_info/task_info_factory.h" -#include "graph/load/model_manager/task_info/task_info.h" -#include "graph/load/model_manager/task_info/stream_active_task_info.h" -#include "graph/load/model_manager/task_info/stream_switch_task_info.h" -#include "graph/load/model_manager/task_info/profiler_trace_task_info.h" -#include "graph/load/model_manager/task_info/memcpy_async_task_info.h" -#include "graph/load/model_manager/task_info/label_set_task_info.h" -#include "graph/load/model_manager/task_info/kernel_ex_task_info.h" -#include "graph/load/model_manager/task_info/kernel_task_info.h" -#include "graph/load/model_manager/task_info/hccl_task_info.h" -#include "graph/load/model_manager/task_info/fusion_start_task_info.h" -#include "graph/load/model_manager/task_info/fusion_stop_task_info.h" -#include "graph/load/model_manager/task_info/event_record_task_info.h" -#include "graph/load/model_manager/task_info/event_wait_task_info.h" -#include "graph/manager/graph_var_manager.h" -#include "graph/load/model_manager/model_manager.h" -#undef private -#undef protected - -#include "new_op_test_utils.h" -#include "graph/debug/ge_attr_define.h" -using namespace std; -using namespace testing; -using domi::EventExDef; -using domi::KernelContext; -using domi::KernelDef; -using domi::LogTimeStampDef; -using domi::ModelTaskDef; -using 
domi::StreamActiveDef; -using domi::TaskDef; - -namespace ge { -class UtestModelManagerDavinciModel : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -class DModelListener : public ge::ModelListener { - public: - DModelListener(){}; - uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { - GELOGI("In Call back. OnComputeDone"); - return 0; - } -}; - -shared_ptr g_label_call_back(new DModelListener()); - -static ge::OpDescPtr CreateOpDesc(string name = "", string type = "") { - auto op_desc = std::make_shared(name, type); - op_desc->SetStreamId(0); - op_desc->SetId(0); - - ge::AttrUtils::SetFloat(op_desc, ge::ATTR_NAME_ALPHA, 0); - ge::AttrUtils::SetFloat(op_desc, ge::ATTR_NAME_BETA, 0); - - op_desc->SetWorkspace({}); - ; - op_desc->SetWorkspaceBytes({}); - op_desc->SetInputOffset({}); - op_desc->SetOutputOffset({}); - - ge::AttrUtils::SetListStr(op_desc, ge::ATTR_NAME_WEIGHT_NAME, {}); - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_MODE, 0); - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_PAD_MODE, 0); - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_DATA_MODE, 0); - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_CEIL_MODE, 0); - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_NAN_OPT, 0); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_WINDOW, {}); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_PAD, {}); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_STRIDE, {}); - ge::AttrUtils::SetListInt(op_desc, ge::ATTR_NAME_ACTIVE_STREAM_LIST, {1, 1}); - ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_STREAM_SWITCH_COND, 0); - ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, FMK_TYPE_T); - return op_desc; -} - -// tset failed_rt_free_host -TEST_F(UtestModelManagerDavinciModel, failed_rt_free_host) { - DavinciModel model(0, g_label_call_back); - - OutputData output_data; - - auto op_desc = CreateOpDesc("Pooling", "Pooling"); - op_desc->SetOutputOffset({1}); - 
op_desc->SetInputOffset({1}); - - { - ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 1, 1})); - ge::TensorUtils::SetSize(in_desc, 16); - ge::TensorUtils::SetOutputTensor(in_desc, false); - ge::TensorUtils::SetInputTensor(in_desc, true); - op_desc->AddInputDesc(in_desc); - } - - { - ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 1, 1})); - ge::TensorUtils::SetSize(out_desc, 16); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetInputTensor(out_desc, false); - op_desc->AddOutputDesc(out_desc); - } - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_PAD_MODE, cce::CC_PADDING_DIRECTASSIGN); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_PAD, vector({1, 1, 1, 1})); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_WINDOW, vector({1, 1})); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_STRIDE, vector({1, 1})); - - auto compute_graph = make_shared("g"); - auto node = compute_graph->AddNode(op_desc); - - OmeTestOpUtils::InitModel(model); - - model.data_op_list_.push_back(op_desc); - - EXPECT_EQ(ge::INTERNAL_ERROR, model.ReturnResult(1, false, false, &output_data)); -} - -// test modeldef_fail -TEST_F(UtestModelManagerDavinciModel, contruct_modeldef_createfail) { - DavinciModel model(0, g_label_call_back); - - OmeTestOpUtils::InitModel(model); - - auto op_desc = CreateOpDesc("Pooling", "Pooling"); - op_desc->SetOutputOffset({1}); - op_desc->SetInputOffset({1}); - - { - ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 1, 1})); - ge::TensorUtils::SetSize(in_desc, 16); - ge::TensorUtils::SetOutputTensor(in_desc, false); - ge::TensorUtils::SetInputTensor(in_desc, true); - op_desc->AddInputDesc(in_desc); - } - - { - ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 1, 1})); - ge::TensorUtils::SetSize(out_desc, 16); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetInputTensor(out_desc, false); - op_desc->AddOutputDesc(out_desc); - } - ge::AttrUtils::SetInt(op_desc, ge::POOLING_ATTR_PAD_MODE, cce::CC_PADDING_DIRECTASSIGN); - 
ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_PAD, vector({1, 1, 1, 1})); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_WINDOW, vector({1, 1})); - ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_STRIDE, vector({1, 1})); - - model.GetEventList(); -} - -// test CopyInputDataToModel -TEST_F(UtestModelManagerDavinciModel, copy_input_data_to_model_fail) { - DavinciModel model(0, g_label_call_back); - - ge::InputData input_data; - ge::DataBuffer data_buffer; - data_buffer.data = new char[16]; - data_buffer.length = 16; - input_data.index = 0; - input_data.model_id = 1; - input_data.blobs.push_back(data_buffer); - - model.op_list_.clear(); - - delete[](char *) data_buffer.data; -} - -// test StreamNum -TEST_F(UtestModelManagerDavinciModel, streamnum_success) { - DavinciModel *model = new DavinciModel(0, g_label_call_back); - - OmeTestOpUtils::InitModel(*model); - - EXPECT_EQ(0, model->StreamNum()); - EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart()); - - EXPECT_EQ(ge::SUCCESS, model->ModelRunStop()); - - delete model; -} - -// test EventNum -TEST_F(UtestModelManagerDavinciModel, eventnum_success) { - DavinciModel *model = new DavinciModel(0, g_label_call_back); - - OmeTestOpUtils::InitModel(*model); - - EXPECT_EQ(0, model->EventNum()); - EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart()); - - EXPECT_EQ(ge::SUCCESS, model->ModelRunStop()); - - delete model; -} - -TEST_F(UtestModelManagerDavinciModel, handlelist_success) { - DavinciModel *model = new DavinciModel(0, g_label_call_back); - - OmeTestOpUtils::InitModel(*model); - - EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart()); - - EXPECT_EQ(ge::SUCCESS, model->ModelRunStop()); - - delete model; -} - -// test GetEventList -TEST_F(UtestModelManagerDavinciModel, eventlist_success) { - DavinciModel *model = new DavinciModel(0, g_label_call_back); - - OmeTestOpUtils::InitModel(*model); - - EXPECT_EQ(true, model->GetEventList().empty()); - EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart()); - - 
EXPECT_EQ(ge::SUCCESS, model->ModelRunStop()); - - delete model; -} - -// test Shrink -TEST_F(UtestModelManagerDavinciModel, shrink_success) { - DavinciModel model(0, g_label_call_back); - OpDescPtr op_desc_ptr = make_shared("Cast", "Cast"); - void *addr = nullptr; - rtMalloc(&addr, 128, RT_MEMORY_HBM); - model.saved_task_addrs_.emplace(op_desc_ptr, addr); - model.Shrink(); - EXPECT_EQ(model.saved_task_addrs_.isEmpty(), true); -} - -// test rtMalloc -TEST_F(UtestModelManagerDavinciModel, failed_reset_device) { - DavinciModel model(0, g_label_call_back); - ge::OutputData output_data; - ge::DataBuffer buf_data; - rtMalloc(&buf_data.data, 128, RT_MEMORY_HBM); - buf_data.length = 128; - output_data.blobs.push_back(buf_data); - EXPECT_EQ(ge::INTERNAL_ERROR, model.ReturnResult(1, true, false, &output_data)); - rtFree(buf_data.data); -} - -// test priority -TEST_F(UtestModelManagerDavinciModel, init_not_support_priority) { - int32_t priority = 8; - DavinciModel model(priority, g_label_call_back); -} - -// test GetInputOutputDescInfo -TEST_F(UtestModelManagerDavinciModel, success_GetInputOutputDescInfo_without_netoutput) { - DavinciModel model(0, g_label_call_back); - - auto op_desc = CreateOpDesc("Data", "Data"); - op_desc->SetOutputOffset({1}); - op_desc->SetInputOffset({1}); - op_desc->SetStreamId(0); - - { - ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetOutputTensor(in_desc, false); - ge::TensorUtils::SetInputTensor(in_desc, true); - op_desc->AddInputDesc(in_desc); - } - - { - ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetInputTensor(out_desc, false); - op_desc->AddOutputDesc(out_desc); - } - - op_desc->SetSrcName({"Pooling1", "Pooling0"}); - op_desc->SetSrcIndex({0, 1}); - - auto compute_graph = make_shared("g"); - auto node = compute_graph->AddNode(op_desc); - - 
model.data_op_list_.push_back(op_desc); - model.output_size_list_.push_back(32); - - model.op_list_[0] = op_desc; - - model.output_op_list_.push_back(op_desc); - - vector input_shapes; - vector output_shapes; - EXPECT_EQ(ge::SUCCESS, model.GetInputOutputDescInfo(input_shapes, output_shapes)); -} - -TEST_F(UtestModelManagerDavinciModel, CopyTensorFromSrcVarNode_input_is_nullptr) { - NodePtr src_node = nullptr; - NodePtr dst_node = nullptr; - DavinciModel model(0, g_label_call_back); - Status ret = model.CopyTensorFromSrcVarNode(src_node, dst_node); - EXPECT_EQ(FAILED, ret); -} - -TEST_F(UtestModelManagerDavinciModel, CopyTensorFromSrcVarNode_success) { - ge::ComputeGraphPtr graph = std::make_shared("default"); - OpDescPtr op_desc_ptr = make_shared("Cast", "Cast"); - GeTensorDesc dims_tensor_desc(GeShape({1, 1, 1, 1}), FORMAT_NCHW, DT_FLOAT16); - GeTensorDesc dims_tensor_desc_in(GeShape({1, 1, 1, 1}), FORMAT_NCHW, DT_FLOAT); - op_desc_ptr->AddInputDesc(dims_tensor_desc_in); - op_desc_ptr->AddOutputDesc(dims_tensor_desc); - - NodePtr src_node = graph->AddNode(op_desc_ptr); - NodePtr dst_node = graph->AddNode(op_desc_ptr); - DavinciModel model(0, g_label_call_back); - Status ret = model.CopyTensorFromSrcVarNode(src_node, dst_node); -} - -TEST_F(UtestModelManagerDavinciModel, CopyVarData_graph_is_nullptr) { - ge::ComputeGraphPtr graph = nullptr; - DavinciModel model(0, g_label_call_back); - Status ret = model.CopyVarData(graph); - EXPECT_EQ(FAILED, ret); -} - -TEST_F(UtestModelManagerDavinciModel, copy_var_data_success) { - ge::ComputeGraphPtr graph = std::make_shared("default"); - OpDescPtr op_desc_ptr = make_shared("Variable", "Variable"); - GeTensorDesc dims_tensor_desc(GeShape({1, 1, 1, 1}), FORMAT_NCHW, DT_FLOAT16); - GeTensorDesc dims_tensor_desc_in(GeShape({1, 1, 1, 1}), FORMAT_NCHW, DT_FLOAT16); - op_desc_ptr->AddInputDesc(dims_tensor_desc_in); - op_desc_ptr->AddOutputDesc(dims_tensor_desc); - - NodePtr src_node = graph->AddNode(op_desc_ptr); - 
(void)ge::AttrUtils::SetStr(src_node->GetOpDesc(), "_copy_from_var_node", "abc"); - (void)ge::AttrUtils::SetBool(src_node->GetOpDesc(), "_copy_value", false); - - DavinciModel model(0, g_label_call_back); - Status ret = model.CopyVarData(graph); -} - -TEST_F(UtestModelManagerDavinciModel, get_input_output_desc_info_without_data_op_list) { - DavinciModel model(0, g_label_call_back); - vector input_list; - vector output_list; - Status ret = model.GetInputOutputDescInfo(input_list, output_list); - EXPECT_EQ(SUCCESS, ret); -} - -// test GetInputOutputDescInfo -TEST_F(UtestModelManagerDavinciModel, success_get_input_output_descInfo_with_net_output) { - DavinciModel model(0, g_label_call_back); - - auto op_desc = CreateOpDesc("Data", "Data"); - op_desc->SetOutputOffset({1}); - op_desc->SetInputOffset({1}); - op_desc->SetStreamId(0); - - { - ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetOutputTensor(in_desc, false); - ge::TensorUtils::SetInputTensor(in_desc, true); - op_desc->AddInputDesc(in_desc); - } - - { - ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetInputTensor(out_desc, false); - op_desc->AddOutputDesc(out_desc); - } - op_desc->SetSrcName({"Pooling1", "Pooling0"}); - op_desc->SetSrcIndex({0, 1}); - - auto compute_graph = make_shared("g"); - auto data_node = compute_graph->AddNode(op_desc); - - model.data_op_list_.push_back(op_desc); - - op_desc->SetType("NetOutput"); - - auto no_node = compute_graph->AddNode(op_desc); - - model.op_list_[0] = op_desc; - - model.output_op_list_.push_back(op_desc); - model.output_size_list_.push_back(32); - - vector input_shapes; - vector output_shapes; - EXPECT_EQ(ge::SUCCESS, model.GetInputOutputDescInfo(input_shapes, output_shapes)); -} - -TEST_F(UtestModelManagerDavinciModel, 
success_get_input_output_desc_info_for_zero_copy_with_net_output) { - DavinciModel model(0, g_label_call_back); - - auto op_desc = CreateOpDesc("Data", "Data"); - op_desc->SetOutputOffset({1}); - op_desc->SetInputOffset({1}); - op_desc->SetStreamId(0); - - { - ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetOutputTensor(in_desc, false); - ge::TensorUtils::SetInputTensor(in_desc, true); - op_desc->AddInputDesc(in_desc); - } - - { - ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetInputTensor(out_desc, false); - op_desc->AddOutputDesc(out_desc); - } - - op_desc->SetSrcName({"Pooling1", "Pooling0"}); - op_desc->SetSrcIndex({0, 1}); - - auto compute_graph = make_shared("g"); - auto data_node = compute_graph->AddNode(op_desc); - - model.data_op_list_.push_back(op_desc); - - op_desc->SetType("NetOutput"); - - auto net_out_node = compute_graph->AddNode(op_desc); - model.op_list_[0] = op_desc; - - model.output_op_list_.push_back(op_desc); - model.output_size_list_.push_back(32); - model.output_memory_size_list_.push_back(64); - - vector input_shapes; - vector output_shapes; - EXPECT_EQ(ge::SUCCESS, model.GetInputOutputDescInfoForZeroCopy(input_shapes, output_shapes)); -} - -TEST_F(UtestModelManagerDavinciModel, success_get_input_output_desc_info_dim_size_not4) { - DavinciModel model(0, g_label_call_back); - - auto op_desc = CreateOpDesc("Data", "Data"); - op_desc->SetOutputOffset({1}); - op_desc->SetInputOffset({1}); - op_desc->SetStreamId(0); - - { - ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetOutputTensor(in_desc, false); - ge::TensorUtils::SetInputTensor(in_desc, true); - op_desc->AddInputDesc(in_desc); - } - - { - ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 10}), ge::FORMAT_NCHW, 
ge::DT_FLOAT16); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetInputTensor(out_desc, false); - op_desc->AddOutputDesc(out_desc); - } - - op_desc->SetSrcName({"Pooling1", "Pooling0"}); - op_desc->SetSrcIndex({0, 1}); - - auto compute_graph = make_shared("g"); - auto data_node = compute_graph->AddNode(op_desc); - - model.data_op_list_.push_back(op_desc); - - op_desc->SetType("NetOutput"); - - auto net_out_node = compute_graph->AddNode(op_desc); - model.op_list_[0] = op_desc; - - model.output_op_list_.push_back(op_desc); - model.output_size_list_.push_back(32); - - vector input_shapes; - vector output_shapes; - EXPECT_EQ(ge::SUCCESS, model.GetInputOutputDescInfo(input_shapes, output_shapes)); -} - -// test GetLabelList -TEST_F(UtestModelManagerDavinciModel, get_label_list_success) { - DavinciModel model(0, g_label_call_back); - OmeTestOpUtils::InitModel(model); - vector label_list; - model.label_list_ = label_list; - EXPECT_EQ(label_list, model.GetLabelList()); -} - -// test GetInputListSize -TEST_F(UtestModelManagerDavinciModel, get_label_list_size_success) { - DavinciModel model(0, g_label_call_back); - OmeTestOpUtils::InitModel(model); - vector data_op_list; - data_op_list.push_back(std::make_shared()); - model.data_op_list_ = data_op_list; -} - -// test GetFlowctrlOpList -TEST_F(UtestModelManagerDavinciModel, get_flow_ctrl_op_list_success) { - DavinciModel model(0, g_label_call_back); - OmeTestOpUtils::InitModel(model); - std::map flowctrl_op_index_internal_map; - flowctrl_op_index_internal_map.insert(pair(1, 1)); - model.flowctrl_op_index_internal_map_ = flowctrl_op_index_internal_map; -} - -// test SetFlowctrlOpList -TEST_F(UtestModelManagerDavinciModel, get_flow_ctrl_index_success) { - DavinciModel model(0, g_label_call_back); - OmeTestOpUtils::InitModel(model); - EXPECT_EQ(0, model.GetFlowctrlIndex(0)); - EXPECT_EQ(1, model.GetFlowctrlIndex(0)); - EXPECT_EQ(0, 
model.GetFlowctrlIndex(1)); - EXPECT_EQ(1, model.GetFlowctrlIndex(1)); - EXPECT_EQ(2, model.GetFlowctrlIndex(0)); -} - -// test GetRegisterStub -TEST_F(UtestModelManagerDavinciModel, success_get_register_stub) { - DavinciModel model(0, g_label_call_back); - OmeTestOpUtils::InitModel(model); - std::string binfile = "tvmbin"; - string ret = model.GetRegisterStub(binfile); - EXPECT_EQ("tvmbin", ret); - model.tvm_bin_kernel_.insert("tvmbin"); - ret = model.GetRegisterStub(binfile); - EXPECT_EQ("tvmbin", ret); -} - -// test InitTbeHandle -TEST_F(UtestModelManagerDavinciModel, success_init_tbe_handle) { - DavinciModel model(0, g_label_call_back); - OmeTestOpUtils::InitModel(model); - std::shared_ptr op_desc = std::make_shared(); - Status ret = model.InitTbeHandle(op_desc); - EXPECT_EQ(ge::INTERNAL_ERROR, ret); -} - -// test InitTVMTask failed -TEST_F(UtestModelManagerDavinciModel, init_tvm_task_failed1) { - DavinciModel model(0, g_label_call_back); - uint16_t offset = 0; - TaskDef *task_def = new TaskDef(); - KernelDef *kernel_def = task_def->mutable_kernel(); - map op_list; - model.op_list_ = op_list; - - KernelTaskInfo *kernel_task_info = new KernelTaskInfo(); - Status ret = kernel_task_info->InitTVMTask(&model, offset, kernel_def[0]); - EXPECT_EQ(INTERNAL_ERROR, ret); - task_def->clear_kernel(); - delete kernel_task_info; - delete task_def; -} - -TEST_F(UtestModelManagerDavinciModel, kernel_taskInfo_init_cce_task_failed1) { - DavinciModel model(0, g_label_call_back); - - TaskDef *task_def = new TaskDef(); - KernelTaskInfo *kernel_task_info = new KernelTaskInfo(); - KernelDef *kernel_def = task_def->mutable_kernel(); - Status ret = kernel_task_info->InitCceTask(&model, kernel_def[0]); - EXPECT_EQ(ge::INTERNAL_ERROR, ret); - task_def->clear_kernel(); - delete kernel_task_info; - delete task_def; -} - -// test SetContext success -TEST_F(UtestModelManagerDavinciModel, success_kernel_taskInfo_init_set_context) { - DavinciModel model(0, g_label_call_back); - - TaskDef 
*task_def = new TaskDef(); - KernelTaskInfo *kernel_task_info = new KernelTaskInfo(); - KernelDef *kernel_def = task_def->mutable_kernel(); - KernelContext *context = kernel_def->mutable_context(); - context->set_op_id(1); - context->set_kernel_func_id(1); - context->set_is_flowtable(true); - context->set_args_count(1); - context->set_args_offset("args111111", 10); - - Status ret = kernel_task_info->SetContext(kernel_def[0]); - EXPECT_EQ(ge::SUCCESS, ret); - - ret = kernel_task_info->Release(); - EXPECT_EQ(ge::SUCCESS, ret); - kernel_def->clear_context(); - task_def->clear_kernel(); - delete kernel_task_info; - delete task_def; -} - -// test SetContext failed -TEST_F(UtestModelManagerDavinciModel, kernel_taskInfo_init_set_context_failed1) { - DavinciModel model(0, g_label_call_back); - - TaskDef *task_def = new TaskDef(); - KernelTaskInfo *kernel_task_info = new KernelTaskInfo(); - KernelDef *kernel_def = task_def->mutable_kernel(); - KernelContext *context = kernel_def->mutable_context(); - context->set_op_id(1); - context->set_kernel_func_id(1); - context->set_is_flowtable(true); - context->set_args_count(0); - Status ret = kernel_task_info->SetContext(kernel_def[0]); - EXPECT_EQ(ge::INTERNAL_ERROR, ret); - - kernel_def->clear_context(); - task_def->clear_kernel(); - delete kernel_task_info; - delete task_def; -} - -TEST_F(UtestModelManagerDavinciModel, kernel_taskInfo_init_set_context_failed2) { - DavinciModel model(0, g_label_call_back); - - TaskDef *task_def = new TaskDef(); - KernelTaskInfo *kernel_task_info = new KernelTaskInfo(); - KernelDef *kernel_def = task_def->mutable_kernel(); - KernelContext *context = kernel_def->mutable_context(); - context->set_op_id(1); - context->set_kernel_func_id(1); - context->set_is_flowtable(true); - context->set_args_count(5); - context->set_args_offset("\0\0"); // args_offset = 0 - - Status ret = kernel_task_info->SetContext(kernel_def[0]); - EXPECT_EQ(ge::PARAM_INVALID, ret); - - kernel_def->clear_context(); - 
task_def->clear_kernel(); - delete kernel_task_info; - delete task_def; -} - -// test success DistributeDumpTask -TEST_F(UtestModelManagerDavinciModel, success_distribute_dump_task) { - DavinciModel model(0, g_label_call_back); - TaskDef *task_def = new TaskDef(); - KernelTaskInfo *kernel_task_info = new KernelTaskInfo(); - KernelDef *kernel_def = task_def->mutable_kernel(); - - kernel_def->set_stub_func("kerneltaskinfo"); - kernel_def->set_block_dim(10); - kernel_def->set_args("args111111", 10); - kernel_def->set_args_size(10); - rtSmDesc_t l2CtrlInfo; - l2CtrlInfo.data[0].L2_mirror_addr = 1024; - kernel_def->set_sm_desc((void *)&l2CtrlInfo, sizeof(rtSmDesc_t)); - - // for SetStream - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - std::vector stream_list; - stream_list.push_back(stream); - Status ret = kernel_task_info->SetStream(0, stream_list); - EXPECT_EQ(SUCCESS, ret); - - ret = kernel_task_info->Release(); - EXPECT_EQ(SUCCESS, ret); - rtStreamDestroy(stream); - task_def->clear_kernel(); - delete kernel_task_info; - delete task_def; -} - -// test success GetTaskID -TEST_F(UtestModelManagerDavinciModel, success_get_task_id) { - ModelTaskDef *model_task_def = new ModelTaskDef(); - TaskDef *task = model_task_def->add_task(); - task->set_type(RT_MODEL_TASK_KERNEL); - TaskInfoPtr task_info = TaskInfoFactory::Instance().Create(static_cast(task->type())); - - KernelTaskInfo *kernel_task_info = new KernelTaskInfo(); - uint32_t ret = task_info->GetTaskID(); - EXPECT_EQ(0, ret); - ret = kernel_task_info->GetTaskID(); - EXPECT_EQ(0, ret); - HcclTaskInfo *hccl_task_info = new HcclTaskInfo(); - ret = hccl_task_info->GetTaskID(); - EXPECT_EQ(0, ret); - - delete hccl_task_info; - delete kernel_task_info; - delete model_task_def; -} - -// test StoreInputOutputTensor success -TEST_F(UtestModelManagerDavinciModel, success_store_input_output_tensor) { - DavinciModel model(0, g_label_call_back); - TaskDef *task_def = new TaskDef(); - KernelTaskInfo *kernel_task_info 
= new KernelTaskInfo(); - - std::vector input_data_addrs; - std::vector output_data_addrs; - std::vector<::tagCcAICPUTensor> input_descs; - std::vector<::tagCcAICPUTensor> output_descs; - - int test = 1; - int *addr = &test; - void *input; - void *output; - input = addr; - output = addr; - input_data_addrs.push_back(&input); - output_data_addrs.push_back(output); - - tagCcAICPUTensor input_desc; - tagCcAICPUTensor output_desc; - input_descs.push_back(input_desc); - output_descs.push_back(output_desc); - - Status ret = kernel_task_info->StoreInputOutputTensor(input_data_addrs, output_data_addrs, input_descs, output_descs); - EXPECT_EQ(SUCCESS, ret); - ret = kernel_task_info->Release(); - EXPECT_EQ(SUCCESS, ret); - delete kernel_task_info; - delete task_def; -} - -// test init EventRecordTaskInfo -TEST_F(UtestModelManagerDavinciModel, success_event_record_task_init) { - DavinciModel *model1 = nullptr; - TaskDef *task_def1 = new TaskDef(); - EventRecordTaskInfo *eventRecordTaskInfo1 = new EventRecordTaskInfo(); - Status ret1 = eventRecordTaskInfo1->Init(task_def1[0], model1); - EXPECT_EQ(PARAM_INVALID, ret1); - - delete eventRecordTaskInfo1; - delete task_def1; - delete model1; - DavinciModel model(0, g_label_call_back); - - ModelTaskDef *model_task_info = new ModelTaskDef(); - TaskDef *task = model_task_info->add_task(); - task->set_type(RT_MODEL_TASK_EVENT_RECORD); - TaskInfoPtr task_info = TaskInfoFactory::Instance().Create(static_cast(task->type())); - - task->stream_id_ = 0; - rtStream_t rt_stream; - rtStreamCreate(&rt_stream, 1); - vector stream_list; - stream_list.push_back(rt_stream); - model.stream_list_ = stream_list; - - task->set_event_id(1); - model.runtime_param_.event_num = 1; - Status ret = task_info->Init(task[0], &model); - EXPECT_EQ(ge::INTERNAL_ERROR, ret); - - model.runtime_param_.event_num = 2; - rtEvent_t event1; - rtEvent_t event2; - rtEventCreate(&event1); - rtEventCreate(&event2); - model.event_list_.push_back(event1); - 
model.event_list_.push_back(event2); - - EventExDef *event_ex_def = task->mutable_event_ex(); - event_ex_def->set_event_type(1); - - ret = task_info->Init(task[0], &model); - EXPECT_EQ(SUCCESS, ret); - - task->clear_event_ex(); - task_info->Release(); - delete model_task_info; -} - -// test init EventWaitTaskInfo -TEST_F(UtestModelManagerDavinciModel, success_event_wait_task_init) { - DavinciModel *model1 = nullptr; - TaskDef *task_def1 = new TaskDef(); - EventWaitTaskInfo *event_wait_task_info1 = new EventWaitTaskInfo(); - Status ret1 = event_wait_task_info1->Init(task_def1[0], model1); - EXPECT_EQ(PARAM_INVALID, ret1); - - delete event_wait_task_info1; - delete task_def1; - delete model1; - DavinciModel model(0, g_label_call_back); - - ModelTaskDef *model_task_info = new ModelTaskDef(); - TaskDef *task = model_task_info->add_task(); - task->set_type(RT_MODEL_TASK_EVENT_WAIT); - TaskInfoPtr task_info = TaskInfoFactory::Instance().Create(static_cast(task->type())); - - task->stream_id_ = 0; - rtStream_t rt_stream; - rtStreamCreate(&rt_stream, 1); - vector stream_list; - stream_list.push_back(rt_stream); - model.stream_list_ = stream_list; - - task->set_event_id(1); - model.runtime_param_.event_num = 1; - Status ret = task_info->Init(task[0], &model); - EXPECT_EQ(ge::INTERNAL_ERROR, ret); - - model.runtime_param_.event_num = 2; - rtEvent_t event1; - rtEvent_t event2; - rtEventCreate(&event1); - rtEventCreate(&event2); - model.event_list_.push_back(event1); - model.event_list_.push_back(event2); - - EventExDef *event_ex_def = task->mutable_event_ex(); - event_ex_def->set_event_type(1); - - ret = task_info->Init(task[0], &model); - EXPECT_EQ(SUCCESS, ret); - - task->clear_event_ex(); - task_info->Release(); - delete model_task_info; -} - -// test fusion_start_task Init -TEST_F(UtestModelManagerDavinciModel, success_fusion_start_task_init) { - DavinciModel *model1 = nullptr; - TaskDef *task_def1 = new TaskDef(); - FusionStartTaskInfo *fusion_start_task_info1 = new 
FusionStartTaskInfo(); - Status ret1 = fusion_start_task_info1->Init(task_def1[0], model1); - EXPECT_EQ(PARAM_INVALID, ret1); - - delete fusion_start_task_info1; - delete task_def1; - delete model1; - DavinciModel model(0, g_label_call_back); - TaskDef *task_def = new TaskDef(); - FusionStartTaskInfo *fusion_start_task_info = new FusionStartTaskInfo(); - task_def->set_stream_id(0); - rtStream_t stream; - rtStreamCreate(&stream, 0); - model.stream_list_.push_back(stream); - - Status ret = fusion_start_task_info->Init(task_def[0], &model); - EXPECT_EQ(SUCCESS, ret); - delete fusion_start_task_info; - delete task_def; -} - -// test fusion_end_task Init -TEST_F(UtestModelManagerDavinciModel, success_fusion_end_task_rinit) { - DavinciModel *model1 = nullptr; - TaskDef *task_def1 = new TaskDef(); - FusionStopTaskInfo *fusion_stop_task_info1 = new FusionStopTaskInfo(); - Status ret1 = fusion_stop_task_info1->Init(task_def1[0], model1); - EXPECT_EQ(PARAM_INVALID, ret1); - - delete fusion_stop_task_info1; - delete task_def1; - delete model1; - DavinciModel model(0, g_label_call_back); - TaskDef *task_def = new TaskDef(); - FusionStopTaskInfo *fusion_stop_task_info = new FusionStopTaskInfo(); - task_def->set_stream_id(0); - rtStream_t stream; - rtStreamCreate(&stream, 0); - model.stream_list_.push_back(stream); - - Status ret = fusion_stop_task_info->Init(task_def[0], &model); - EXPECT_EQ(SUCCESS, ret); - delete fusion_stop_task_info; - delete task_def; -} - -// test kernel_ex_task_Release -TEST_F(UtestModelManagerDavinciModel, success_kernel_ex_task_release) { - KernelExTaskInfo *kernel_ex_task_info = new KernelExTaskInfo(); - Status ret = kernel_ex_task_info->Release(); - EXPECT_EQ(SUCCESS, ret); - - delete kernel_ex_task_info; -} - -// test hccl_Distribute -TEST_F(UtestModelManagerDavinciModel, success_Distribute7) { - DavinciModel model(0, g_label_call_back); - - ModelTaskDef *model_task_def = new ModelTaskDef(); - TaskDef *task7 = model_task_def->add_task(); - 
task7->set_type(RT_MODEL_TASK_HCCL); - TaskInfoPtr task_info7 = TaskInfoFactory::Instance().Create(static_cast(task7->type())); - Status ret = task_info7->Init(task7[0], &model); - EXPECT_EQ(FAILED, ret); - - std::vector task_list; - task_list.push_back(task_info7); - model.task_list_ = task_list; - - task_info7->Release(); - delete model_task_def; -} - -// test hccl_GetPrivateDefByTaskDef -TEST_F(UtestModelManagerDavinciModel, success_hccl_get_private_def_by_task_def) { - DavinciModel model(0, g_label_call_back); - - ModelTaskDef *model_task_def = new ModelTaskDef(); - TaskDef *task7 = model_task_def->add_task(); - task7->set_type(RT_MODEL_TASK_HCCL); - // for SetStream - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - model.stream_list_.push_back(stream); - // for GetPrivateDefByTaskDef - task7->set_ops_kernel_store_ptr(10); - std::string value = "hccl_task"; - task7->set_private_def(value); - - TaskInfoPtr task_info7 = TaskInfoFactory::Instance().Create(static_cast(task7->type())); - // for Distribute - Status ret = task_info7->Init(task7[0], &model); - EXPECT_EQ(ge::PARAM_INVALID, ret); - - task_info7->Release(); - delete model_task_def; -} - -// test hccl_task_TransToGETaskInfo -TEST_F(UtestModelManagerDavinciModel, success_hccl_trans_to_ge_task_info) { - DavinciModel model(0, g_label_call_back); - - ModelTaskDef *model_task_def = new ModelTaskDef(); - TaskDef *task7 = model_task_def->add_task(); - // for type - task7->set_type(RT_MODEL_TASK_HCCL); - TaskInfoPtr task_info7 = TaskInfoFactory::Instance().Create(static_cast(task7->type())); - - GETaskInfo ge_task; - HcclTaskInfo *hccl_task_info = new HcclTaskInfo(); - hccl_task_info->TransToGETaskInfo(ge_task); - - delete hccl_task_info; - delete model_task_def; -} - -// test stream_active_task Init -TEST_F(UtestModelManagerDavinciModel, success_stream_active_task_init) { - DavinciModel *model1 = nullptr; - TaskDef *task_def1 = new TaskDef(); - StreamActiveTaskInfo *stream_active_task_info1 = new 
StreamActiveTaskInfo(); - Status ret1 = stream_active_task_info1->Init(task_def1[0], model1); - EXPECT_EQ(PARAM_INVALID, ret1); - delete stream_active_task_info1; - delete task_def1; - delete model1; - - DavinciModel model(0, g_label_call_back); - TaskDef *task_def = new TaskDef(); - task_def->set_stream_id(0); - rtStream_t stream1, stream2; - rtStreamCreate(&stream1, 0); - rtStreamCreate(&stream2, 0); - model.stream_list_.push_back(stream1); - - StreamActiveTaskInfo *stream_active_task_info = new StreamActiveTaskInfo(); - - StreamActiveDef *stream_active_def = task_def->mutable_stream_active(); - stream_active_def->set_op_index(0); - stream_active_def->set_active_stream_id(0); - - std::map flowctrl; - flowctrl.insert(pair(1, 1)); - model.flowctrl_op_index_internal_map_ = flowctrl; - - auto opDef = CreateOpDesc("", ""); - model.op_list_[0] = opDef; - - Status ret = stream_active_task_info->Init(task_def[0], &model); - EXPECT_EQ(ge::INTERNAL_ERROR, ret); // line 51 - - model.stream_list_.push_back(stream2); - ret = stream_active_task_info->Init(task_def[0], &model); - EXPECT_EQ(SUCCESS, ret); - - task_def->clear_stream_active(); - delete stream_active_task_info; - delete task_def; -} - -// test label_set_task Init -TEST_F(UtestModelManagerDavinciModel, success_label_set_task_init) { - DavinciModel *model1 = nullptr; - TaskDef *task_def1 = new TaskDef(); - LabelSetTaskInfo *label_set_task_info1 = new LabelSetTaskInfo(); - Status ret1 = label_set_task_info1->Init(task_def1[0], model1); - EXPECT_EQ(PARAM_INVALID, ret1); - delete label_set_task_info1; - delete task_def1; - delete model1; - - DavinciModel model(0, g_label_call_back); - TaskDef *task_def = new TaskDef(); - LabelSetTaskInfo *label_set_task_info = new LabelSetTaskInfo(); - task_def->set_stream_id(0); - rtStream_t stream; - rtStreamCreate(&stream, 0); - model.stream_list_.push_back(stream); - - task_def->set_label_id(1); - model.runtime_param_.batch_num = 0; - Status ret = 
label_set_task_info->Init(task_def[0], &model); - EXPECT_EQ(PARAM_INVALID, ret); - - task_def->clear_label_id(); - task_def->set_label_id(0); - model.runtime_param_.batch_num = 1; - rtLabel_t label; - rtLabelCreate(&label); - model.label_list_.push_back(label); - - ret = label_set_task_info->Init(task_def[0], &model); - EXPECT_EQ(SUCCESS, ret); - delete label_set_task_info; - delete task_def; -} - -// test label_goto_task init -TEST_F(UtestModelManagerDavinciModel, success_label_goto_task_init) { - DavinciModel model(0, g_label_call_back); - TaskDef *task_def = new TaskDef(); - LabelGotoTaskInfo *label_goto_task_info = new LabelGotoTaskInfo(); - task_def->set_stream_id(0); - - rtStream_t stream; - rtStreamCreate(&stream, 0); - model.stream_list_.push_back(stream); - - rtLabel_t label; - rtLabelCreate(&label); - model.label_list_.push_back(label); - - Status ret = label_goto_task_info->Init(task_def[0], &model); - EXPECT_EQ(SUCCESS, ret); - - delete label_goto_task_info; - delete task_def; -} - -// test profiler_trace_task init -TEST_F(UtestModelManagerDavinciModel, success_profiler_trace_task_init) { - DavinciModel *model1 = nullptr; - TaskDef *task_def1 = new TaskDef(); - ProfilerTraceTaskInfo *profiler_trace_task_info1 = new ProfilerTraceTaskInfo(); - Status ret1 = profiler_trace_task_info1->Init(task_def1[0], model1); - EXPECT_EQ(PARAM_INVALID, ret1); - - delete profiler_trace_task_info1; - delete task_def1; - delete model1; - DavinciModel model(0, g_label_call_back); - TaskDef *task_def = new TaskDef(); - task_def->set_stream_id(0); - rtStream_t stream; - rtStreamCreate(&stream, 0); - model.stream_list_.push_back(stream); - LogTimeStampDef *logTimeStampDef = task_def->mutable_log_timestamp(); - logTimeStampDef->set_logid(1); - logTimeStampDef->set_notify(1); - logTimeStampDef->set_flat(1); - ProfilerTraceTaskInfo *profiler_trace_task_info = new ProfilerTraceTaskInfo(); - Status ret = profiler_trace_task_info->Init(task_def[0], &model); - EXPECT_EQ(SUCCESS, 
ret); - - task_def->clear_log_timestamp(); - delete profiler_trace_task_info; - delete task_def; -} - -TEST_F(UtestModelManagerDavinciModel, profiling_model_success) { - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - - DavinciModel model(0, g_label_call_back); - model.model_id_ = 1; - model.name_ = "test"; - model.version_ = 0x01; - - model.stream_list_.push_back(stream); - - ge::ModelData data; - rtMallocHost(&data.model_data, 128); - data.model_len = 128; - - ModelDef *model_def = new ModelDef(); - auto op_def = CreateOpDesc("", "Data"); - op_def->SetInputOffset({1}); - op_def->SetOutputOffset({100}); - - ge::GeTensorDesc descin(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetSize(descin, 4); - op_def->AddInputDesc(descin); - ge::GeTensorDesc desc_out(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetSize(desc_out, 32); - op_def->AddInputDesc(desc_out); - op_def->SetId(0); - - model.data_op_list_.push_back(op_def); - model.op_list_[0] = op_def; - - auto opdef1 = CreateOpDesc("", "Relu"); - opdef1->SetInputOffset({1}); - opdef1->SetOutputOffset({100}); - - ge::GeTensorDesc desc_in1(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetSize(desc_in1, 4); - opdef1->AddInputDesc(desc_in1); - ge::GeTensorDesc desc_out1(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetSize(desc_out1, 32); - opdef1->AddInputDesc(desc_out1); - op_def->SetId(1); - - model.op_list_[1] = opdef1; - - auto opdef2 = CreateOpDesc("", "Relu"); - opdef2->SetInputOffset({1}); - opdef2->SetOutputOffset({100}); - - ge::GeTensorDesc desc_in2(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetSize(desc_in2, 4); - opdef2->AddInputDesc(desc_in2); - ge::GeTensorDesc desc_out2(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetSize(desc_out2, 32); - opdef2->AddInputDesc(desc_out2); - op_def->SetId(2); - - 
model.op_list_[2] = opdef2; - - auto opdef3 = CreateOpDesc("", "Relu"); - opdef3->SetInputOffset({1}); - opdef3->SetOutputOffset({100}); - - ge::GeTensorDesc desc_in3(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetSize(desc_in3, 4); - opdef3->AddInputDesc(desc_in3); - ge::GeTensorDesc desc_out3(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetSize(desc_out3, 32); - opdef3->AddInputDesc(desc_out3); - op_def->SetId(3); - - model.op_list_[3] = opdef3; - - auto opdef4 = CreateOpDesc("", "Relu"); - opdef4->SetInputOffset({1}); - opdef4->SetOutputOffset({100}); - - ge::GeTensorDesc desc_in4(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetSize(desc_in4, 4); - opdef4->AddInputDesc(desc_in4); - ge::GeTensorDesc desc_out4(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT16); - ge::TensorUtils::SetSize(desc_out4, 32); - opdef4->AddInputDesc(desc_out4); - op_def->SetId(4); - - model.op_list_[4] = opdef4; - - ge::InputData input_data; - ge::DataBuffer data_buffer; - data_buffer.data = new char[4]; - data_buffer.length = 4; - input_data.index = 0; - input_data.model_id = 1; - input_data.blobs.push_back(data_buffer); - - rtFreeHost(data.model_data); - delete[](char *) data_buffer.data; - delete model_def; -} - -TEST_F(UtestModelManagerDavinciModel, success_output_list_0) { - DavinciModel model(0, g_label_call_back); - - uint32_t version = 0; - uint64_t session_id = 0; - uint32_t device_id = 0; - uint64_t job_id = 0; - Status ret = VarManager::Instance(session_id)->Init(version, session_id, device_id, job_id); - EXPECT_EQ(ret, ge::SUCCESS); - - ret = model.ReturnNoOutput(1); - EXPECT_EQ(ret, ge::SUCCESS); - - VarManagerPool::Instance().Destroy(); -} - -// test dyncbatch_distributeTask_SUCCESS -TEST_F(UtestModelManagerDavinciModel, dyncbatch_distribute_task_success) { - DavinciModel model(0, g_label_call_back); - - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - 
- rtLabel_t label = nullptr; - rtLabelCreate(&label); - model.label_list_.push_back(label); - rtLabelCreate(&label); - model.label_list_.push_back(label); - rtLabelCreate(&label); - model.label_list_.push_back(label); - - rtLabelDestroy(label); - rtStreamDestroy(stream); -} - -// test GetOutputDescInfo -TEST_F(UtestModelManagerDavinciModel, success_get_output_desc_info_with_netoutput) { - setenv("GE_TRAIN", "1", true); - DavinciModel model(0, g_label_call_back); - - auto op_desc = CreateOpDesc("Data", "Data"); - op_desc->SetOutputOffset({1}); - op_desc->SetInputOffset({1}); - op_desc->SetStreamId(0); - - { - ge::GeTensorDesc in_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_FRACTAL_Z, ge::DT_FLOAT16); - ge::TensorUtils::SetOutputTensor(in_desc, false); - ge::TensorUtils::SetInputTensor(in_desc, true); - op_desc->AddInputDesc(in_desc); - } - - { - ge::GeTensorDesc out_desc(ge::GeShape({1, 1, 10, 10}), ge::FORMAT_NCHW, ge::DT_FLOAT); - ge::TensorUtils::SetOutputTensor(out_desc, true); - ge::TensorUtils::SetInputTensor(out_desc, false); - op_desc->AddOutputDesc(out_desc); - } - - op_desc->SetSrcName({"Pooling1", "Pooling0"}); - op_desc->SetSrcIndex({0, 1}); - - auto compute_graph = make_shared("g"); - - op_desc->SetType("NetOutput"); - - auto net_out_node = compute_graph->AddNode(op_desc); - model.op_list_[0] = op_desc; - - model.output_op_list_.push_back(op_desc); - model.output_size_list_.push_back(32); - model.output_memory_size_list_.push_back(64); - - vector output_shapes; - vector formats; - EXPECT_EQ(ge::SUCCESS, model.GetOutputDescInfo(output_shapes, formats)); - - setenv("GE_TRAIN", "0", true); -} - -TEST_F(UtestModelManagerDavinciModel, device_runtime_success_Run) { - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - - DavinciModel model(0, g_label_call_back); - - model.stream_list_.push_back(stream); - auto model_def = make_shared(); - - auto op_def = CreateOpDesc("", "Data"); - - auto compute_graph = make_shared("g"); - 
compute_graph->AddNode(op_def); - - model_def->SetGraph(ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph)); - - model.data_op_list_.push_back(op_def); - - model.data_inputer_ = new DataInputer(); - - model.ModelRunStart(); - - OutputData output_data; - ge::InputData input_data; - - ge::DataBuffer data_buffer; - data_buffer.data = new char[16]; - data_buffer.length = 16; - - input_data.index = 0; - input_data.model_id = 1; - input_data.blobs.push_back(data_buffer); - - model.ModelRunStop(); - - delete[](char *) data_buffer.data; -} - -TEST_F(UtestModelManagerDavinciModel, run_failed) { - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - - DavinciModel model(0, g_label_call_back); - - model.stream_list_.push_back(stream); - auto model_def = make_shared(); - - auto op_def = CreateOpDesc("", "Data"); - - auto compute_graph = make_shared("g"); - compute_graph->AddNode(op_def); - - model_def->SetGraph(ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph)); - - model.data_op_list_.push_back(op_def); - - model.data_inputer_ = new DataInputer(); - - model.ModelRunStart(); - - OutputData output_data; - ge::InputData input_data; - - ge::DataBuffer data_buffer; - data_buffer.data = new char[16]; - data_buffer.length = 16; - - input_data.index = 0; - input_data.model_id = 1; - input_data.blobs.push_back(data_buffer); - - model.ModelRunStop(); - delete[](char *) data_buffer.data; -} - -TEST_F(UtestModelManagerDavinciModel, run_failed01) { - rtStream_t stream = nullptr; - rtStreamCreate(&stream, 0); - - DavinciModel model(0, g_label_call_back); - - model.stream_list_.push_back(stream); - auto model_def = make_shared(); - - auto op_def = CreateOpDesc("", "Data"); - - auto compute_graph = make_shared("g"); - compute_graph->AddNode(op_def); - - model_def->SetGraph(ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph)); - - model.data_op_list_.push_back(op_def); - - model.data_inputer_ = nullptr; - model.ModelRunStart(); - - model.ModelRunStop(); -} - 
-TEST_F(UtestModelManagerDavinciModel, init_tbe_handle_fe_registered) { - DavinciModel::tvm_bin_kernel_.clear(); - DavinciModel model(0, g_label_call_back); - OpDescPtr op_desc = CreateOpDesc("MatMul", "MatMul"); - - std::vector kernelBin; - TBEKernelPtr tbe_kernel = std::make_shared("name/MatMul", std::move(kernelBin)); - op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); - - std::string kernel_name("kernel/MatMul"); - AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); - - EXPECT_EQ(model.InitTbeHandle(op_desc), SUCCESS); - EXPECT_EQ(model.InitTbeHandle(op_desc), SUCCESS); - - EXPECT_EQ(model.used_tbe_handle_map_.size(), 0); - DavinciModel::tvm_bin_kernel_.clear(); -} - -TEST_F(UtestModelManagerDavinciModel, init_tbe_handle_ge_registered) { - DavinciModel::tvm_bin_kernel_.clear(); - DavinciModel model(0, g_label_call_back); - OpDescPtr op_desc = CreateOpDesc("MatMul", "MatMul"); - - std::vector kernelBin; - TBEKernelPtr tbe_kernel = std::make_shared("name/MatMul", std::move(kernelBin)); - op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel); - - std::string kernel_name("kernel/MatMul"); - AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name); - - string session_graph_id; - AttrUtils::GetStr(op_desc, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); - const char *bin_file_key = DavinciModel::GetRegisterStub(op_desc->GetName(), session_graph_id); - model.used_tbe_handle_map_[bin_file_key] = 1; // test first register. 
- - EXPECT_EQ(model.InitTbeHandle(op_desc), SUCCESS); - EXPECT_EQ(model.InitTbeHandle(op_desc), SUCCESS); - - EXPECT_EQ(model.used_tbe_handle_map_.size(), 1); - - auto it = model.used_tbe_handle_map_.find(bin_file_key); - EXPECT_NE(it, model.used_tbe_handle_map_.end()); - EXPECT_EQ(it->second, 3); - DavinciModel::tvm_bin_kernel_.clear(); -} -} // namespace ge diff --git a/tests/ut/ge/graph/load/new_model_manager_event_manager_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_event_manager_unittest.cc deleted file mode 100644 index ee708501..00000000 --- a/tests/ut/ge/graph/load/new_model_manager_event_manager_unittest.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include "common/debug/log.h" -#include "common/debug/memory_dumper.h" -#include "common/types.h" - -#define private public -#include "graph/manager/model_manager/event_manager.h" -#undef private - -using namespace ge; -using namespace std; -using namespace testing; - -class UtestModelManagerEventManager : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -// test repeat initialize -TEST_F(UtestModelManagerEventManager, repeat_initialization) { - ge::EventManager event_manager; - size_t event_num = 1; - event_manager.Init(event_num); - Status ret = event_manager.Init(event_num); - EXPECT_EQ(ret, SUCCESS); -} - -TEST_F(UtestModelManagerEventManager, call_event_record_normal) { - ge::EventManager event_manager; - size_t event_num = 1; - Status ret = event_manager.Init(event_num); - EXPECT_EQ(SUCCESS, ret); - EXPECT_NE(event_manager.event_list_.size(), 0); - - ret = event_manager.EventRecord(0, NULL); - EXPECT_EQ(SUCCESS, ret); -} - -// test load EventRecore when uninited -TEST_F(UtestModelManagerEventManager, call_event_record_while_uninited) { - ge::EventManager event_manager; - Status ret = event_manager.EventRecord(1, NULL); - EXPECT_EQ(ge::INTERNAL_ERROR, ret); -} - -// test with invalid param when load EventRecord -TEST_F(UtestModelManagerEventManager, call_event_record_with_invalid_param) { - ge::EventManager event_manager; - Status ret = event_manager.Init(1); - EXPECT_EQ(SUCCESS, ret); - ret = event_manager.EventRecord(1, NULL); - EXPECT_EQ(ge::PARAM_INVALID, ret); -} - -// test load EventElapsedTime when uninited -TEST_F(UtestModelManagerEventManager, call_event_elapsed_time_while_uninited) { - ge::EventManager event_manager; - float time = .0f; - Status ret = event_manager.EventElapsedTime(1, 2, time); - EXPECT_EQ(ge::INTERNAL_ERROR, ret); -} - -// test with invalid param when load EventElapsedTime -TEST_F(UtestModelManagerEventManager, call_event_elapsed_time_with_invalid_param) { - ge::EventManager 
*event_manager = new ge::EventManager; - size_t event_num = 2; - Status ret = event_manager->Init(event_num); - EXPECT_EQ(SUCCESS, ret); - float time = .0f; - - // normal load - ret = event_manager->EventElapsedTime(0, 1, time); - EXPECT_EQ(SUCCESS, ret); - - // startevent_idx overstep boundary - ret = event_manager->EventElapsedTime(2, 1, time); - EXPECT_EQ(ge::PARAM_INVALID, ret); - - // stopevent_idx overstep boundary - ret = event_manager->EventElapsedTime(1, 2, time); - EXPECT_EQ(ge::PARAM_INVALID, ret); - - // startevent_idx > stopevent_idx - ret = event_manager->EventElapsedTime(1, 0, time); - EXPECT_EQ(ge::PARAM_INVALID, ret); - - delete event_manager; -} -TEST_F(UtestModelManagerEventManager, call_get_event) { - ge::EventManager event_manager; - size_t event_num = 1; - event_manager.Init(event_num); - rtEvent_t event = nullptr; - Status ret = event_manager.GetEvent(2, event); - EXPECT_EQ(ge::PARAM_INVALID, ret); - ret = event_manager.GetEvent(0, event); - EXPECT_EQ(SUCCESS, ret); -} diff --git a/tests/ut/ge/graph/load/new_model_manager_task_build_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_task_build_unittest.cc deleted file mode 100644 index f10ccd7f..00000000 --- a/tests/ut/ge/graph/load/new_model_manager_task_build_unittest.cc +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include "common/debug/log.h" -#include "common/debug/memory_dumper.h" -#include "common/types.h" -#include "new_op_test_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/attr_utils.h" -#include "graph/detail/model_serialize_imp.h" -#include "proto/ge_ir.pb.h" - -#define private public -#define protected public -#include "graph/compute_graph.h" -#include "graph/utils/graph_utils.h" -#include "graph/model_serialize.h" -#include "graph/load/model_manager/davinci_model.h" -#include "common/properties_manager.h" -#include "common/op/ge_op_utils.h" -#include -#include "runtime/dev.h" -#include "runtime/kernel.h" -#include "cce/fwk_adpt_struct.h" -#undef private -#undef protected - -using namespace std; -using namespace testing; - -namespace ge { -class UtestModelManagerTaskBuilder : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} - - /// data weight - /// | | | | - /// |-conv-| | | - /// | | | - /// conv2d | - /// | | - /// |-resApply - - void BuildGraph(ComputeGraphPtr graph) { - OpDescPtr data = std::make_shared("DATA1", "data"); - OpDescPtr weight = std::make_shared("WEIGHT", "weight"); - OpDescPtr conv_op = std::make_shared("conv", "conv"); - OpDescPtr conv_2D = std::make_shared("conv_2D", "conv2d"); - OpDescPtr res_apply_op = std::make_shared("res_apply_op", "resapply"); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - int32_t blockSize = 4096; - - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 1); - data->AddOutputDesc(out_desc); - - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 2); - weight->AddOutputDesc(out_desc); - - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 1); - conv_op->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 2); - conv_op->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 3); - conv_op->AddOutputDesc(out_desc); - - 
ge::TensorUtils::SetDataOffset(out_desc, blockSize * 3); - conv_2D->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 2); - conv_2D->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 4); - conv_2D->AddOutputDesc(out_desc); - - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 4); - res_apply_op->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 1); - res_apply_op->AddInputDesc(out_desc); - ge::TensorUtils::SetDataOffset(out_desc, blockSize * 5); - res_apply_op->AddOutputDesc(out_desc); - - NodePtr data_node = graph->AddNode(data); - NodePtr weigth_node = graph->AddNode(weight); - NodePtr conv_node = graph->AddNode(conv_op); - NodePtr conv_2D_node = graph->AddNode(conv_2D); - NodePtr res_node = graph->AddNode(res_apply_op); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(weigth_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv_2D_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(weigth_node->GetOutDataAnchor(0), conv_2D_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_2D_node->GetOutDataAnchor(0), res_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(weigth_node->GetOutDataAnchor(0), res_node->GetInDataAnchor(1)); - return; - } -}; -} // namespace ge diff --git a/tests/ut/ge/graph/load/output_net_output_unittest.cc b/tests/ut/ge/graph/load/output_net_output_unittest.cc deleted file mode 100644 index 97246dad..00000000 --- a/tests/ut/ge/graph/load/output_net_output_unittest.cc +++ /dev/null @@ -1,300 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "securec.h" - -#define protected public -#define private public -#include "common/debug/memory_dumper.h" -#include "common/op/ge_op_utils.h" -#include "graph/load/model_manager/davinci_model.h" -#include "graph/load/model_manager/model_utils.h" -#include "graph/manager/graph_var_manager.h" -#include "new_op_test_utils.h" -#include "proto/om.pb.h" - -using namespace std; - -namespace ge { -class UtestNetOutput : public testing::Test { - protected: - void TearDown() {} - shared_ptr GenOpdef(OpDescPtr &op_desc, int flag) { - shared_ptr builder = make_shared(op_desc); - builder->SetStreamId(0); - builder->AddInput(1); - builder->SetType("NetOutput"); - - if (flag == 1) { - auto input_desc_1 = builder->AddInputDesc({1, 1, 10, 10}, FORMAT_NCHW, DT_FLOAT16); - } - auto input_desc_1 = builder->AddInputDesc({1, 1, 10, 10}, FORMAT_NCHW, DT_FLOAT16); - - if (flag == 2) { - auto input_desc_2 = builder->AddInputDesc({1, 1, 10, 10}, FORMAT_NCHW, DT_FLOAT16); - } - if (flag == 3) { - builder->AddInput(10); - } - - return builder; - } - shared_ptr GenOpdef2(OpDescPtr &op_desc) { - shared_ptr builder = make_shared(op_desc); - builder->SetStreamId(0); - builder->SetType("NetOutput"); - builder->AddInput(10); - - auto input_desc_1 = builder->AddInputDesc({64, 32, 5, 5}, FORMAT_FRACTAL_Z, DT_FLOAT); - - builder->AddInput(1000000); - auto input_desc_2 = builder->AddInputDesc({1, 10, 10, 1}, FORMAT_NHWC, DT_FLOAT); - - builder->AddOutput(2000000); - auto output_desc_1 = builder->AddOutputDesc({64, 32, 5, 5}, FORMAT_NCHW, DT_FLOAT); 
- - builder->AddOutput(2100000); - output_desc_1 = builder->AddOutputDesc({1, 10, 10, 1}, FORMAT_NHWC, DT_FLOAT); - - return builder; - } - - public: - shared_ptr dav_model_; -}; - -TEST_F(UtestNetOutput, test_get_input_size) { - shared_ptr custom_op_desc = make_shared(); - OmeTestOpDescBuilder builder(custom_op_desc); - builder.SetName("netoutput"); - builder.SetStreamId(0); - builder.SetType("NetOutput"); - - auto input_desc_1 = builder.AddInputDesc({1, 1, 1, 1}, FORMAT_FRACTAL_Z, DT_FLOAT); - builder.AddInput(1); - auto output_desc = builder.AddOutputDesc({1, 1, 1, 1}, FORMAT_NCHW, DT_FLOAT); - builder.AddOutput(1); - builder.Finish(); - - vector v_output_size = ModelUtils::GetInputSize(custom_op_desc); - EXPECT_EQ(v_output_size.size(), 1); -} - -// test ModelUtils::IsOutput -TEST_F(UtestNetOutput, success_is_output) { - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - OmeTestOpDescBuilder builder(op_desc); - builder.SetType("NetOutput"); - vector outputs_desc; - std::shared_ptr desc = std::make_shared(); - outputs_desc.push_back(desc); - op_desc->outputs_desc_ = outputs_desc; - bool ret = model_utils->IsOutput(op_desc); - EXPECT_EQ(false, ret); - - delete model_utils; -} - -// test ModelUtils::IsOutput -TEST_F(UtestNetOutput, true_is_output) { - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - OmeTestOpDescBuilder builder(op_desc); - builder.SetType("NetOutput"); - vector outputs_desc; - std::shared_ptr desc = std::make_shared(); - outputs_desc.push_back(desc); - op_desc->outputs_desc_ = outputs_desc; - ge::TensorUtils::SetOutputTensor(*(outputs_desc[0].get()), true); - bool ret = model_utils->IsOutput(op_desc); - EXPECT_EQ(true, ret); - - delete model_utils; -} - -// test ModelUtils::IsInputTensorNeedTrans -TEST_F(UtestNetOutput, success_is_output_tensor_need_trans) { - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - 
OmeTestOpDescBuilder builder(op_desc); - builder.SetType("NetOutput"); - size_t tensor_index = 1; - vector outputs_desc; - std::shared_ptr desc = std::make_shared(); - outputs_desc.push_back(desc); - op_desc->outputs_desc_ = outputs_desc; - op_desc->inputs_desc_ = outputs_desc; - - bool ret = model_utils->IsInputTensorNeedTrans(op_desc, tensor_index); - EXPECT_EQ(false, ret); - - delete model_utils; -} - -// test ModelUtils::GetOutputSize -TEST_F(UtestNetOutput, success_get_output_size) { - vector v_output_size; - - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - vector outputs_desc; - std::shared_ptr desc = std::make_shared(); - outputs_desc.push_back(desc); - op_desc->outputs_desc_ = outputs_desc; - EXPECT_EQ(v_output_size, model_utils->GetOutputSize(op_desc)); - - vector output = {1}; - op_desc->SetOutputOffset(output); - uint32_t tensor_size = 0; - v_output_size.push_back(tensor_size); - EXPECT_EQ(v_output_size, model_utils->GetOutputSize(op_desc)); - delete model_utils; -} - -// test ModelUtils::GetWorkspaceSize -TEST_F(UtestNetOutput, success_get_workspace_size) { - vector v_workspace_size; - - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - vector workspace = {1}; - op_desc->SetWorkspace(workspace); - EXPECT_EQ(v_workspace_size, model_utils->GetWorkspaceSize(op_desc)); - - op_desc->SetWorkspaceBytes(workspace); - v_workspace_size.push_back(1); - EXPECT_EQ(v_workspace_size, model_utils->GetWorkspaceSize(op_desc)); - delete model_utils; -} - -// test ModelUtils::GetWeightSize -TEST_F(UtestNetOutput, success_get_weight_size) { - vector v_weight_size; - - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - op_desc->SetType("Const"); - EXPECT_EQ(v_weight_size, model_utils->GetWeightSize(op_desc)); - - op_desc->SetType("NetOutput"); - vector inputs_desc; - std::shared_ptr desc = std::make_shared(); - inputs_desc.push_back(desc); - 
op_desc->inputs_desc_ = inputs_desc; - - vector is_input_const = {true}; - op_desc->SetIsInputConst(is_input_const); - v_weight_size.push_back(0); - EXPECT_EQ(v_weight_size, model_utils->GetWeightSize(op_desc)); - - delete model_utils; -} - -// test ModelUtils::GetWeights -TEST_F(UtestNetOutput, success_get_weights) { - vector v_weights; - - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - op_desc->SetType("Const"); - EXPECT_EQ(v_weights, model_utils->GetWeights(op_desc)); - - op_desc->SetType("NetOutput"); - vector inputs_desc; - std::shared_ptr desc = std::make_shared(); - inputs_desc.push_back(desc); - op_desc->inputs_desc_ = inputs_desc; - - vector is_input_const = {true}; - op_desc->SetIsInputConst(is_input_const); - GeTensorDesc tensor_desc; - EXPECT_EQ(v_weights, model_utils->GetWeights(op_desc)); - - delete model_utils; -} - -// test ModelUtils::GetInputDescs -TEST_F(UtestNetOutput, success_get_input_descs) { - vector<::opTensor_t> v_input_descs; - vector<::tagCcAICPUTensor> ret; - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - ret = model_utils->GetInputDescs(op_desc); - EXPECT_EQ(v_input_descs.size(), ret.size()); - - vector inputs_desc; - std::shared_ptr desc = std::make_shared(); - inputs_desc.push_back(desc); - op_desc->inputs_desc_ = inputs_desc; - vector is_input_const = {false}; - op_desc->SetIsInputConst(is_input_const); - - opTensor_t tmp; - tmp.format = OP_TENSOR_FORMAT_NC1HWC0; - tmp.dim_cnt = 0; - tmp.data_type = OP_DATA_FLOAT; - v_input_descs.push_back(tmp); - ret = model_utils->GetInputDescs(op_desc); - EXPECT_EQ(v_input_descs.size(), ret.size()); - - delete model_utils; -} - -// test ModelUtils::GetOutputDescs -TEST_F(UtestNetOutput, success_get_output_descs) { - vector<::opTensor_t> v_output_descs; - vector<::tagCcAICPUTensor> ret; - ModelUtils *model_utils = new ModelUtils(); - std::shared_ptr op_desc = std::make_shared(); - ret = 
model_utils->GetOutputDescs(op_desc); - EXPECT_EQ(v_output_descs.size(), ret.size()); - - vector outputs_desc; - std::shared_ptr desc = std::make_shared(); - outputs_desc.push_back(desc); - op_desc->outputs_desc_ = outputs_desc; - - opTensor_t tmp; - tmp.format = OP_TENSOR_FORMAT_NC1HWC0; - tmp.dim_cnt = 0; - tmp.data_type = OP_DATA_FLOAT; - v_output_descs.push_back(tmp); - ret = model_utils->GetOutputDescs(op_desc); - EXPECT_EQ(v_output_descs.size(), ret.size()); - - delete model_utils; -} - -// test Output::GetOutputData -TEST_F(UtestNetOutput, success_get_output_data) { - Output *output = new Output(nullptr, nullptr); - output->v_input_data_addr_.push_back((void *)1); - output->v_input_size_.push_back(1); - output->input_num_ = 1; - - vector v_data_addr; - vector v_data_size; - output->GetOutputData(v_data_addr, v_data_size); - - EXPECT_EQ(output->v_input_data_addr_, v_data_addr); - EXPECT_EQ(output->v_input_size_, v_data_size); - delete output; -} -} // namespace ge diff --git a/tests/ut/ge/graph/manager/graph_manager_unittest.cc b/tests/ut/ge/graph/manager/graph_manager_unittest.cc index 518cfdcd..b40690e2 100644 --- a/tests/ut/ge/graph/manager/graph_manager_unittest.cc +++ b/tests/ut/ge/graph/manager/graph_manager_unittest.cc @@ -30,9 +30,6 @@ #define protected public #define private public #include "graph/manager/graph_manager.h" -#define const -#include "common/helper/model_cache_helper.h" -#undef const #include "init/gelib.h" #include "common/math/math_util.h" From 67974b31362c13d8fa986abc7ecdee3f9e50b2f4 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Sat, 17 Jul 2021 14:41:15 +0800 Subject: [PATCH 28/33] fix pytorch infershape origin shape --- ge/graph/passes/infershape_pass.cc | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 05b1b5fc..0555929d 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -228,19 
+228,13 @@ bool InferShapePass::SameTensorDesc(const GeTensorDescPtr &src, const GeTensorDe } graphStatus InferShapePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) { - changed = !SameTensorDesc(src, dst); - // refresh src itself - src->SetOriginShape(src->GetShape()); - src->SetOriginDataType(src->GetDataType()); - TensorUtils::SetRealDimCnt(*src, static_cast(src->GetOriginShape().GetDims().size())); - vector> src_shape_range; - src->GetShapeRange(src_shape_range); - src->SetOriginShapeRange(src_shape_range); - - if (!changed) { + changed = false; + if (SameTensorDesc(src, dst)) { GELOGD("Peer dst tensor_desc is same as src tensor_desc. No need update."); return SUCCESS; } + + changed = true; UpdateShapeAndDType(src, dst); GELOGD( "UpdatePeerInputDesc from src Node: shape: [%s], datatype: %s, original datatype is %s." From 3dc9881cd65bef96a7583b9db4be0c9e138ddee9 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Sat, 17 Jul 2021 19:22:59 +0800 Subject: [PATCH 29/33] bugfix for taskdef's random variation in offline case --- tests/ut/ge/graph/build/task_generator_unittest.cc | 77 +++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/tests/ut/ge/graph/build/task_generator_unittest.cc b/tests/ut/ge/graph/build/task_generator_unittest.cc index 1e865050..7be20fa1 100644 --- a/tests/ut/ge/graph/build/task_generator_unittest.cc +++ b/tests/ut/ge/graph/build/task_generator_unittest.cc @@ -29,6 +29,8 @@ #define protected public #define private public +#include "init/gelib.h" +#include "ge/opskernel_manager/ops_kernel_builder_manager.h" #include "graph/build/task_generator.h" #include "graph/manager/graph_mem_manager.h" #include "graph/manager/graph_var_manager.h" @@ -41,9 +43,46 @@ using namespace ge; namespace { const char *const kIsInputVar = "INPUT_IS_VAR"; const char *const kIsOutputVar = "OUTPUT_IS_VAR"; -} +const char *const kKernelInfoNameHccl = "ops_kernel_info_hccl"; +} // namespace 
class UtestTaskGeneratorTest : public testing::Test { public: + struct FakeOpsKernelBuilder : OpsKernelBuilder { + FakeOpsKernelBuilder(){}; + + private: + Status Initialize(const map &options) override { + return SUCCESS; + }; + Status Finalize() override { + return SUCCESS; + }; + Status CalcOpRunningParam(Node &node) override { + return SUCCESS; + }; + Status GenerateTask(const Node &node, RunContext &context, std::vector &tasks) override { + domi::TaskDef task_def; + tasks.push_back(task_def); + return SUCCESS; + }; + }; + + struct FakeOpsKernelInfoStore : OpsKernelInfoStore { + FakeOpsKernelInfoStore() = default; + + private: + Status Initialize(const std::map &options) override { + return SUCCESS; + }; + Status Finalize() override { + return SUCCESS; + }; + bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override { + return true; + }; + void GetAllOpsKernelInfo(std::map &infos) const override{}; + }; + ge::ComputeGraphPtr BuildGraphFpProfiling() { ge::ut::GraphBuilder builder("graph"); auto data = builder.AddNode("data", "phony", 1, 1); @@ -95,6 +134,14 @@ class UtestTaskGeneratorTest : public testing::Test { return builder.GetGraph(); } + ge::ComputeGraphPtr BuildHcclGraph() { + ge::ut::GraphBuilder builder("graph"); + auto hccl_node = builder.AddNode("hccl_phony_node", "HCCL_PHONY", 0, 0); + auto op_desc = hccl_node->GetOpDesc(); + op_desc->SetOpKernelLibName(kKernelInfoNameHccl); + op_desc->SetStreamId(0); + return builder.GetGraph(); + } protected: void SetUp() {} @@ -156,3 +203,31 @@ TEST_F(UtestTaskGeneratorTest, AutoFindBpOpIndex) { output_desc->SetName("hcom"); EXPECT_EQ(task_generator.AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes), SUCCESS); } + +TEST_F(UtestTaskGeneratorTest, GenerateTask) { + map options; + Status ret = ge::GELib::Initialize(options); + EXPECT_EQ(ret, SUCCESS); + + shared_ptr instance_ptr = ge::GELib::GetInstance(); + EXPECT_NE(instance_ptr, nullptr); + + OpsKernelInfoStorePtr 
ops_kernel_info_store_ptr = MakeShared(); + instance_ptr->opsManager_.ops_kernel_store_.insert(make_pair(kKernelInfoNameHccl, ops_kernel_info_store_ptr)); + + OpsKernelBuilderManager &builder_manager_instance_ptr = ge::OpsKernelBuilderManager::Instance(); + OpsKernelBuilderPtr fake_builder = MakeShared(); + builder_manager_instance_ptr.ops_kernel_builders_[kKernelInfoNameHccl] = fake_builder; + + auto graph = BuildHcclGraph(); + TaskGenerator task_generator(nullptr, 0); + RunContext run_context; + run_context.graphStreamList.push_back(static_cast(ops_kernel_info_store_ptr.get())); + vector all_reduce_nodes; + vector task_def_list; + map op_name_map; + + EXPECT_EQ(task_generator.GenerateTask(run_context, graph, task_def_list, op_name_map), SUCCESS); + EXPECT_EQ(task_def_list.size(), 1); + EXPECT_EQ(task_def_list[0].ops_kernel_store_ptr(), reinterpret_cast(ops_kernel_info_store_ptr.get())); +} \ No newline at end of file From d715f462b1723dc56cd61d31a565da2c8abc8d35 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Mon, 19 Jul 2021 09:06:53 +0800 Subject: [PATCH 30/33] Delete unused from SubGraphInfo --- ge/graph/manager/graph_manager_utils.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/ge/graph/manager/graph_manager_utils.h b/ge/graph/manager/graph_manager_utils.h index efdbecf8..e17d9046 100644 --- a/ge/graph/manager/graph_manager_utils.h +++ b/ge/graph/manager/graph_manager_utils.h @@ -105,10 +105,7 @@ class SubGraphInfo { std::vector output_flag_; ModelIdInfo model_id_info_; GeModelPtr ge_model_ptr_; - bool malloc_flag_; - std::vector buffer_addr_; std::string output_names_; - std::vector buffer_size_; std::string stream_label_; std::unordered_map end_to_pld_; std::unordered_map pld_to_end_; From 84928b083e3b160bc01fc3bc201e4167fb6873c4 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Mon, 19 Jul 2021 19:19:28 +0800 Subject: [PATCH 31/33] fos code check --- ge/graph/build/task_generator.cc | 8 +------- 1 file changed, 1 insertion(+), 7 
deletions(-) diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index 1adcd0aa..abb409c4 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -387,13 +387,7 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GE_CHK_BOOL_EXEC_INFO(!op_kernel_lib_name.empty(), continue, "Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); - if (kernel_info_store == nullptr) { - REPORT_INNER_ERROR("E19999", "Get ops kernel info store failed for op:%s(%s), op_kernel_name:%s", - node->GetName().c_str(), node->GetType().c_str(), op_kernel_lib_name.c_str()); - GELOGE(INTERNAL_ERROR, "[Call][GetOpsKernelInfoStore] No ops kernel store or ops kernel builder found. " - "node:%s(%s), op_kernel_lib_name=%s.", name.c_str(), type.c_str(), op_kernel_lib_name.c_str()); - return INTERNAL_ERROR; - } + GE_CHECK_NOTNULL(kernel_info_store); GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", name.c_str(), type.c_str()); if (node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) { From 4f5a7fcefd3cebfc788cdbce74a7583b22a76a1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B6=9B?= Date: Mon, 19 Jul 2021 19:58:54 +0800 Subject: [PATCH 32/33] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!2?= =?UTF-8?q?028=20:=20Fix=20bug=20of=20single=5Fop.'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ge/single_op/single_op.cc | 4 +--- ge/single_op/task/op_task.cc | 25 +++--------------------- ge/single_op/task/op_task.h | 5 ++--- ge/single_op/task/tbe_task_builder.cc | 2 +- tests/ut/ge/single_op/single_op_task_unittest.cc | 24 ----------------------- 5 files changed, 7 insertions(+), 53 deletions(-) diff --git a/ge/single_op/single_op.cc b/ge/single_op/single_op.cc index 23f4cfad..a82c30ba 100755 --- 
a/ge/single_op/single_op.cc +++ b/ge/single_op/single_op.cc @@ -433,13 +433,11 @@ Status DynamicSingleOp::ExecuteAsync(const vector &input_desc, if (!inputs_size.empty()) { StreamResource *stream_resource = SingleOpManager::GetInstance().GetResource(resource_id_, stream_); GE_CHK_STATUS_RET_NOLOG(UpdateInputsBufferAddr(stream_resource, stream_, inputs_size, update_buffers)); + GE_CHK_STATUS_RET_NOLOG(SetHostTensorValue(input_desc, input_buffers)); } if (hybrid_model_executor_ != nullptr) { GELOGD("Execute multi-task dynamic single op by hybrid model executor"); - if (!inputs_size.empty()) { - GE_CHK_STATUS_RET_NOLOG(SetHostTensorValue(input_desc, input_buffers)); - } hybrid::HybridModelExecutor::ExecuteArgs args; GE_CHK_STATUS_RET_NOLOG(InitHybridModelArgs(update_buffers, output_buffers, input_desc, args)); diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index ee752022..dbc90ac5 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -293,9 +293,6 @@ Status TbeOpTask::UpdateNodeByShape(const vector &input_desc, cons } Status TbeOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) { - node_ = node; - tiling_buffer_ = tiling_buffer; - max_tiling_size_ = max_tiling_size; if (tiling_buffer != nullptr) { uintptr_t *arg_base = nullptr; size_t arg_num = 0; @@ -313,6 +310,9 @@ Status TbeOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, } arg_base[tiling_index] = reinterpret_cast(tiling_buffer); } + node_ = node; + tiling_buffer_ = tiling_buffer; + max_tiling_size_ = max_tiling_size; return SUCCESS; } @@ -481,25 +481,6 @@ void TbeOpTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) { } } -Status AtomicAddrCleanOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) { - node_ = node; - tiling_buffer_ = tiling_buffer; - max_tiling_size_ = max_tiling_size; - if (tiling_buffer != nullptr) { - uintptr_t *arg_base = 
nullptr; - size_t arg_num = 0; - GetIoAddr(arg_base, arg_num); - uint32_t tiling_index = atomic_output_indices_.size(); - if (arg_num == 0 || arg_num < tiling_index) { - GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[Check][Size]Tiling index %u, arg number %zu is invalid.", - tiling_index, arg_num); - return ACL_ERROR_GE_INTERNAL_ERROR; - } - arg_base[tiling_index] = reinterpret_cast(tiling_buffer); - } - return SUCCESS; -} - Status AtomicAddrCleanOpTask::UpdateNodeByShape(const vector &input_desc, const vector &output_desc) { return SUCCESS; diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index 4a839389..132672b0 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -97,7 +97,7 @@ class TbeOpTask : public OpTask { const void *GetArgs() const; size_t GetArgSize() const; const std::string &GetStubName() const; - virtual Status EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size); + Status EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size); const std::string &GetTaskType() const override; void SetHandle(void *handle); @@ -149,7 +149,6 @@ class TbeOpTask : public OpTask { class AtomicAddrCleanOpTask : public TbeOpTask { public: Status InitAtomicAddrCleanIndices(); - Status EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, uint32_t max_tiling_size) override; private: Status UpdateNodeByShape(const vector &input_desc, @@ -157,8 +156,8 @@ class AtomicAddrCleanOpTask : public TbeOpTask { Status UpdateIoAddr(const vector &inputs, const vector &outputs) override; Status UpdateTilingArgs(rtStream_t stream) override; Status CalcTilingInfo(optiling::utils::OpRunInfo &run_info) override; - std::vector atomic_output_indices_; + }; class AiCpuBaseTask : public OpTask { diff --git a/ge/single_op/task/tbe_task_builder.cc b/ge/single_op/task/tbe_task_builder.cc index f947ca57..017dac25 100644 --- a/ge/single_op/task/tbe_task_builder.cc +++ 
b/ge/single_op/task/tbe_task_builder.cc @@ -425,7 +425,7 @@ Status TbeTaskBuilder::InitTilingInfo(TbeOpTask &task) { GELOGD("[%s] Done allocating tiling buffer, size=%ld.", op_desc_->GetName().c_str(), max_size); } - GE_CHK_STATUS_RET_NOLOG(task.EnableDynamicSupport(node_, tiling_buffer, static_cast(max_size))); + task.EnableDynamicSupport(node_, tiling_buffer, static_cast(max_size)); return SUCCESS; } diff --git a/tests/ut/ge/single_op/single_op_task_unittest.cc b/tests/ut/ge/single_op/single_op_task_unittest.cc index 9a0381cd..8964df74 100644 --- a/tests/ut/ge/single_op/single_op_task_unittest.cc +++ b/tests/ut/ge/single_op/single_op_task_unittest.cc @@ -237,27 +237,3 @@ TEST_F(UtestSingleOpTask, test_aicpu_task_update_io_addr) { ASSERT_EQ(ret, PARAM_INVALID); } } - -TEST_F(UtestSingleOpTask, test_dynamic_support) { - auto graph = make_shared("graph"); - auto op_desc = make_shared("Add", "Add"); - auto node = graph->AddNode(op_desc); - AtomicAddrCleanOpTask atomic_task; - TbeOpTask tbe_task; - - tbe_task.arg_size_ = sizeof(void *) * 1; - tbe_task.args_.reset(new (std::nothrow) uint8_t[tbe_task.arg_size_]); - atomic_task.arg_size_ = sizeof(void *) * 1; - atomic_task.args_.reset(new (std::nothrow) uint8_t[atomic_task.arg_size_]); - ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR); - ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), ACL_ERROR_GE_INTERNAL_ERROR); - - tbe_task.arg_size_ = sizeof(void *) * 2; - tbe_task.args_.reset(new (std::nothrow) uint8_t[tbe_task.arg_size_]); - atomic_task.arg_size_ = sizeof(void *) * 2; - atomic_task.args_.reset(new (std::nothrow) uint8_t[atomic_task.arg_size_]); - ASSERT_EQ(tbe_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS); - ASSERT_EQ(atomic_task.EnableDynamicSupport(node, (void *)0x0001, 1), SUCCESS); - tbe_task.tiling_buffer_ = nullptr; - atomic_task.tiling_buffer_ = nullptr; -} From a1dd84cc5354dd71fc494da33cc21c90505936ff Mon Sep 17 00:00:00 2001 
From: =?UTF-8?q?=E9=99=88=E5=8D=8E?= Date: Mon, 19 Jul 2021 22:32:54 +0800 Subject: [PATCH 33/33] aicpu op --- ge/engine_manager/dnnengine_manager.cc | 4 + ge/graph/load/model_manager/davinci_model.cc | 52 +++++ ge/graph/load/model_manager/davinci_model.h | 6 + .../model_manager/task_info/kernel_ex_task_info.cc | 109 +++++++++- .../model_manager/task_info/kernel_ex_task_info.h | 8 + .../model_manager/task_info/kernel_task_info.cc | 106 +++++++++- .../model_manager/task_info/kernel_task_info.h | 8 + ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc | 30 +++ ge/hybrid/node_executor/aicpu/aicpu_ext_info.h | 5 + .../node_executor/aicpu/aicpu_node_executor.cc | 121 +++++++++++ .../node_executor/aicpu/aicpu_node_executor.h | 10 +- ge/single_op/task/op_task.cc | 117 +++++++++++ ge/single_op/task/op_task.h | 7 + tests/depends/runtime/src/runtime_stub.cc | 93 ++++++++- tests/depends/runtime/src/runtime_stub.h | 70 +++++++ tests/ut/ge/CMakeLists.txt | 1 + .../ge/graph/load/kernel_ex_task_info_unittest.cc | 141 ++++++++++++- .../ut/ge/graph/load/kernel_task_info_unittest.cc | 140 ++++++++++++- .../aicpu/aicpu_node_executor_unittest.cc | 227 ++++++++++++++++++++- tests/ut/ge/single_op/single_op_task_unittest.cc | 131 +++++++++++- third_party/fwkacllib/inc/cce/fwk_adpt_struct.h | 16 ++ third_party/fwkacllib/inc/runtime/config.h | 8 + third_party/fwkacllib/inc/runtime/dev.h | 12 ++ 23 files changed, 1396 insertions(+), 26 deletions(-) create mode 100644 tests/depends/runtime/src/runtime_stub.h diff --git a/ge/engine_manager/dnnengine_manager.cc b/ge/engine_manager/dnnengine_manager.cc index 0fadd993..36f11828 100644 --- a/ge/engine_manager/dnnengine_manager.cc +++ b/ge/engine_manager/dnnengine_manager.cc @@ -239,6 +239,10 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { op_desc->SetOpEngineName(it.engine); op_desc->SetOpKernelLibName(kernel_name); // set attrs for taking information when load txt to graph object + if (it.flagAsync) { + GELOGD("Set 
aicpu blocking op:%s attribute(is_blocking_op):true", op_desc->GetName().c_str()); + (void)AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + } (void) AttrUtils::SetStr(op_desc, ATTR_NAME_ENGINE_NAME_FOR_LX, it.engine); (void) AttrUtils::SetStr(op_desc, ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX, kernel_name); GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s to op_desc %s", kernel_name.c_str(), diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index aba06173..495ec28e 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc @@ -238,6 +238,12 @@ DavinciModel::~DavinciModel() { GE_LOGW_IF(rtEventDestroy(event_list_[i]) != RT_ERROR_NONE, "Destroy event failed, index: %zu", i); } + for (const auto &it : stream_2_event_) { + if (rtEventDestroy(it.second) != RT_ERROR_NONE) { + GELOGW("Destroy event failed"); + } + } + FreeWeightsMem(); FreeFeatureMapMem(); @@ -4648,4 +4654,50 @@ Status DavinciModel::GetTotalMemSizeExcludeZeroCopy(int64_t &total_useful_size) total_useful_size = runtime_param_.mem_size - runtime_param_.zero_copy_size; return SUCCESS; } + +Status DavinciModel::GetEventIdForBlockingAicpuOp(const OpDescPtr &op_desc, rtStream_t stream, uint32_t &event_id) { + GELOGI("Get event id for aicpu blocking op:%s", op_desc->GetName().c_str()); + auto it = stream_2_event_.find(stream); + if (it != stream_2_event_.end()) { + auto rt_ret = rtGetEventID(it->second, &event_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetEventID failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetEventID] failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + } else { + rtEvent_t rt_event = nullptr; + auto rt_ret = rtEventCreateWithFlag(&rt_event, 
RT_EVENT_WITH_FLAG); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtEventCreateWithFlag failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + GELOGE(RT_FAILED, "[Call][rtEventCreateWithFlag] failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + rt_ret = rtGetEventID(rt_event, &event_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetEventID failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetEventID] failed for op:%s(%s), ret:0x%X", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + stream_2_event_.emplace(stream, rt_event); + } + return SUCCESS; +} + +Status DavinciModel::GetEventByStream(const rtStream_t &stream, rtEvent_t &rt_event) { + auto it = stream_2_event_.find(stream); + if (it == stream_2_event_.end()) { + REPORT_INNER_ERROR("E19999", "Get event failed"); + GELOGE(FAILED, "[Get][Event] Get event failed"); + return FAILED; + } + rt_event = it->second; + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/load/model_manager/davinci_model.h b/ge/graph/load/model_manager/davinci_model.h index fe89f66f..76b0beef 100755 --- a/ge/graph/load/model_manager/davinci_model.h +++ b/ge/graph/load/model_manager/davinci_model.h @@ -582,6 +582,10 @@ class DavinciModel { void SetRunningFlag(bool flag) { running_flg_ = flag; } Status SetRunAsyncListenerCallback(const RunAsyncCallback &callback); + // for blocking aicpu op + Status GetEventByStream(const rtStream_t &stream, rtEvent_t &rt_event); + Status GetEventIdForBlockingAicpuOp(const OpDescPtr &op_desc, rtStream_t stream, uint32_t &event_id); + private: // memory address of weights uint8_t *weights_mem_base_; @@ -1107,6 +1111,8 @@ class DavinciModel { // op name to attrs mapping 
std::map>> op_name_to_attrs_; + + std::map stream_2_event_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ diff --git a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc index 1a6ab542..fe9cd0cc 100644 --- a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc +++ b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc @@ -26,8 +26,8 @@ #include "external/graph/attr_value.h" #include "graph/load/model_manager/davinci_model.h" #include "graph/load/model_manager/model_manager.h" -#include "hybrid/node_executor/aicpu/aicpu_ext_info.h" #include "framework/common/debug/log.h" +#include "runtime/rt.h" namespace { const char *const kAicpuAllshape = "_AllShape"; @@ -43,7 +43,7 @@ Status KernelExTaskInfo::InitTaskExtInfo(const std::string &ext_info, const OpDe UnknowShapeOpType unknown_type = static_cast(unknown_shape_type_val); uint32_t num_inputs = op_desc->GetInputsSize(); uint32_t num_outputs = op_desc->GetOutputsSize(); - std::unique_ptr ext_handle( + std::shared_ptr ext_handle( new(std::nothrow) ::ge::hybrid::AicpuExtInfoHandler(op_desc->GetName(), num_inputs, num_outputs, @@ -76,6 +76,16 @@ Status KernelExTaskInfo::InitTaskExtInfo(const std::string &ext_info, const OpDe } } } + + AttrUtils::GetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, is_blocking_aicpu_op_); + GELOGD("Get op:%s attribute(is_blocking_op), value:%d", op_desc->GetName().c_str(), is_blocking_aicpu_op_); + + if (UpdateEventIdForAicpuBlockingOp(op_desc, ext_handle) != SUCCESS) { + GELOGE(FAILED, "[Call][UpdateEventIdForAicpuBlockingOp] failed for op:%s(%s)", + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + return FAILED; + } + auto rt_ret = rtMalloc(&ext_info_addr_, ext_handle->GetExtInfoLen(), RT_MEMORY_HBM); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, REPORT_CALL_ERROR("E19999", "Call rtMalloc failed, size:%zu, ret:0x%X", ext_info.size(), rt_ret); @@ -448,6 +458,101 @@ Status 
KernelExTaskInfo::Distribute() { stream_id_ = stream_id; GELOGI("KernelExTaskInfo Distribute Success. task id: %u, stream id: %u", task_id_, stream_id_); + if (is_blocking_aicpu_op_) { + if (DistributeWaitTaskForAicpuBlockingOp() != SUCCESS) { + GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed"); + return FAILED; + } + } + return SUCCESS; +} + +Status KernelExTaskInfo::CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support) { + int32_t device_id = 0; + auto rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetDevice failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetDevice] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + int32_t value = 0; + rt_ret = rtGetDeviceCapability(device_id, FEATURE_TYPE_BLOCKING_OPERATOR, RT_MODULE_TYPE_AICPU, &value); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetDeviceCapability failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetDeviceCapability] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + if (value != RT_AICPU_BLOCKING_OP_NOT_SUPPORT && value != RT_AICPU_BLOCKING_OP_SUPPORT) { + REPORT_INNER_ERROR("E19999", "Value should be %d or %d but %d", + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value); + GELOGE(FAILED, "[Check][Value] Value should be %d or %d but %d", + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value); + return FAILED; + } + is_support = (value == RT_AICPU_BLOCKING_OP_SUPPORT ? 
true : false); + return SUCCESS; +} + +Status KernelExTaskInfo::UpdateEventIdForAicpuBlockingOp(const OpDescPtr &op_desc, + std::shared_ptr &ext_handle) { + if (is_blocking_aicpu_op_) { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("Device not support blocking aicpu op process"); + return SUCCESS; + } + uint32_t event_id = 0; + if (davinci_model_->GetEventIdForBlockingAicpuOp(op_desc, stream_, event_id) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get event id failed for op:%s(%s).", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + GELOGE(FAILED, "[Get][EventId] Get event id failed for op:%s(%s)", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + return FAILED; + } + if (ext_handle->UpdateEventId(event_id) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Update event id failed for op:%s(%s).", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + GELOGE(FAILED, "[Update][EventId] Update event id failed for op:%s(%s)", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + return FAILED; + } + GELOGI("Update event_id=%u success", event_id); + } + return SUCCESS; +} + +Status KernelExTaskInfo::DistributeWaitTaskForAicpuBlockingOp() { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("Device not support blocking aicpu op process."); + return SUCCESS; + } + GELOGD("Distribute wait task begin"); + rtEvent_t rt_event = nullptr; + if (davinci_model_->GetEventByStream(stream_, rt_event) != SUCCESS) { + GELOGE(FAILED, "[Call][GetEventByStream] Call GetEventByStream failed"); + return FAILED; + } + 
auto rt_ret = rtStreamWaitEvent(stream_, rt_event); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtStreamWaitEvent failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + rt_ret = rtEventReset(rt_event, stream_); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtEventReset failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } return SUCCESS; } diff --git a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h index 7d07eb7f..eb411576 100644 --- a/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h +++ b/ge/graph/load/model_manager/task_info/kernel_ex_task_info.h @@ -19,6 +19,7 @@ #include "graph/load/model_manager/task_info/task_info.h" #include "graph/op_desc.h" +#include "hybrid/node_executor/aicpu/aicpu_ext_info.h" namespace ge { class KernelExTaskInfo : public TaskInfo { @@ -65,6 +66,12 @@ class KernelExTaskInfo : public TaskInfo { void InitDumpArgs(void *addr, const OpDescPtr &op_desc); Status InitTaskExtInfo(const std::string &ext_info, const OpDescPtr &op_desc); + // for blocking aicpu op + Status DistributeWaitTaskForAicpuBlockingOp(); + Status CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support); + Status UpdateEventIdForAicpuBlockingOp(const OpDescPtr &op_desc, + std::shared_ptr &ext_handle); + uint32_t task_id_; uint32_t stream_id_; uint32_t dump_flag_; @@ -79,6 +86,7 @@ class KernelExTaskInfo : public TaskInfo { uint32_t args_offset_ = 0; int64_t fixed_addr_offset_ = 0; int32_t topic_type_flag_ = -1; + bool is_blocking_aicpu_op_ = false; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_ diff --git a/ge/graph/load/model_manager/task_info/kernel_task_info.cc b/ge/graph/load/model_manager/task_info/kernel_task_info.cc index 
019a0a8b..6bbfe58e 100755 --- a/ge/graph/load/model_manager/task_info/kernel_task_info.cc +++ b/ge/graph/load/model_manager/task_info/kernel_task_info.cc @@ -28,11 +28,10 @@ #include "graph/load/model_manager/davinci_model.h" #include "graph/load/model_manager/model_manager.h" #include "graph/load/model_manager/model_utils.h" -#include "runtime/kernel.h" +#include "runtime/rt.h" #include "graph/load/model_manager/task_info/super_kernel/super_kernel.h" #include "graph/load/model_manager/task_info/super_kernel/super_kernel_factory.h" #include "cce/aicpu_engine_struct.h" -#include "hybrid/node_executor/aicpu/aicpu_ext_info.h" #include "framework/common/debug/log.h" namespace { @@ -474,6 +473,12 @@ Status KernelTaskInfo::Distribute() { } // set for task_id_ UpdateTaskId(); + if (is_blocking_aicpu_op_) { + if (DistributeWaitTaskForAicpuBlockingOp() != SUCCESS) { + GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed"); + return FAILED; + } + } GELOGD( "KernelTaskInfo Distribute Success. 
sktenable:%d taskid:%d sktid:%d stubfunc_name:%s stubfunc:%p " "blockdim:%d stream:%p", @@ -482,6 +487,91 @@ Status KernelTaskInfo::Distribute() { return SUCCESS; } +Status KernelTaskInfo::CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support) { + int32_t device_id = 0; + auto rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetDevice failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetDevice] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + int32_t value = 0; + rt_ret = rtGetDeviceCapability(device_id, FEATURE_TYPE_BLOCKING_OPERATOR, RT_MODULE_TYPE_AICPU, &value); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetDeviceCapability failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetDeviceCapability] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + if (value != RT_AICPU_BLOCKING_OP_NOT_SUPPORT && value != RT_AICPU_BLOCKING_OP_SUPPORT) { + REPORT_INNER_ERROR("E19999", "Value should be %d or %d but %d", + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value); + GELOGE(FAILED, "[Check][Value] Value should be %d or %d but %d", + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value); + return FAILED; + } + is_support = (value == RT_AICPU_BLOCKING_OP_SUPPORT ? 
true : false); + return SUCCESS; +} + +Status KernelTaskInfo::UpdateEventIdForAicpuBlockingOp(std::shared_ptr &ext_handle) { + if (is_blocking_aicpu_op_) { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("Device not support blocking aicpu op process"); + return SUCCESS; + } + uint32_t event_id = 0; + if (davinci_model_->GetEventIdForBlockingAicpuOp(op_desc_, stream_, event_id) != SUCCESS) { + GELOGE(FAILED, "[Get][EventId] Get event id failed for op:%s(%s)", op_desc_->GetName().c_str(), + op_desc_->GetType().c_str()); + return FAILED; + } + if (ext_handle->UpdateEventId(event_id) != SUCCESS) { + GELOGE(FAILED, "[Update][EventId] Update event id failed for op:%s(%s)", op_desc_->GetName().c_str(), + op_desc_->GetType().c_str()); + return FAILED; + } + GELOGI("Update event_id=%u success", event_id); + } + return SUCCESS; +} + +Status KernelTaskInfo::DistributeWaitTaskForAicpuBlockingOp() { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("device not support blocking aicpu op process."); + return SUCCESS; + } + GELOGD("Distribute wait task begin"); + rtEvent_t rt_event = nullptr; + if (davinci_model_->GetEventByStream(stream_, rt_event) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Call GetEventByStream failed"); + GELOGE(FAILED, "[Call][GetEventByStream] Call GetEventByStream failed"); + return FAILED; + } + auto rt_ret = rtStreamWaitEvent(stream_, rt_event); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtStreamWaitEvent failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", 
rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + rt_ret = rtEventReset(rt_event, stream_); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtEventReset failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + return SUCCESS; +} + void KernelTaskInfo::SetIoAddrs(const OpDescPtr &op_desc) { const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); @@ -1109,7 +1199,7 @@ Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) { UnknowShapeOpType unknown_type = static_cast(unknown_shape_type_val); uint32_t num_inputs = op_desc_->GetInputsSize(); uint32_t num_outputs = op_desc_->GetOutputsSize(); - std::unique_ptr ext_handle( + std::shared_ptr ext_handle( new(std::nothrow) ::ge::hybrid::AicpuExtInfoHandler(op_desc_->GetName(), num_inputs, num_outputs, @@ -1145,6 +1235,16 @@ Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) { j, op_desc_->GetName().c_str()); } } + + AttrUtils::GetBool(op_desc_, ATTR_NAME_IS_BLOCKING_OP, is_blocking_aicpu_op_); + GELOGD("Get op:%s attribute(is_blocking_op), value:%d", op_desc_->GetName().c_str(), is_blocking_aicpu_op_); + + if (UpdateEventIdForAicpuBlockingOp(ext_handle) != SUCCESS) { + GELOGE(FAILED, "[Call][UpdateEventIdForAicpuBlockingOp] failed for op:%s(%s)", + op_desc_->GetName().c_str(), op_desc_->GetType().c_str()); + return FAILED; + } + auto rt_ret = rtMalloc(&aicpu_ext_info_addr_, ext_handle->GetExtInfoLen(), RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { REPORT_CALL_ERROR("E19999", "Call rtMalloc failed for op:%s(%s), size:%zu, ret:0x%X", diff --git a/ge/graph/load/model_manager/task_info/kernel_task_info.h b/ge/graph/load/model_manager/task_info/kernel_task_info.h index d9dd30bb..59a91aee 100644 --- a/ge/graph/load/model_manager/task_info/kernel_task_info.h +++ 
b/ge/graph/load/model_manager/task_info/kernel_task_info.h @@ -24,6 +24,8 @@ #include "graph/load/model_manager/task_info/task_info.h" #include "graph/op_desc.h" +#include "hybrid/node_executor/aicpu/aicpu_ext_info.h" + namespace ge { class KernelTaskInfo : public TaskInfo { public: @@ -148,6 +150,11 @@ class KernelTaskInfo : public TaskInfo { bool DoubleCallSKTSaveCheck(); void SetArgs(); + // for blocking aicpu op + Status DistributeWaitTaskForAicpuBlockingOp(); + Status CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support); + Status UpdateEventIdForAicpuBlockingOp(std::shared_ptr &ext_handle); + void *stub_func_; void *args_; void *sm_desc_; @@ -187,6 +194,7 @@ class KernelTaskInfo : public TaskInfo { uint32_t skt_dump_flag_ = RT_KERNEL_DEFAULT; void *superkernel_device_args_addr_ = nullptr; void *superkernel_dev_nav_table_ = nullptr; + bool is_blocking_aicpu_op_ = false; struct AICPUCustomInfo { void *input_descs = nullptr; diff --git a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc index c607a43e..6e8841b9 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc +++ b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc @@ -81,6 +81,9 @@ Status AicpuExtInfoHandler::Parse(const std::string &ext_info) { case aicpu::FWKAdapter::FWK_ADPT_EXT_TOPIC_TYPE: GE_CHK_STATUS_RET(ParseExtTopicType(aicpu_ext_info), "[Parse][ExtTopicType] failed."); break; + case aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT: + GE_CHK_STATUS_RET(ParseExtAsyncWait(aicpu_ext_info), "[Parse][ExtAsyncWait] failed."); + break; default: GELOGD("Node[%s] ignore infoType=%d, infoLen=%u.", node_name_.c_str(), aicpu_ext_info->infoType, aicpu_ext_info->infoLen); @@ -101,6 +104,22 @@ Status AicpuExtInfoHandler::Parse(const std::string &ext_info) { return SUCCESS; } +Status AicpuExtInfoHandler::ParseExtAsyncWait(AicpuExtInfo *aicpu_ext_info) { + if (aicpu_ext_info->infoLen != sizeof(AsyncWaitInfo)) { + REPORT_INNER_ERROR("E19999", + "Node[%s] parse ext 
async wait info failed as infoLen must be %zu but %u.", + node_name_.c_str(), sizeof(AsyncWaitInfo), aicpu_ext_info->infoLen); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, + "[Check][DataLen]Node[%s] parse ext async wait info failed as infoLen must be %zu but %u.", + node_name_.c_str(), sizeof(AsyncWaitInfo), aicpu_ext_info->infoLen); + return ACL_ERROR_GE_PARAM_INVALID; + } + + async_wait_ = reinterpret_cast(aicpu_ext_info->infoMsg); + GELOGI("Node[%s] parse async wait info success infoLen=%u.", node_name_.c_str(), aicpu_ext_info->infoLen); + return SUCCESS; +} + Status AicpuExtInfoHandler::ParseExtShapeType(AicpuExtInfo *aicpu_ext_info) { GE_IF_BOOL_EXEC(aicpu_ext_info->infoLen != sizeof(int32_t), REPORT_INNER_ERROR("E19999", "Node[%s] parse ext shape type failed as infoLen must be %zu but %u.", @@ -280,6 +299,17 @@ Status AicpuExtInfoHandler::UpdateSessionInfo(uint64_t session_id, uint64_t kern return SUCCESS; } +Status AicpuExtInfoHandler::UpdateEventId(uint32_t event_id) { + if (async_wait_ == nullptr) { + REPORT_INNER_ERROR("E19999", "async_wait_ is nullptr."); + GELOGE(FAILED, "[Check][async_wait_] async_wait_ is nullptr."); + return FAILED; + } + async_wait_->waitType = 1; + async_wait_->waitId = event_id; + return SUCCESS; +} + Status AicpuExtInfoHandler::UpdateSessionInfoSessionId(uint64_t session_id) { if (session_info_ == nullptr) { GELOGD("There is no session info in ext_info, no need update."); diff --git a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h index 46fb7c05..80e3bb92 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h +++ b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h @@ -27,6 +27,7 @@ namespace ge { namespace hybrid { using AicpuShapeAndType = aicpu::FWKAdapter::ShapeAndType; using AicpuExtInfo = aicpu::FWKAdapter::ExtInfo; +using AsyncWaitInfo = aicpu::FWKAdapter::AsyncWait; using AicpuSessionInfo = SessionInfo; class AicpuExtInfoHandler { @@ -59,6 +60,8 @@ class AicpuExtInfoHandler { 
Status UpdateExecuteMode(bool flag); + Status UpdateEventId(uint32_t event_id); + Status GetOutputShapeAndType(uint32_t output_index, GeShape &shape, DataType &data_type); bool IsNeedRefreshIOAddr(); @@ -73,6 +76,7 @@ class AicpuExtInfoHandler { Status ParseExtBitMap(AicpuExtInfo *aicpu_ext_info); Status ParseExtUpdateAddr(AicpuExtInfo *aicpu_ext_info); Status ParseExtTopicType(AicpuExtInfo *aicpu_ext_info); + Status ParseExtAsyncWait(AicpuExtInfo *aicpu_ext_info); static Status UpdateShapeAndType(const GeShape &shape, DataType data_type, @@ -90,6 +94,7 @@ class AicpuExtInfoHandler { const uint32_t output_num_; UnknowShapeOpType unknown_type_; AicpuSessionInfo *session_info_ = nullptr; + AsyncWaitInfo *async_wait_ = nullptr; uint64_t *bit_map_ = nullptr; uint32_t *update_addr_ = nullptr; int32_t topic_type_flag_ = -1; diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc index cf20303c..f309ebd0 100755 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc +++ b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc @@ -22,6 +22,7 @@ #include "graph/utils/node_utils.h" #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/model/hybrid_model.h" +#include "runtime/rt.h" namespace ge { namespace hybrid { @@ -33,6 +34,12 @@ const char *const kAicpuAllshape = "_AllShape"; REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_TF, AiCpuNodeExecutor); REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_CUSTOM, AiCpuNodeExecutor); +AicpuNodeTaskBase::~AicpuNodeTaskBase() { + if (rt_event_ != nullptr) { + (void)rtEventDestroy(rt_event_); + } +} + Status AicpuNodeTaskBase::AllocTensorBuffer(size_t size, std::unique_ptr &tensor_buffer) { auto allocator = NpuMemoryAllocator::GetAllocator(); GE_CHECK_NOTNULL(allocator); @@ -64,6 +71,13 @@ Status AicpuNodeTaskBase::InitExtInfo(const std::string &kernel_ext_info, int64_ 
GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateSessionInfoSessionId(session_id), "[Update][SessionInfoSessionId] failed, session_id:%ld.", session_id); + if (is_blocking_aicpu_op_) { + if (UpdateEventIdForBlockingAicpuOp() != SUCCESS) { + GELOGE(FAILED, "[Call][UpdateEventIdForBlockingAicpuOp] Call UpdateEventIdForBlockingAicpuOp failed"); + return FAILED; + } + } + // copy task args buf GE_CHK_STATUS_RET(AllocTensorBuffer(aicpu_ext_handle_.GetExtInfoLen(), ext_info_addr_dev_), "[Invoke][AllocTensorBuffer]Node[%s] alloc kernel_ext_info buf failed, size=%zu", @@ -230,6 +244,96 @@ Status AicpuNodeTaskBase::ExecuteAsync(TaskContext &context, std::functionnum_outputs == 0)) { GELOGD("Node[%s] type[%s] unknown_type is %d, output num is %d.", @@ -325,6 +429,9 @@ Status AicpuTfNodeTask::Init(const HybridModel &model) { // init ext info uint64_t ext_session_id = model.GetSessionId(); + const OpDescPtr op_desc = node_item_->GetOpDesc(); + AttrUtils::GetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, is_blocking_aicpu_op_); + GELOGD("Get op:%s attribute(is_blocking_op), value:%d", op_desc->GetName().c_str(), is_blocking_aicpu_op_); GE_CHK_STATUS_RET(InitExtInfo(kernel_ext_info, ext_session_id), "[Init][ExtInfo] failed for Node[%s].", node_name_.c_str()); GE_CHK_STATUS_RET(InitForDependComputeTask(), "[Init][DependComputeTask] failed for Node[%s].", node_name_.c_str()); @@ -642,6 +749,12 @@ Status AicpuTfNodeTask::LaunchTask(TaskContext &context) { kernel_buf_->GetSize(), flag, context.GetStream())); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[AicpuTfNodertKernelLaunchEx] End"); GELOGD("Node[%s] launch end.", node_name_.c_str()); + if (is_blocking_aicpu_op_) { + if (DistributeWaitTaskForAicpuBlockingOp(context.GetStream()) != SUCCESS) { + GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed"); + return FAILED; + } + } if (need_sync_) { GELOGD("[%s] Task needs sync", node_name_.c_str()); 
GE_CHK_STATUS_RET_NOLOG(context.Synchronize()); @@ -760,6 +873,8 @@ Status AicpuNodeTask::Init(const HybridModel &model) { return FAILED;); uint64_t ext_session_id = model.GetSessionId(); + AttrUtils::GetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, is_blocking_aicpu_op_); + GELOGD("Get op:%s attribute(is_blocking_op), value:%d", op_desc->GetName().c_str(), is_blocking_aicpu_op_); GE_CHK_STATUS_RET(InitExtInfo(kernel_ext_info, ext_session_id), "[Init][ExtInfo] failed for Node[%s].", node_name.c_str()); @@ -826,6 +941,12 @@ Status AicpuNodeTask::LaunchTask(TaskContext &context) { args_.get(), args_size_, nullptr, context.GetStream(), flag); GE_CHK_RT_RET(rt_ret); + if (is_blocking_aicpu_op_) { + if (DistributeWaitTaskForAicpuBlockingOp(context.GetStream()) != SUCCESS) { + GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed"); + return FAILED; + } + } GELOGD("Node[%s] launch task end.", node_name_.c_str()); return SUCCESS; } diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h index 14bc8fcc..3911e090 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h +++ b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h @@ -35,7 +35,7 @@ class AicpuNodeTaskBase : public NodeTask { node_item->num_outputs, node_item->shape_inference_type) {} - ~AicpuNodeTaskBase() override = default; + ~AicpuNodeTaskBase() override; using NodeTask::Init; @@ -61,6 +61,10 @@ class AicpuNodeTaskBase : public NodeTask { static Status AllocTensorBuffer(size_t size, std::unique_ptr &tensor_buffer); + Status DistributeWaitTaskForAicpuBlockingOp(rtStream_t stream); + Status CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support); + Status UpdateEventIdForBlockingAicpuOp(); + protected: const NodeItem *node_item_; // just reference. 
@@ -78,6 +82,10 @@ class AicpuNodeTaskBase : public NodeTask { // ext info addr, device mem std::unique_ptr ext_info_addr_dev_; + + // for blocking aicpu op + bool is_blocking_aicpu_op_ = false; + rtEvent_t rt_event_ = nullptr; }; class AicpuTfNodeTask : public AicpuNodeTaskBase { diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc index dbc90ac5..83cb0529 100755 --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -564,6 +564,41 @@ AiCpuBaseTask::~AiCpuBaseTask() { if (ext_info_addr_dev_ != nullptr) { (void)rtFree(ext_info_addr_dev_); } + if (rt_event_ != nullptr) { + (void)rtEventDestroy(rt_event_); + } +} + +Status AiCpuBaseTask::UpdateEventIdForBlockingAicpuOp() { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("Device not support blocking aicpu op process"); + return SUCCESS; + } + uint32_t event_id = 0; + auto rt_ret = rtEventCreateWithFlag(&rt_event_, RT_EVENT_WITH_FLAG); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtEventCreateWithFlag failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtEventCreateWithFlag] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + rt_ret = rtGetEventID(rt_event_, &event_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetEventID failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetEventID] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + if (aicpu_ext_handle_->UpdateEventId(event_id) != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Update event id=%u failed.", event_id); + GELOGE(FAILED, "[Update][EventId] Update event id=%u failed", event_id); + return FAILED; + } + GELOGI("Update event_id=%u success", event_id); + return SUCCESS; } Status 
AiCpuBaseTask::SetExtInfoAndType(const std::string &kernel_ext_info, uint64_t kernel_id) { @@ -577,6 +612,9 @@ Status AiCpuBaseTask::SetExtInfoAndType(const std::string &kernel_ext_info, uint GELOGD("Get unknown_type is %d.", unknown_shape_type_val); unknown_type_ = static_cast(unknown_shape_type_val); + AttrUtils::GetBool(op_desc_, ATTR_NAME_IS_BLOCKING_OP, is_blocking_aicpu_op_); + GELOGD("Get op:%s attribute(is_blocking_op), value:%d", op_desc_->GetName().c_str(), is_blocking_aicpu_op_); + aicpu_ext_handle_.reset(new(std::nothrow) ::ge::hybrid::AicpuExtInfoHandler(op_desc_->GetName(), num_inputs_, num_outputs_, @@ -595,6 +633,13 @@ Status AiCpuBaseTask::SetExtInfoAndType(const std::string &kernel_ext_info, uint GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateSessionInfo(ULLONG_MAX, kernel_id, false), "[Update][SessionInfo] failed."); + if (is_blocking_aicpu_op_) { + if (UpdateEventIdForBlockingAicpuOp() != SUCCESS) { + GELOGE(FAILED, "[Call][UpdateEventIdForBlockingAicpuOp] Call UpdateEventIdForBlockingAicpuOp failed"); + return FAILED; + } + } + GE_CHK_RT_RET(rtMalloc(&ext_info_addr_dev_, aicpu_ext_handle_->GetExtInfoLen(), RT_MEMORY_HBM)); GE_CHK_RT_RET(rtMemcpy(ext_info_addr_dev_, aicpu_ext_handle_->GetExtInfoLen(), aicpu_ext_handle_->GetExtInfo(), aicpu_ext_handle_->GetExtInfoLen(), @@ -770,6 +815,63 @@ Status AiCpuBaseTask::UpdateIoAddr(const vector &inputs, const vecto return SUCCESS; } +Status AiCpuBaseTask::CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support) { + int32_t device_id = 0; + auto rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtGetDevice failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetDevice] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + int32_t value = 0; + rt_ret = rtGetDeviceCapability(device_id, FEATURE_TYPE_BLOCKING_OPERATOR, RT_MODULE_TYPE_AICPU, &value); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call 
rtGetDeviceCapability failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][rtGetDeviceCapability] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + if (value != RT_AICPU_BLOCKING_OP_NOT_SUPPORT && value != RT_AICPU_BLOCKING_OP_SUPPORT) { + REPORT_INNER_ERROR("E19999", "Value should be %d or %d but %d", + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value); + GELOGE(FAILED, "[Check][Value] Value should be %d or %d but %d", + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, RT_AICPU_BLOCKING_OP_SUPPORT, value); + return FAILED; + } + is_support = (value == RT_AICPU_BLOCKING_OP_SUPPORT ? true : false); + return SUCCESS; +} + +Status AiCpuBaseTask::DistributeWaitTaskForAicpuBlockingOp(rtStream_t stream) { + bool is_support = false; + if (CheckDeviceSupportBlockingAicpuOpProcess(is_support) != SUCCESS) { + GELOGE(FAILED, "[Call][CheckDeviceSupportBlockingAicpuOpProcess] Call CheckDeviceSupportBlockingAicpuOpProcess failed"); + return FAILED; + } + if (!is_support) { + GELOGD("Device not support blocking aicpu op process."); + return SUCCESS; + } + GELOGI("Distribute queue task begin"); + if (rt_event_ == nullptr) { + REPORT_INNER_ERROR("E19999", "rt_event_ is nullptr"); + GELOGE(FAILED, "[Check][rt_event_] rt_event_ is nullptr"); + return FAILED; + } + auto rt_ret = rtStreamWaitEvent(stream, rt_event_); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtStreamWaitEvent failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + rt_ret = rtEventReset(rt_event_, stream); + if (rt_ret != RT_ERROR_NONE) { + REPORT_CALL_ERROR("E19999", "Call rtEventReset failed, ret:0x%X", rt_ret); + GELOGE(RT_FAILED, "[Call][RtApi] failed, ret:0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + return SUCCESS; +} + AiCpuTask::~AiCpuTask() { FreeHbm(args_); FreeHbm(io_addr_); @@ -813,6 +915,14 @@ Status AiCpuTask::LaunchKernel(rtStream_t stream) { 
GELOGI("[TASK_INFO] %lu/%s", kernel_id_, op_type_.c_str()); GELOGD("Done launch kernel successfully. task = %s", this->op_type_.c_str()); + + if (is_blocking_aicpu_op_) { + if (DistributeWaitTaskForAicpuBlockingOp(stream) != SUCCESS) { + GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed"); + return FAILED; + } + } + return SUCCESS; } @@ -1089,6 +1199,13 @@ Status AiCpuCCTask::LaunchKernel(rtStream_t stream) { } GELOGI("[TASK_INFO] %lu/%s", kernel_id_, op_type_.c_str()); GELOGD("Invoke rtCpuKernelLaunch succeeded"); + + if (is_blocking_aicpu_op_) { + if (DistributeWaitTaskForAicpuBlockingOp(stream) != SUCCESS) { + GELOGE(FAILED, "[Call][DistributeWaitTaskForAicpuBlockingOp] Call DistributeWaitTaskForAicpuBlockingOp failed"); + return FAILED; + } + } return SUCCESS; } diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index 132672b0..adf51dba 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -178,6 +178,10 @@ class AiCpuBaseTask : public OpTask { rtStream_t stream); Status UpdateOutputShape(vector &output_desc); Status UpdateShapeToOutputDesc(const GeShape &shape_new, GeTensorDesc &output_desc); + // for blocking aicpu op + Status DistributeWaitTaskForAicpuBlockingOp(rtStream_t stream); + Status UpdateEventIdForBlockingAicpuOp(); + Status CheckDeviceSupportBlockingAicpuOpProcess(bool &is_support); protected: size_t num_inputs_ = 0; @@ -186,6 +190,9 @@ class AiCpuBaseTask : public OpTask { std::unique_ptr aicpu_ext_handle_; void *ext_info_addr_dev_ = nullptr; vector input_is_const_; + // for blocking aicpu op + bool is_blocking_aicpu_op_ = false; + rtEvent_t rt_event_ = nullptr; }; class AiCpuTask : public AiCpuBaseTask { diff --git a/tests/depends/runtime/src/runtime_stub.cc b/tests/depends/runtime/src/runtime_stub.cc index 25d6c2d3..32df7552 100644 --- a/tests/depends/runtime/src/runtime_stub.cc +++ b/tests/depends/runtime/src/runtime_stub.cc @@ -16,12 +16,94 
@@ #include #include +#include "runtime_stub.h" +#include "runtime/rt.h" + +#define ADD_STUB_RETURN_VALUE(FUNC, TYPE) std::vector g_Stub_##FUNC##_RETURN + +#define GET_STUB_RETURN_VALUE(FUNC, TYPE, DEFAULT) ({ \ + TYPE result = DEFAULT; \ + if (!g_Stub_##FUNC##_RETURN.empty()) { \ + result = g_Stub_##FUNC##_RETURN.back(); \ + g_Stub_##FUNC##_RETURN.pop_back(); \ + } \ + result; \ +}) + +#define DEL_STUB_RETURN_VALUE(FUNC, TYPE) \ +do { \ + extern std::vector g_Stub_##FUNC##_RETURN; \ + g_Stub_##FUNC##_RETURN.clear(); \ +} while (0) + + +#define ADD_STUB_OUTBOUND_VALUE(FUNC, TYPE, NAME) std::vector g_Stub_##FUNC##_OUT_##NAME + +#define GET_STUB_OUTBOUND_VALUE(FUNC, TYPE, NAME, DEFAULT) ({ \ + TYPE value; \ + if (!g_Stub_##FUNC##_OUT_##NAME.empty()) { \ + value = g_Stub_##FUNC##_OUT_##NAME.back(); \ + g_Stub_##FUNC##_OUT_##NAME.pop_back(); \ + } else { \ + value = DEFAULT; \ + } \ + value; \ +}) + +#define DEL_STUB_OUTBOUND_VALUE(FUNC, TYPE, NAME) \ +do { \ + extern std::vector g_Stub_##FUNC##_OUT_##NAME; \ + g_Stub_##FUNC##_OUT_##NAME.clear(); \ +} while (0) #ifdef __cplusplus extern "C" { #endif #define EVENT_LENTH 10 +void rtStubTearDown() { + DEL_STUB_RETURN_VALUE(rtGetDevice, rtError_t); + DEL_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t); + DEL_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t); + DEL_STUB_RETURN_VALUE(rtEventReset, rtError_t); + DEL_STUB_RETURN_VALUE(rtEventCreate, rtError_t); + DEL_STUB_RETURN_VALUE(rtGetEventID, rtError_t); +} + +ADD_STUB_RETURN_VALUE(rtGetDevice, rtError_t); +rtError_t rtGetDevice(int32_t *device) { + return GET_STUB_RETURN_VALUE(rtGetDevice, rtError_t, RT_ERROR_NONE); +} + +ADD_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t); +ADD_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value); +rtError_t rtGetDeviceCapability(int32_t device, int32_t moduleType, int32_t featureType, int32_t *value) { + *value = GET_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT); + return 
GET_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); +} + +ADD_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t); +rtError_t rtStreamWaitEvent(rtStream_t stream, rtEvent_t event) { + return GET_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, RT_ERROR_NONE); +} + +ADD_STUB_RETURN_VALUE(rtEventReset, rtError_t); +rtError_t rtEventReset(rtEvent_t event, rtStream_t stream) { + return GET_STUB_RETURN_VALUE(rtEventReset, rtError_t, RT_ERROR_NONE); +} + +ADD_STUB_RETURN_VALUE(rtEventCreate, rtError_t); +rtError_t rtEventCreate(rtEvent_t *event) { + *event = new int[EVENT_LENTH]; + return GET_STUB_RETURN_VALUE(rtEventCreate, rtError_t, RT_ERROR_NONE); +} + +ADD_STUB_RETURN_VALUE(rtGetEventID, rtError_t); +rtError_t rtGetEventID(rtEvent_t event, uint32_t *event_id) { + *event_id = 0; + return GET_STUB_RETURN_VALUE(rtGetEventID, rtError_t, RT_ERROR_NONE); +} + rtError_t rtCtxSetCurrent(rtContext_t ctx) { return RT_ERROR_NONE; } rtError_t rtGetStreamId(rtStream_t stream, int32_t *stream_id) { @@ -42,11 +124,6 @@ rtError_t rtEventGetTimeStamp(uint64_t *time, rtEvent_t event) { return RT_ERROR_NONE; } -rtError_t rtEventCreate(rtEvent_t *event) { - *event = new int[EVENT_LENTH]; - return RT_ERROR_NONE; -} - rtError_t rtEventCreateWithFlag(rtEvent_t *event, uint32_t flag) { return rtEventCreate(event); } @@ -112,8 +189,6 @@ rtError_t rtMemcpyAsync(void *dst, uint64_t dest_max, const void *src, uint64_t return RT_ERROR_NONE; } -rtError_t rtStreamWaitEvent(rtStream_t stream, rtEvent_t event) { return RT_ERROR_NONE; } - rtError_t rtSetTSDevice(uint32_t tsId) { return RT_ERROR_NONE; } @@ -347,10 +422,6 @@ rtError_t rtStreamSwitchEx(void *ptr, rtCondition_t condition, void *value_ptr, rtError_t rtStreamActive(rtStream_t active_stream, rtStream_t stream) { return RT_ERROR_NONE; } -rtError_t rtEventReset(rtEvent_t event, rtStream_t stream) { return RT_ERROR_NONE; } - -rtError_t rtGetDevice(int32_t *device) { return RT_ERROR_NONE; } - rtError_t rtDatadumpInfoLoad(const 
void *dump_info, uint32_t length) { return RT_ERROR_NONE; } rtError_t rtKernelLaunchWithFlag(const void *stub_func, uint32_t block_dim, void *args, uint32_t args_size, diff --git a/tests/depends/runtime/src/runtime_stub.h b/tests/depends/runtime/src/runtime_stub.h new file mode 100644 index 00000000..b693b9ea --- /dev/null +++ b/tests/depends/runtime/src/runtime_stub.h @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __INC_LLT_RUNTIME_STUB_H +#define __INC_LLT_RUNTIME_STUB_H + +#include + +#ifdef __cplusplus +extern "C" { +#endif +void rtStubTearDown(); + +#define RTS_STUB_SETUP() \ +do { \ + rtStubTearDown(); \ +} while (0) + +#define RTS_STUB_TEARDOWN() \ +do { \ + rtStubTearDown(); \ +} while (0) + +#define RTS_STUB_RETURN_VALUE(FUNC, TYPE, VALUE) \ +do { \ + g_Stub_##FUNC##_RETURN.emplace(g_Stub_##FUNC##_RETURN.begin(), VALUE); \ +} while (0) + +#define RTS_STUB_OUTBOUND_VALUE(FUNC, TYPE, NAME, VALUE) \ +do { \ + g_Stub_##FUNC##_OUT_##NAME.emplace(g_Stub_##FUNC##_OUT_##NAME.begin(), VALUE); \ +} while (0) + + +#define RTS_STUB_RETURN_EXTERN(FUNC, TYPE) extern std::vector g_Stub_##FUNC##_RETURN; +#define RTS_STUB_OUTBOUND_EXTERN(FUNC, TYPE, NAME) extern std::vector g_Stub_##FUNC##_OUT_##NAME; + +RTS_STUB_RETURN_EXTERN(rtGetDevice, rtError_t); +RTS_STUB_OUTBOUND_EXTERN(rtGetDevice, int32_t, device) + +RTS_STUB_RETURN_EXTERN(rtGetDeviceCapability, rtError_t); +RTS_STUB_OUTBOUND_EXTERN(rtGetDeviceCapability, int32_t, value); + +RTS_STUB_RETURN_EXTERN(rtStreamWaitEvent, rtError_t); + +RTS_STUB_RETURN_EXTERN(rtEventReset, rtError_t); + +RTS_STUB_RETURN_EXTERN(rtEventCreate, rtError_t); +RTS_STUB_OUTBOUND_EXTERN(rtEventCreate, rtEvent_t, event); + +RTS_STUB_RETURN_EXTERN(rtGetEventID, rtError_t); +RTS_STUB_OUTBOUND_EXTERN(rtEventCreate, uint32_t, event_id); + +#ifdef __cplusplus +} +#endif +#endif // __INC_LLT_RUNTIME_STUB_H diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index a0790cf2..f9d9e857 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -935,6 +935,7 @@ target_link_libraries(ge_single_op PRIVATE ascend_protobuf json c_sec + runtime_stub ) # ut binary diff --git a/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc b/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc index 327dd248..86569789 100644 --- a/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc +++ 
b/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc @@ -23,15 +23,20 @@ #include "graph/load/model_manager/task_info/kernel_ex_task_info.h" #include "cce/aicpu_engine_struct.h" +#include "tests/depends/runtime/src/runtime_stub.h" namespace ge { extern OpDescPtr CreateOpDesc(string name, string type); class UtestKernelExTaskInfo : public testing::Test { protected: - void SetUp() {} + void SetUp() { + RTS_STUB_SETUP(); + } - void TearDown() {} + void TearDown() { + RTS_STUB_TEARDOWN(); + } }; // test kernel_ex_task_Release @@ -209,4 +214,136 @@ TEST_F(UtestKernelExTaskInfo, parse_topic_type_failed_2) { KernelExTaskInfo kernel_ex_task_info; EXPECT_NE(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS); } + +TEST_F(UtestKernelExTaskInfo, blocking_aicpu_op) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::TaskDef task_def; + domi::KernelExDef kernel_ex_def; + kernel_ex_def.set_kernel_ext_info(buf, len); + kernel_ex_def.set_kernel_ext_info_size(len); + domi::KernelExDef *kernel_ex_def_tmp = task_def.mutable_kernel_ex(); + *kernel_ex_def_tmp = kernel_ex_def; + + const OpDescPtr op_desc = CreateOpDesc("deque", "Deque"); + ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + + KernelExTaskInfo kernel_ex_task_info; + kernel_ex_task_info.op_desc_ = op_desc; + DavinciModel davinci_model(0, nullptr); + kernel_ex_task_info.davinci_model_ = &davinci_model; + 
EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS); + EXPECT_EQ(kernel_ex_task_info.Distribute(), SUCCESS); + kernel_ex_task_info.op_desc_ = op_desc; + EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS); + EXPECT_EQ(kernel_ex_task_info.Distribute(), SUCCESS); +} + +TEST_F(UtestKernelExTaskInfo, blocking_aicpu_op_fail_01) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::TaskDef task_def; + domi::KernelExDef kernel_ex_def; + kernel_ex_def.set_kernel_ext_info(buf, len); + kernel_ex_def.set_kernel_ext_info_size(len); + domi::KernelExDef *kernel_ex_def_tmp = task_def.mutable_kernel_ex(); + *kernel_ex_def_tmp = kernel_ex_def; + + const OpDescPtr op_desc = CreateOpDesc("deque", "Deque"); + + KernelExTaskInfo kernel_ex_task_info; + kernel_ex_task_info.op_desc_ = op_desc; + DavinciModel davinci_model(0, nullptr); + kernel_ex_task_info.davinci_model_ = &davinci_model; + EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS); + + kernel_ex_task_info.is_blocking_aicpu_op_ = true; + EXPECT_EQ(kernel_ex_task_info.Distribute(), FAILED); +} + +TEST_F(UtestKernelExTaskInfo, blocking_aicpu_op_fail_02) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = 
reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::TaskDef task_def; + domi::KernelExDef kernel_ex_def; + kernel_ex_def.set_kernel_ext_info(buf, len); + kernel_ex_def.set_kernel_ext_info_size(len); + domi::KernelExDef *kernel_ex_def_tmp = task_def.mutable_kernel_ex(); + *kernel_ex_def_tmp = kernel_ex_def; + + const OpDescPtr op_desc = CreateOpDesc("deque", "Deque"); + ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + KernelExTaskInfo kernel_ex_task_info; + kernel_ex_task_info.op_desc_ = op_desc; + DavinciModel davinci_model(0, nullptr); + kernel_ex_task_info.davinci_model_ = &davinci_model; + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT + 1); + EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + EXPECT_EQ(kernel_ex_task_info.Distribute(), FAILED); + + EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS); + 
RTS_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, 0x78000001); + EXPECT_EQ(kernel_ex_task_info.Distribute(), FAILED); + + EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS); + RTS_STUB_RETURN_VALUE(rtEventReset, rtError_t, 0x78000001); + EXPECT_EQ(kernel_ex_task_info.Distribute(), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(kernel_ex_def.kernel_ext_info(), op_desc), SUCCESS); + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(kernel_ex_task_info.Distribute(), SUCCESS); +} + } // namespace ge diff --git a/tests/ut/ge/graph/load/kernel_task_info_unittest.cc b/tests/ut/ge/graph/load/kernel_task_info_unittest.cc index 0c8da4b5..45ae7853 100644 --- a/tests/ut/ge/graph/load/kernel_task_info_unittest.cc +++ b/tests/ut/ge/graph/load/kernel_task_info_unittest.cc @@ -22,15 +22,20 @@ #include "graph/load/model_manager/davinci_model.h" #include "graph/load/model_manager/task_info/kernel_task_info.h" #include "graph/load/model_manager/task_info/hccl_task_info.h" +#include "tests/depends/runtime/src/runtime_stub.h" namespace ge { extern OpDescPtr CreateOpDesc(string name, string type); class UtestKernelTaskInfo : public testing::Test { protected: - void SetUp() {} + void SetUp() { + RTS_STUB_SETUP(); + } - void TearDown() {} + void TearDown() { + RTS_STUB_TEARDOWN(); + } }; // test KernelTaskInfo Init. 
@@ -1240,4 +1245,135 @@ TEST_F(UtestKernelTaskInfo, kernel_task_info_super_kernel_info) { EXPECT_EQ(kernel_task_info.SKTFinalize(), SUCCESS); } +TEST_F(UtestKernelTaskInfo, blocking_aicpu_op) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::TaskDef task_def; + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + + const OpDescPtr op_desc = CreateOpDesc("deque", "Deque"); + op_desc->SetId(0); + ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + DavinciModel davinci_model(0, nullptr); + davinci_model.op_list_.emplace(0, op_desc); + + KernelTaskInfo kernel_task_info; + kernel_task_info.op_desc_ = op_desc; + kernel_task_info.davinci_model_ = &davinci_model; + EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS); + EXPECT_EQ(kernel_task_info.Distribute(), SUCCESS); + kernel_task_info.op_desc_ = op_desc; + EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS); + EXPECT_EQ(kernel_task_info.Distribute(), SUCCESS); +} + +TEST_F(UtestKernelTaskInfo, blocking_aicpu_op_fail_01) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + 
ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + + const OpDescPtr op_desc = CreateOpDesc("deque", "Deque"); + op_desc->SetId(0); + DavinciModel davinci_model(0, nullptr); + davinci_model.op_list_.emplace(0, op_desc); + + KernelTaskInfo kernel_task_info; + kernel_task_info.davinci_model_ = &davinci_model; + kernel_task_info.op_desc_ = op_desc; + + EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS); + + kernel_task_info.is_blocking_aicpu_op_ = true; + EXPECT_EQ(kernel_task_info.Distribute(), FAILED); +} + +TEST_F(UtestKernelTaskInfo, blocking_aicpu_op_fail_02) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + + const OpDescPtr op_desc = CreateOpDesc("deque", "Deque"); + ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + op_desc->SetId(0); + DavinciModel davinci_model(0, nullptr); + davinci_model.op_list_.emplace(0, op_desc); + + KernelTaskInfo kernel_task_info; + kernel_task_info.davinci_model_ = 
&davinci_model; + kernel_task_info.op_desc_ = op_desc; + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT + 1); + EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + EXPECT_EQ(kernel_task_info.Distribute(), FAILED); + + EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS); + RTS_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, 0x78000001); + EXPECT_EQ(kernel_task_info.Distribute(), FAILED); + + EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS); + RTS_STUB_RETURN_VALUE(rtEventReset, rtError_t, 0x78000001); + EXPECT_EQ(kernel_task_info.Distribute(), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(kernel_task_info.InitAicpuTaskExtInfo(kernel_def.kernel_ext_info()), SUCCESS); + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(kernel_task_info.Distribute(), SUCCESS); +} + } // namespace ge diff --git a/tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc 
b/tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc index b225949b..034b3f47 100644 --- a/tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc +++ b/tests/ut/ge/hybrid/node_executor/aicpu/aicpu_node_executor_unittest.cc @@ -27,7 +27,7 @@ #include "hybrid/node_executor/aicpu/aicpu_node_executor.h" #undef protected #undef private - +#include "tests/depends/runtime/src/runtime_stub.h" using namespace std; using namespace testing; @@ -43,8 +43,12 @@ using namespace hybrid; class UtestAicpuNodeExecutor : public testing::Test { protected: - void SetUp() {} - void TearDown() {} + void SetUp() { + RTS_STUB_SETUP(); + } + void TearDown() { + RTS_STUB_TEARDOWN(); + } }; static NodePtr CreateNode(ComputeGraphPtr graph, const string &name, const string &type, int in_num, int out_num) { @@ -164,5 +168,222 @@ TEST_F(UtestAicpuNodeExecutor, aicpu_tf_node_task) { } +TEST_F(UtestAicpuNodeExecutor, aicpu_blocking_node_task) { + ComputeGraphPtr graph = std::make_shared("test"); + GeRootModelPtr ge_root_model = std::make_shared(graph); + ge_root_model->SetModelName("test_name"); + HybridModel hybrid_model(ge_root_model); + + NodePtr node = CreateNode(graph, "deque", FRAMEWORK_OP_TYPE, 1, 1); + ge::AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_IS_BLOCKING_OP, true); + std::unique_ptr new_node; + ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); + NodeItem *node_item = new_node.get(); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->is_dynamic = true; + node_item->shape_inference_type = DEPEND_SHAPE_RANGE; + + GraphItem graph_item; + graph_item.node_items_.emplace_back(node_item); + graph_item.total_inputs_ = 1; + graph_item.total_outputs_ = 1; + + GraphExecutionContext graph_execution_context; + SubgraphContext subgraph_context(&graph_item, &graph_execution_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); + graph_execution_context.callback_manager = std::unique_ptr(new CallbackManager()); + + auto node_state 
= subgraph_context.GetOrCreateNodeState(node_item); + ASSERT_NE(node_state, nullptr); + + uint64_t value_0 = 512; + TensorValue in_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetInput(*node_item, 0, in_tensor0); + + TensorValue out_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetOutput(*node_item, 0, out_tensor0); + + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + domi::TaskDef task_def; + + AicpuTaskStruct args; + args.head.length = sizeof(args); + args.head.ioAddrNum = 2; + + kernel_def.set_args(reinterpret_cast(&args), args.head.length); + kernel_def.set_args_size(args.head.length); + domi::KernelDef *kernel_def_tmp = task_def.mutable_kernel(); + *kernel_def_tmp = kernel_def; + + AicpuNodeTask aicpu_node_task(node_item, task_def); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), SUCCESS); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); + + node_item->shape_inference_type = DEPEND_COMPUTE; + domi::KernelExDef kernel_ex_def; + kernel_ex_def.set_kernel_ext_info(buf, len); + kernel_ex_def.set_kernel_ext_info_size(len); + kernel_ex_def.set_args(reinterpret_cast(&args), args.head.length); + kernel_ex_def.set_args_size(args.head.length); + domi::KernelExDef *kernel_ex_def_tmp = task_def.mutable_kernel_ex(); + *kernel_ex_def_tmp = kernel_ex_def; + 
hybrid_model.task_defs_[node] = std::vector({task_def, task_def}); + + AicpuTfNodeTask aicpu_tf_node_task(node_item, task_def); + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), SUCCESS); + ASSERT_EQ(aicpu_tf_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); +} + +TEST_F(UtestAicpuNodeExecutor, aicpu_blocking_node_task_fail) { + ComputeGraphPtr graph = std::make_shared("test"); + GeRootModelPtr ge_root_model = std::make_shared(graph); + ge_root_model->SetModelName("test_name"); + HybridModel hybrid_model(ge_root_model); + + NodePtr node = CreateNode(graph, "deque", FRAMEWORK_OP_TYPE, 1, 1); + ge::AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_IS_BLOCKING_OP, true); + std::unique_ptr new_node; + ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); + NodeItem *node_item = new_node.get(); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->is_dynamic = true; + node_item->shape_inference_type = DEPEND_SHAPE_RANGE; + + GraphItem graph_item; + graph_item.node_items_.emplace_back(node_item); + graph_item.total_inputs_ = 1; + graph_item.total_outputs_ = 1; + + GraphExecutionContext graph_execution_context; + SubgraphContext subgraph_context(&graph_item, &graph_execution_context); + ASSERT_EQ(subgraph_context.Init(), SUCCESS); + graph_execution_context.callback_manager = std::unique_ptr(new CallbackManager()); + + auto node_state = subgraph_context.GetOrCreateNodeState(node_item); + ASSERT_NE(node_state, nullptr); + + uint64_t value_0 = 512; + TensorValue in_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetInput(*node_item, 0, in_tensor0); + + TensorValue out_tensor0(&value_0, sizeof(value_0)); + subgraph_context.SetOutput(*node_item, 0, out_tensor0); + + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = 
aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + domi::TaskDef task_def; + + AicpuTaskStruct args; + args.head.length = sizeof(args); + args.head.ioAddrNum = 2; + + kernel_def.set_args(reinterpret_cast(&args), args.head.length); + kernel_def.set_args_size(args.head.length); + domi::KernelDef *kernel_def_tmp = task_def.mutable_kernel(); + *kernel_def_tmp = kernel_def; + + AicpuNodeTask aicpu_node_task(node_item, task_def); + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT + 1); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), FAILED); + + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), FAILED); + + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtEventReset, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), 
FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + ASSERT_EQ(aicpu_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + ASSERT_EQ(aicpu_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); + + node_item->shape_inference_type = DEPEND_COMPUTE; + domi::KernelExDef kernel_ex_def; + kernel_ex_def.set_kernel_ext_info(buf, len); + kernel_ex_def.set_kernel_ext_info_size(len); + kernel_ex_def.set_args(reinterpret_cast(&args), args.head.length); + kernel_ex_def.set_args_size(args.head.length); + domi::KernelExDef *kernel_ex_def_tmp = task_def.mutable_kernel_ex(); + *kernel_ex_def_tmp = kernel_ex_def; + hybrid_model.task_defs_[node] = std::vector({task_def, task_def}); + + AicpuTfNodeTask aicpu_tf_node_task(node_item, task_def); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT + 1); + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), FAILED); + + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_tf_node_task.LaunchTask(*node_state->GetTaskContext()), FAILED); + + ASSERT_EQ(aicpu_tf_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtEventReset, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_tf_node_task.LaunchTask(*node_state->GetTaskContext()), FAILED); + + 
RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(aicpu_tf_node_task.Init(hybrid_model), SUCCESS); + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(aicpu_tf_node_task.LaunchTask(*node_state->GetTaskContext()), SUCCESS); +} } // namespace ge diff --git a/tests/ut/ge/single_op/single_op_task_unittest.cc b/tests/ut/ge/single_op/single_op_task_unittest.cc index 8964df74..52091856 100644 --- a/tests/ut/ge/single_op/single_op_task_unittest.cc +++ b/tests/ut/ge/single_op/single_op_task_unittest.cc @@ -19,6 +19,7 @@ #include "graph/load/model_manager/model_utils.h" #include "graph/utils/graph_utils.h" +#include "hybrid/node_executor/aicpu/aicpu_ext_info.h" #include "runtime/rt.h" #define protected public @@ -30,6 +31,7 @@ #include "external/register/op_tiling_registry.h" #undef private #undef protected +#include "tests/depends/runtime/src/runtime_stub.h" using namespace std; using namespace testing; @@ -38,9 +40,13 @@ using namespace optiling; class UtestSingleOpTask : public testing::Test { protected: - void SetUp() {} + void SetUp() { + RTS_STUB_SETUP(); + } - void TearDown() {} + void TearDown() { + RTS_STUB_TEARDOWN(); + } }; TEST_F(UtestSingleOpTask, test_build_kernel_task) { @@ -237,3 +243,124 @@ TEST_F(UtestSingleOpTask, test_aicpu_task_update_io_addr) { ASSERT_EQ(ret, PARAM_INVALID); } } + +TEST_F(UtestSingleOpTask, test_blocking_aicpu_op_01) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += 
sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + + auto op_desc = make_shared("deque", "Deque"); + ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + AiCpuCCTask aicpu_task; + aicpu_task.SetOpDesc(op_desc); + rtStream_t stream; + ASSERT_EQ(rtStreamCreate(&stream, 0), RT_ERROR_NONE); + + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), SUCCESS); +} + +TEST_F(UtestSingleOpTask, test_blocking_aicpu_op_02) { + int len = sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + + auto op_desc = make_shared("deque", "Deque"); + ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + AiCpuTask aicpu_task; + aicpu_task.SetOpDesc(op_desc); + rtStream_t stream; + ASSERT_EQ(rtStreamCreate(&stream, 0), RT_ERROR_NONE); + + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), SUCCESS); +} + +TEST_F(UtestSingleOpTask, test_blocking_aicpu_op_fail) { + int len = 
sizeof(hybrid::AicpuExtInfo) + sizeof(hybrid::AsyncWaitInfo); + vector aicpu_ext_info(len, 0); + char *buf = aicpu_ext_info.data(); + int offset = 0; + hybrid::AicpuExtInfo *ext_info = reinterpret_cast(buf + offset); + ext_info->infoType = aicpu::FWKAdapter::FWK_ADPT_EXT_ASYNCWAIT; + ext_info->infoLen = sizeof(hybrid::AsyncWaitInfo); + offset += sizeof(hybrid::AicpuExtInfo); + hybrid::AsyncWaitInfo *async_wait_info = reinterpret_cast(buf + offset); + async_wait_info->waitType = 0; + async_wait_info->waitId = 0; + async_wait_info->timeOut = 0; + async_wait_info->reserved = 0; + + domi::KernelDef kernel_def; + kernel_def.set_kernel_ext_info(buf, len); + kernel_def.set_kernel_ext_info_size(len); + + auto op_desc = make_shared("deque", "Deque"); + ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_BLOCKING_OP, true); + AiCpuTask aicpu_task; + aicpu_task.SetOpDesc(op_desc); + rtStream_t stream; + ASSERT_EQ(rtStreamCreate(&stream, 0), RT_ERROR_NONE); + + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), SUCCESS); + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_SUPPORT + 1); + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDevice, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), FAILED); + + 
ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + RTS_STUB_RETURN_VALUE(rtStreamWaitEvent, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), FAILED); + + ASSERT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + RTS_STUB_RETURN_VALUE(rtEventReset, rtError_t, 0x78000001); + ASSERT_EQ(aicpu_task.LaunchKernel(stream), FAILED); + + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(aicpu_task.SetExtInfoAndType(kernel_def.kernel_ext_info(), 0), SUCCESS); + RTS_STUB_RETURN_VALUE(rtGetDeviceCapability, rtError_t, RT_ERROR_NONE); + RTS_STUB_OUTBOUND_VALUE(rtGetDeviceCapability, int32_t, value, RT_AICPU_BLOCKING_OP_NOT_SUPPORT); + EXPECT_EQ(aicpu_task.LaunchKernel(stream), SUCCESS); +} diff --git a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h index df57c82e..5733d68f 100644 --- a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h +++ b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h @@ -62,6 +62,7 @@ enum FWKTaskExtInfoType { FWK_ADPT_EXT_SESSION_INFO, FWK_ADPT_EXT_BITMAP, FWK_ADPT_EXT_TOPIC_TYPE, + FWK_ADPT_EXT_ASYNCWAIT, FWK_ADPT_EXT_INVALID }; @@ -80,6 +81,12 @@ enum FWKExtUpdateAddrType { FWK_ADPT_UPDATE_INPUT_OUTPUT }; +enum FWKExtWaitType { + FWK_ADPT_WAIT_TYPE_NULL = 0, + FWK_ADPT_WAIT_TYPE_EVENT, + FWK_ADPT_WAIT_TYPE_INVALID +}; + #pragma pack(push, 1) // API Parameter Structure struct StrFWKKernel { @@ -133,6 +140,15 @@ struct ResultSummary { uint64_t raw_data_size; // size of raw data }; #pragma pack(pop) + +#pragma pack(push, 1) +struct AsyncWait { + uint8_t waitType; // wait type, FWK_ADPT_WAIT_TYPE_EVENT: event wait + uint32_t waitId; // wait id, GE refresh + uint32_t timeOut; // reserved + uint64_t reserved; +}; +#pragma pack(pop) } // end namespace FWKAdapter } // namespace aicpu diff --git 
a/third_party/fwkacllib/inc/runtime/config.h b/third_party/fwkacllib/inc/runtime/config.h index c1327c45..a244c793 100644 --- a/third_party/fwkacllib/inc/runtime/config.h +++ b/third_party/fwkacllib/inc/runtime/config.h @@ -52,6 +52,14 @@ typedef enum tagRtAicpuScheType { SCHEDULE_HARDWARE, /* HWTS Schedule */ } rtAicpuScheType; +typedef enum tagRtDeviceCapabilityType { + RT_SCHEDULE_SOFTWARE = 0, // SoftWare Schedule + RT_SCHEDULE_SOFTWARE_OPT, + RT_SCHEDULE_HARDWARE, // HWTS Schedule + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, + RT_AICPU_BLOCKING_OP_SUPPORT, // 1910/1980/1951 ts support AICPU blocking operation +} rtDeviceCapabilityType; + typedef enum tagRtVersion { VER_BEGIN = 0, VER_NA = VER_BEGIN, diff --git a/third_party/fwkacllib/inc/runtime/dev.h b/third_party/fwkacllib/inc/runtime/dev.h index 2cf6712f..18d837eb 100644 --- a/third_party/fwkacllib/inc/runtime/dev.h +++ b/third_party/fwkacllib/inc/runtime/dev.h @@ -65,6 +65,7 @@ typedef enum tagRtFeatureType { typedef enum tagRtDeviceFeatureType { FEATURE_TYPE_SCHE, + FEATURE_TYPE_BLOCKING_OPERATOR, FEATURE_TYPE_END, } rtDeviceFeatureType_t; @@ -78,6 +79,17 @@ typedef enum tagMemoryInfo { MEMORY_INFO_RSV } rtMemoryInfo_t; +typedef enum tagRtDeviceModuleType { + RT_MODULE_TYPE_SYSTEM = 0, + RT_MODULE_TYPE_AICPU, + RT_MODULE_TYPE_CCPU, + RT_MODULE_TYPE_DCPU, + RT_MODULE_TYPE_AICORE, + RT_MODULE_TYPE_TSCPU, + RT_MODULE_TYPE_PCIE, + RT_MODULE_TYPE_VECTOR_CORE +} tagRtDeviceModuleType_t; + /** * @ingroup dvrt_dev * @brief get total device number.