From e0032f656f4aac1f29cba177c1e400b2545a9842 Mon Sep 17 00:00:00 2001 From: lujiale Date: Sat, 31 Oct 2020 11:32:10 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!187?= =?UTF-8?q?=20:=20sync=20src=20code=20and=20update=20the=20source=20folder?= =?UTF-8?q?=20location'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitmodules | 8 - CMakeLists.txt | 236 +- build.sh | 124 +- cmake/FindModule.cmake | 23 - cmake/external_libs/eigen.cmake | 22 + cmake/external_libs/gflags.cmake | 39 - cmake/external_libs/gtest.cmake | 24 + cmake/external_libs/json.cmake | 40 +- cmake/external_libs/onnx.cmake | 42 +- cmake/external_libs/protobuf.cmake | 63 + cmake/external_libs/protobuf_shared.cmake | 59 - cmake/external_libs/protobuf_static.cmake | 43 - cmake/external_libs/protoc.cmake | 103 - cmake/external_libs/securec.cmake | 73 +- cmake/ge_utils.cmake | 371 +++ cmake/intf_pub_android.cmake | 52 - cmake/intf_pub_linux.cmake | 33 - cmake/intf_pub_windows.cmake | 24 - ge/CMakeLists.txt | 910 ------- ge/README.md | 0 ge/common/CMakeLists.txt | 171 -- ge/common/proto/ge_ir.proto | 206 -- ge/common/proto/insert_op.proto | 152 -- ge/common/proto/om.proto | 401 --- ge/common/proto/task.proto | 170 -- ge/common/proto/tensorflow/attr_value.proto | 62 - ge/common/proto/tensorflow/function.proto | 100 - ge/common/proto/tensorflow/graph.proto | 56 - ge/common/proto/tensorflow/graph_library.proto | 14 - ge/common/proto/tensorflow/node_def.proto | 63 - ge/common/proto/tensorflow/op_def.proto | 164 -- ge/common/proto/tensorflow/resource_handle.proto | 29 - ge/common/proto/tensorflow/tensor.proto | 94 - ge/common/proto/tensorflow/tensor_shape.proto | 45 - ge/common/proto/tensorflow/types.proto | 74 - ge/common/proto/tensorflow/versions.proto | 31 - ge/executor/CMakeLists.txt | 115 - ge/executor/proto/ge_ir.proto | 206 -- ge/executor/proto/insert_op.proto | 152 -- ge/executor/proto/om.proto | 401 --- 
ge/executor/proto/op_mapping_info.proto | 89 - ge/executor/proto/task.proto | 170 -- ge/ge_local_engine/CMakeLists.txt | 225 -- ge/ge_local_engine/proto/task.proto | 170 -- ge/ge_runtime/CMakeLists.txt | 65 - ge/ge_runtime/module.mk | 66 - ge/graph/build/memory/CMakeLists.txt | 38 - ge/host_cpu_engine/CMakeLists.txt | 214 -- ge/host_cpu_engine/proto/task.proto | 1 - ge/offline/CMakeLists.txt | 81 - ge/offline/main.cc | 1334 ---------- ge/offline/module.mk | 52 - ge/offline/proto/ge_ir.proto | 1 - ge/offline/proto/insert_op.proto | 1 - ge/offline/proto/om.proto | 1 - ge/offline/proto/task.proto | 1 - ge/offline/single_op_parser.cc | 503 ---- ge/offline/single_op_parser.h | 78 - ge/plugin/engine/CMakeLists.txt | 49 - ge/proto/caffe/caffe.proto | 1821 ------------- ge/proto/dump_task.proto | 127 - ge/proto/ge_api.proto | 104 - ge/proto/ge_ir.proto | 206 -- ge/proto/insert_op.proto | 152 -- ge/proto/om.proto | 401 --- ge/proto/op_mapping_info.proto | 89 - ge/proto/task.proto | 170 -- ge/proto/tensorflow/attr_value.proto | 62 - ge/proto/tensorflow/function.proto | 100 - ge/proto/tensorflow/graph.proto | 56 - ge/proto/tensorflow/graph_library.proto | 14 - ge/proto/tensorflow/node_def.proto | 63 - ge/proto/tensorflow/op_def.proto | 164 -- ge/proto/tensorflow/resource_handle.proto | 29 - ge/proto/tensorflow/tensor.proto | 94 - ge/proto/tensorflow/tensor_shape.proto | 45 - ge/proto/tensorflow/types.proto | 74 - ge/proto/tensorflow/versions.proto | 31 - ge/session/readme.txt | 3 - inc/common/blocking_queue.h | 141 + inc/common/dynamic_aipp.h | 104 + inc/common/npu_error_define.h | 94 + inc/common/opskernel/ge_task_info.h | 74 + inc/common/opskernel/ops_kernel_info_store.h | 88 + inc/common/opskernel/ops_kernel_info_types.h | 66 + inc/common/optimizer/graph_optimizer.h | 71 + inc/common/optimizer/graph_optimizer_types.h | 34 + .../util/ai_core/common/aicore_util_attr_define.h | 41 + inc/common/util/ai_core/common/aicore_util_types.h | 118 + 
inc/common/util/ai_core/common/graph_comm.h | 107 + inc/common/util/ai_core/common/scope_allocator.h | 43 + .../param_calculate/aicore_param_calculator.h | 33 + .../param_calculate/tensorsize_calculator.h | 45 + inc/common/util/compress/compress.h | 37 + inc/common/util/compress/compress_weight.h | 33 + inc/common/util/error_manager/error_manager.h | 94 + inc/common/util/platform_info.h | 101 + inc/common/util/platform_info_def.h | 140 + inc/external/graph/attr_value.h | 75 + inc/external/graph/ge_error_codes.h | 38 + inc/external/graph/graph.h | 81 + inc/external/graph/inference_context.h | 76 + inc/external/graph/operator.h | 289 ++ inc/external/graph/operator_factory.h | 68 + inc/external/graph/operator_reg.h | 376 +++ inc/external/graph/tensor.h | 131 + inc/external/graph/types.h | 240 ++ inc/external/register/register.h | 163 ++ inc/external/register/register_error_codes.h | 39 + inc/external/register/register_fmk_types.h | 37 + inc/external/register/register_types.h | 59 + .../register/scope/scope_fusion_pass_register.h | 334 +++ inc/framework/omg/parser/model_parser.h | 111 - inc/framework/omg/parser/op_parser.h | 92 - inc/framework/omg/parser/parser_api.h | 31 - inc/framework/omg/parser/parser_factory.h | 138 - inc/framework/omg/parser/parser_inner_ctx.h | 43 - inc/framework/omg/parser/weights_parser.h | 74 - inc/graph/anchor.h | 284 ++ inc/graph/attr_value_serializable.h | 191 ++ inc/graph/buffer.h | 82 + inc/graph/compute_graph.h | 308 +++ inc/graph/debug/ge_attr_define.h | 1130 ++++++++ inc/graph/def_types.h | 195 ++ inc/graph/detail/any_map.h | 120 + inc/graph/detail/attributes_holder.h | 165 ++ inc/graph/detail/model_serialize_imp.h | 93 + inc/graph/ge_attr_value.h | 343 +++ inc/graph/ge_context.h | 46 + inc/graph/ge_global_options.h | 26 + inc/graph/ge_local_context.h | 44 + inc/graph/ge_tensor.h | 193 ++ inc/graph/graph_util.h | 134 + inc/graph/model.h | 94 + inc/graph/model_serialize.h | 52 + inc/graph/node.h | 213 ++ inc/graph/op_desc.h | 329 +++ 
inc/graph/op_kernel_bin.h | 48 + inc/graph/operator_factory_impl.h | 56 + inc/graph/opsproto_manager.h | 46 + inc/graph/range_vistor.h | 57 + inc/graph/ref_relation.h | 79 + inc/graph/runtime_inference_context.h | 49 + inc/graph/shape_refiner.h | 40 + inc/graph/tuning_utils.h | 130 + inc/graph/usr_types.h | 133 + inc/graph/utils/anchor_utils.h | 45 + inc/graph/utils/attr_utils.h | 150 ++ inc/graph/utils/graph_utils.h | 771 ++++++ inc/graph/utils/node_utils.h | 170 ++ inc/graph/utils/op_desc_utils.h | 182 ++ inc/graph/utils/tensor_adapter.h | 43 + inc/graph/utils/tensor_utils.h | 77 + inc/graph/utils/type_utils.h | 53 + metadef | 1 - parser | 1 - src/common/graph/CMakeLists.txt | 77 + src/common/graph/anchor.cc | 371 +++ src/common/graph/attr_value.cc | 38 + src/common/graph/buffer.cc | 112 + src/common/graph/compute_graph.cc | 1314 ++++++++++ src/common/graph/debug/ge_log.h | 147 ++ src/common/graph/debug/ge_op_types.h | 69 + src/common/graph/debug/ge_util.h | 274 ++ src/common/graph/debug/graph_debug.cc | 246 ++ src/common/graph/debug/graph_debug.h | 48 + src/common/graph/detail/attributes_holder.cc | 241 ++ src/common/graph/format_refiner.cc | 508 ++++ src/common/graph/format_refiner.h | 50 + src/common/graph/ge_attr_define.cc | 1086 ++++++++ src/common/graph/ge_attr_value.cc | 1289 +++++++++ src/common/graph/ge_tensor.cc | 1021 ++++++++ src/common/graph/graph.cc | 384 +++ src/common/graph/graph.mk | 294 +++ src/common/graph/inference_context.cc | 112 + src/common/graph/model.cc | 190 ++ src/common/graph/model_serialize.cc | 763 ++++++ src/common/graph/module.mk | 3 + src/common/graph/node.cc | 878 +++++++ src/common/graph/op_desc.cc | 1410 ++++++++++ src/common/graph/op_imp.cc | 79 + src/common/graph/operator.cc | 1587 +++++++++++ src/common/graph/operator_factory.cc | 48 + src/common/graph/operator_factory_impl.cc | 149 ++ src/common/graph/opsproto/opsproto_manager.cc | 187 ++ src/common/graph/option/ge_context.cc | 104 + 
src/common/graph/option/ge_local_context.cc | 60 + src/common/graph/ref_relation.cc | 455 ++++ src/common/graph/runtime_inference_context.cc | 129 + src/common/graph/shape_refiner.cc | 688 +++++ src/common/graph/stub/Makefile | 6 + src/common/graph/stub/gen_stubapi.py | 578 ++++ src/common/graph/tensor.cc | 704 +++++ src/common/graph/utils/anchor_utils.cc | 102 + src/common/graph/utils/ge_ir_utils.cc | 1178 +++++++++ src/common/graph/utils/ge_ir_utils.h | 206 ++ src/common/graph/utils/graph_utils.cc | 2767 ++++++++++++++++++++ src/common/graph/utils/mem_utils.h | 32 + src/common/graph/utils/node_utils.cc | 1005 +++++++ src/common/graph/utils/op_desc_utils.cc | 825 ++++++ src/common/graph/utils/string_utils.h | 68 + src/common/graph/utils/tensor_utils.cc | 401 +++ src/common/graph/utils/tuning_utils.cc | 684 +++++ src/common/graph/utils/type_utils.cc | 448 ++++ src/ge/CMakeLists.txt | 380 +++ {ge => src/ge}/analyzer/analyzer.cc | 0 {ge => src/ge}/analyzer/analyzer.h | 0 src/ge/client/CMakeLists.txt | 74 + {ge => src/ge}/client/ge_api.cc | 0 {ge => src/ge}/client/ge_prof.cc | 0 {ge => src/ge}/client/module.mk | 0 src/ge/common/CMakeLists.txt | 103 + {ge => src/ge}/common/auth/file_saver.cc | 0 {ge => src/ge}/common/auth/file_saver.h | 0 {ge => src/ge}/common/base64.h | 0 {ge => src/ge}/common/context/ctx.cc | 0 src/ge/common/convert/pb2json.cc | 248 ++ src/ge/common/convert/pb2json.h | 68 + {ge => src/ge}/common/cust_aicpu_kernel_store.cc | 0 {ge => src/ge}/common/cust_aicpu_kernel_store.h | 0 {ge => src/ge}/common/debug/memory_dumper.cc | 0 {ge => src/ge}/common/debug/memory_dumper.h | 0 {ge => src/ge}/common/dump/dump_manager.cc | 0 {ge => src/ge}/common/dump/dump_manager.h | 0 {ge => src/ge}/common/dump/dump_op.cc | 0 {ge => src/ge}/common/dump/dump_op.h | 0 {ge => src/ge}/common/dump/dump_properties.cc | 0 {ge => src/ge}/common/dump/dump_properties.h | 0 {ge => src/ge}/common/dump/dump_server.cc | 0 {ge => src/ge}/common/fmk_error_codes.cc | 0 
.../formats/format_transfers/datatype_transfer.cc | 0 .../formats/format_transfers/datatype_transfer.h | 0 .../format_transfer_c1hwncoc0_hwcn.cc | 0 .../format_transfer_c1hwncoc0_hwcn.h | 0 .../format_transfer_dhwcn_fracz3D.cc | 0 .../format_transfer_dhwcn_fracz3D.h | 0 .../format_transfer_dhwnc_fracz3D_transpose.cc | 0 .../format_transfer_dhwnc_fracz3D_transpose.h | 0 .../format_transfers/format_transfer_fractal_nz.cc | 0 .../format_transfers/format_transfer_fractal_nz.h | 0 .../format_transfers/format_transfer_fractal_z.cc | 0 .../format_transfers/format_transfer_fractal_z.h | 0 .../format_transfers/format_transfer_fractal_zz.cc | 0 .../format_transfers/format_transfer_fractal_zz.h | 0 .../format_transfers/format_transfer_fracz_hwcn.cc | 0 .../format_transfers/format_transfer_fracz_hwcn.h | 0 .../format_transfers/format_transfer_fracz_nchw.cc | 0 .../format_transfers/format_transfer_fracz_nchw.h | 0 .../format_transfers/format_transfer_fracz_nhwc.cc | 0 .../format_transfers/format_transfer_fracz_nhwc.h | 0 .../format_transfer_hwcn_c1hwncoc0.cc | 0 .../format_transfer_hwcn_c1hwncoc0.h | 0 .../format_transfer_nc1hwc0_nchw.cc | 0 .../format_transfer_nc1hwc0_nchw.h | 0 .../format_transfer_nc1hwc0_nhwc.cc | 0 .../format_transfer_nc1hwc0_nhwc.h | 0 .../format_transfer_nchw_fz_c04.cc | 0 .../format_transfers/format_transfer_nchw_fz_c04.h | 0 .../format_transfer_nchw_nc1hwc0.cc | 0 .../format_transfer_nchw_nc1hwc0.h | 0 .../format_transfer_nhwc_nc1hwc0.cc | 0 .../format_transfer_nhwc_nc1hwc0.h | 0 .../format_transfers/format_transfer_transpose.cc | 0 .../format_transfers/format_transfer_transpose.h | 0 {ge => src/ge}/common/formats/formats.cc | 0 {ge => src/ge}/common/formats/formats.h | 0 .../ge}/common/formats/utils/formats_definitions.h | 0 .../common/formats/utils/formats_trans_utils.cc | 0 .../ge}/common/formats/utils/formats_trans_utils.h | 0 {ge => src/ge}/common/fp16_t.cc | 0 {ge => src/ge}/common/fp16_t.h | 0 {ge => src/ge}/common/ge/datatype_util.cc | 0 {ge => 
src/ge}/common/ge/datatype_util.h | 0 {ge => src/ge}/common/ge/ge_util.h | 0 {ge => src/ge}/common/ge/op_tiling_manager.cc | 0 {ge => src/ge}/common/ge/op_tiling_manager.h | 0 {ge => src/ge}/common/ge/plugin_manager.cc | 0 {ge => src/ge}/common/ge/plugin_manager.h | 0 {ge => src/ge}/common/ge/tbe_plugin_manager.cc | 0 {ge => src/ge}/common/ge/tbe_plugin_manager.h | 0 {ge => src/ge}/common/ge_common.mk | 0 {ge => src/ge}/common/ge_format_util.cc | 0 {ge => src/ge}/common/helper/model_cache_helper.cc | 0 {ge => src/ge}/common/helper/model_cache_helper.h | 0 {ge => src/ge}/common/helper/model_helper.cc | 0 {ge => src/ge}/common/helper/om_file_helper.cc | 0 {ge => src/ge}/common/kernel_store.cc | 0 {ge => src/ge}/common/kernel_store.h | 0 {ge => src/ge}/common/math/fp16_math.cc | 0 {ge => src/ge}/common/math/fp16_math.h | 0 {ge => src/ge}/common/math/math_util.h | 0 {ge => src/ge}/common/math_util.h | 0 {ge => src/ge}/common/model_parser/base.cc | 0 {ge => src/ge}/common/model_parser/base.h | 0 {ge => src/ge}/common/model_saver.cc | 0 {ge => src/ge}/common/model_saver.h | 0 {ge => src/ge}/common/module.mk | 0 {ge => src/ge}/common/op/attr_value_util.cc | 0 {ge => src/ge}/common/op/ge_op_utils.cc | 0 .../ge}/common/profiling/profiling_manager.cc | 0 .../ge}/common/profiling/profiling_manager.h | 0 {ge => src/ge}/common/properties_manager.cc | 0 {ge => src/ge}/common/properties_manager.h | 0 {ge => src/ge}/common/singleton.h | 0 {ge => src/ge}/common/tbe_kernel_store.cc | 0 {ge => src/ge}/common/tbe_kernel_store.h | 0 {ge => src/ge}/common/thread_pool.cc | 0 {ge => src/ge}/common/thread_pool.h | 0 {ge => src/ge}/common/types.cc | 0 {ge => src/ge}/common/util.cc | 0 {ge => src/ge}/engine_manager/dnnengine_manager.cc | 0 {ge => src/ge}/engine_manager/dnnengine_manager.h | 0 {ge => src/ge}/engine_manager/engine_conf.json | 0 src/ge/executor/CMakeLists.txt | 126 + {ge => src/ge}/executor/ge_executor.cc | 0 {ge => src/ge}/executor/module.mk | 0 {ge => src/ge}/ge_inference.mk 
| 0 src/ge/ge_local_engine/CMakeLists.txt | 52 + .../ge}/ge_local_engine/common/constant/constant.h | 0 .../ge}/ge_local_engine/engine/ge_local_engine.cc | 0 .../ge}/ge_local_engine/engine/ge_local_engine.h | 0 .../ge}/ge_local_engine/engine/host_cpu_engine.cc | 0 .../ge}/ge_local_engine/engine/host_cpu_engine.h | 0 {ge => src/ge}/ge_local_engine/module.mk | 0 .../ops_kernel_store/ge_local_ops_kernel_info.cc | 0 .../ops_kernel_store/ge_local_ops_kernel_info.h | 0 .../ops_kernel_store/op/ge_deleted_op.cc | 0 .../ops_kernel_store/op/ge_deleted_op.h | 0 .../ge_local_engine/ops_kernel_store/op/no_op.cc | 0 .../ge_local_engine/ops_kernel_store/op/no_op.h | 0 .../ge}/ge_local_engine/ops_kernel_store/op/op.cc | 0 .../ge}/ge_local_engine/ops_kernel_store/op/op.h | 0 .../ops_kernel_store/op/op_factory.cc | 0 .../ops_kernel_store/op/op_factory.h | 0 {ge => src/ge}/ge_runner.mk | 0 src/ge/ge_runtime/CMakeLists.txt | 51 + {ge => src/ge}/ge_runtime/model_context.h | 0 {ge => src/ge}/ge_runtime/model_runner.cc | 0 {ge => src/ge}/ge_runtime/output.cc | 0 {ge => src/ge}/ge_runtime/output.h | 0 {ge => src/ge}/ge_runtime/runtime_model.cc | 0 {ge => src/ge}/ge_runtime/runtime_model.h | 0 {ge => src/ge}/ge_runtime/task/aicpu_task.cc | 0 {ge => src/ge}/ge_runtime/task/aicpu_task.h | 0 {ge => src/ge}/ge_runtime/task/cce_task.cc | 0 {ge => src/ge}/ge_runtime/task/cce_task.h | 0 .../ge}/ge_runtime/task/event_record_task.cc | 0 {ge => src/ge}/ge_runtime/task/event_record_task.h | 0 {ge => src/ge}/ge_runtime/task/event_wait_task.cc | 0 {ge => src/ge}/ge_runtime/task/event_wait_task.h | 0 {ge => src/ge}/ge_runtime/task/hccl_task.cc | 0 {ge => src/ge}/ge_runtime/task/hccl_task.h | 0 {ge => src/ge}/ge_runtime/task/label_goto_task.cc | 0 {ge => src/ge}/ge_runtime/task/label_goto_task.h | 0 {ge => src/ge}/ge_runtime/task/label_set_task.cc | 0 {ge => src/ge}/ge_runtime/task/label_set_task.h | 0 .../ge}/ge_runtime/task/label_switch_task.cc | 0 {ge => src/ge}/ge_runtime/task/label_switch_task.h | 0 
.../ge}/ge_runtime/task/memcpy_async_task.cc | 0 {ge => src/ge}/ge_runtime/task/memcpy_async_task.h | 0 {ge => src/ge}/ge_runtime/task/profiler_task.cc | 0 {ge => src/ge}/ge_runtime/task/profiler_task.h | 0 .../ge}/ge_runtime/task/stream_active_task.cc | 0 .../ge}/ge_runtime/task/stream_active_task.h | 0 .../ge}/ge_runtime/task/stream_switch_task.cc | 0 .../ge}/ge_runtime/task/stream_switch_task.h | 0 {ge => src/ge}/ge_runtime/task/task.h | 0 {ge => src/ge}/ge_runtime/task/task_factory.h | 0 {ge => src/ge}/ge_runtime/task/tbe_task.cc | 0 {ge => src/ge}/ge_runtime/task/tbe_task.h | 0 {ge => src/ge}/generator/ge_generator.cc | 0 {ge => src/ge}/generator/generator_api.cc | 0 {ge => src/ge}/graph/build/graph_builder.cc | 0 {ge => src/ge}/graph/build/graph_builder.h | 0 {ge => src/ge}/graph/build/label_allocator.cc | 0 {ge => src/ge}/graph/build/label_allocator.h | 0 .../ge}/graph/build/logical_stream_allocator.cc | 0 .../ge}/graph/build/logical_stream_allocator.h | 0 src/ge/graph/build/memory/CMakeLists.txt | 51 + .../build/memory/binary_block_mem_assigner.cc | 0 .../graph/build/memory/binary_block_mem_assigner.h | 0 .../ge}/graph/build/memory/block_mem_assigner.cc | 0 .../ge}/graph/build/memory/block_mem_assigner.h | 0 .../ge}/graph/build/memory/graph_mem_assigner.cc | 0 .../ge}/graph/build/memory/graph_mem_assigner.h | 0 .../ge}/graph/build/memory/hybrid_mem_assigner.cc | 0 .../ge}/graph/build/memory/hybrid_mem_assigner.h | 0 .../graph/build/memory/max_block_mem_assigner.cc | 0 .../graph/build/memory/max_block_mem_assigner.h | 0 {ge => src/ge}/graph/build/memory/mem_assigner.h | 0 .../ge}/graph/build/memory/memory_assigner.cc | 0 {ge => src/ge}/graph/build/memory/module.mk | 0 .../ge}/graph/build/memory/var_mem_assign_util.cc | 0 .../ge}/graph/build/memory/var_mem_assign_util.h | 0 {ge => src/ge}/graph/build/model_builder.cc | 0 {ge => src/ge}/graph/build/model_builder.h | 0 {ge => src/ge}/graph/build/run_context.cc | 0 {ge => src/ge}/graph/build/run_context.h | 0 
{ge => src/ge}/graph/build/stream_allocator.cc | 0 {ge => src/ge}/graph/build/stream_allocator.h | 0 .../ge}/graph/build/stream_graph_optimizer.cc | 0 .../ge}/graph/build/stream_graph_optimizer.h | 0 {ge => src/ge}/graph/build/task_generator.cc | 0 {ge => src/ge}/graph/build/task_generator.h | 0 {ge => src/ge}/graph/common/bcast.cc | 0 {ge => src/ge}/graph/common/bcast.h | 0 {ge => src/ge}/graph/common/ge_call_wrapper.h | 0 {ge => src/ge}/graph/common/local_context.cc | 0 {ge => src/ge}/graph/common/local_context.h | 0 {ge => src/ge}/graph/common/omg_util.cc | 0 {ge => src/ge}/graph/common/omg_util.h | 0 {ge => src/ge}/graph/common/transop_util.cc | 0 {ge => src/ge}/graph/common/transop_util.h | 0 {ge => src/ge}/graph/execute/graph_execute.cc | 0 {ge => src/ge}/graph/execute/graph_execute.h | 0 {ge => src/ge}/graph/label/case_label_maker.cc | 0 {ge => src/ge}/graph/label/case_label_maker.h | 0 {ge => src/ge}/graph/label/if_label_maker.cc | 0 {ge => src/ge}/graph/label/if_label_maker.h | 0 {ge => src/ge}/graph/label/label_maker.cc | 0 {ge => src/ge}/graph/label/label_maker.h | 0 {ge => src/ge}/graph/label/label_maker_factory.h | 0 .../graph/label/partitioned_call_label_maker.cc | 0 .../ge}/graph/label/partitioned_call_label_maker.h | 0 {ge => src/ge}/graph/label/while_label_maker.cc | 0 {ge => src/ge}/graph/label/while_label_maker.h | 0 {ge => src/ge}/graph/load/graph_loader.cc | 0 {ge => src/ge}/graph/load/graph_loader.h | 0 .../ge}/graph/load/new_model_manager/aipp_utils.cc | 0 .../ge}/graph/load/new_model_manager/aipp_utils.h | 0 .../load/new_model_manager/cpu_queue_schedule.cc | 0 .../load/new_model_manager/cpu_queue_schedule.h | 0 .../graph/load/new_model_manager/data_dumper.cc | 0 .../ge}/graph/load/new_model_manager/data_dumper.h | 0 .../graph/load/new_model_manager/data_inputer.cc | 0 .../graph/load/new_model_manager/data_inputer.h | 0 .../graph/load/new_model_manager/davinci_model.cc | 0 .../graph/load/new_model_manager/davinci_model.h | 0 
.../load/new_model_manager/davinci_model_parser.cc | 0 .../load/new_model_manager/davinci_model_parser.h | 0 .../graph/load/new_model_manager/model_manager.cc | 0 .../graph/load/new_model_manager/model_manager.h | 0 .../graph/load/new_model_manager/model_utils.cc | 0 .../ge}/graph/load/new_model_manager/model_utils.h | 0 .../task_info/end_graph_task_info.cc | 0 .../task_info/end_graph_task_info.h | 0 .../task_info/event_record_task_info.cc | 0 .../task_info/event_record_task_info.h | 0 .../task_info/event_wait_task_info.cc | 0 .../task_info/event_wait_task_info.h | 0 .../task_info/fusion_start_task_info.cc | 0 .../task_info/fusion_start_task_info.h | 0 .../task_info/fusion_stop_task_info.cc | 0 .../task_info/fusion_stop_task_info.h | 0 .../new_model_manager/task_info/hccl_task_info.cc | 0 .../new_model_manager/task_info/hccl_task_info.h | 0 .../task_info/kernel_ex_task_info.cc | 0 .../task_info/kernel_ex_task_info.h | 0 .../task_info/kernel_task_info.cc | 0 .../new_model_manager/task_info/kernel_task_info.h | 0 .../task_info/label_goto_ex_task_info.cc | 0 .../task_info/label_goto_ex_task_info.h | 0 .../task_info/label_set_task_info.cc | 0 .../task_info/label_set_task_info.h | 0 .../task_info/label_switch_by_index_task_info.cc | 0 .../task_info/label_switch_by_index_task_info.h | 0 .../task_info/memcpy_addr_async_task_info.cc | 0 .../task_info/memcpy_addr_async_task_info.h | 0 .../task_info/memcpy_async_task_info.cc | 0 .../task_info/memcpy_async_task_info.h | 0 .../task_info/profiler_trace_task_info.cc | 0 .../task_info/profiler_trace_task_info.h | 0 .../task_info/stream_active_task_info.cc | 0 .../task_info/stream_active_task_info.h | 0 .../task_info/stream_switch_task_info.cc | 0 .../task_info/stream_switch_task_info.h | 0 .../task_info/stream_switchn_task_info.cc | 0 .../task_info/stream_switchn_task_info.h | 0 .../task_info/super_kernel/super_kernel.cc | 0 .../task_info/super_kernel/super_kernel.h | 0 .../task_info/super_kernel/super_kernel_factory.cc | 0 
.../task_info/super_kernel/super_kernel_factory.h | 0 .../load/new_model_manager/task_info/task_info.cc | 0 .../load/new_model_manager/task_info/task_info.h | 0 .../task_info/task_info_factory.h | 0 .../load/new_model_manager/tbe_handle_store.cc | 0 .../load/new_model_manager/tbe_handle_store.h | 0 .../load/new_model_manager/zero_copy_offset.cc | 0 .../load/new_model_manager/zero_copy_offset.h | 0 .../graph/load/new_model_manager/zero_copy_task.cc | 0 .../graph/load/new_model_manager/zero_copy_task.h | 0 {ge => src/ge}/graph/manager/block_memory.h | 0 .../ge}/graph/manager/graph_caching_allocator.cc | 0 .../ge}/graph/manager/graph_caching_allocator.h | 0 {ge => src/ge}/graph/manager/graph_context.cc | 0 {ge => src/ge}/graph/manager/graph_context.h | 0 {ge => src/ge}/graph/manager/graph_manager.cc | 0 {ge => src/ge}/graph/manager/graph_manager.h | 0 .../ge}/graph/manager/graph_manager_utils.cc | 0 {ge => src/ge}/graph/manager/graph_manager_utils.h | 0 .../ge}/graph/manager/graph_mem_allocator.cc | 0 {ge => src/ge}/graph/manager/graph_mem_allocator.h | 0 {ge => src/ge}/graph/manager/graph_var_manager.cc | 0 {ge => src/ge}/graph/manager/graph_var_manager.h | 0 {ge => src/ge}/graph/manager/host_mem_manager.cc | 0 {ge => src/ge}/graph/manager/host_mem_manager.h | 0 {ge => src/ge}/graph/manager/memory_api.cc | 0 .../graph/manager/model_manager/event_manager.cc | 0 .../graph/manager/model_manager/event_manager.h | 0 .../ge}/graph/manager/rdma_pool_allocator.cc | 0 {ge => src/ge}/graph/manager/rdma_pool_allocator.h | 0 .../ge}/graph/manager/trans_var_data_utils.cc | 0 .../ge}/graph/manager/trans_var_data_utils.h | 0 {ge => src/ge}/graph/manager/util/debug.cc | 0 {ge => src/ge}/graph/manager/util/debug.h | 0 {ge => src/ge}/graph/manager/util/hcom_util.cc | 0 {ge => src/ge}/graph/manager/util/hcom_util.h | 0 .../ge}/graph/manager/util/rt_context_util.cc | 0 .../ge}/graph/manager/util/rt_context_util.h | 0 .../graph/manager/util/variable_accelerate_ctrl.cc | 0 
.../graph/manager/util/variable_accelerate_ctrl.h | 0 {ge => src/ge}/graph/optimize/common/params.h | 0 {ge => src/ge}/graph/optimize/graph_optimize.cc | 0 {ge => src/ge}/graph/optimize/graph_optimize.h | 0 .../ge}/graph/optimize/mem_rw_conflict_optimize.cc | 0 .../optimize/optimizer/allreduce_fusion_pass.cc | 0 .../optimize/optimizer/allreduce_fusion_pass.h | 0 {ge => src/ge}/graph/optimize/summary_optimize.cc | 0 .../ge}/graph/partition/dynamic_shape_partition.cc | 0 .../ge}/graph/partition/dynamic_shape_partition.h | 0 {ge => src/ge}/graph/partition/engine_place.cc | 0 {ge => src/ge}/graph/partition/engine_place.h | 0 {ge => src/ge}/graph/partition/graph_partition.cc | 0 {ge => src/ge}/graph/partition/graph_partition.h | 0 {ge => src/ge}/graph/passes/addn_pass.cc | 0 {ge => src/ge}/graph/passes/addn_pass.h | 0 .../graph/passes/aicpu_constant_folding_pass.cc | 0 .../ge}/graph/passes/aicpu_constant_folding_pass.h | 0 {ge => src/ge}/graph/passes/assert_pass.cc | 0 {ge => src/ge}/graph/passes/assert_pass.h | 0 {ge => src/ge}/graph/passes/assign_pass.cc | 0 {ge => src/ge}/graph/passes/assign_pass.h | 0 .../ge}/graph/passes/atomic_addr_clean_pass.cc | 0 .../ge}/graph/passes/atomic_addr_clean_pass.h | 0 .../ge}/graph/passes/attach_stream_label_pass.cc | 0 .../ge}/graph/passes/attach_stream_label_pass.h | 0 {ge => src/ge}/graph/passes/base_pass.cc | 0 {ge => src/ge}/graph/passes/base_pass.h | 0 {ge => src/ge}/graph/passes/bitcast_pass.cc | 0 {ge => src/ge}/graph/passes/bitcast_pass.h | 0 {ge => src/ge}/graph/passes/cast_remove_pass.cc | 0 {ge => src/ge}/graph/passes/cast_remove_pass.h | 0 {ge => src/ge}/graph/passes/cast_translate_pass.cc | 0 {ge => src/ge}/graph/passes/cast_translate_pass.h | 0 .../common_subexpression_elimination_pass.cc | 0 .../passes/common_subexpression_elimination_pass.h | 0 {ge => src/ge}/graph/passes/compile_nodes_pass.cc | 0 {ge => src/ge}/graph/passes/compile_nodes_pass.h | 0 {ge => src/ge}/graph/passes/cond_pass.cc | 0 {ge => 
src/ge}/graph/passes/cond_pass.h | 0 {ge => src/ge}/graph/passes/cond_remove_pass.cc | 0 {ge => src/ge}/graph/passes/cond_remove_pass.h | 0 .../ge}/graph/passes/constant_folding_pass.cc | 0 .../ge}/graph/passes/constant_folding_pass.h | 0 .../ge}/graph/passes/constant_fuse_same_pass.cc | 0 .../ge}/graph/passes/constant_fuse_same_pass.h | 0 .../ge}/graph/passes/control_trigger_pass.cc | 0 {ge => src/ge}/graph/passes/control_trigger_pass.h | 0 .../ge}/graph/passes/ctrl_edge_transfer_pass.cc | 0 .../ge}/graph/passes/ctrl_edge_transfer_pass.h | 0 {ge => src/ge}/graph/passes/data_pass.cc | 0 {ge => src/ge}/graph/passes/data_pass.h | 0 .../ge}/graph/passes/dimension_adjust_pass.cc | 0 .../ge}/graph/passes/dimension_adjust_pass.h | 0 .../ge}/graph/passes/dimension_compute_pass.cc | 0 .../ge}/graph/passes/dimension_compute_pass.h | 0 {ge => src/ge}/graph/passes/dropout_pass.cc | 0 {ge => src/ge}/graph/passes/dropout_pass.h | 0 .../passes/end_of_sequence_add_control_pass.cc | 0 .../passes/end_of_sequence_add_control_pass.h | 0 {ge => src/ge}/graph/passes/enter_pass.cc | 0 {ge => src/ge}/graph/passes/enter_pass.h | 0 {ge => src/ge}/graph/passes/flow_ctrl_pass.cc | 0 {ge => src/ge}/graph/passes/flow_ctrl_pass.h | 0 {ge => src/ge}/graph/passes/folding_pass.cc | 0 {ge => src/ge}/graph/passes/folding_pass.h | 0 {ge => src/ge}/graph/passes/for_pass.cc | 0 {ge => src/ge}/graph/passes/for_pass.h | 0 .../ge}/graph/passes/get_original_format_pass.cc | 0 .../ge}/graph/passes/get_original_format_pass.h | 0 .../ge}/graph/passes/global_step_insert_pass.cc | 0 .../ge}/graph/passes/global_step_insert_pass.h | 0 .../ge}/graph/passes/guarantee_const_pass.cc | 0 {ge => src/ge}/graph/passes/guarantee_const_pass.h | 0 {ge => src/ge}/graph/passes/hccl_group_pass.cc | 0 {ge => src/ge}/graph/passes/hccl_group_pass.h | 0 {ge => src/ge}/graph/passes/hccl_memcpy_pass.cc | 0 {ge => src/ge}/graph/passes/hccl_memcpy_pass.h | 0 {ge => src/ge}/graph/passes/identity_pass.cc | 0 {ge => 
src/ge}/graph/passes/identity_pass.h | 0 {ge => src/ge}/graph/passes/infershape_pass.cc | 0 {ge => src/ge}/graph/passes/infershape_pass.h | 0 .../input_output_connection_identify_pass.cc | 0 .../passes/input_output_connection_identify_pass.h | 0 .../ge}/graph/passes/isolated_op_remove_pass.cc | 0 .../ge}/graph/passes/isolated_op_remove_pass.h | 0 {ge => src/ge}/graph/passes/iterator_op_pass.cc | 0 {ge => src/ge}/graph/passes/iterator_op_pass.h | 0 .../ge}/graph/passes/link_gen_mask_nodes_pass.cc | 0 .../ge}/graph/passes/link_gen_mask_nodes_pass.h | 0 {ge => src/ge}/graph/passes/mark_agnostic_pass.cc | 0 {ge => src/ge}/graph/passes/mark_agnostic_pass.h | 0 .../graph/passes/mark_graph_unknown_status_pass.cc | 0 .../graph/passes/mark_graph_unknown_status_pass.h | 0 {ge => src/ge}/graph/passes/mark_same_addr_pass.cc | 0 {ge => src/ge}/graph/passes/mark_same_addr_pass.h | 0 .../ge}/graph/passes/memcpy_addr_async_pass.cc | 0 .../ge}/graph/passes/memcpy_addr_async_pass.h | 0 {ge => src/ge}/graph/passes/merge_pass.cc | 0 {ge => src/ge}/graph/passes/merge_pass.h | 0 .../ge}/graph/passes/merge_to_stream_merge_pass.cc | 0 .../ge}/graph/passes/merge_to_stream_merge_pass.h | 0 .../ge}/graph/passes/multi_batch_clone_pass.cc | 0 .../ge}/graph/passes/multi_batch_clone_pass.h | 0 {ge => src/ge}/graph/passes/multi_batch_pass.cc | 0 {ge => src/ge}/graph/passes/multi_batch_pass.h | 0 {ge => src/ge}/graph/passes/net_output_pass.cc | 0 {ge => src/ge}/graph/passes/net_output_pass.h | 0 {ge => src/ge}/graph/passes/next_iteration_pass.cc | 0 {ge => src/ge}/graph/passes/next_iteration_pass.h | 0 .../ge}/graph/passes/no_use_reshape_remove_pass.cc | 0 .../ge}/graph/passes/no_use_reshape_remove_pass.h | 0 .../graph/passes/parallel_concat_start_op_pass.cc | 0 .../graph/passes/parallel_concat_start_op_pass.h | 0 {ge => src/ge}/graph/passes/pass_manager.cc | 0 {ge => src/ge}/graph/passes/pass_utils.cc | 0 {ge => src/ge}/graph/passes/pass_utils.h | 0 {ge => src/ge}/graph/passes/permute_pass.cc | 0 
{ge => src/ge}/graph/passes/permute_pass.h | 0 .../graph/passes/placeholder_with_default_pass.cc | 0 .../graph/passes/placeholder_with_default_pass.h | 0 .../ge}/graph/passes/prevent_gradient_pass.cc | 0 .../ge}/graph/passes/prevent_gradient_pass.h | 0 {ge => src/ge}/graph/passes/print_op_pass.cc | 0 {ge => src/ge}/graph/passes/print_op_pass.h | 0 {ge => src/ge}/graph/passes/prune_pass.cc | 0 {ge => src/ge}/graph/passes/prune_pass.h | 0 .../graph/passes/ref_identity_delete_op_pass.cc | 0 .../ge}/graph/passes/ref_identity_delete_op_pass.h | 0 {ge => src/ge}/graph/passes/remove_nodes_pass.cc | 0 {ge => src/ge}/graph/passes/remove_nodes_pass.h | 0 .../ge}/graph/passes/replace_transshape_pass.cc | 0 .../ge}/graph/passes/replace_transshape_pass.h | 0 .../graph/passes/replace_with_empty_const_pass.cc | 0 .../graph/passes/replace_with_empty_const_pass.h | 0 .../ge}/graph/passes/reshape_recovery_pass.cc | 0 .../ge}/graph/passes/reshape_recovery_pass.h | 0 {ge => src/ge}/graph/passes/reshape_remove_pass.cc | 0 {ge => src/ge}/graph/passes/reshape_remove_pass.h | 0 .../graph/passes/resource_pair_add_control_pass.cc | 0 .../graph/passes/resource_pair_add_control_pass.h | 0 .../passes/resource_pair_remove_control_pass.cc | 0 .../passes/resource_pair_remove_control_pass.h | 0 .../passes/same_transdata_breadth_fusion_pass.cc | 0 .../passes/same_transdata_breadth_fusion_pass.h | 0 {ge => src/ge}/graph/passes/save_pass.cc | 0 {ge => src/ge}/graph/passes/save_pass.h | 0 .../graph/passes/set_input_output_offset_pass.cc | 0 .../graph/passes/set_input_output_offset_pass.h | 0 .../graph/passes/shape_operate_op_remove_pass.cc | 0 .../graph/passes/shape_operate_op_remove_pass.h | 0 {ge => src/ge}/graph/passes/snapshot_pass.cc | 0 {ge => src/ge}/graph/passes/snapshot_pass.h | 0 {ge => src/ge}/graph/passes/stop_gradient_pass.cc | 0 {ge => src/ge}/graph/passes/stop_gradient_pass.h | 0 .../graph/passes/subexpression_migration_pass.cc | 0 .../graph/passes/subexpression_migration_pass.h | 0 {ge 
=> src/ge}/graph/passes/subgraph_pass.cc | 0 {ge => src/ge}/graph/passes/subgraph_pass.h | 0 .../ge}/graph/passes/switch_data_edges_bypass.cc | 0 .../ge}/graph/passes/switch_data_edges_bypass.h | 0 .../graph/passes/switch_dead_branch_elimination.cc | 0 .../graph/passes/switch_dead_branch_elimination.h | 0 .../ge}/graph/passes/switch_logic_remove_pass.cc | 0 .../ge}/graph/passes/switch_logic_remove_pass.h | 0 .../graph/passes/switch_to_stream_switch_pass.cc | 0 .../graph/passes/switch_to_stream_switch_pass.h | 0 .../graph/passes/transop_breadth_fusion_pass.cc | 0 .../ge}/graph/passes/transop_breadth_fusion_pass.h | 0 .../ge}/graph/passes/transop_depth_fusion_pass.cc | 0 .../ge}/graph/passes/transop_depth_fusion_pass.h | 0 .../passes/transop_nearby_allreduce_fusion_pass.cc | 0 .../passes/transop_nearby_allreduce_fusion_pass.h | 0 .../passes/transop_symmetry_elimination_pass.cc | 0 .../passes/transop_symmetry_elimination_pass.h | 0 .../passes/transop_without_reshape_fusion_pass.cc | 0 .../passes/transop_without_reshape_fusion_pass.h | 0 .../ge}/graph/passes/transpose_transdata_pass.cc | 0 .../ge}/graph/passes/transpose_transdata_pass.h | 0 .../ge}/graph/passes/unused_args_clean_pass.cc | 0 .../ge}/graph/passes/unused_args_clean_pass.h | 0 {ge => src/ge}/graph/passes/unused_const_pass.cc | 0 {ge => src/ge}/graph/passes/unused_const_pass.h | 0 .../ge}/graph/passes/unused_op_remove_pass.cc | 0 .../ge}/graph/passes/unused_op_remove_pass.h | 0 .../ge}/graph/passes/var_is_initialized_op_pass.cc | 0 .../ge}/graph/passes/var_is_initialized_op_pass.h | 0 .../ge}/graph/passes/variable_format_pass.cc | 0 {ge => src/ge}/graph/passes/variable_format_pass.h | 0 {ge => src/ge}/graph/passes/variable_op_pass.cc | 0 {ge => src/ge}/graph/passes/variable_op_pass.h | 0 .../ge}/graph/passes/variable_prepare_op_pass.cc | 0 .../ge}/graph/passes/variable_prepare_op_pass.h | 0 .../graph/passes/variable_ref_delete_op_pass.cc | 0 .../ge}/graph/passes/variable_ref_delete_op_pass.h | 0 
...variable_ref_useless_control_out_delete_pass.cc | 0 .../variable_ref_useless_control_out_delete_pass.h | 0 .../ge}/graph/preprocess/graph_preprocess.cc | 0 {ge => src/ge}/graph/preprocess/graph_preprocess.h | 0 .../graph/preprocess/insert_op/base_insert_op.h | 0 .../ge}/graph/preprocess/insert_op/ge_aipp_op.cc | 0 .../ge}/graph/preprocess/insert_op/ge_aipp_op.h | 0 .../preprocess/insert_op/util_insert_aipp_op.cc | 0 .../preprocess/insert_op/util_insert_aipp_op.h | 0 .../ge}/graph/preprocess/multi_batch_copy_graph.cc | 0 .../ge}/graph/preprocess/multi_batch_copy_graph.h | 0 .../ge}/graph/preprocess/multi_batch_options.cc | 0 .../ge}/graph/preprocess/multi_batch_options.h | 0 .../ge}/host_cpu_engine/common/constant/constant.h | 0 .../ge}/host_cpu_engine/engine/host_cpu_engine.cc | 0 .../ge}/host_cpu_engine/engine/host_cpu_engine.h | 0 {ge => src/ge}/host_cpu_engine/module.mk | 0 .../ops_kernel_store/host_cpu_ops_kernel_info.cc | 0 .../ops_kernel_store/host_cpu_ops_kernel_info.h | 0 .../host_cpu_engine/ops_kernel_store/op/host_op.cc | 0 .../host_cpu_engine/ops_kernel_store/op/host_op.h | 0 .../ge}/host_cpu_engine/ops_kernel_store/op/op.h | 0 .../ops_kernel_store/op/op_factory.cc | 0 .../ops_kernel_store/op/op_factory.h | 0 src/ge/host_cpu_engine/proto/task.proto | 1 + {ge => src/ge}/host_kernels/add_kernel.cc | 0 {ge => src/ge}/host_kernels/add_kernel.h | 0 .../ge}/host_kernels/broadcast_args_kernel.cc | 0 .../ge}/host_kernels/broadcast_args_kernel.h | 0 .../host_kernels/broadcast_gradient_args_kernel.cc | 0 .../host_kernels/broadcast_gradient_args_kernel.h | 0 {ge => src/ge}/host_kernels/cast_kernel.cc | 0 {ge => src/ge}/host_kernels/cast_kernel.h | 0 .../ge}/host_kernels/concat_offset_kernel.cc | 0 {ge => src/ge}/host_kernels/concat_offset_kernel.h | 0 {ge => src/ge}/host_kernels/concat_v2_kernel.cc | 0 {ge => src/ge}/host_kernels/concat_v2_kernel.h | 0 .../ge}/host_kernels/dynamic_stitch_kernel.cc | 0 .../ge}/host_kernels/dynamic_stitch_kernel.h | 0 {ge => 
src/ge}/host_kernels/empty_kernel.cc | 0 {ge => src/ge}/host_kernels/empty_kernel.h | 0 {ge => src/ge}/host_kernels/expanddims_kernel.cc | 0 {ge => src/ge}/host_kernels/expanddims_kernel.h | 0 {ge => src/ge}/host_kernels/fill_kernel.cc | 0 {ge => src/ge}/host_kernels/fill_kernel.h | 0 {ge => src/ge}/host_kernels/floordiv_kernel.cc | 0 {ge => src/ge}/host_kernels/floordiv_kernel.h | 0 {ge => src/ge}/host_kernels/floormod_kernel.cc | 0 {ge => src/ge}/host_kernels/floormod_kernel.h | 0 {ge => src/ge}/host_kernels/gather_v2_kernel.cc | 0 {ge => src/ge}/host_kernels/gather_v2_kernel.h | 0 {ge => src/ge}/host_kernels/greater_kernel.cc | 0 {ge => src/ge}/host_kernels/greater_kernel.h | 0 {ge => src/ge}/host_kernels/identity_kernel.cc | 0 {ge => src/ge}/host_kernels/identity_kernel.h | 0 {ge => src/ge}/host_kernels/kernel_utils.cc | 0 {ge => src/ge}/host_kernels/kernel_utils.h | 0 {ge => src/ge}/host_kernels/maximum_kernel.cc | 0 {ge => src/ge}/host_kernels/maximum_kernel.h | 0 {ge => src/ge}/host_kernels/mul_kernel.cc | 0 {ge => src/ge}/host_kernels/mul_kernel.h | 0 {ge => src/ge}/host_kernels/pack_kernel.cc | 0 {ge => src/ge}/host_kernels/pack_kernel.h | 0 {ge => src/ge}/host_kernels/permute_kernel.cc | 0 {ge => src/ge}/host_kernels/permute_kernel.h | 0 {ge => src/ge}/host_kernels/range_kernel.cc | 0 {ge => src/ge}/host_kernels/range_kernel.h | 0 {ge => src/ge}/host_kernels/rank_kernel.cc | 0 {ge => src/ge}/host_kernels/rank_kernel.h | 0 {ge => src/ge}/host_kernels/reduce_prod_kernel.cc | 0 {ge => src/ge}/host_kernels/reduce_prod_kernel.h | 0 {ge => src/ge}/host_kernels/reformat_kernel.cc | 0 {ge => src/ge}/host_kernels/reformat_kernel.h | 0 {ge => src/ge}/host_kernels/reshape_kernel.cc | 0 {ge => src/ge}/host_kernels/reshape_kernel.h | 0 {ge => src/ge}/host_kernels/rsqrt_kernel.cc | 0 {ge => src/ge}/host_kernels/rsqrt_kernel.h | 0 {ge => src/ge}/host_kernels/shape_kernel.cc | 0 {ge => src/ge}/host_kernels/shape_kernel.h | 0 {ge => src/ge}/host_kernels/shape_n_kernel.cc 
| 0 {ge => src/ge}/host_kernels/shape_n_kernel.h | 0 {ge => src/ge}/host_kernels/size_kernel.cc | 0 {ge => src/ge}/host_kernels/size_kernel.h | 0 {ge => src/ge}/host_kernels/slice_d_kernel.cc | 0 {ge => src/ge}/host_kernels/slice_d_kernel.h | 0 {ge => src/ge}/host_kernels/slice_kernel.cc | 0 {ge => src/ge}/host_kernels/slice_kernel.h | 0 {ge => src/ge}/host_kernels/squeeze_kernel.cc | 0 {ge => src/ge}/host_kernels/squeeze_kernel.h | 0 .../ge}/host_kernels/ssd_prior_box_kernel.cc | 0 {ge => src/ge}/host_kernels/ssd_prior_box_kernel.h | 0 .../ge}/host_kernels/strided_slice_kernel.cc | 0 {ge => src/ge}/host_kernels/strided_slice_kernel.h | 0 {ge => src/ge}/host_kernels/sub_kernel.cc | 0 {ge => src/ge}/host_kernels/sub_kernel.h | 0 {ge => src/ge}/host_kernels/transdata_kernel.cc | 0 {ge => src/ge}/host_kernels/transdata_kernel.h | 0 {ge => src/ge}/host_kernels/transpose_kernel.cc | 0 {ge => src/ge}/host_kernels/transpose_kernel.h | 0 {ge => src/ge}/host_kernels/unpack_kernel.cc | 0 {ge => src/ge}/host_kernels/unpack_kernel.h | 0 {ge => src/ge}/host_kernels/unsqueeze_kernel.cc | 0 {ge => src/ge}/host_kernels/unsqueeze_kernel.h | 0 .../ge}/hybrid/common/npu_memory_allocator.cc | 0 .../ge}/hybrid/common/npu_memory_allocator.h | 0 {ge => src/ge}/hybrid/common/tensor_value.cc | 0 {ge => src/ge}/hybrid/common/tensor_value.h | 0 .../hybrid/executor/hybrid_execution_context.cc | 0 .../ge}/hybrid/executor/hybrid_execution_context.h | 0 .../hybrid/executor/hybrid_model_async_executor.cc | 0 .../hybrid/executor/hybrid_model_async_executor.h | 0 .../ge}/hybrid/executor/hybrid_model_executor.cc | 0 .../ge}/hybrid/executor/hybrid_model_executor.h | 0 {ge => src/ge}/hybrid/executor/hybrid_profiler.cc | 0 {ge => src/ge}/hybrid/executor/hybrid_profiler.h | 0 .../ge}/hybrid/executor/node_done_manager.cc | 0 {ge => src/ge}/hybrid/executor/node_done_manager.h | 0 {ge => src/ge}/hybrid/executor/node_state.cc | 0 {ge => src/ge}/hybrid/executor/node_state.h | 0 
.../ge}/hybrid/executor/rt_callback_manager.cc | 0 .../ge}/hybrid/executor/rt_callback_manager.h | 0 {ge => src/ge}/hybrid/executor/subgraph_context.cc | 0 {ge => src/ge}/hybrid/executor/subgraph_context.h | 0 .../ge}/hybrid/executor/subgraph_executor.cc | 0 {ge => src/ge}/hybrid/executor/subgraph_executor.h | 0 .../ge}/hybrid/executor/worker/execution_engine.cc | 0 .../ge}/hybrid/executor/worker/execution_engine.h | 0 .../executor/worker/shape_inference_engine.cc | 0 .../executor/worker/shape_inference_engine.h | 0 .../hybrid/executor/worker/task_compile_engine.cc | 0 .../hybrid/executor/worker/task_compile_engine.h | 0 {ge => src/ge}/hybrid/hybrid_davinci_model.cc | 0 {ge => src/ge}/hybrid/hybrid_davinci_model.h | 0 {ge => src/ge}/hybrid/hybrid_davinci_model_stub.cc | 0 {ge => src/ge}/hybrid/model/graph_item.cc | 0 {ge => src/ge}/hybrid/model/graph_item.h | 0 {ge => src/ge}/hybrid/model/hybrid_model.cc | 0 {ge => src/ge}/hybrid/model/hybrid_model.h | 0 .../ge}/hybrid/model/hybrid_model_builder.cc | 0 {ge => src/ge}/hybrid/model/hybrid_model_builder.h | 0 {ge => src/ge}/hybrid/model/node_item.cc | 0 {ge => src/ge}/hybrid/model/node_item.h | 0 .../node_executor/aicore/aicore_node_executor.cc | 0 .../node_executor/aicore/aicore_node_executor.h | 0 .../hybrid/node_executor/aicore/aicore_op_task.cc | 0 .../hybrid/node_executor/aicore/aicore_op_task.h | 0 .../node_executor/aicore/aicore_task_builder.cc | 0 .../node_executor/aicore/aicore_task_builder.h | 0 .../node_executor/aicore/aicore_task_compiler.cc | 0 .../node_executor/aicore/aicore_task_compiler.h | 0 .../hybrid/node_executor/aicpu/aicpu_ext_info.cc | 0 .../hybrid/node_executor/aicpu/aicpu_ext_info.h | 0 .../node_executor/aicpu/aicpu_node_executor.cc | 0 .../node_executor/aicpu/aicpu_node_executor.h | 0 .../compiledsubgraph/known_node_executor.cc | 0 .../compiledsubgraph/known_node_executor.h | 0 .../node_executor/controlop/control_op_executor.cc | 0 .../node_executor/controlop/control_op_executor.h | 0 
.../ge_local/ge_local_node_executor.cc | 0 .../ge_local/ge_local_node_executor.h | 0 .../node_executor/hccl/hccl_node_executor.cc | 0 .../hybrid/node_executor/hccl/hccl_node_executor.h | 0 .../host_cpu/host_cpu_node_executor.cc | 0 .../host_cpu/host_cpu_node_executor.h | 0 .../node_executor/host_cpu/kernel/assign_kernel.cc | 0 .../node_executor/host_cpu/kernel/assign_kernel.h | 0 .../hybrid/node_executor/host_cpu/kernel/kernel.h | 0 .../node_executor/host_cpu/kernel/no_op_kernel.cc | 0 .../node_executor/host_cpu/kernel/no_op_kernel.h | 0 .../host_cpu/kernel/random_uniform_kernel.cc | 0 .../host_cpu/kernel/random_uniform_kernel.h | 0 .../host_cpu/kernel/variable_kernel.cc | 0 .../host_cpu/kernel/variable_kernel.h | 0 .../node_executor/host_cpu/kernel_factory.cc | 0 .../hybrid/node_executor/host_cpu/kernel_factory.h | 0 .../ge}/hybrid/node_executor/node_executor.cc | 0 .../ge}/hybrid/node_executor/node_executor.h | 0 .../partitioned_call_node_executor.cc | 0 .../partitioned_call_node_executor.h | 0 .../hybrid/node_executor/rts/rts_node_executor.cc | 0 .../hybrid/node_executor/rts/rts_node_executor.h | 0 .../ge}/hybrid/node_executor/task_context.cc | 0 {ge => src/ge}/hybrid/node_executor/task_context.h | 0 {ge => src/ge}/inc/graph_pass.h | 0 {ge => src/ge}/inc/kernel.h | 0 {ge => src/ge}/inc/kernel_factory.h | 0 {ge => src/ge}/inc/pass.h | 0 {ge => src/ge}/inc/pass_manager.h | 0 {ge => src/ge}/init/gelib.cc | 0 {ge => src/ge}/init/gelib.h | 0 {ge => src/ge}/ir_build/atc_ir_common.cc | 0 {ge => src/ge}/ir_build/atc_ir_common.h | 0 {ge => src/ge}/ir_build/ge_ir_build.cc | 0 {ge => src/ge}/model/ge_model.cc | 0 {ge => src/ge}/model/ge_model.h | 0 {ge => src/ge}/model/ge_root_model.cc | 0 {ge => src/ge}/model/ge_root_model.h | 0 {ge => src/ge}/module.mk | 0 {ge => src/ge}/omm/csa_interact.cc | 0 {ge => src/ge}/omm/csa_interact.h | 0 .../ge}/opskernel_manager/ops_kernel_manager.cc | 0 .../ge}/opskernel_manager/ops_kernel_manager.h | 0 
.../ge}/opskernel_manager/optimizer_priority.pbtxt | 0 src/ge/plugin/engine/CMakeLists.txt | 45 + {ge => src/ge}/plugin/engine/dnnengines.cc | 0 {ge => src/ge}/plugin/engine/dnnengines.h | 0 {ge => src/ge}/plugin/engine/engine_manage.cc | 0 {ge => src/ge}/plugin/engine/engine_manage.h | 0 {ge => src/ge}/plugin/engine/module.mk | 0 {ge => src/ge}/session/inner_session.cc | 0 {ge => src/ge}/session/inner_session.h | 0 {ge => src/ge}/session/omg.cc | 0 {ge => src/ge}/session/session_manager.cc | 0 {ge => src/ge}/session/session_manager.h | 0 {ge => src/ge}/single_op/single_op.cc | 0 {ge => src/ge}/single_op/single_op.h | 0 {ge => src/ge}/single_op/single_op_manager.cc | 0 {ge => src/ge}/single_op/single_op_manager.h | 0 {ge => src/ge}/single_op/single_op_model.cc | 0 {ge => src/ge}/single_op/single_op_model.h | 0 {ge => src/ge}/single_op/stream_resource.cc | 0 {ge => src/ge}/single_op/stream_resource.h | 0 .../single_op/task/aicpu_kernel_task_builder.cc | 0 .../ge}/single_op/task/aicpu_kernel_task_builder.h | 0 .../ge}/single_op/task/aicpu_task_builder.cc | 0 {ge => src/ge}/single_op/task/aicpu_task_builder.h | 0 {ge => src/ge}/single_op/task/build_task_utils.cc | 0 {ge => src/ge}/single_op/task/build_task_utils.h | 0 {ge => src/ge}/single_op/task/op_task.cc | 0 {ge => src/ge}/single_op/task/op_task.h | 0 {ge => src/ge}/single_op/task/tbe_task_builder.cc | 0 {ge => src/ge}/single_op/task/tbe_task_builder.h | 0 {ge => src/ge}/stub/Makefile | 0 {ge => src/ge}/stub/README | 0 {ge => src/ge}/stub/README.md | 0 {ge => src/ge}/stub/gen_stubapi.py | 0 {ge/executor => src}/proto/dump_task.proto | 0 {ge => src}/proto/fusion_model.proto | 0 {ge => src}/proto/fwk_adapter.proto | 0 {ge/client => src}/proto/ge_api.proto | 0 {ge/client => src}/proto/ge_ir.proto | 0 {ge/client => src}/proto/insert_op.proto | 0 {ge/client => src}/proto/om.proto | 0 {ge/common => src}/proto/op_mapping_info.proto | 0 {ge => src}/proto/optimizer_priority.proto | 0 {ge/client => src}/proto/task.proto | 0 
tests/depends/cce/src/cce_stub.cc | 1 + tests/st/resnet50/resnet50_train.cc | 2 + .../graph/testcase/ge_graph/ge_model_unittest.cc | 2 + tests/ut/ge/CMakeLists.txt | 1 - .../new_model_manager_davinci_model_unittest.cc | 16 + .../new_model_manager_model_manager_unittest.cc | 95 + tests/ut/ge/graph/load/new_op_test_utils.h | 5 + .../graph/passes/dimension_adjust_pass_unittest.cc | 1 + .../strided_slice_kernel_unittest.cc | 15 + .../graph/passes/guarantee_const_pass_unittest.cc | 1 + tests/ut/ge/graph/passes/identity_pass_unittest.cc | 1 + .../ut/ge/graph/passes/net_output_pass_unittest.cc | 1 + .../placeholder_with_default_pass_unittest.cc | 1 + .../graph/passes/prevent_gradient_pass_unittest.cc | 1 + .../graph/passes/reshape_remove_pass_unittest.cc | 1 + tests/ut/ge/graph/passes/snapshot_pass_unittest.cc | 1 + .../ge/graph/passes/stop_gradient_pass_unittest.cc | 1 + tests/ut/ge/graph/passes/switch_pass_unittest.cc | 2 + .../unused_and_isolated_op_remove_pass_unittest.cc | 1 + .../ge/graph/passes/variable_op_pass_unittest.cc | 4 + .../securec/0001-add-securec-cmake-script.patch | 105 - 982 files changed, 35072 insertions(+), 11871 deletions(-) delete mode 100644 .gitmodules delete mode 100644 cmake/FindModule.cmake create mode 100644 cmake/external_libs/eigen.cmake delete mode 100755 cmake/external_libs/gflags.cmake create mode 100644 cmake/external_libs/gtest.cmake mode change 100755 => 100644 cmake/external_libs/json.cmake mode change 100755 => 100644 cmake/external_libs/onnx.cmake create mode 100644 cmake/external_libs/protobuf.cmake delete mode 100755 cmake/external_libs/protobuf_shared.cmake delete mode 100755 cmake/external_libs/protobuf_static.cmake delete mode 100755 cmake/external_libs/protoc.cmake mode change 100755 => 100644 cmake/external_libs/securec.cmake create mode 100644 cmake/ge_utils.cmake delete mode 100755 cmake/intf_pub_android.cmake delete mode 100755 cmake/intf_pub_linux.cmake delete mode 100755 cmake/intf_pub_windows.cmake delete mode 100755 
ge/CMakeLists.txt delete mode 100755 ge/README.md delete mode 100755 ge/common/CMakeLists.txt delete mode 100644 ge/common/proto/ge_ir.proto delete mode 100644 ge/common/proto/insert_op.proto delete mode 100644 ge/common/proto/om.proto delete mode 100644 ge/common/proto/task.proto delete mode 100644 ge/common/proto/tensorflow/attr_value.proto delete mode 100644 ge/common/proto/tensorflow/function.proto delete mode 100644 ge/common/proto/tensorflow/graph.proto delete mode 100644 ge/common/proto/tensorflow/graph_library.proto delete mode 100644 ge/common/proto/tensorflow/node_def.proto delete mode 100644 ge/common/proto/tensorflow/op_def.proto delete mode 100644 ge/common/proto/tensorflow/resource_handle.proto delete mode 100644 ge/common/proto/tensorflow/tensor.proto delete mode 100644 ge/common/proto/tensorflow/tensor_shape.proto delete mode 100644 ge/common/proto/tensorflow/types.proto delete mode 100644 ge/common/proto/tensorflow/versions.proto delete mode 100755 ge/executor/CMakeLists.txt delete mode 100644 ge/executor/proto/ge_ir.proto delete mode 100644 ge/executor/proto/insert_op.proto delete mode 100644 ge/executor/proto/om.proto delete mode 100644 ge/executor/proto/op_mapping_info.proto delete mode 100644 ge/executor/proto/task.proto delete mode 100755 ge/ge_local_engine/CMakeLists.txt delete mode 100644 ge/ge_local_engine/proto/task.proto delete mode 100644 ge/ge_runtime/CMakeLists.txt delete mode 100755 ge/ge_runtime/module.mk delete mode 100644 ge/graph/build/memory/CMakeLists.txt delete mode 100644 ge/host_cpu_engine/CMakeLists.txt delete mode 100644 ge/host_cpu_engine/proto/task.proto delete mode 100644 ge/offline/CMakeLists.txt delete mode 100755 ge/offline/main.cc delete mode 100755 ge/offline/module.mk delete mode 100644 ge/offline/proto/ge_ir.proto delete mode 100644 ge/offline/proto/insert_op.proto delete mode 100644 ge/offline/proto/om.proto delete mode 100644 ge/offline/proto/task.proto delete mode 100644 ge/offline/single_op_parser.cc delete 
mode 100644 ge/offline/single_op_parser.h delete mode 100644 ge/plugin/engine/CMakeLists.txt delete mode 100644 ge/proto/caffe/caffe.proto delete mode 100644 ge/proto/dump_task.proto delete mode 100755 ge/proto/ge_api.proto delete mode 100644 ge/proto/ge_ir.proto delete mode 100644 ge/proto/insert_op.proto delete mode 100644 ge/proto/om.proto delete mode 100644 ge/proto/op_mapping_info.proto delete mode 100644 ge/proto/task.proto delete mode 100644 ge/proto/tensorflow/attr_value.proto delete mode 100644 ge/proto/tensorflow/function.proto delete mode 100644 ge/proto/tensorflow/graph.proto delete mode 100644 ge/proto/tensorflow/graph_library.proto delete mode 100644 ge/proto/tensorflow/node_def.proto delete mode 100644 ge/proto/tensorflow/op_def.proto delete mode 100644 ge/proto/tensorflow/resource_handle.proto delete mode 100644 ge/proto/tensorflow/tensor.proto delete mode 100644 ge/proto/tensorflow/tensor_shape.proto delete mode 100644 ge/proto/tensorflow/types.proto delete mode 100644 ge/proto/tensorflow/versions.proto delete mode 100644 ge/session/readme.txt create mode 100644 inc/common/blocking_queue.h create mode 100644 inc/common/dynamic_aipp.h create mode 100644 inc/common/npu_error_define.h create mode 100644 inc/common/opskernel/ge_task_info.h create mode 100644 inc/common/opskernel/ops_kernel_info_store.h create mode 100644 inc/common/opskernel/ops_kernel_info_types.h create mode 100644 inc/common/optimizer/graph_optimizer.h create mode 100644 inc/common/optimizer/graph_optimizer_types.h create mode 100644 inc/common/util/ai_core/common/aicore_util_attr_define.h create mode 100644 inc/common/util/ai_core/common/aicore_util_types.h create mode 100644 inc/common/util/ai_core/common/graph_comm.h create mode 100644 inc/common/util/ai_core/common/scope_allocator.h create mode 100644 inc/common/util/ai_core/param_calculate/aicore_param_calculator.h create mode 100644 inc/common/util/ai_core/param_calculate/tensorsize_calculator.h create mode 100644 
inc/common/util/compress/compress.h create mode 100644 inc/common/util/compress/compress_weight.h create mode 100644 inc/common/util/error_manager/error_manager.h create mode 100644 inc/common/util/platform_info.h create mode 100644 inc/common/util/platform_info_def.h create mode 100644 inc/external/graph/attr_value.h create mode 100644 inc/external/graph/ge_error_codes.h create mode 100644 inc/external/graph/graph.h create mode 100644 inc/external/graph/inference_context.h create mode 100644 inc/external/graph/operator.h create mode 100644 inc/external/graph/operator_factory.h create mode 100644 inc/external/graph/operator_reg.h create mode 100644 inc/external/graph/tensor.h create mode 100644 inc/external/graph/types.h create mode 100644 inc/external/register/register.h create mode 100644 inc/external/register/register_error_codes.h create mode 100644 inc/external/register/register_fmk_types.h create mode 100644 inc/external/register/register_types.h create mode 100644 inc/external/register/scope/scope_fusion_pass_register.h delete mode 100644 inc/framework/omg/parser/model_parser.h delete mode 100644 inc/framework/omg/parser/op_parser.h delete mode 100644 inc/framework/omg/parser/parser_api.h delete mode 100644 inc/framework/omg/parser/parser_factory.h delete mode 100644 inc/framework/omg/parser/parser_inner_ctx.h delete mode 100644 inc/framework/omg/parser/weights_parser.h create mode 100644 inc/graph/anchor.h create mode 100644 inc/graph/attr_value_serializable.h create mode 100644 inc/graph/buffer.h create mode 100644 inc/graph/compute_graph.h create mode 100644 inc/graph/debug/ge_attr_define.h create mode 100644 inc/graph/def_types.h create mode 100644 inc/graph/detail/any_map.h create mode 100644 inc/graph/detail/attributes_holder.h create mode 100644 inc/graph/detail/model_serialize_imp.h create mode 100644 inc/graph/ge_attr_value.h create mode 100644 inc/graph/ge_context.h create mode 100644 inc/graph/ge_global_options.h create mode 100644 
inc/graph/ge_local_context.h create mode 100644 inc/graph/ge_tensor.h create mode 100644 inc/graph/graph_util.h create mode 100644 inc/graph/model.h create mode 100644 inc/graph/model_serialize.h create mode 100644 inc/graph/node.h create mode 100644 inc/graph/op_desc.h create mode 100644 inc/graph/op_kernel_bin.h create mode 100644 inc/graph/operator_factory_impl.h create mode 100644 inc/graph/opsproto_manager.h create mode 100644 inc/graph/range_vistor.h create mode 100644 inc/graph/ref_relation.h create mode 100644 inc/graph/runtime_inference_context.h create mode 100644 inc/graph/shape_refiner.h create mode 100644 inc/graph/tuning_utils.h create mode 100644 inc/graph/usr_types.h create mode 100644 inc/graph/utils/anchor_utils.h create mode 100644 inc/graph/utils/attr_utils.h create mode 100644 inc/graph/utils/graph_utils.h create mode 100644 inc/graph/utils/node_utils.h create mode 100644 inc/graph/utils/op_desc_utils.h create mode 100644 inc/graph/utils/tensor_adapter.h create mode 100644 inc/graph/utils/tensor_utils.h create mode 100644 inc/graph/utils/type_utils.h delete mode 160000 metadef delete mode 160000 parser create mode 100755 src/common/graph/CMakeLists.txt create mode 100644 src/common/graph/anchor.cc create mode 100644 src/common/graph/attr_value.cc create mode 100644 src/common/graph/buffer.cc create mode 100644 src/common/graph/compute_graph.cc create mode 100644 src/common/graph/debug/ge_log.h create mode 100644 src/common/graph/debug/ge_op_types.h create mode 100644 src/common/graph/debug/ge_util.h create mode 100644 src/common/graph/debug/graph_debug.cc create mode 100644 src/common/graph/debug/graph_debug.h create mode 100644 src/common/graph/detail/attributes_holder.cc create mode 100644 src/common/graph/format_refiner.cc create mode 100644 src/common/graph/format_refiner.h create mode 100644 src/common/graph/ge_attr_define.cc create mode 100644 src/common/graph/ge_attr_value.cc create mode 100644 src/common/graph/ge_tensor.cc create mode 
100644 src/common/graph/graph.cc create mode 100644 src/common/graph/graph.mk create mode 100644 src/common/graph/inference_context.cc create mode 100644 src/common/graph/model.cc create mode 100644 src/common/graph/model_serialize.cc create mode 100644 src/common/graph/module.mk create mode 100644 src/common/graph/node.cc create mode 100644 src/common/graph/op_desc.cc create mode 100644 src/common/graph/op_imp.cc create mode 100644 src/common/graph/operator.cc create mode 100644 src/common/graph/operator_factory.cc create mode 100644 src/common/graph/operator_factory_impl.cc create mode 100644 src/common/graph/opsproto/opsproto_manager.cc create mode 100644 src/common/graph/option/ge_context.cc create mode 100644 src/common/graph/option/ge_local_context.cc create mode 100644 src/common/graph/ref_relation.cc create mode 100644 src/common/graph/runtime_inference_context.cc create mode 100644 src/common/graph/shape_refiner.cc create mode 100644 src/common/graph/stub/Makefile create mode 100644 src/common/graph/stub/gen_stubapi.py create mode 100644 src/common/graph/tensor.cc create mode 100644 src/common/graph/utils/anchor_utils.cc create mode 100644 src/common/graph/utils/ge_ir_utils.cc create mode 100644 src/common/graph/utils/ge_ir_utils.h create mode 100644 src/common/graph/utils/graph_utils.cc create mode 100644 src/common/graph/utils/mem_utils.h create mode 100644 src/common/graph/utils/node_utils.cc create mode 100644 src/common/graph/utils/op_desc_utils.cc create mode 100644 src/common/graph/utils/string_utils.h create mode 100644 src/common/graph/utils/tensor_utils.cc create mode 100644 src/common/graph/utils/tuning_utils.cc create mode 100644 src/common/graph/utils/type_utils.cc create mode 100755 src/ge/CMakeLists.txt rename {ge => src/ge}/analyzer/analyzer.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/analyzer/analyzer.h (100%) mode change 100755 => 100644 create mode 100755 src/ge/client/CMakeLists.txt rename {ge => 
src/ge}/client/ge_api.cc (100%) rename {ge => src/ge}/client/ge_prof.cc (100%) rename {ge => src/ge}/client/module.mk (100%) create mode 100755 src/ge/common/CMakeLists.txt rename {ge => src/ge}/common/auth/file_saver.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/auth/file_saver.h (100%) rename {ge => src/ge}/common/base64.h (100%) rename {ge => src/ge}/common/context/ctx.cc (100%) mode change 100755 => 100644 create mode 100644 src/ge/common/convert/pb2json.cc create mode 100644 src/ge/common/convert/pb2json.h rename {ge => src/ge}/common/cust_aicpu_kernel_store.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/cust_aicpu_kernel_store.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/debug/memory_dumper.cc (100%) rename {ge => src/ge}/common/debug/memory_dumper.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/dump/dump_manager.cc (100%) rename {ge => src/ge}/common/dump/dump_manager.h (100%) rename {ge => src/ge}/common/dump/dump_op.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/dump/dump_op.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/dump/dump_properties.cc (100%) rename {ge => src/ge}/common/dump/dump_properties.h (100%) rename {ge => src/ge}/common/dump/dump_server.cc (100%) rename {ge => src/ge}/common/fmk_error_codes.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/datatype_transfer.cc (100%) rename {ge => src/ge}/common/formats/format_transfers/datatype_transfer.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc (100%) rename {ge => src/ge}/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h (100%) rename {ge => src/ge}/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc (100%) rename {ge => src/ge}/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h (100%) mode change 100755 => 100644 
rename {ge => src/ge}/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc (100%) rename {ge => src/ge}/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fractal_nz.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fractal_nz.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fractal_z.cc (100%) rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fractal_z.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fractal_zz.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fractal_zz.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fracz_hwcn.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fracz_hwcn.h (100%) rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fracz_nchw.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fracz_nchw.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fracz_nhwc.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_fracz_nhwc.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h (100%) rename {ge => src/ge}/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc (100%) mode change 100755 => 100644 rename {ge => 
src/ge}/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc (100%) rename {ge => src/ge}/common/formats/format_transfers/format_transfer_nchw_fz_c04.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h (100%) rename {ge => src/ge}/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_transpose.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/format_transfers/format_transfer_transpose.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/formats.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/formats.h (100%) rename {ge => src/ge}/common/formats/utils/formats_definitions.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/utils/formats_trans_utils.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/formats/utils/formats_trans_utils.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/fp16_t.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/fp16_t.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/ge/datatype_util.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/ge/datatype_util.h (100%) rename {ge 
=> src/ge}/common/ge/ge_util.h (100%) rename {ge => src/ge}/common/ge/op_tiling_manager.cc (100%) rename {ge => src/ge}/common/ge/op_tiling_manager.h (100%) rename {ge => src/ge}/common/ge/plugin_manager.cc (100%) rename {ge => src/ge}/common/ge/plugin_manager.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/ge/tbe_plugin_manager.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/ge/tbe_plugin_manager.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/ge_common.mk (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/ge_format_util.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/helper/model_cache_helper.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/helper/model_cache_helper.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/helper/model_helper.cc (100%) rename {ge => src/ge}/common/helper/om_file_helper.cc (100%) rename {ge => src/ge}/common/kernel_store.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/kernel_store.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/math/fp16_math.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/math/fp16_math.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/math/math_util.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/math_util.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/model_parser/base.cc (100%) rename {ge => src/ge}/common/model_parser/base.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/model_saver.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/model_saver.h (100%) rename {ge => src/ge}/common/module.mk (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/op/attr_value_util.cc (100%) rename {ge => src/ge}/common/op/ge_op_utils.cc (100%) rename {ge => src/ge}/common/profiling/profiling_manager.cc (100%) rename {ge => 
src/ge}/common/profiling/profiling_manager.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/properties_manager.cc (100%) rename {ge => src/ge}/common/properties_manager.h (100%) rename {ge => src/ge}/common/singleton.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/tbe_kernel_store.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/tbe_kernel_store.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/thread_pool.cc (100%) rename {ge => src/ge}/common/thread_pool.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/types.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/common/util.cc (100%) rename {ge => src/ge}/engine_manager/dnnengine_manager.cc (100%) rename {ge => src/ge}/engine_manager/dnnengine_manager.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/engine_manager/engine_conf.json (100%) create mode 100755 src/ge/executor/CMakeLists.txt rename {ge => src/ge}/executor/ge_executor.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/executor/module.mk (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_inference.mk (100%) mode change 100755 => 100644 create mode 100755 src/ge/ge_local_engine/CMakeLists.txt rename {ge => src/ge}/ge_local_engine/common/constant/constant.h (100%) rename {ge => src/ge}/ge_local_engine/engine/ge_local_engine.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_local_engine/engine/ge_local_engine.h (100%) rename {ge => src/ge}/ge_local_engine/engine/host_cpu_engine.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_local_engine/engine/host_cpu_engine.h (100%) rename {ge => src/ge}/ge_local_engine/module.mk (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h (100%) mode change 100755 => 100644 rename {ge => 
src/ge}/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_local_engine/ops_kernel_store/op/ge_deleted_op.h (100%) rename {ge => src/ge}/ge_local_engine/ops_kernel_store/op/no_op.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_local_engine/ops_kernel_store/op/no_op.h (100%) rename {ge => src/ge}/ge_local_engine/ops_kernel_store/op/op.cc (100%) rename {ge => src/ge}/ge_local_engine/ops_kernel_store/op/op.h (100%) rename {ge => src/ge}/ge_local_engine/ops_kernel_store/op/op_factory.cc (100%) rename {ge => src/ge}/ge_local_engine/ops_kernel_store/op/op_factory.h (100%) rename {ge => src/ge}/ge_runner.mk (100%) create mode 100755 src/ge/ge_runtime/CMakeLists.txt rename {ge => src/ge}/ge_runtime/model_context.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/model_runner.cc (100%) rename {ge => src/ge}/ge_runtime/output.cc (100%) rename {ge => src/ge}/ge_runtime/output.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/runtime_model.cc (100%) rename {ge => src/ge}/ge_runtime/runtime_model.h (100%) rename {ge => src/ge}/ge_runtime/task/aicpu_task.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/aicpu_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/cce_task.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/cce_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/event_record_task.cc (100%) rename {ge => src/ge}/ge_runtime/task/event_record_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/event_wait_task.cc (100%) rename {ge => src/ge}/ge_runtime/task/event_wait_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/hccl_task.cc (100%) rename {ge => src/ge}/ge_runtime/task/hccl_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/label_goto_task.cc (100%) 
rename {ge => src/ge}/ge_runtime/task/label_goto_task.h (100%) rename {ge => src/ge}/ge_runtime/task/label_set_task.cc (100%) rename {ge => src/ge}/ge_runtime/task/label_set_task.h (100%) rename {ge => src/ge}/ge_runtime/task/label_switch_task.cc (100%) rename {ge => src/ge}/ge_runtime/task/label_switch_task.h (100%) rename {ge => src/ge}/ge_runtime/task/memcpy_async_task.cc (100%) rename {ge => src/ge}/ge_runtime/task/memcpy_async_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/profiler_task.cc (100%) rename {ge => src/ge}/ge_runtime/task/profiler_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/stream_active_task.cc (100%) rename {ge => src/ge}/ge_runtime/task/stream_active_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/stream_switch_task.cc (100%) rename {ge => src/ge}/ge_runtime/task/stream_switch_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/task_factory.h (100%) rename {ge => src/ge}/ge_runtime/task/tbe_task.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/ge_runtime/task/tbe_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/generator/ge_generator.cc (100%) rename {ge => src/ge}/generator/generator_api.cc (100%) rename {ge => src/ge}/graph/build/graph_builder.cc (100%) rename {ge => src/ge}/graph/build/graph_builder.h (100%) rename {ge => src/ge}/graph/build/label_allocator.cc (100%) rename {ge => src/ge}/graph/build/label_allocator.h (100%) rename {ge => src/ge}/graph/build/logical_stream_allocator.cc (100%) rename {ge => src/ge}/graph/build/logical_stream_allocator.h (100%) create mode 100644 src/ge/graph/build/memory/CMakeLists.txt rename {ge => src/ge}/graph/build/memory/binary_block_mem_assigner.cc (100%) rename {ge => src/ge}/graph/build/memory/binary_block_mem_assigner.h (100%) rename {ge => 
src/ge}/graph/build/memory/block_mem_assigner.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/memory/block_mem_assigner.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/memory/graph_mem_assigner.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/memory/graph_mem_assigner.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/memory/hybrid_mem_assigner.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/memory/hybrid_mem_assigner.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/memory/max_block_mem_assigner.cc (100%) rename {ge => src/ge}/graph/build/memory/max_block_mem_assigner.h (100%) rename {ge => src/ge}/graph/build/memory/mem_assigner.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/memory/memory_assigner.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/memory/module.mk (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/memory/var_mem_assign_util.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/memory/var_mem_assign_util.h (100%) rename {ge => src/ge}/graph/build/model_builder.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/model_builder.h (100%) rename {ge => src/ge}/graph/build/run_context.cc (100%) rename {ge => src/ge}/graph/build/run_context.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/stream_allocator.cc (100%) rename {ge => src/ge}/graph/build/stream_allocator.h (100%) rename {ge => src/ge}/graph/build/stream_graph_optimizer.cc (100%) rename {ge => src/ge}/graph/build/stream_graph_optimizer.h (100%) rename {ge => src/ge}/graph/build/task_generator.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/build/task_generator.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/common/bcast.cc (100%) rename {ge => src/ge}/graph/common/bcast.h (100%) rename {ge => 
src/ge}/graph/common/ge_call_wrapper.h (100%) rename {ge => src/ge}/graph/common/local_context.cc (100%) rename {ge => src/ge}/graph/common/local_context.h (100%) rename {ge => src/ge}/graph/common/omg_util.cc (100%) rename {ge => src/ge}/graph/common/omg_util.h (100%) rename {ge => src/ge}/graph/common/transop_util.cc (100%) rename {ge => src/ge}/graph/common/transop_util.h (100%) rename {ge => src/ge}/graph/execute/graph_execute.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/execute/graph_execute.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/label/case_label_maker.cc (100%) rename {ge => src/ge}/graph/label/case_label_maker.h (100%) rename {ge => src/ge}/graph/label/if_label_maker.cc (100%) rename {ge => src/ge}/graph/label/if_label_maker.h (100%) rename {ge => src/ge}/graph/label/label_maker.cc (100%) rename {ge => src/ge}/graph/label/label_maker.h (100%) rename {ge => src/ge}/graph/label/label_maker_factory.h (100%) rename {ge => src/ge}/graph/label/partitioned_call_label_maker.cc (100%) rename {ge => src/ge}/graph/label/partitioned_call_label_maker.h (100%) rename {ge => src/ge}/graph/label/while_label_maker.cc (100%) rename {ge => src/ge}/graph/label/while_label_maker.h (100%) rename {ge => src/ge}/graph/load/graph_loader.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/graph_loader.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/aipp_utils.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/aipp_utils.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/cpu_queue_schedule.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/cpu_queue_schedule.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/data_dumper.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/data_dumper.h (100%) mode change 100755 => 100644 rename {ge => 
src/ge}/graph/load/new_model_manager/data_inputer.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/data_inputer.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/davinci_model.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/davinci_model.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/davinci_model_parser.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/davinci_model_parser.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/model_manager.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/model_manager.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/model_utils.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/model_utils.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/end_graph_task_info.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/end_graph_task_info.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/event_record_task_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/event_record_task_info.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/event_wait_task_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/event_wait_task_info.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/fusion_start_task_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/fusion_start_task_info.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc 
(100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/fusion_stop_task_info.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/hccl_task_info.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/hccl_task_info.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/kernel_ex_task_info.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/kernel_task_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/kernel_task_info.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/label_set_task_info.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/label_set_task_info.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/memcpy_async_task_info.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc (100%) mode change 100755 => 100644 rename {ge => 
src/ge}/graph/load/new_model_manager/task_info/profiler_trace_task_info.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/stream_active_task_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/stream_active_task_info.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/stream_switch_task_info.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/stream_switch_task_info.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/stream_switchn_task_info.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/task_info.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/task_info/task_info.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/task_info/task_info_factory.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/tbe_handle_store.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/load/new_model_manager/tbe_handle_store.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/zero_copy_offset.cc (100%) rename {ge => src/ge}/graph/load/new_model_manager/zero_copy_offset.h (100%) rename {ge => src/ge}/graph/load/new_model_manager/zero_copy_task.cc (100%) mode change 100755 => 100644 rename {ge => 
src/ge}/graph/load/new_model_manager/zero_copy_task.h (100%) rename {ge => src/ge}/graph/manager/block_memory.h (100%) rename {ge => src/ge}/graph/manager/graph_caching_allocator.cc (100%) rename {ge => src/ge}/graph/manager/graph_caching_allocator.h (100%) rename {ge => src/ge}/graph/manager/graph_context.cc (100%) rename {ge => src/ge}/graph/manager/graph_context.h (100%) rename {ge => src/ge}/graph/manager/graph_manager.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/manager/graph_manager.h (100%) rename {ge => src/ge}/graph/manager/graph_manager_utils.cc (100%) rename {ge => src/ge}/graph/manager/graph_manager_utils.h (100%) rename {ge => src/ge}/graph/manager/graph_mem_allocator.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/manager/graph_mem_allocator.h (100%) rename {ge => src/ge}/graph/manager/graph_var_manager.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/manager/graph_var_manager.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/manager/host_mem_manager.cc (100%) rename {ge => src/ge}/graph/manager/host_mem_manager.h (100%) rename {ge => src/ge}/graph/manager/memory_api.cc (100%) rename {ge => src/ge}/graph/manager/model_manager/event_manager.cc (100%) rename {ge => src/ge}/graph/manager/model_manager/event_manager.h (100%) rename {ge => src/ge}/graph/manager/rdma_pool_allocator.cc (100%) rename {ge => src/ge}/graph/manager/rdma_pool_allocator.h (100%) rename {ge => src/ge}/graph/manager/trans_var_data_utils.cc (100%) rename {ge => src/ge}/graph/manager/trans_var_data_utils.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/manager/util/debug.cc (100%) rename {ge => src/ge}/graph/manager/util/debug.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/manager/util/hcom_util.cc (100%) rename {ge => src/ge}/graph/manager/util/hcom_util.h (100%) rename {ge => src/ge}/graph/manager/util/rt_context_util.cc (100%) rename {ge => 
src/ge}/graph/manager/util/rt_context_util.h (100%) rename {ge => src/ge}/graph/manager/util/variable_accelerate_ctrl.cc (100%) rename {ge => src/ge}/graph/manager/util/variable_accelerate_ctrl.h (100%) rename {ge => src/ge}/graph/optimize/common/params.h (100%) rename {ge => src/ge}/graph/optimize/graph_optimize.cc (100%) rename {ge => src/ge}/graph/optimize/graph_optimize.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/optimize/mem_rw_conflict_optimize.cc (100%) rename {ge => src/ge}/graph/optimize/optimizer/allreduce_fusion_pass.cc (100%) rename {ge => src/ge}/graph/optimize/optimizer/allreduce_fusion_pass.h (100%) rename {ge => src/ge}/graph/optimize/summary_optimize.cc (100%) rename {ge => src/ge}/graph/partition/dynamic_shape_partition.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/partition/dynamic_shape_partition.h (100%) rename {ge => src/ge}/graph/partition/engine_place.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/partition/engine_place.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/partition/graph_partition.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/partition/graph_partition.h (100%) rename {ge => src/ge}/graph/passes/addn_pass.cc (100%) rename {ge => src/ge}/graph/passes/addn_pass.h (100%) rename {ge => src/ge}/graph/passes/aicpu_constant_folding_pass.cc (100%) rename {ge => src/ge}/graph/passes/aicpu_constant_folding_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/assert_pass.cc (100%) rename {ge => src/ge}/graph/passes/assert_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/assign_pass.cc (100%) rename {ge => src/ge}/graph/passes/assign_pass.h (100%) rename {ge => src/ge}/graph/passes/atomic_addr_clean_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/atomic_addr_clean_pass.h (100%) mode change 100755 => 100644 rename {ge => 
src/ge}/graph/passes/attach_stream_label_pass.cc (100%) rename {ge => src/ge}/graph/passes/attach_stream_label_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/base_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/base_pass.h (100%) rename {ge => src/ge}/graph/passes/bitcast_pass.cc (100%) rename {ge => src/ge}/graph/passes/bitcast_pass.h (100%) rename {ge => src/ge}/graph/passes/cast_remove_pass.cc (100%) rename {ge => src/ge}/graph/passes/cast_remove_pass.h (100%) rename {ge => src/ge}/graph/passes/cast_translate_pass.cc (100%) rename {ge => src/ge}/graph/passes/cast_translate_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/common_subexpression_elimination_pass.cc (100%) rename {ge => src/ge}/graph/passes/common_subexpression_elimination_pass.h (100%) rename {ge => src/ge}/graph/passes/compile_nodes_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/compile_nodes_pass.h (100%) rename {ge => src/ge}/graph/passes/cond_pass.cc (100%) rename {ge => src/ge}/graph/passes/cond_pass.h (100%) rename {ge => src/ge}/graph/passes/cond_remove_pass.cc (100%) rename {ge => src/ge}/graph/passes/cond_remove_pass.h (100%) rename {ge => src/ge}/graph/passes/constant_folding_pass.cc (100%) rename {ge => src/ge}/graph/passes/constant_folding_pass.h (100%) rename {ge => src/ge}/graph/passes/constant_fuse_same_pass.cc (100%) rename {ge => src/ge}/graph/passes/constant_fuse_same_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/control_trigger_pass.cc (100%) rename {ge => src/ge}/graph/passes/control_trigger_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/ctrl_edge_transfer_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/ctrl_edge_transfer_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/data_pass.cc (100%) rename {ge => src/ge}/graph/passes/data_pass.h 
(100%) rename {ge => src/ge}/graph/passes/dimension_adjust_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/dimension_adjust_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/dimension_compute_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/dimension_compute_pass.h (100%) rename {ge => src/ge}/graph/passes/dropout_pass.cc (100%) rename {ge => src/ge}/graph/passes/dropout_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/end_of_sequence_add_control_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/end_of_sequence_add_control_pass.h (100%) rename {ge => src/ge}/graph/passes/enter_pass.cc (100%) rename {ge => src/ge}/graph/passes/enter_pass.h (100%) rename {ge => src/ge}/graph/passes/flow_ctrl_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/flow_ctrl_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/folding_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/folding_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/for_pass.cc (100%) rename {ge => src/ge}/graph/passes/for_pass.h (100%) rename {ge => src/ge}/graph/passes/get_original_format_pass.cc (100%) rename {ge => src/ge}/graph/passes/get_original_format_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/global_step_insert_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/global_step_insert_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/guarantee_const_pass.cc (100%) rename {ge => src/ge}/graph/passes/guarantee_const_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/hccl_group_pass.cc (100%) rename {ge => src/ge}/graph/passes/hccl_group_pass.h (100%) rename {ge => src/ge}/graph/passes/hccl_memcpy_pass.cc (100%) mode change 100755 => 100644 rename 
{ge => src/ge}/graph/passes/hccl_memcpy_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/identity_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/identity_pass.h (100%) rename {ge => src/ge}/graph/passes/infershape_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/infershape_pass.h (100%) rename {ge => src/ge}/graph/passes/input_output_connection_identify_pass.cc (100%) rename {ge => src/ge}/graph/passes/input_output_connection_identify_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/isolated_op_remove_pass.cc (100%) rename {ge => src/ge}/graph/passes/isolated_op_remove_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/iterator_op_pass.cc (100%) rename {ge => src/ge}/graph/passes/iterator_op_pass.h (100%) rename {ge => src/ge}/graph/passes/link_gen_mask_nodes_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/link_gen_mask_nodes_pass.h (100%) rename {ge => src/ge}/graph/passes/mark_agnostic_pass.cc (100%) rename {ge => src/ge}/graph/passes/mark_agnostic_pass.h (100%) rename {ge => src/ge}/graph/passes/mark_graph_unknown_status_pass.cc (100%) rename {ge => src/ge}/graph/passes/mark_graph_unknown_status_pass.h (100%) rename {ge => src/ge}/graph/passes/mark_same_addr_pass.cc (100%) rename {ge => src/ge}/graph/passes/mark_same_addr_pass.h (100%) rename {ge => src/ge}/graph/passes/memcpy_addr_async_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/memcpy_addr_async_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/merge_pass.cc (100%) rename {ge => src/ge}/graph/passes/merge_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/merge_to_stream_merge_pass.cc (100%) rename {ge => src/ge}/graph/passes/merge_to_stream_merge_pass.h (100%) rename {ge => src/ge}/graph/passes/multi_batch_clone_pass.cc (100%) mode change 100755 
=> 100644 rename {ge => src/ge}/graph/passes/multi_batch_clone_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/multi_batch_pass.cc (100%) rename {ge => src/ge}/graph/passes/multi_batch_pass.h (100%) rename {ge => src/ge}/graph/passes/net_output_pass.cc (100%) rename {ge => src/ge}/graph/passes/net_output_pass.h (100%) rename {ge => src/ge}/graph/passes/next_iteration_pass.cc (100%) rename {ge => src/ge}/graph/passes/next_iteration_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/no_use_reshape_remove_pass.cc (100%) rename {ge => src/ge}/graph/passes/no_use_reshape_remove_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/parallel_concat_start_op_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/parallel_concat_start_op_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/pass_manager.cc (100%) rename {ge => src/ge}/graph/passes/pass_utils.cc (100%) rename {ge => src/ge}/graph/passes/pass_utils.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/permute_pass.cc (100%) rename {ge => src/ge}/graph/passes/permute_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/placeholder_with_default_pass.cc (100%) rename {ge => src/ge}/graph/passes/placeholder_with_default_pass.h (100%) rename {ge => src/ge}/graph/passes/prevent_gradient_pass.cc (100%) rename {ge => src/ge}/graph/passes/prevent_gradient_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/print_op_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/print_op_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/prune_pass.cc (100%) rename {ge => src/ge}/graph/passes/prune_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/ref_identity_delete_op_pass.cc (100%) rename {ge => src/ge}/graph/passes/ref_identity_delete_op_pass.h (100%) 
rename {ge => src/ge}/graph/passes/remove_nodes_pass.cc (100%) rename {ge => src/ge}/graph/passes/remove_nodes_pass.h (100%) rename {ge => src/ge}/graph/passes/replace_transshape_pass.cc (100%) rename {ge => src/ge}/graph/passes/replace_transshape_pass.h (100%) rename {ge => src/ge}/graph/passes/replace_with_empty_const_pass.cc (100%) rename {ge => src/ge}/graph/passes/replace_with_empty_const_pass.h (100%) rename {ge => src/ge}/graph/passes/reshape_recovery_pass.cc (100%) rename {ge => src/ge}/graph/passes/reshape_recovery_pass.h (100%) rename {ge => src/ge}/graph/passes/reshape_remove_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/reshape_remove_pass.h (100%) rename {ge => src/ge}/graph/passes/resource_pair_add_control_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/resource_pair_add_control_pass.h (100%) rename {ge => src/ge}/graph/passes/resource_pair_remove_control_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/resource_pair_remove_control_pass.h (100%) rename {ge => src/ge}/graph/passes/same_transdata_breadth_fusion_pass.cc (100%) rename {ge => src/ge}/graph/passes/same_transdata_breadth_fusion_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/save_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/save_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/set_input_output_offset_pass.cc (100%) rename {ge => src/ge}/graph/passes/set_input_output_offset_pass.h (100%) rename {ge => src/ge}/graph/passes/shape_operate_op_remove_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/shape_operate_op_remove_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/snapshot_pass.cc (100%) rename {ge => src/ge}/graph/passes/snapshot_pass.h (100%) rename {ge => src/ge}/graph/passes/stop_gradient_pass.cc (100%) rename {ge => 
src/ge}/graph/passes/stop_gradient_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/subexpression_migration_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/subexpression_migration_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/subgraph_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/subgraph_pass.h (100%) rename {ge => src/ge}/graph/passes/switch_data_edges_bypass.cc (100%) rename {ge => src/ge}/graph/passes/switch_data_edges_bypass.h (100%) rename {ge => src/ge}/graph/passes/switch_dead_branch_elimination.cc (100%) rename {ge => src/ge}/graph/passes/switch_dead_branch_elimination.h (100%) rename {ge => src/ge}/graph/passes/switch_logic_remove_pass.cc (100%) rename {ge => src/ge}/graph/passes/switch_logic_remove_pass.h (100%) rename {ge => src/ge}/graph/passes/switch_to_stream_switch_pass.cc (100%) rename {ge => src/ge}/graph/passes/switch_to_stream_switch_pass.h (100%) rename {ge => src/ge}/graph/passes/transop_breadth_fusion_pass.cc (100%) rename {ge => src/ge}/graph/passes/transop_breadth_fusion_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/transop_depth_fusion_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/transop_depth_fusion_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/transop_nearby_allreduce_fusion_pass.cc (100%) rename {ge => src/ge}/graph/passes/transop_nearby_allreduce_fusion_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/transop_symmetry_elimination_pass.cc (100%) rename {ge => src/ge}/graph/passes/transop_symmetry_elimination_pass.h (100%) rename {ge => src/ge}/graph/passes/transop_without_reshape_fusion_pass.cc (100%) rename {ge => src/ge}/graph/passes/transop_without_reshape_fusion_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/transpose_transdata_pass.cc (100%) rename {ge 
=> src/ge}/graph/passes/transpose_transdata_pass.h (100%) rename {ge => src/ge}/graph/passes/unused_args_clean_pass.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/unused_args_clean_pass.h (100%) rename {ge => src/ge}/graph/passes/unused_const_pass.cc (100%) rename {ge => src/ge}/graph/passes/unused_const_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/unused_op_remove_pass.cc (100%) rename {ge => src/ge}/graph/passes/unused_op_remove_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/var_is_initialized_op_pass.cc (100%) rename {ge => src/ge}/graph/passes/var_is_initialized_op_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/variable_format_pass.cc (100%) rename {ge => src/ge}/graph/passes/variable_format_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/variable_op_pass.cc (100%) rename {ge => src/ge}/graph/passes/variable_op_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/variable_prepare_op_pass.cc (100%) rename {ge => src/ge}/graph/passes/variable_prepare_op_pass.h (100%) rename {ge => src/ge}/graph/passes/variable_ref_delete_op_pass.cc (100%) rename {ge => src/ge}/graph/passes/variable_ref_delete_op_pass.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/passes/variable_ref_useless_control_out_delete_pass.cc (100%) rename {ge => src/ge}/graph/passes/variable_ref_useless_control_out_delete_pass.h (100%) rename {ge => src/ge}/graph/preprocess/graph_preprocess.cc (100%) rename {ge => src/ge}/graph/preprocess/graph_preprocess.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/preprocess/insert_op/base_insert_op.h (100%) rename {ge => src/ge}/graph/preprocess/insert_op/ge_aipp_op.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/preprocess/insert_op/ge_aipp_op.h (100%) mode change 100755 => 100644 rename {ge => 
src/ge}/graph/preprocess/insert_op/util_insert_aipp_op.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/graph/preprocess/insert_op/util_insert_aipp_op.h (100%) rename {ge => src/ge}/graph/preprocess/multi_batch_copy_graph.cc (100%) rename {ge => src/ge}/graph/preprocess/multi_batch_copy_graph.h (100%) rename {ge => src/ge}/graph/preprocess/multi_batch_options.cc (100%) rename {ge => src/ge}/graph/preprocess/multi_batch_options.h (100%) rename {ge => src/ge}/host_cpu_engine/common/constant/constant.h (100%) rename {ge => src/ge}/host_cpu_engine/engine/host_cpu_engine.cc (100%) rename {ge => src/ge}/host_cpu_engine/engine/host_cpu_engine.h (100%) rename {ge => src/ge}/host_cpu_engine/module.mk (100%) rename {ge => src/ge}/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc (100%) rename {ge => src/ge}/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h (100%) rename {ge => src/ge}/host_cpu_engine/ops_kernel_store/op/host_op.cc (100%) rename {ge => src/ge}/host_cpu_engine/ops_kernel_store/op/host_op.h (100%) rename {ge => src/ge}/host_cpu_engine/ops_kernel_store/op/op.h (100%) rename {ge => src/ge}/host_cpu_engine/ops_kernel_store/op/op_factory.cc (100%) rename {ge => src/ge}/host_cpu_engine/ops_kernel_store/op/op_factory.h (100%) create mode 120000 src/ge/host_cpu_engine/proto/task.proto rename {ge => src/ge}/host_kernels/add_kernel.cc (100%) rename {ge => src/ge}/host_kernels/add_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/broadcast_args_kernel.cc (100%) rename {ge => src/ge}/host_kernels/broadcast_args_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/broadcast_gradient_args_kernel.cc (100%) rename {ge => src/ge}/host_kernels/broadcast_gradient_args_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/cast_kernel.cc (100%) rename {ge => src/ge}/host_kernels/cast_kernel.h (100%) mode change 100755 => 100644 rename {ge => 
src/ge}/host_kernels/concat_offset_kernel.cc (100%) rename {ge => src/ge}/host_kernels/concat_offset_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/concat_v2_kernel.cc (100%) rename {ge => src/ge}/host_kernels/concat_v2_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/dynamic_stitch_kernel.cc (100%) rename {ge => src/ge}/host_kernels/dynamic_stitch_kernel.h (100%) rename {ge => src/ge}/host_kernels/empty_kernel.cc (100%) rename {ge => src/ge}/host_kernels/empty_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/expanddims_kernel.cc (100%) rename {ge => src/ge}/host_kernels/expanddims_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/fill_kernel.cc (100%) rename {ge => src/ge}/host_kernels/fill_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/floordiv_kernel.cc (100%) rename {ge => src/ge}/host_kernels/floordiv_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/floormod_kernel.cc (100%) rename {ge => src/ge}/host_kernels/floormod_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/gather_v2_kernel.cc (100%) rename {ge => src/ge}/host_kernels/gather_v2_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/greater_kernel.cc (100%) rename {ge => src/ge}/host_kernels/greater_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/identity_kernel.cc (100%) rename {ge => src/ge}/host_kernels/identity_kernel.h (100%) rename {ge => src/ge}/host_kernels/kernel_utils.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/kernel_utils.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/maximum_kernel.cc (100%) rename {ge => src/ge}/host_kernels/maximum_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/mul_kernel.cc (100%) rename {ge => 
src/ge}/host_kernels/mul_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/pack_kernel.cc (100%) rename {ge => src/ge}/host_kernels/pack_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/permute_kernel.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/permute_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/range_kernel.cc (100%) rename {ge => src/ge}/host_kernels/range_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/rank_kernel.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/rank_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/reduce_prod_kernel.cc (100%) rename {ge => src/ge}/host_kernels/reduce_prod_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/reformat_kernel.cc (100%) rename {ge => src/ge}/host_kernels/reformat_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/reshape_kernel.cc (100%) rename {ge => src/ge}/host_kernels/reshape_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/rsqrt_kernel.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/rsqrt_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/shape_kernel.cc (100%) rename {ge => src/ge}/host_kernels/shape_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/shape_n_kernel.cc (100%) rename {ge => src/ge}/host_kernels/shape_n_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/size_kernel.cc (100%) rename {ge => src/ge}/host_kernels/size_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/slice_d_kernel.cc (100%) rename {ge => src/ge}/host_kernels/slice_d_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/slice_kernel.cc (100%) rename {ge => 
src/ge}/host_kernels/slice_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/squeeze_kernel.cc (100%) rename {ge => src/ge}/host_kernels/squeeze_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/ssd_prior_box_kernel.cc (100%) rename {ge => src/ge}/host_kernels/ssd_prior_box_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/strided_slice_kernel.cc (100%) rename {ge => src/ge}/host_kernels/strided_slice_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/sub_kernel.cc (100%) rename {ge => src/ge}/host_kernels/sub_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/transdata_kernel.cc (100%) rename {ge => src/ge}/host_kernels/transdata_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/transpose_kernel.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/transpose_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/unpack_kernel.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/unpack_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/host_kernels/unsqueeze_kernel.cc (100%) rename {ge => src/ge}/host_kernels/unsqueeze_kernel.h (100%) rename {ge => src/ge}/hybrid/common/npu_memory_allocator.cc (100%) rename {ge => src/ge}/hybrid/common/npu_memory_allocator.h (100%) rename {ge => src/ge}/hybrid/common/tensor_value.cc (100%) rename {ge => src/ge}/hybrid/common/tensor_value.h (100%) rename {ge => src/ge}/hybrid/executor/hybrid_execution_context.cc (100%) rename {ge => src/ge}/hybrid/executor/hybrid_execution_context.h (100%) rename {ge => src/ge}/hybrid/executor/hybrid_model_async_executor.cc (100%) rename {ge => src/ge}/hybrid/executor/hybrid_model_async_executor.h (100%) rename {ge => src/ge}/hybrid/executor/hybrid_model_executor.cc (100%) mode change 100755 => 100644 rename {ge => 
src/ge}/hybrid/executor/hybrid_model_executor.h (100%) rename {ge => src/ge}/hybrid/executor/hybrid_profiler.cc (100%) rename {ge => src/ge}/hybrid/executor/hybrid_profiler.h (100%) rename {ge => src/ge}/hybrid/executor/node_done_manager.cc (100%) rename {ge => src/ge}/hybrid/executor/node_done_manager.h (100%) rename {ge => src/ge}/hybrid/executor/node_state.cc (100%) rename {ge => src/ge}/hybrid/executor/node_state.h (100%) rename {ge => src/ge}/hybrid/executor/rt_callback_manager.cc (100%) rename {ge => src/ge}/hybrid/executor/rt_callback_manager.h (100%) rename {ge => src/ge}/hybrid/executor/subgraph_context.cc (100%) rename {ge => src/ge}/hybrid/executor/subgraph_context.h (100%) rename {ge => src/ge}/hybrid/executor/subgraph_executor.cc (100%) rename {ge => src/ge}/hybrid/executor/subgraph_executor.h (100%) rename {ge => src/ge}/hybrid/executor/worker/execution_engine.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/executor/worker/execution_engine.h (100%) rename {ge => src/ge}/hybrid/executor/worker/shape_inference_engine.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/executor/worker/shape_inference_engine.h (100%) rename {ge => src/ge}/hybrid/executor/worker/task_compile_engine.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/executor/worker/task_compile_engine.h (100%) rename {ge => src/ge}/hybrid/hybrid_davinci_model.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/hybrid_davinci_model.h (100%) rename {ge => src/ge}/hybrid/hybrid_davinci_model_stub.cc (100%) rename {ge => src/ge}/hybrid/model/graph_item.cc (100%) rename {ge => src/ge}/hybrid/model/graph_item.h (100%) rename {ge => src/ge}/hybrid/model/hybrid_model.cc (100%) rename {ge => src/ge}/hybrid/model/hybrid_model.h (100%) rename {ge => src/ge}/hybrid/model/hybrid_model_builder.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/model/hybrid_model_builder.h (100%) rename {ge => 
src/ge}/hybrid/model/node_item.cc (100%) rename {ge => src/ge}/hybrid/model/node_item.h (100%) rename {ge => src/ge}/hybrid/node_executor/aicore/aicore_node_executor.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/aicore/aicore_node_executor.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/aicore/aicore_op_task.cc (100%) rename {ge => src/ge}/hybrid/node_executor/aicore/aicore_op_task.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/aicore/aicore_task_builder.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/aicore/aicore_task_builder.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/aicore/aicore_task_compiler.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/aicore/aicore_task_compiler.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/aicpu/aicpu_ext_info.cc (100%) rename {ge => src/ge}/hybrid/node_executor/aicpu/aicpu_ext_info.h (100%) rename {ge => src/ge}/hybrid/node_executor/aicpu/aicpu_node_executor.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/aicpu/aicpu_node_executor.h (100%) rename {ge => src/ge}/hybrid/node_executor/compiledsubgraph/known_node_executor.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/compiledsubgraph/known_node_executor.h (100%) rename {ge => src/ge}/hybrid/node_executor/controlop/control_op_executor.cc (100%) rename {ge => src/ge}/hybrid/node_executor/controlop/control_op_executor.h (100%) rename {ge => src/ge}/hybrid/node_executor/ge_local/ge_local_node_executor.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/ge_local/ge_local_node_executor.h (100%) rename {ge => src/ge}/hybrid/node_executor/hccl/hccl_node_executor.cc (100%) rename {ge => src/ge}/hybrid/node_executor/hccl/hccl_node_executor.h (100%) rename {ge => 
src/ge}/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/host_cpu/host_cpu_node_executor.h (100%) rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc (100%) rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel/assign_kernel.h (100%) rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel/kernel.h (100%) rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc (100%) rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h (100%) rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc (100%) rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel/variable_kernel.h (100%) rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel_factory.cc (100%) rename {ge => src/ge}/hybrid/node_executor/host_cpu/kernel_factory.h (100%) rename {ge => src/ge}/hybrid/node_executor/node_executor.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/node_executor.h (100%) rename {ge => src/ge}/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h (100%) rename {ge => src/ge}/hybrid/node_executor/rts/rts_node_executor.cc (100%) rename {ge => src/ge}/hybrid/node_executor/rts/rts_node_executor.h (100%) rename {ge => src/ge}/hybrid/node_executor/task_context.cc (100%) rename {ge => src/ge}/hybrid/node_executor/task_context.h (100%) rename {ge => src/ge}/inc/graph_pass.h (100%) rename {ge => src/ge}/inc/kernel.h (100%) rename {ge => src/ge}/inc/kernel_factory.h (100%) rename {ge => src/ge}/inc/pass.h (100%) 
rename {ge => src/ge}/inc/pass_manager.h (100%) rename {ge => src/ge}/init/gelib.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/init/gelib.h (100%) rename {ge => src/ge}/ir_build/atc_ir_common.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/ir_build/atc_ir_common.h (100%) rename {ge => src/ge}/ir_build/ge_ir_build.cc (100%) rename {ge => src/ge}/model/ge_model.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/model/ge_model.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/model/ge_root_model.cc (100%) rename {ge => src/ge}/model/ge_root_model.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/module.mk (100%) mode change 100755 => 100644 rename {ge => src/ge}/omm/csa_interact.cc (100%) rename {ge => src/ge}/omm/csa_interact.h (100%) rename {ge => src/ge}/opskernel_manager/ops_kernel_manager.cc (100%) rename {ge => src/ge}/opskernel_manager/ops_kernel_manager.h (100%) rename {ge => src/ge}/opskernel_manager/optimizer_priority.pbtxt (100%) create mode 100644 src/ge/plugin/engine/CMakeLists.txt rename {ge => src/ge}/plugin/engine/dnnengines.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/plugin/engine/dnnengines.h (100%) rename {ge => src/ge}/plugin/engine/engine_manage.cc (100%) rename {ge => src/ge}/plugin/engine/engine_manage.h (100%) rename {ge => src/ge}/plugin/engine/module.mk (100%) mode change 100755 => 100644 rename {ge => src/ge}/session/inner_session.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/session/inner_session.h (100%) rename {ge => src/ge}/session/omg.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/session/session_manager.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/session/session_manager.h (100%) rename {ge => src/ge}/single_op/single_op.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/single_op.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/single_op_manager.cc (100%) rename {ge => 
src/ge}/single_op/single_op_manager.h (100%) rename {ge => src/ge}/single_op/single_op_model.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/single_op_model.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/stream_resource.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/stream_resource.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/task/aicpu_kernel_task_builder.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/task/aicpu_kernel_task_builder.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/task/aicpu_task_builder.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/task/aicpu_task_builder.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/task/build_task_utils.cc (100%) rename {ge => src/ge}/single_op/task/build_task_utils.h (100%) rename {ge => src/ge}/single_op/task/op_task.cc (100%) mode change 100755 => 100644 rename {ge => src/ge}/single_op/task/op_task.h (100%) rename {ge => src/ge}/single_op/task/tbe_task_builder.cc (100%) rename {ge => src/ge}/single_op/task/tbe_task_builder.h (100%) mode change 100755 => 100644 rename {ge => src/ge}/stub/Makefile (100%) rename {ge => src/ge}/stub/README (100%) rename {ge => src/ge}/stub/README.md (100%) rename {ge => src/ge}/stub/gen_stubapi.py (100%) rename {ge/executor => src}/proto/dump_task.proto (100%) rename {ge => src}/proto/fusion_model.proto (100%) mode change 100755 => 100644 rename {ge => src}/proto/fwk_adapter.proto (100%) rename {ge/client => src}/proto/ge_api.proto (100%) rename {ge/client => src}/proto/ge_ir.proto (100%) rename {ge/client => src}/proto/insert_op.proto (100%) rename {ge/client => src}/proto/om.proto (100%) mode change 100755 => 100644 rename {ge/common => src}/proto/op_mapping_info.proto (100%) rename {ge => src}/proto/optimizer_priority.proto (100%) rename {ge/client => src}/proto/task.proto (100%) delete mode 100644 
third_party/patch/securec/0001-add-securec-cmake-script.patch diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index c6b7cc8e..00000000 --- a/.gitmodules +++ /dev/null @@ -1,8 +0,0 @@ -[submodule "parser"] - path = parser - url = https://gitee.com/ascend/parser.git - branch = master -[submodule "metadef"] - path = metadef - url = https://gitee.com/ascend/metadef.git - branch = master diff --git a/CMakeLists.txt b/CMakeLists.txt index 9a9a7a9d..457fa086 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,133 +1,135 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + cmake_minimum_required(VERSION 3.14) project (GraphEngine[CXX]) +set(CMAKE_CXX_STANDARD 17) +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) -set(GE_CODE_DIR ${CMAKE_CURRENT_LIST_DIR}) -set(CMAKE_SKIP_INSTALL_ALL_DEPENDENCY TRUE) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) +set(GE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) +set(GE_PROTO_DIR ${GE_SOURCE_DIR}/src) if (NOT BUILD_PATH) set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") endif() +# architecture: aarch64 or x86_64 +message(STATUS "System architecture: ${CMAKE_HOST_SYSTEM_PROCESSOR}") +# system: euleros or ubuntu +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + execute_process( + COMMAND bash "-c" "cat /etc/os-release | grep ^ID= | awk -F '=' '{print $2}'" + OUTPUT_VARIABLE SYSTEM_TYPE + ) + MESSAGE(STATUS "System type: ${SYSTEM_TYPE}.") +endif() -option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) - -if (ENABLE_OPEN_SRC) - set(HI_PYTHON python3.7) - - include(cmake/external_libs/protobuf_shared.cmake) - include(cmake/external_libs/protobuf_static.cmake) - include(cmake/external_libs/protoc.cmake) - include(cmake/external_libs/gflags.cmake) - include(cmake/external_libs/securec.cmake) - include(cmake/external_libs/json.cmake) - include(cmake/FindModule.cmake) - include(cmake/intf_pub_linux.cmake) +# download json headers, rather than whole repository +include(${GE_SOURCE_DIR}/cmake/ge_utils.cmake) +include(${GE_SOURCE_DIR}/cmake/external_libs/json.cmake) +include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake) +include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake) +include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) +include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake) +include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake) +set(CMAKE_SKIP_RPATH TRUE) - # for CPU/GPU mode, find c_sec and slog from local prebuild - #if(NOT ENABLE_D AND NOT GE_ONLY) - # set(GE_PREBUILD_PATH 
${GE_CODE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}) - # find_module(slog libslog.so ${GE_PREBUILD_PATH}) - # if D_LINK_PATH is set in environment variables, search libraries in given path - if(DEFINED ENV{D_LINK_PATH}) - # D_LINK_PATH is set - set(GE_LIB_PATH $ENV{D_LINK_PATH}) - set(GE_SYS_ARCH "") - if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") - # x86 ubuntu - set(GE_SYS_ARCH "x86_64") - elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") - # arm euleros - set(GE_SYS_ARCH "aarch64") - else() - message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") - endif() - set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) - set(STATIC_ACL_LIB ${GE_LIB_PATH}) - find_module(slog libslog.so ${GE_LIB_PATH}) - find_module(mmpa libmmpa.so ${GE_LIB_PATH}) - find_module(msprof libmsprof.so ${GE_LIB_PATH}) - find_module(hccl libhccl.so ${GE_LIB_PATH}) - find_module(adump_server libadump_server.a ${GE_LIB_PATH}) - find_module(runtime libruntime.so ${GE_LIB_PATH}) - find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH}) - find_module(resource libresource.so ${GE_LIB_PATH}) - find_module(error_manager liberror_manager.so ${GE_LIB_PATH}) - find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) - find_module(error_manager_static liberror_manager.a ${GE_LIB_PATH}) - find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) +# for CPU/GPU mode, find c_sec and slog from local prebuild +if(NOT ENABLE_D AND NOT GE_ONLY) + set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}) + find_library(slog libslog.so ${GE_PREBUILD_PATH}) +# if D_LINK_PATH is set in environment variables, search libraries in given path +elseif(DEFINED ENV{D_LINK_PATH}) + # D_LINK_PATH is set + set(GE_LIB_PATH $ENV{D_LINK_PATH}) + set(GE_SYS_ARCH "") + if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") + # x86 ubuntu + set(GE_SYS_ARCH "x86_64") + elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") + # arm euleros + 
set(GE_SYS_ARCH "aarch64") + else() + message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") + endif() + set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) + find_library(slog libslog.so ${GE_LIB_PATH}) + find_library(mmpa libmmpa.so ${GE_LIB_PATH}) + find_library(runtime libruntime.so ${GE_LIB_PATH}) + find_library(msprof libmsprofiler.a ${GE_LIB_PATH}) + find_library(register libregister.so ${GE_LIB_PATH}) + find_library(hccl libhccl.so ${GE_LIB_PATH}) + find_library(resource libresource.so ${GE_LIB_PATH}) + find_library(error_manager liberror_manager.so ${GE_LIB_PATH}) + find_library(adump_server libadump_server.a ${GE_LIB_PATH}) +else() + # Ascend mode + if(DEFINED ENV{ASCEND_CUSTOM_PATH}) + set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) else() - if(DEFINED ENV{ASCEND_CUSTOM_PATH}) - set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) - else() - set(ASCEND_DIR /usr/local/Ascend) - endif() - set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) - set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) - set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) - set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) - set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) - set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) - set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) - find_module(slog libslog.so ${ASCEND_ATC_DIR}) - find_module(mmpa libmmpa.so ${ASCEND_ATC_DIR}) - if(PLATFORM STREQUAL "train") - find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) - find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) - find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) - find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) - find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) - find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) - find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR}) - find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) - if(PRODUCT STREQUAL "flr3") - message(FATAL_ERROR 
"This platform is not supported in train mode, build terminated") - endif() - elseif(PLATFORM STREQUAL "inference") - find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) - find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) - find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) - find_module(resource libresource.so ${ASCEND_ATC_DIR}) - find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) - find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) - find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) - if(PRODUCT STREQUAL "flr3") - find_module(msprof libmsprof.so ${ASCEND_DRIVER_SHARE_DIR}) - elseif(PRODUCT STREQUAL "flr1") - find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) - find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) - elseif(PRODUCT STREQUAL "flr2") - # flr2 ascend_hal_stub limsprof ? - else() - find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) - find_module(msprof libmsprof.so ${ASCEND_DRIVER_DIR}) - endif() - elseif(PLATFORM STREQUAL "all") - find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) - find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) - find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) - find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) - find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) - find_module(resource libresource.so ${ASCEND_ATC_DIR}) - find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) - find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) - find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) - find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) - else() - message(FATAL_ERROR "PLATFORM param is invalid, should be train or inference, build terminated") - endif() + set(ASCEND_DIR /usr/local/Ascend) endif() + set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64/common) + set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) + 
find_library(slog libslog.so ${ASCEND_DRIVER_DIR}) + find_library(mmpa libmmpa.so ${ASCEND_DRIVER_DIR}) + find_library(msprof libmsprofiler.a ${ASCEND_RUNTIME_DIR}) - set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) - set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser) - set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..) + find_library(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) + find_library(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) + find_library(register libregister.so ${ASCEND_RUNTIME_DIR}) + find_library(resource libresource.so ${ASCEND_RUNTIME_DIR}) + find_library(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) + find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) +endif() - add_subdirectory(metadef) - add_subdirectory(parser) - #add_subdirectory(metadef/graph) - #add_subdirectory(metadef/register) +# add compile flags +if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + message("Build in Debug mode") + set(CMAKE_C_FLAGS "-O0 -g -Wall -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe -fPIC ${CMAKE_C_FLAGS}") + set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe -fPIC ${CMAKE_CXX_FLAGS}") + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic") + endif() else() - set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef) - set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) - set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..) 
+ set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe ${CMAKE_C_FLAGS}") + set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe ${CMAKE_CXX_FLAGS}") +endif () + +# force __FILE__ to show relative path of file, from source directory, as cmake project makes __FILE__ absolute directory +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst $(realpath ${CMAKE_SOURCE_DIR})/,,$(abspath $<))\"' -Wno-builtin-macro-redefined") + +# compile libraries from following directories +# libgraph is compiled in any situation +add_subdirectory(${GE_SOURCE_DIR}/src/common/graph) +if(ENABLE_D) + # if MindSpore compiles in D mode, compile the following libraries + add_subdirectory(${GE_SOURCE_DIR}/src/ge/common) + add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime) +elseif(GE_ONLY) + # standalone GraphEngine compiles all following libraries + add_subdirectory(${GE_SOURCE_DIR}/src/ge/common) + add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime) + add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_local_engine) + add_subdirectory(${GE_SOURCE_DIR}/src/ge/graph/build/memory) + add_subdirectory(${GE_SOURCE_DIR}/src/ge/) + add_subdirectory(${GE_SOURCE_DIR}/src/ge/plugin/engine) endif() -add_subdirectory(ge) +# if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST) +# add_subdirectory(tests) +# endif() + diff --git a/build.sh b/build.sh index 572b1c3c..5227f21f 100644 --- a/build.sh +++ b/build.sh @@ -23,7 +23,7 @@ export BUILD_PATH="${BASEPATH}/build/" usage() { echo "Usage:" - echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c] [-p]" + echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c]" echo "" echo "Options:" echo " -h Print usage" @@ -32,7 +32,6 @@ usage() echo " -j[n] Set the number of threads used for building GraphEngine, default is 8" echo " -t Build and execute ut" echo " -c Build ut with coverage tag" - echo " -p Build inference or train" echo " -v Display build command" echo "to be 
continued ..." } @@ -47,10 +46,8 @@ checkopts() ENABLE_GE_ST="off" ENABLE_GE_COV="off" GE_ONLY="on" - PLATFORM="inference" - PRODUCT="normal" # Process the options - while getopts 'ustchj:p:g:v' opt + while getopts 'ustchj:v' opt do OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') case "${opt}" in @@ -80,12 +77,6 @@ checkopts() v) VERBOSE="VERBOSE=1" ;; - p) - PLATFORM=$OPTARG - ;; - g) - PRODUCT=$OPTARG - ;; *) echo "Undefined option: ${opt}" usage @@ -95,9 +86,6 @@ checkopts() } checkopts "$@" -git submodule update --init metadef -git submodule update --init parser - mk_dir() { local create_dir="$1" # the target to make @@ -112,8 +100,8 @@ echo "---------------- GraphEngine build start ----------------" build_graphengine() { echo "create build directory and build GraphEngine"; - mk_dir "${BUILD_PATH}" - cd "${BUILD_PATH}" + mk_dir "${BUILD_PATH}/graphengine" + cd "${BUILD_PATH}/graphengine" CMAKE_ARGS="-DBUILD_PATH=$BUILD_PATH -DGE_ONLY=$GE_ONLY" if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then @@ -129,42 +117,17 @@ build_graphengine() CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON" fi - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_OPEN_SRC=True -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH} -DPLATFORM=${PLATFORM} -DPRODUCT=${PRODUCT}" echo "${CMAKE_ARGS}" - cmake ${CMAKE_ARGS} .. - if [ $? -ne 0 ] - then - echo "execute command: cmake ${CMAKE_ARGS} .. failed." 
- return 1 - fi - COMMON_TARGET="ge_common engine fmk_parser parser_common _caffe_parser fmk_onnx_parser graph register engine_conf.json optimizer_priority.pbtxt " - TARGET=${COMMON_TARGET} - if [ "x${PLATFORM}" = "xtrain" ] - then - TARGET="ge_runner ge_local_engine ge_local_opskernel_builder host_cpu_engine host_cpu_opskernel_builder ${TARGET}" - elif [ "x${PLATFORM}" = "xinference" ] - then - TARGET="ge_compiler atc_ge_local_engine atc_ge_local_opskernel_builder atc_host_cpu_engine atc_host_cpu_opskernel_builder atc opensrc_ascendcl ${TARGET}" - elif [ "x${PLATFORM}" = "xall" ] - then - # build all the target - TARGET="" - fi - - make ${VERBOSE} ${TARGET} -j${THREAD_NUM} && make install - if [ $? -ne 0 ] - then - echo "execute command: make ${VERBOSE} -j${THREAD_NUM} && make install failed." - return 1 - fi + cmake ${CMAKE_ARGS} ../.. + make ${VERBOSE} -j${THREAD_NUM} echo "GraphEngine build success!" } g++ -v -mk_dir ${OUTPUT_PATH} -build_graphengine || { echo "GraphEngine build failed."; return; } +build_graphengine echo "---------------- GraphEngine build finished ----------------" -#cp -rf "${BUILD_PATH}/graphengine/"*.so "${OUTPUT_PATH}" -#rm -rf "${OUTPUT_PATH}/"libproto* +mk_dir ${OUTPUT_PATH} +cp -rf "${BUILD_PATH}/graphengine/"*.so "${OUTPUT_PATH}" +rm -rf "${OUTPUT_PATH}/"libproto* rm -f ${OUTPUT_PATH}/libgmock*.so rm -f ${OUTPUT_PATH}/libgtest*.so rm -f ${OUTPUT_PATH}/lib*_stub.so @@ -212,82 +175,43 @@ echo "---------------- GraphEngine output generated ----------------" generate_package() { cd "${BASEPATH}" - - GRAPHENGINE_LIB_PATH="lib" - ACL_PATH="acllib/lib64" FWK_PATH="fwkacllib/lib64" ATC_PATH="atc/lib64" - ATC_BIN_PATH="atc/bin" NNENGINE_PATH="plugin/nnengine/ge_config" OPSKERNEL_PATH="plugin/opskernel" - ATC_LIB=("libc_sec.so" "libge_common.so" "libge_compiler.so" "libgraph.so" "libregister.so") - FWK_LIB=("libge_common.so" "libge_runner.so" "libgraph.so" "libregister.so") - PLUGIN_OPSKERNEL=("libge_local_engine.so" 
"libge_local_opskernel_builder.so" "libhost_cpu_engine.so" "libhost_cpu_opskernel_builder.so" "optimizer_priority.pbtxt") - PARSER_LIB=("lib_caffe_parser.so" "libfmk_onnx_parser.so" "libfmk_parser.so" "libparser_common.so") + ATC_LIB=("libc_sec.so" "libge_common.so" "libge_compiler.so" "libgraph.so") + FWK_LIB=("libge_common.so" "libge_runner.so" "libgraph.so") rm -rf ${OUTPUT_PATH:?}/${FWK_PATH}/ - rm -rf ${OUTPUT_PATH:?}/${ACL_PATH}/ rm -rf ${OUTPUT_PATH:?}/${ATC_PATH}/ - rm -rf ${OUTPUT_PATH:?}/${ATC_BIN_PATH}/ - mk_dir "${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH}" mk_dir "${OUTPUT_PATH}/${FWK_PATH}/${OPSKERNEL_PATH}" mk_dir "${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH}" mk_dir "${OUTPUT_PATH}/${ATC_PATH}/${OPSKERNEL_PATH}" - mk_dir "${OUTPUT_PATH}/${ACL_PATH}" - mk_dir "${OUTPUT_PATH}/${ATC_BIN_PATH}" - - cd "${OUTPUT_PATH}" - find ./ -name graphengine_lib.tar -exec rm {} \; + find output/ -name graphengine_lib.tar -exec rm {} \; + cp src/ge/engine_manager/engine_conf.json ${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH} + cp src/ge/engine_manager/engine_conf.json ${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH} - cp ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH}/engine_conf.json ${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH} - cp ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH}/engine_conf.json ${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH} + find output/ -maxdepth 1 -name libengine.so -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH}/../ \; + find output/ -maxdepth 1 -name libengine.so -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH}/../ \; - find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name libengine.so -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH}/../ \; - find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name libengine.so -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH}/../ \; + find output/ -maxdepth 1 -name libge_local_engine.so -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${OPSKERNEL_PATH} \; + find output/ -maxdepth 1 -name 
libge_local_engine.so -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${OPSKERNEL_PATH} \; - MAX_DEPTH=1 - if [ "x${PLATFORM}" = "xall" ] || [ "x${PLATFORM}" = "xinference" ] - then - MAX_DEPTH=2 - fi - for lib in "${PLUGIN_OPSKERNEL[@]}"; - do - find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth ${MAX_DEPTH} -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${OPSKERNEL_PATH} \; - find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth ${MAX_DEPTH} -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${OPSKERNEL_PATH} \; - done - - for lib in "${PARSER_LIB[@]}"; + cd "${OUTPUT_PATH}" + for lib in "${ATC_LIB[@]}"; do - find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH} \; - find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; + cp "$lib" "${OUTPUT_PATH}/${ATC_PATH}" done for lib in "${FWK_LIB[@]}"; do - find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH} \; + cp "$lib" "${OUTPUT_PATH}/${FWK_PATH}" done - for lib in "${ATC_LIB[@]}"; - do - find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; - done - - find ./bin -name atc -exec cp {} "${OUTPUT_PATH}/${ATC_BIN_PATH}" \; - find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "libascendcl.so" -exec cp -f {} ${OUTPUT_PATH}/${ACL_PATH} \; - - if [ "x${PLATFORM}" = "xtrain" ] - then - tar -cf graphengine_lib.tar fwkacllib - elif [ "x${PLATFORM}" = "xinference" ] - then - tar -cf graphengine_lib.tar acllib atc - elif [ "x${PLATFORM}" = "xall" ] - then - tar -cf graphengine_lib.tar fwkacllib acllib atc - fi + tar -cf graphengine_lib.tar fwkacllib/ atc/ } if [[ "X$ENABLE_GE_UT" = "Xoff" ]]; then diff --git a/cmake/FindModule.cmake b/cmake/FindModule.cmake deleted file mode 100644 index 74a63634..00000000 --- a/cmake/FindModule.cmake +++ /dev/null @@ -1,23 +0,0 @@ -#[[ - module 
- the name of export imported target - name - find the library name - path - find the library path -#]] -function(find_module module name path) - if (TARGET ${module}) - return() - endif() - find_library(${module}_LIBRARY_DIR NAMES ${name} NAMES_PER_DIR PATHS ${path} - PATH_SUFFIXES lib - ) - - message(STATUS "find ${name} location ${${module}_LIBRARY_DIR}") - if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND") - message(FATAL_ERROR "${name} not found in ${path}") - endif() - - add_library(${module} SHARED IMPORTED) - set_target_properties(${module} PROPERTIES - IMPORTED_LOCATION ${${module}_LIBRARY_DIR} - ) -endfunction() diff --git a/cmake/external_libs/eigen.cmake b/cmake/external_libs/eigen.cmake new file mode 100644 index 00000000..5cdfc346 --- /dev/null +++ b/cmake/external_libs/eigen.cmake @@ -0,0 +1,22 @@ +set(Eigen3_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") +set(Eigen3_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") +set(Eigen3_NS "ge_") + +if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/eigen-git-mirrorsource/repository/archive/3.3.7.tar.gz") + set(MD5 "cf6552a5d90c1aca4b5e0b011f65ea93") +else() + set(REQ_URL "https://gitlab.com/libeigen/eigen/-/archive/3.3.7/eigen-3.3.7.tar.gz") + set(MD5 "9e30f67e8531477de4117506fe44669b") +endif () + +graphengine_add_pkg(Eigen3 + VER 3.3.7 + URL ${REQ_URL} + MD5 ${MD5} + CMAKE_OPTION -DBUILD_TESTING=OFF) + +find_package(Eigen3 3.3.7 REQUIRED ${GE_FIND_NO_DEFAULT_PATH}) +set_property(TARGET Eigen3::Eigen PROPERTY IMPORTED_GLOBAL TRUE) +add_library(graphengine::eigen ALIAS Eigen3::Eigen) +include_directories(${EIGEN3_INCLUDE_DIRS}) diff --git a/cmake/external_libs/gflags.cmake b/cmake/external_libs/gflags.cmake deleted file mode 100755 index 5a4c5338..00000000 --- a/cmake/external_libs/gflags.cmake +++ /dev/null @@ -1,39 +0,0 @@ -if (HAVE_GFLAGS) - return() -endif() - -include(ExternalProject) -#set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output) - -if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR - 
(${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) - set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) - message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") -endif() - -ExternalProject_Add(gflags_build - URL https://github.com/gflags/gflags/archive/v2.2.2.tar.gz - #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz - #SOURCE_DIR ${GE_CODE_DIR}/../third_party/gflags/src/gflags-2.2.2 - CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags - BUILD_COMMAND $(MAKE) - INSTALL_COMMAND $(MAKE) install - EXCLUDE_FROM_ALL TRUE -) - -set(GFLAGS_PKG_DIR ${CMAKE_INSTALL_PREFIX}/gflags) - -add_library(gflags_static STATIC IMPORTED) - -set_target_properties(gflags_static PROPERTIES - IMPORTED_LOCATION ${GFLAGS_PKG_DIR}/lib/libgflags.a -) - -add_library(gflags INTERFACE) -target_include_directories(gflags INTERFACE ${GFLAGS_PKG_DIR}/include) -target_link_libraries(gflags INTERFACE gflags_static) - -add_dependencies(gflags gflags_build) - -#set(HAVE_GFLAGS TRUE CACHE BOOL "gflags build add") -set(HAVE_GFLAGS TRUE) diff --git a/cmake/external_libs/gtest.cmake b/cmake/external_libs/gtest.cmake new file mode 100644 index 00000000..5e175fd2 --- /dev/null +++ b/cmake/external_libs/gtest.cmake @@ -0,0 +1,24 @@ +set(ge_gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") +set(ge_gtest_CFLAGS "-D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") + +if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") + set(MD5 "89e13ca1aa48d370719d58010b83f62c") +else() + set(REQ_URL "https://github.com/google/googletest/archive/release-1.8.0.tar.gz") + set(MD5 "16877098823401d1bf2ed7891d7dce36") +endif () + +graphengine_add_pkg(ge_gtest + VER 1.8.0 + LIBS gtest gtest_main + URL ${REQ_URL} + MD5 
${MD5} + CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON + -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON) + +add_library(graphengine::gtest ALIAS ge_gtest::gtest) +add_library(graphengine::gtest_main ALIAS ge_gtest::gtest_main) +include_directories(${ge_gtest_INC}) +file(COPY ${ge_gtest_INC}/../lib/libgtest.so DESTINATION ${CMAKE_SOURCE_DIR}/build/graphengine) +file(COPY ${ge_gtest_INC}/../lib/libgtest_main.so DESTINATION ${CMAKE_SOURCE_DIR}/build/graphengine) diff --git a/cmake/external_libs/json.cmake b/cmake/external_libs/json.cmake old mode 100755 new mode 100644 index cf020b40..f2ae5310 --- a/cmake/external_libs/json.cmake +++ b/cmake/external_libs/json.cmake @@ -1,24 +1,20 @@ -if (HAVE_JSON) - return() -endif() +set(nlohmann_json_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") +set(nlohmann_json_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") -include(ExternalProject) +if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") + set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") + set(INCLUDE "./include") +else() + set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") + set(MD5 "0dc903888211db3a0f170304cd9f3a89") + set(INCLUDE "./") +endif () -set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) -ExternalProject_Add(json_build - URL https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip - #URL /home/txd/workspace/cloud_code/pkg/include.zip - SOURCE_DIR ${JSON_SRC_DIR} - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - INSTALL_COMMAND "" - EXCLUDE_FROM_ALL TRUE -) - - -add_library(json INTERFACE) -target_include_directories(json INTERFACE ${JSON_SRC_DIR}) -add_dependencies(json json_build) - -#set(HAVE_JSON TRUE CACHE BOOL "json build add") -set(HAVE_JSON TRUE) +graphengine_add_pkg(ge_nlohmann_json + VER 3.6.1 + HEAD_ONLY ${INCLUDE} + URL ${REQ_URL} + MD5 ${MD5}) +include_directories(${ge_nlohmann_json_INC}) +add_library(graphengine::json ALIAS 
ge_nlohmann_json) \ No newline at end of file diff --git a/cmake/external_libs/onnx.cmake b/cmake/external_libs/onnx.cmake old mode 100755 new mode 100644 index 889c95c3..a092f964 --- a/cmake/external_libs/onnx.cmake +++ b/cmake/external_libs/onnx.cmake @@ -1,29 +1,13 @@ -include(ExternalProject) - -#set(ONNX_SRC_DIR /home/txd/workspace/cloud_code/graphengine/build/graphengine/open_source/onnx) -#set(ONNX_PROTO ${ONNX_SRC_DIR}/onnx/onnx.proto) -set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx) -set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) -file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) - -ExternalProject_Add(onnx - URL https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz - #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz - #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 - #SOURCE_DIR ${ONNX_SRC_DIR} - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - #INSTALL_COMMAND "" - INSTALL_COMMAND ${CMAKE_COMMAND} -E copy /onnx/onnx.proto ${ONNX_PROTO_FILE} - #BUILD_ALWAYS TRUE - EXCLUDE_FROM_ALL TRUE -) - -macro(onnx_protobuf_generate comp c_var h_var) - add_custom_command(OUTPUT ${ONNX_PROTO_FILE} - DEPENDS onnx - ) - ge_protobuf_generate(${comp} ${c_var} ${h_var} ${ONNX_PROTO_FILE}) -endmacro() - - +if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") + set(MD5 "1bdbcecdd68ea8392630467646776e02") +else() + set(REQ_URL "https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz") + set(MD5 "512f2779d6215d4a36f366b6b9acdf1e") +endif () + +graphengine_add_pkg(onnx + VER 1.6.0 + HEAD_ONLY ./ + URL ${REQ_URL} + MD5 ${MD5}) diff --git a/cmake/external_libs/protobuf.cmake b/cmake/external_libs/protobuf.cmake new file mode 100644 index 00000000..8be594c7 --- /dev/null +++ b/cmake/external_libs/protobuf.cmake @@ -0,0 +1,63 @@ +if (NOT TARGET protobuf::protobuf) +set(protobuf_USE_STATIC_LIBS ON) +set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC 
-fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") +set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") +set(_ge_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) +string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") +string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + +if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") + set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") +else() + set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") + set(MD5 "3d9e32700639618a4d2d342c99d4507a") +endif () + +graphengine_add_pkg(protobuf + VER 3.8.0 + LIBS protobuf + EXE protoc + URL ${REQ_URL} + MD5 ${MD5} + CMAKE_PATH ../cmake/ + CMAKE_OPTION -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF) +set(CMAKE_CXX_FLAGS ${_ge_tmp_CMAKE_CXX_FLAGS}) +endif() +add_library(graphengine::protobuf ALIAS protobuf::protobuf) +set(PROTOBUF_LIBRARY protobuf::protobuf) +include_directories(${protobuf_INC}) +include_directories(${protobuf_DIRPATH}/src) + +function(ge_protobuf_generate comp c_var h_var) + if(NOT ARGN) + message(SEND_ERROR "Error: ge_protobuf_generate() called without any proto files") + return() + endif() + + set(${c_var}) + set(${h_var}) + + foreach(file ${ARGN}) + get_filename_component(abs_file ${file} ABSOLUTE) + get_filename_component(file_name ${file} NAME_WE) + get_filename_component(file_dir ${abs_file} PATH) + + list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/proto/${comp}/proto/${file_name}.pb.cc") + list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/proto/${comp}/proto/${file_name}.pb.h") + + add_custom_command( + OUTPUT "${CMAKE_BINARY_DIR}/proto/${comp}/proto/${file_name}.pb.cc" + "${CMAKE_BINARY_DIR}/proto/${comp}/proto/${file_name}.pb.h" + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/proto/${comp}/proto" + COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/proto/${comp}/proto ${abs_file} + 
DEPENDS protobuf::protoc ${abs_file} + COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM ) + endforeach() + + set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) + set(${c_var} ${${c_var}} PARENT_SCOPE) + set(${h_var} ${${h_var}} PARENT_SCOPE) + +endfunction() diff --git a/cmake/external_libs/protobuf_shared.cmake b/cmake/external_libs/protobuf_shared.cmake deleted file mode 100755 index b9c4c105..00000000 --- a/cmake/external_libs/protobuf_shared.cmake +++ /dev/null @@ -1,59 +0,0 @@ -if (HAVE_PROTOBUF) - return() -endif() - -include(ExternalProject) -include(GNUInstallDirs) - -if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR - (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) - set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) - message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") -endif() -set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") -set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") -ExternalProject_Add(protobuf_build - URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz - #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz - #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 - #DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E copy_directory ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 - #CONFIGURE_COMMAND ${CMAKE_COMMAND} - #-DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} - #-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} - #-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} - #-DCMAKE_LINKER=${CMAKE_LINKER} - #-DCMAKE_AR=${CMAKE_AR} - #-DCMAKE_RANLIB=${CMAKE_RANLIB} - #-Dprotobuf_WITH_ZLIB=OFF - #-Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protobuf /cmake - 
CONFIGURE_COMMAND cd - && ./autogen.sh && cd && /configure --prefix=${CMAKE_INSTALL_PREFIX}/protobuf --with-zlib=no CC=${CMAKE_C_COMPILER} CXX=${CMAKE_CXX_COMPILER} CXXFLAGS=${protobuf_CXXFLAGS} LDFLAGS=${protobuf_LDFLAGS} - && bash -c "sed -i 's|^hardcode_libdir_flag_spec=.*|hardcode_libdir_flag_spec=\"\"|g' libtool && sed -i 's|^runpath_var=LD_RUN_PATH|runpath_var=DIE_RPATH_DIE|g' libtool" - BUILD_COMMAND $(MAKE) - INSTALL_COMMAND $(MAKE) install - EXCLUDE_FROM_ALL TRUE -) -include(GNUInstallDirs) - -set(PROTOBUF_SHARED_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf) - -add_library(protobuf SHARED IMPORTED) - -file(MAKE_DIRECTORY ${PROTOBUF_SHARED_PKG_DIR}/include) - -set_target_properties(protobuf PROPERTIES - IMPORTED_LOCATION ${PROTOBUF_SHARED_PKG_DIR}/lib/libprotobuf.so -) - -target_include_directories(protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/include) - -set(INSTALL_BASE_DIR "") -set(INSTALL_LIBRARY_DIR lib) - -install(FILES ${PROTOBUF_SHARED_PKG_DIR}/lib/libprotobuf.so ${PROTOBUF_SHARED_PKG_DIR}/lib/libprotobuf.so.19.0.0 OPTIONAL - DESTINATION ${INSTALL_LIBRARY_DIR}) - -add_dependencies(protobuf protobuf_build) - -#set(HAVE_PROTOBUF TRUE CACHE BOOL "protobuf build add") -set(HAVE_PROTOBUF TRUE) diff --git a/cmake/external_libs/protobuf_static.cmake b/cmake/external_libs/protobuf_static.cmake deleted file mode 100755 index 81535f21..00000000 --- a/cmake/external_libs/protobuf_static.cmake +++ /dev/null @@ -1,43 +0,0 @@ -include(ExternalProject) -include(GNUInstallDirs) -#set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output) - -if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR - (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) - set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) - message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") -endif() - -set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 
-D_GLIBCXX_USE_CXX11_ABI=0 -O2") -set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") -set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) -ExternalProject_Add(protobuf_static_build - URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz - #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz - #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 - CONFIGURE_COMMAND ${CMAKE_COMMAND} - -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} - -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} - -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} - -DCMAKE_LINKER=${CMAKE_LINKER} - -DCMAKE_AR=${CMAKE_AR} - -DCMAKE_RANLIB=${CMAKE_RANLIB} - -Dprotobuf_WITH_ZLIB=OFF - -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${PROTOBUF_STATIC_PKG_DIR} /cmake - BUILD_COMMAND $(MAKE) - INSTALL_COMMAND $(MAKE) install - EXCLUDE_FROM_ALL TRUE -) -include(GNUInstallDirs) - -add_library(protobuf_static_lib STATIC IMPORTED) - -set_target_properties(protobuf_static_lib PROPERTIES - IMPORTED_LOCATION ${PROTOBUF_STATIC_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/libprotobuf.a -) - -add_library(protobuf_static INTERFACE) -target_include_directories(protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include) -target_link_libraries(protobuf_static INTERFACE protobuf_static_lib) - -add_dependencies(protobuf_static protobuf_static_build) diff --git a/cmake/external_libs/protoc.cmake b/cmake/external_libs/protoc.cmake deleted file mode 100755 index 74ef785a..00000000 --- a/cmake/external_libs/protoc.cmake +++ /dev/null @@ -1,103 +0,0 @@ -if (HAVE_PROTOC) - return() -endif() - -include(ExternalProject) -include(GNUInstallDirs) -#set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output) - -if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR - (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) - set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) - 
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") -endif() - -set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") -set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") -ExternalProject_Add(protoc_build - URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz - #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz - #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 - CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc /cmake - BUILD_COMMAND $(MAKE) - INSTALL_COMMAND $(MAKE) install - EXCLUDE_FROM_ALL TRUE -) - -set(PROTOC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protoc) - -set(protoc_EXECUTABLE ${PROTOC_PKG_DIR}/${CMAKE_INSTALL_BINDIR}/protoc) - -function(protobuf_generate comp c_var h_var) - if(NOT ARGN) - message(SEND_ERROR "Error: protobuf_generate() called without any proto files") - return() - endif() - set(${c_var}) - set(${h_var}) - - foreach(file ${ARGN}) - get_filename_component(abs_file ${file} ABSOLUTE) - get_filename_component(file_name ${file} NAME_WE) - get_filename_component(file_dir ${abs_file} PATH) - get_filename_component(parent_subdir ${file_dir} NAME) - - if("${parent_subdir}" STREQUAL "proto") - set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto) - else() - set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir}) - endif() - list(APPEND ${c_var} "${proto_output_path}/${file_name}.pb.cc") - list(APPEND ${h_var} "${proto_output_path}/${file_name}.pb.h") - - add_custom_command( - OUTPUT "${proto_output_path}/${file_name}.pb.cc" "${proto_output_path}/${file_name}.pb.h" - WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} - COMMAND ${CMAKE_COMMAND} -E make_directory 
"${proto_output_path}" - COMMAND ${protoc_EXECUTABLE} -I${file_dir} --cpp_out=${proto_output_path} ${abs_file} - DEPENDS protoc_build ${abs_file} - COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM ) - endforeach() - - set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) - set(${c_var} ${${c_var}} PARENT_SCOPE) - set(${h_var} ${${h_var}} PARENT_SCOPE) - -endfunction() - -function(protobuf_generate_py comp py_var) - if(NOT ARGN) - message(SEND_ERROR "Error: protobuf_generate_py() called without any proto files") - return() - endif() - set(${py_var}) - - foreach(file ${ARGN}) - get_filename_component(abs_file ${file} ABSOLUTE) - get_filename_component(file_name ${file} NAME_WE) - get_filename_component(file_dir ${abs_file} PATH) - get_filename_component(parent_subdir ${file_dir} NAME) - - if("${parent_subdir}" STREQUAL "proto") - set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto) - else() - set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir}) - endif() - list(APPEND ${py_var} "${proto_output_path}/${file_name}_pb2.py") - - add_custom_command( - OUTPUT "${proto_output_path}/${file_name}_pb2.py" - WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} - COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}" - COMMAND ${protoc_EXECUTABLE} -I${file_dir} --python_out=${proto_output_path} ${abs_file} - DEPENDS protoc_build ${abs_file} - COMMENT "Running PYTHON protocol buffer compiler on ${file}" VERBATIM ) - endforeach() - - set_source_files_properties(${${py_var}} PROPERTIES GENERATED TRUE) - set(${py_var} ${${py_var}} PARENT_SCOPE) - -endfunction() - -#set(HAVE_PROTOC TRUE CACHE BOOL "protoc build add") -set(HAVE_PROTOC TRUE) diff --git a/cmake/external_libs/securec.cmake b/cmake/external_libs/securec.cmake old mode 100755 new mode 100644 index 0bd62ab2..2fbf8b80 --- a/cmake/external_libs/securec.cmake +++ b/cmake/external_libs/securec.cmake @@ -1,62 +1,11 @@ -if (HAVE_C_SEC) - return() 
-endif() - -include(ExternalProject) - -if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR - (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) - set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) - message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") -endif() - -ExternalProject_Add(c_sec_build - URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz - #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz - #SOURCE_DIR ${GE_CODE_DIR}/../libc_sec - PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch - CONFIGURE_COMMAND ${CMAKE_COMMAND} - -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} - -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} - -DCMAKE_LINKER=${CMAKE_LINKER} - -DCMAKE_AR=${CMAKE_AR} - -DCMAKE_RANLIB=${CMAKE_RANLIB} - -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/c_sec - BUILD_COMMAND $(MAKE) - INSTALL_COMMAND $(MAKE) install - EXCLUDE_FROM_ALL TRUE -) - -set(C_SEC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/c_sec) - -add_library(c_sec SHARED IMPORTED) - -file(MAKE_DIRECTORY ${C_SEC_PKG_DIR}/include) - -set_target_properties(c_sec PROPERTIES - IMPORTED_LOCATION ${C_SEC_PKG_DIR}/lib/libc_sec.so -) - -target_include_directories(c_sec INTERFACE ${C_SEC_PKG_DIR}/include) - -add_dependencies(c_sec c_sec_build) - -set(INSTALL_BASE_DIR "") -set(INSTALL_LIBRARY_DIR lib) - -install(FILES ${C_SEC_PKG_DIR}/lib/libc_sec.so OPTIONAL - DESTINATION ${INSTALL_LIBRARY_DIR}) - -add_library(c_sec_static_lib STATIC IMPORTED) -set_target_properties(c_sec_static_lib PROPERTIES - IMPORTED_LOCATION ${C_SEC_PKG_DIR}/lib/libc_sec.a -) - -add_library(c_sec_static INTERFACE) -target_include_directories(c_sec_static INTERFACE ${C_SEC_PKG_DIR}/include) -target_link_libraries(c_sec_static INTERFACE c_sec_static_lib) - -add_dependencies(c_sec_static c_sec_build) - -#set(HAVE_C_SEC TRUE CACHE BOOL "c_sec build add") -set(HAVE_C_SEC TRUE) 
+graphengine_add_pkg(securec + VER 1.1.10 + URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz + MD5 193f0ca5246c1dd84920db34d2d8249f + LIBS c_sec + PATCHES ${GE_SOURCE_DIR}/third_party/patch/securec/securec.patch001 + CMAKE_OPTION "-DCMAKE_BUILD_TYPE=Release" + ) +include_directories(${securec_INC}) +file(COPY ${securec_INC}/../lib/libc_sec.so DESTINATION ${CMAKE_SOURCE_DIR}/build/graphengine) +add_library(graphengine::securec ALIAS securec::c_sec) \ No newline at end of file diff --git a/cmake/ge_utils.cmake b/cmake/ge_utils.cmake new file mode 100644 index 00000000..75480ded --- /dev/null +++ b/cmake/ge_utils.cmake @@ -0,0 +1,371 @@ +include(FetchContent) +set(FETCHCONTENT_QUIET OFF) + +function(graphengine_add_submodule_obj des_submodule_objs sub_dir submodule_name_obj) + + add_subdirectory(${sub_dir}) + + if(NOT TARGET ${submodule_name_obj}) + message(FATAL_ERROR "Can not find submodule '${submodule_name_obj}'. in ${CMAKE_CURRENT_LIST_FILE}") + endif() + if("$" IN_LIST ${des_submodule_objs}) + message(FATAL_ERROR "submodule '${submodule_name_obj}' added more than once. 
in ${CMAKE_CURRENT_LIST_FILE}") + endif() + + set(${des_submodule_objs} ${${des_submodule_objs}} $ PARENT_SCOPE) + +endfunction() + +if (DEFINED ENV{MSLIBS_CACHE_PATH}) + set(_MS_LIB_CACHE $ENV{MSLIBS_CACHE_PATH}) +else() + set(_MS_LIB_CACHE ${CMAKE_BINARY_DIR}/.mslib) +endif () +message("MS LIBS CACHE PATH: ${_MS_LIB_CACHE}") + +if (NOT EXISTS ${_MS_LIB_CACHE}) + file(MAKE_DIRECTORY ${_MS_LIB_CACHE}) +endif () + +if (DEFINED ENV{MSLIBS_SERVER}) + set(LOCAL_LIBS_SERVER $ENV{MSLIBS_SERVER}) + message("LOCAL_LIBS_SERVER: ${LOCAL_LIBS_SERVER}") +endif () + +include(ProcessorCount) +ProcessorCount(N) +if (JOBS) + set(THNUM ${JOBS}) +else() + set(JOBS 8) + if (${JOBS} GREATER ${N}) + set(THNUM ${N}) + endif() +endif () +message("set make thread num: ${THNUM}") + +if(LOCAL_LIBS_SERVER) + if (NOT ENV{no_proxy}) + set(ENV{no_proxy} "${LOCAL_LIBS_SERVER}") + else() + string(FIND $ENV{no_proxy} ${LOCAL_LIBS_SERVER} IP_POS) + if (${IP_POS} EQUAL -1) + set(ENV{no_proxy} "$ENV{no_proxy},${LOCAL_LIBS_SERVER}") + endif () + endif () +endif() + +function(__download_pkg pkg_name pkg_url pkg_md5) + + if(LOCAL_LIBS_SERVER) + get_filename_component(_URL_FILE_NAME ${pkg_url} NAME) + set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${_URL_FILE_NAME}" ${pkg_url}) + endif() + + FetchContent_Declare( + ${pkg_name} + URL ${pkg_url} + URL_HASH MD5=${pkg_md5} + ) + FetchContent_GetProperties(${pkg_name}) + message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}") + if(NOT ${pkg_name}_POPULATED) + FetchContent_Populate(${pkg_name}) + set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) + endif() + +endfunction() + +function(__download_pkg_with_git pkg_name pkg_url pkg_git_commit pkg_md5) + + if(LOCAL_LIBS_SERVER) + set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${pkg_git_commit}") + FetchContent_Declare( + ${pkg_name} + URL ${pkg_url} + URL_HASH MD5=${pkg_md5} + ) + else() + FetchContent_Declare( + ${pkg_name} + GIT_REPOSITORY 
${pkg_url} + GIT_TAG ${pkg_git_commit}) + endif() + FetchContent_GetProperties(${pkg_name}) + message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}") + if(NOT ${pkg_name}_POPULATED) + FetchContent_Populate(${pkg_name}) + set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) + endif() + +endfunction() + + +function(__find_pkg_then_add_target pkg_name pkg_exe) + + unset(${pkg_name}_LIBS) + + message("_FIND:${${pkg_name}_BASE_DIR}") + + if(pkg_exe) + find_program(${pkg_exe}_EXE ${pkg_exe} PATHS ${${pkg_name}_BASE_DIR}/bin NO_DEFAULT_PATH) + if(NOT ${pkg_exe}_EXE) + return() + endif() + add_executable(${pkg_name}::${pkg_exe} IMPORTED GLOBAL) + set_target_properties(${pkg_name}::${pkg_exe} PROPERTIES + IMPORTED_LOCATION ${${pkg_exe}_EXE} + ) + message("found ${${pkg_exe}_EXE}") + endif() + + foreach(_LIB_NAME ${ARGN}) + set(_LIB_SEARCH_NAME ${_LIB_NAME}) + set(_LIB_TYPE SHARED) + if (${pkg_name}_USE_STATIC_LIBS) + set(_LIB_SEARCH_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(_LIB_TYPE STATIC) + endif () + set(${_LIB_NAME}_LIB ${_LIB_NAME}_LIB-NOTFOUND) + find_library(${_LIB_NAME}_LIB ${_LIB_SEARCH_NAME} PATHS ${${pkg_name}_BASE_DIR}/lib NO_DEFAULT_PATH) + if(NOT ${_LIB_NAME}_LIB) + return() + endif() + add_library(${pkg_name}::${_LIB_NAME} ${_LIB_TYPE} IMPORTED GLOBAL) + set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${${pkg_name}_BASE_DIR}/include" + IMPORTED_LOCATION ${${_LIB_NAME}_LIB} + ) + list(APPEND ${pkg_name}_LIBS ${pkg_name}::${_LIB_NAME}) + message("found ${${_LIB_NAME}_LIB}") + STRING( REGEX REPLACE "(.+)/(.+)" "\\1" LIBPATH ${${_LIB_NAME}_LIB}) + set(${pkg_name}_LIBPATH ${LIBPATH} CACHE STRING INTERNAL) + endforeach(_LIB_NAME) + + set(${pkg_name}_LIBS ${${pkg_name}_LIBS} PARENT_SCOPE) +endfunction() + +function(__exec_cmd) + set(options ) + set(oneValueArgs WORKING_DIRECTORY) + set(multiValueArgs COMMAND) + + cmake_parse_arguments(EXEC 
"${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) + + execute_process(COMMAND ${EXEC_COMMAND} + WORKING_DIRECTORY ${EXEC_WORKING_DIRECTORY} + RESULT_VARIABLE RESULT) + if(NOT RESULT EQUAL "0") + message(FATAL_ERROR "error! when ${EXEC_COMMAND} in ${EXEC_WORKING_DIRECTORY}") + endif() +endfunction() + +function(__check_patches pkg_patches) + # check patches + if (PKG_PATCHES) + file(TOUCH ${_MS_LIB_CACHE}/${pkg_name}_patch.md5) + file(READ ${_MS_LIB_CACHE}/${pkg_name}_patch.md5 ${pkg_name}_PATCHES_MD5) + + message("patches md5:${${pkg_name}_PATCHES_MD5}") + + set(${pkg_name}_PATCHES_NEW_MD5 ) + foreach(_PATCH ${PKG_PATCHES}) + file(MD5 ${_PATCH} _PF_MD5) + set(${pkg_name}_PATCHES_NEW_MD5 "${${pkg_name}_PATCHES_NEW_MD5},${_PF_MD5}") + endforeach(_PATCH) + + if (NOT ${pkg_name}_PATCHES_MD5 STREQUAL ${pkg_name}_PATCHES_NEW_MD5) + set(${pkg_name}_PATCHES ${PKG_PATCHES}) + file(REMOVE_RECURSE "${_MS_LIB_CACHE}/${pkg_name}-subbuild") + file(WRITE ${_MS_LIB_CACHE}/${pkg_name}_patch.md5 ${${pkg_name}_PATCHES_NEW_MD5}) + message("patches changed : ${${pkg_name}_PATCHES_NEW_MD5}") + endif () + endif () +endfunction() + +set(GE_FIND_NO_DEFAULT_PATH NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH NO_SYSTEM_ENVIRONMENT_PATH + NO_CMAKE_BUILDS_PATH NO_CMAKE_PACKAGE_REGISTRY NO_CMAKE_SYSTEM_PATH + NO_CMAKE_SYSTEM_PACKAGE_REGISTRY) +set(GE_FIND_NO_DEFAULT_PATH ${GE_FIND_NO_DEFAULT_PATH} PARENT_SCOPE) + +function(graphengine_add_pkg pkg_name ) + set(options ) + set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH) + set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) + + if (NOT PKG_CMAKE_PATH) + set(PKG_CMAKE_PATH ..) 
+ endif () + + set(__FIND_PKG_NAME ${pkg_name}) + string(TOLOWER ${pkg_name} pkg_name) + message("pkg name:${__FIND_PKG_NAME},${pkg_name}") + + set(${pkg_name}_PATCHES_HASH ) + foreach(_PATCH ${PKG_PATCHES}) + file(MD5 ${_PATCH} _PF_MD5) + set(${pkg_name}_PATCHES_HASH "${${pkg_name}_PATCHES_HASH},${_PF_MD5}") + endforeach(_PATCH) + + # check options + set(${pkg_name}_CONFIG_TXT + "${CMAKE_CXX_COMPILER_VERSION}-${CMAKE_C_COMPILER_VERSION} + ${ARGN} - ${${pkg_name}_USE_STATIC_LIBS}- ${${pkg_name}_PATCHES_HASH} + ${${pkg_name}_CXXFLAGS}--${${pkg_name}_CFLAGS}--${${pkg_name}_LDFLAGS}") + string(REPLACE ";" "-" ${pkg_name}_CONFIG_TXT ${${pkg_name}_CONFIG_TXT}) + string(MD5 ${pkg_name}_CONFIG_HASH ${${pkg_name}_CONFIG_TXT}) + + message("${pkg_name} config hash: ${${pkg_name}_CONFIG_HASH}") + + set(${pkg_name}_BASE_DIR ${_MS_LIB_CACHE}/${pkg_name}_${${pkg_name}_CONFIG_HASH}) + set(${pkg_name}_DIRPATH ${${pkg_name}_BASE_DIR} CACHE STRING INTERNAL) + + if(EXISTS ${${pkg_name}_BASE_DIR}/options.txt AND PKG_HEAD_ONLY) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE) + add_library(${pkg_name} INTERFACE) + target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC}) + return() + endif () + + if(NOT PKG_EXE) + set(PKG_EXE 0) + endif() + + set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR}) + set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR} PARENT_SCOPE) + + if (PKG_LIBS) + __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIBS}) + if(${pkg_name}_LIBS) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) + message("Found libs: ${${pkg_name}_LIBS}") + return() + endif() + elseif(NOT PKG_HEAD_ONLY) + find_package(${__FIND_PKG_NAME} ${PKG_VER} ${GE_FIND_NO_DEFAULT_PATH}) + if (${__FIND_PKG_NAME}_FOUND) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) + message("Found pkg: ${__FIND_PKG_NAME}") + return() + endif () + endif () + + if (NOT PKG_DIR) + if (PKG_GIT_REPOSITORY) + 
__download_pkg_with_git(${pkg_name} ${PKG_GIT_REPOSITORY} ${PKG_GIT_TAG} ${PKG_MD5}) + else() + __download_pkg(${pkg_name} ${PKG_URL} ${PKG_MD5}) + endif() + else() + set(${pkg_name}_SOURCE_DIR ${PKG_DIR}) + endif () + file(WRITE ${${pkg_name}_BASE_DIR}/options.txt ${${pkg_name}_CONFIG_TXT}) + message("${pkg_name}_SOURCE_DIR : ${${pkg_name}_SOURCE_DIR}") + + foreach(_PATCH_FILE ${PKG_PATCHES}) + message("patching ${${pkg_name}_SOURCE_DIR} -p1 < ${_PATCH_FILE}") + execute_process(COMMAND patch -p1 INPUT_FILE ${_PATCH_FILE} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR} + RESULT_VARIABLE Result) + if(NOT Result EQUAL "0") + message(FATAL_ERROR "Failed patch: ${_PATCH_FILE}") + endif() + endforeach(_PATCH_FILE) + + file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600) + if(NOT ${pkg_name}_LOCK_RET EQUAL "0") + message(FATAL_ERROR "error! when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}") + endif() + + if(${pkg_name}_SOURCE_DIR) + if (PKG_HEAD_ONLY) + file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*) + file(COPY ${${pkg_name}_SOURCE_SUBDIRS} DESTINATION ${${pkg_name}_BASE_DIR}) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE) + add_library(${pkg_name} INTERFACE) + target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC}) + + elseif (PKG_CMAKE_OPTION) + # in cmake + file(MAKE_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) + if (${pkg_name}_CFLAGS) + set(${pkg_name}_CMAKE_CFLAGS "-DCMAKE_C_FLAGS=${${pkg_name}_CFLAGS}") + endif () + if (${pkg_name}_CXXFLAGS) + set(${pkg_name}_CMAKE_CXXFLAGS "-DCMAKE_CXX_FLAGS=${${pkg_name}_CXXFLAGS}") + endif () + + if (${pkg_name}_LDFLAGS) + if (${pkg_name}_USE_STATIC_LIBS) + #set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_STATIC_LINKER_FLAGS=${${pkg_name}_LDFLAGS}") + else() + set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_SHARED_LINKER_FLAGS=${${pkg_name}_LDFLAGS}") + endif () + endif () + + __exec_cmd(COMMAND ${CMAKE_COMMAND} 
${PKG_CMAKE_OPTION} -G ${CMAKE_GENERATOR} + ${${pkg_name}_CMAKE_CFLAGS} ${${pkg_name}_CMAKE_CXXFLAGS} ${${pkg_name}_CMAKE_LDFLAGS} + -DCMAKE_INSTALL_PREFIX=${${pkg_name}_BASE_DIR} ${PKG_CMAKE_PATH} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) + + __exec_cmd(COMMAND ${CMAKE_COMMAND} --build . --target install -- -j${THNUM} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) + + else() + if (${pkg_name}_CFLAGS) + set(${pkg_name}_MAKE_CFLAGS "CFLAGS=${${pkg_name}_CFLAGS}") + endif () + if (${pkg_name}_CXXFLAGS) + set(${pkg_name}_MAKE_CXXFLAGS "CXXFLAGS=${${pkg_name}_CXXFLAGS}") + endif () + if (${pkg_name}_LDFLAGS) + set(${pkg_name}_MAKE_LDFLAGS "LDFLAGS=${${pkg_name}_LDFLAGS}") + endif () + # in configure && make + if (PKG_PRE_CONFIGURE_COMMAND) + __exec_cmd(COMMAND ${PKG_PRE_CONFIGURE_COMMAND} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) + endif () + + if (PKG_CONFIGURE_COMMAND) + __exec_cmd(COMMAND ${PKG_CONFIGURE_COMMAND} + ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS} + --prefix=${${pkg_name}_BASE_DIR} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) + endif () + set(${pkg_name}_BUILD_OPTION ${PKG_BUILD_OPTION}) + if (NOT PKG_CONFIGURE_COMMAND) + set(${pkg_name}_BUILD_OPTION ${${pkg_name}_BUILD_OPTION} + ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS}) + endif () + # build + __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_BUILD_OPTION} -j${THNUM} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) + + if (PKG_INSTALL_INCS OR PKG_INSTALL_LIBS) + file(GLOB ${pkg_name}_INSTALL_INCS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_INCS}) + file(GLOB ${pkg_name}_INSTALL_LIBS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_LIBS}) + file(COPY ${${pkg_name}_INSTALL_INCS} DESTINATION ${${pkg_name}_BASE_DIR}/include) + file(COPY ${${pkg_name}_INSTALL_LIBS} DESTINATION ${${pkg_name}_BASE_DIR}/lib) + else() + __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} install WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) + 
endif () + endif () + endif() + + if (PKG_LIBS) + __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIBS}) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) + if(NOT ${pkg_name}_LIBS) + message(FATAL_ERROR "Can not find pkg: ${pkg_name}") + endif() + else() + find_package(${__FIND_PKG_NAME} ${PKG_VER} QUIET) + if (${__FIND_PKG_NAME}_FOUND) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) + message("Found pkg: ${${__FIND_PKG_NAME}_LIBRARIES}") + return() + endif () + endif () +endfunction() diff --git a/cmake/intf_pub_android.cmake b/cmake/intf_pub_android.cmake deleted file mode 100755 index 153d5764..00000000 --- a/cmake/intf_pub_android.cmake +++ /dev/null @@ -1,52 +0,0 @@ - -add_library(intf_pub INTERFACE) - -target_compile_options(intf_pub INTERFACE - -Wall - -fPIC - -fstack-protector-strong -) -target_compile_definitions(intf_pub INTERFACE - $<$:_GLIBCXX_USE_CXX11_ABI=0> - $<$:CFG_BUILD_NDEBUG> - $<$:CFG_BUILD_DEBUG> - WIN64=1 - LINUX=0 -) -target_link_options(intf_pub INTERFACE - -Wl,-z,relro - -Wl,-z,now - -Wl,-z,noexecstack - $<$:-Wl,--build-id=none> -) -target_link_directories(intf_pub INTERFACE -) - -add_library(intf_ccec INTERFACE) -target_compile_options(intf_ccec INTERFACE - -mcpu=cortex-a73 - --target=aarch64-linux-android29 - --sysroot=${HCC_PATH}/../sysroot - -L${HCC_PATH}/../lib/gcc/aarch64-linux-android/4.9.x - -Wall - -fPIC - -fstack-protector-strong -) -target_compile_definitions(intf_ccec INTERFACE - $<$:_GLIBCXX_USE_CXX11_ABI=0> - $<$:CFG_BUILD_NDEBUG> - $<$:CFG_BUILD_DEBUG> -) - -target_link_options(intf_ccec INTERFACE - -mcpu=cortex-a73 - --target=aarch64-linux-android29 - --sysroot=${HCC_PATH}/../sysroot - -L${HCC_PATH}/../lib/gcc/aarch64-linux-android/4.9.x - -Wl,-cce-host-android - -Wl,-z,relro - -Wl,-z,now - -Wl,-z,noexecstack - $<$:-Wl,--build-id=none> -) - diff --git a/cmake/intf_pub_linux.cmake b/cmake/intf_pub_linux.cmake deleted file mode 100755 index 40c6bca9..00000000 --- 
a/cmake/intf_pub_linux.cmake +++ /dev/null @@ -1,33 +0,0 @@ -if (HAVE_PUB) - return() -endif() - -add_library(intf_pub INTERFACE) - -target_compile_options(intf_pub INTERFACE - -Wall - -fPIC - $,-fstack-protector-all,-fstack-protector-strong> - $<$:-std=c++11> -) -target_compile_definitions(intf_pub INTERFACE - _GLIBCXX_USE_CXX11_ABI=0 - $<$:CFG_BUILD_NDEBUG> - $<$:CFG_BUILD_DEBUG> - WIN64=1 - LINUX=0 -) -target_link_options(intf_pub INTERFACE - -Wl,-z,relro - -Wl,-z,now - -Wl,-z,noexecstack - $<$:-Wl,--build-id=none> -) -target_link_directories(intf_pub INTERFACE -) -target_link_libraries(intf_pub INTERFACE - -lpthread -) - -#set(HAVE_PUB TRUE CACHE BOOL "pub add") -set(HAVE_PUB TRUE) diff --git a/cmake/intf_pub_windows.cmake b/cmake/intf_pub_windows.cmake deleted file mode 100755 index 19e37283..00000000 --- a/cmake/intf_pub_windows.cmake +++ /dev/null @@ -1,24 +0,0 @@ - -add_library(intf_pub INTERFACE) - -target_compile_options(intf_pub INTERFACE - -Wall - -fPIC - $,-fstack-protector-all,-fstack-protector-strong> - $<$:-std=c++11> -) -target_compile_definitions(intf_pub INTERFACE - $<$:_GLIBCXX_USE_CXX11_ABI=0> - OS_TYPE=WIN64 - WIN64=1 - LINUX=0 - $<$:CFG_BUILD_NDEBUG> - $<$:CFG_BUILD_DEBUG> -) -target_link_options(intf_pub INTERFACE - $<$:-Wl,--build-id=none> -) -target_link_directories(intf_pub INTERFACE -) -target_link_libraries(intf_pub INTERFACE -) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt deleted file mode 100755 index 4f162fd3..00000000 --- a/ge/CMakeLists.txt +++ /dev/null @@ -1,910 +0,0 @@ -add_subdirectory(common) -add_subdirectory(plugin/engine) -add_subdirectory(graph/build/memory) -add_subdirectory(ge_local_engine) -add_subdirectory(host_cpu_engine) -add_subdirectory(executor) -add_subdirectory(offline) - -set(PROTO_LIST - "${METADEF_DIR}/proto/fusion_model.proto" - "${GE_CODE_DIR}/ge/proto/optimizer_priority.proto" -) - -set(PROTO_CLIENT_LIST - "${METADEF_DIR}/proto/ge_api.proto" -) - -set(PROTO_HEADER_LIST - 
"${METADEF_DIR}/proto/om.proto" - "${METADEF_DIR}/proto/task.proto" - "${METADEF_DIR}/proto/insert_op.proto" - "${METADEF_DIR}/proto/ge_ir.proto" - "${METADEF_DIR}/proto/fwk_adapter.proto" - "${METADEF_DIR}/proto/op_mapping_info.proto" -) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) -protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) - -############ libge_runner.so ############ -set(TRAIN_SRC_LIST - "common/formats/format_transfers/datatype_transfer.cc" - "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" - "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" - "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" - "common/formats/format_transfers/format_transfer_fractal_nz.cc" - "common/formats/format_transfers/format_transfer_fractal_z.cc" - "common/formats/format_transfers/format_transfer_fractal_zz.cc" - "common/formats/format_transfers/format_transfer_fracz_hwcn.cc" - "common/formats/format_transfers/format_transfer_fracz_nchw.cc" - "common/formats/format_transfers/format_transfer_fracz_nhwc.cc" - "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_transpose.cc" - "common/formats/formats.cc" - "common/formats/utils/formats_trans_utils.cc" - "common/fp16_t.cc" - "common/ge/plugin_manager.cc" - "common/ge/op_tiling_manager.cc" - "common/helper/model_cache_helper.cc" - "common/profiling/profiling_manager.cc" - "common/dump/dump_manager.cc" - "common/dump/dump_properties.cc" - "common/dump/dump_op.cc" - "engine_manager/dnnengine_manager.cc" - 
"ge_local_engine/engine/host_cpu_engine.cc" - "generator/ge_generator.cc" - "generator/generator_api.cc" - "graph/build/graph_builder.cc" - "graph/build/label_allocator.cc" - "graph/build/logical_stream_allocator.cc" - "graph/build/model_builder.cc" - "graph/build/run_context.cc" - "graph/build/stream_allocator.cc" - "graph/build/stream_graph_optimizer.cc" - "graph/build/task_generator.cc" - "graph/common/bcast.cc" - "graph/common/local_context.cc" - "graph/common/omg_util.cc" - "graph/common/transop_util.cc" - "graph/execute/graph_execute.cc" - "graph/label/case_label_maker.cc" - "graph/label/if_label_maker.cc" - "graph/label/label_maker.cc" - "graph/label/partitioned_call_label_maker.cc" - "graph/label/while_label_maker.cc" - "graph/load/graph_loader.cc" - "graph/load/new_model_manager/cpu_queue_schedule.cc" - "graph/load/new_model_manager/data_dumper.cc" - "graph/load/new_model_manager/data_inputer.cc" - "graph/load/new_model_manager/davinci_model.cc" - "graph/load/new_model_manager/davinci_model_parser.cc" - "graph/load/new_model_manager/model_manager.cc" - "graph/load/new_model_manager/model_utils.cc" - "graph/load/new_model_manager/aipp_utils.cc" - "graph/load/new_model_manager/task_info/end_graph_task_info.cc" - "graph/load/new_model_manager/task_info/model_exit_task_info.cc" - "graph/load/new_model_manager/task_info/event_record_task_info.cc" - "graph/load/new_model_manager/task_info/event_wait_task_info.cc" - "graph/load/new_model_manager/task_info/fusion_start_task_info.cc" - "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" - "graph/load/new_model_manager/task_info/hccl_task_info.cc" - "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" - "graph/load/new_model_manager/task_info/kernel_task_info.cc" - "graph/load/new_model_manager/task_info/label_set_task_info.cc" - "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" - "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" - 
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" - "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" - "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" - "graph/load/new_model_manager/task_info/stream_active_task_info.cc" - "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" - "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" - "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" - "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" - "graph/load/new_model_manager/task_info/task_info.cc" - "graph/load/new_model_manager/tbe_handle_store.cc" - "graph/load/new_model_manager/zero_copy_task.cc" - "graph/load/new_model_manager/zero_copy_offset.cc" - "graph/manager/graph_context.cc" - "graph/manager/graph_manager.cc" - "graph/manager/graph_manager_utils.cc" - "graph/manager/graph_mem_allocator.cc" - "graph/manager/graph_caching_allocator.cc" - "graph/manager/graph_var_manager.cc" - "graph/manager/host_mem_manager.cc" - "graph/manager/rdma_pool_allocator.cc" - "graph/manager/memory_api.cc" - "graph/manager/model_manager/event_manager.cc" - "graph/manager/trans_var_data_utils.cc" - "graph/manager/util/debug.cc" - "graph/manager/util/hcom_util.cc" - "graph/manager/util/rt_context_util.cc" - "graph/manager/util/variable_accelerate_ctrl.cc" - "graph/optimize/graph_optimize.cc" - "graph/optimize/mem_rw_conflict_optimize.cc" - "graph/optimize/summary_optimize.cc" - "graph/partition/engine_place.cc" - "graph/partition/graph_partition.cc" - "graph/passes/addn_pass.cc" - "graph/passes/aicpu_constant_folding_pass.cc" - "graph/passes/assert_pass.cc" - "graph/passes/input_output_connection_identify_pass.cc" - "graph/passes/atomic_addr_clean_pass.cc" - "graph/passes/mark_same_addr_pass.cc" - "graph/passes/mark_graph_unknown_status_pass.cc" - "graph/passes/mark_agnostic_pass.cc" - "graph/partition/dynamic_shape_partition.cc" - 
"graph/partition/stage_partition.cc" - "graph/passes/base_pass.cc" - "graph/passes/bitcast_pass.cc" - "graph/passes/cast_remove_pass.cc" - "graph/passes/cast_translate_pass.cc" - "graph/passes/common_subexpression_elimination_pass.cc" - "graph/passes/transop_symmetry_elimination_pass.cc" - "graph/passes/compile_nodes_pass.cc" - "graph/passes/constant_folding_pass.cc" - "graph/passes/constant_fuse_same_pass.cc" - "graph/passes/control_trigger_pass.cc" - "graph/passes/dimension_adjust_pass.cc" - "graph/passes/dimension_compute_pass.cc" - "graph/passes/dropout_pass.cc" - "graph/passes/hccl_group_pass.cc" - "graph/passes/enter_pass.cc" - "graph/passes/assign_pass.cc" - "graph/passes/flow_ctrl_pass.cc" - "graph/passes/global_step_insert_pass.cc" - "host_kernels/transpose_kernel.cc" - "host_kernels/add_kernel.cc" - "host_kernels/broadcast_args_kernel.cc" - "host_kernels/broadcast_gradient_args_kernel.cc" - "host_kernels/cast_kernel.cc" - "host_kernels/concat_offset_kernel.cc" - "host_kernels/concat_v2_kernel.cc" - "host_kernels/dynamic_stitch_kernel.cc" - "host_kernels/identity_kernel.cc" - "host_kernels/empty_kernel.cc" - "host_kernels/expanddims_kernel.cc" - "host_kernels/fill_kernel.cc" - "host_kernels/floordiv_kernel.cc" - "host_kernels/floormod_kernel.cc" - "host_kernels/gather_v2_kernel.cc" - "host_kernels/greater_kernel.cc" - "host_kernels/kernel_utils.cc" - "host_kernels/maximum_kernel.cc" - "host_kernels/mul_kernel.cc" - "host_kernels/pack_kernel.cc" - "host_kernels/permute_kernel.cc" - "host_kernels/range_kernel.cc" - "host_kernels/rank_kernel.cc" - "host_kernels/reduce_prod_kernel.cc" - "host_kernels/reshape_kernel.cc" - "host_kernels/rsqrt_kernel.cc" - "host_kernels/shape_kernel.cc" - "host_kernels/shape_n_kernel.cc" - "host_kernels/size_kernel.cc" - "host_kernels/slice_d_kernel.cc" - "host_kernels/slice_kernel.cc" - "host_kernels/squeeze_kernel.cc" - "host_kernels/unsqueeze_kernel.cc" - "host_kernels/ssd_prior_box_kernel.cc" - 
"host_kernels/strided_slice_kernel.cc" - "host_kernels/sub_kernel.cc" - "host_kernels/transdata_kernel.cc" - "host_kernels/unpack_kernel.cc" - "graph/passes/folding_pass.cc" - "graph/passes/get_original_format_pass.cc" - "graph/passes/guarantee_const_pass.cc" - "graph/passes/hccl_memcpy_pass.cc" - "graph/passes/identity_pass.cc" - "graph/passes/ref_identity_delete_op_pass.cc" - "graph/passes/infershape_pass.cc" - "graph/passes/isolated_op_remove_pass.cc" - "graph/passes/iterator_op_pass.cc" - "graph/passes/link_gen_mask_nodes_pass.cc" - "graph/passes/merge_pass.cc" - "graph/passes/multi_batch_pass.cc" - "graph/passes/multi_batch_clone_pass.cc" - "graph/passes/subexpression_migration_pass.cc" - "graph/passes/subgraph_const_migration_pass.cc" - "graph/passes/unused_args_clean_pass.cc" - "graph/passes/net_output_pass.cc" - "graph/passes/next_iteration_pass.cc" - "graph/passes/no_use_reshape_remove_pass.cc" - "graph/passes/pass_manager.cc" - "graph/passes/pass_utils.cc" - "graph/passes/permute_pass.cc" - "graph/passes/placeholder_with_default_pass.cc" - "graph/passes/prevent_gradient_pass.cc" - "graph/passes/print_op_pass.cc" - "graph/passes/prune_pass.cc" - "graph/passes/ctrl_edge_transfer_pass.cc" - "graph/passes/replace_with_empty_const_pass.cc" - "graph/passes/reshape_remove_pass.cc" - "graph/passes/reshape_recovery_pass.cc" - "graph/passes/resource_pair_add_control_pass.cc" - "graph/passes/resource_pair_remove_control_pass.cc" - "graph/passes/same_transdata_breadth_fusion_pass.cc" - "graph/passes/save_pass.cc" - "graph/passes/shape_operate_op_remove_pass.cc" - "graph/passes/snapshot_pass.cc" - "graph/passes/stop_gradient_pass.cc" - "graph/passes/subgraph_pass.cc" - "graph/passes/data_pass.cc" - "graph/passes/switch_data_edges_bypass.cc" - "graph/passes/switch_logic_remove_pass.cc" - "graph/passes/merge_to_stream_merge_pass.cc" - "graph/passes/switch_to_stream_switch_pass.cc" - "graph/passes/attach_stream_label_pass.cc" - 
"graph/passes/switch_dead_branch_elimination.cc" - "graph/passes/replace_transshape_pass.cc" - "graph/passes/transop_breadth_fusion_pass.cc" - "graph/passes/transop_depth_fusion_pass.cc" - "graph/passes/transop_nearby_allreduce_fusion_pass.cc" - "graph/passes/transop_without_reshape_fusion_pass.cc" - "graph/passes/transpose_transdata_pass.cc" - "graph/passes/unused_const_pass.cc" - "graph/passes/unused_op_remove_pass.cc" - "graph/passes/var_is_initialized_op_pass.cc" - "graph/passes/parallel_concat_start_op_pass.cc" - "graph/passes/cond_pass.cc" - "graph/passes/cond_remove_pass.cc" - "graph/passes/for_pass.cc" - "graph/passes/variable_format_pass.cc" - "graph/passes/variable_op_pass.cc" - "graph/passes/variable_prepare_op_pass.cc" - "graph/passes/variable_ref_delete_op_pass.cc" - "graph/passes/variable_ref_useless_control_out_delete_pass.cc" - "graph/passes/end_of_sequence_add_control_pass.cc" - "graph/passes/memcpy_addr_async_pass.cc" - "graph/passes/set_input_output_offset_pass.cc" - "graph/preprocess/graph_preprocess.cc" - "graph/preprocess/insert_op/ge_aipp_op.cc" - "graph/preprocess/insert_op/util_insert_aipp_op.cc" - "graph/preprocess/multi_batch_options.cc" - "graph/preprocess/multi_batch_copy_graph.cc" - "init/gelib.cc" - "model/ge_model.cc" - "model/ge_root_model.cc" - "omm/csa_interact.cc" - "opskernel_manager/ops_kernel_manager.cc" - "opskernel_manager/ops_kernel_builder_manager.cc" - "session/inner_session.cc" - "session/session_manager.cc" - "single_op/single_op.cc" - "single_op/single_op_manager.cc" - "single_op/single_op_model.cc" - "single_op/stream_resource.cc" - "single_op/task/build_task_utils.cc" - "single_op/task/op_task.cc" - "single_op/task/tbe_task_builder.cc" - "single_op/task/aicpu_task_builder.cc" - "single_op/task/aicpu_kernel_task_builder.cc" - "hybrid/common/tensor_value.cc" - "hybrid/common/npu_memory_allocator.cc" - "hybrid/executor/rt_callback_manager.cc" - "hybrid/executor/node_state.cc" - "hybrid/executor/node_done_manager.cc" - 
"hybrid/executor/hybrid_profiler.cc" - "hybrid/executor/hybrid_model_executor.cc" - "hybrid/executor/hybrid_model_async_executor.cc" - "hybrid/executor/hybrid_execution_context.cc" - "hybrid/executor/subgraph_context.cc" - "hybrid/executor/subgraph_executor.cc" - "hybrid/executor/worker/task_compile_engine.cc" - "hybrid/executor/worker/shape_inference_engine.cc" - "hybrid/executor/worker/execution_engine.cc" - "hybrid/model/hybrid_model.cc" - "hybrid/model/hybrid_model_builder.cc" - "hybrid/model/node_item.cc" - "hybrid/model/graph_item.cc" - "hybrid/node_executor/aicore/aicore_node_executor.cc" - "hybrid/node_executor/aicore/aicore_op_task.cc" - "hybrid/node_executor/aicore/aicore_task_builder.cc" - "hybrid/node_executor/aicore/aicore_task_compiler.cc" - "hybrid/node_executor/aicpu/aicpu_ext_info.cc" - "hybrid/node_executor/aicpu/aicpu_node_executor.cc" - "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" - "hybrid/node_executor/ge_local/ge_local_node_executor.cc" - "hybrid/node_executor/host_cpu/host_cpu_node_executor.cc" - "hybrid/node_executor/host_cpu/kernel_factory.cc" - "hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc" - "hybrid/node_executor/host_cpu/kernel/variable_kernel.cc" - "hybrid/node_executor/host_cpu/kernel/assign_kernel.cc" - "hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc" - "hybrid/node_executor/controlop/control_op_executor.cc" - "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" - "hybrid/node_executor/hccl/hccl_node_executor.cc" - "hybrid/node_executor/rts/rts_node_executor.cc" - "hybrid/node_executor/node_executor.cc" - "hybrid/node_executor/task_context.cc" - "hybrid/hybrid_davinci_model.cc" - "executor/ge_executor.cc" - "client/ge_api.cc" - "client/ge_prof.cc" - "analyzer/analyzer.cc" -) - -add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) - -target_compile_definitions(ge_runner PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - DAVINCI_SUPPORT_PROFILING - 
REUSE_MEMORY=1 - FMK_SUPPORT_DUMP - DAVINCI_CLOUD -) - -target_compile_options(ge_runner PRIVATE - -O2 -) - -target_include_directories(ge_runner PRIVATE - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/analyzer - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common - ${METADEF_DIR} - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/external - ${GE_CODE_DIR}/../inc/cce - ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external - #### blue zone - ${ASCEND_DIR}/driver/include - ${ASCEND_DIR}/fwkacllib/include - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain -) - -target_link_libraries(ge_runner - $ - ge_memory - adump_server - msprofiler - -Wl,--no-as-needed - graph - ge_common - protobuf - register - c_sec - slog - mmpa - msprof - runtime - resource - error_manager - ascend_hal_stub - -Wl,--as-needed - json - -lrt - -ldl -) - -############ libge_compiler.so ############ -set(INFER_SRC_LIST - "graph/manager/trans_var_data_utils.cc" - "omm/csa_interact.cc" - "common/fp16_t.cc" - "common/formats/utils/formats_trans_utils.cc" - "common/formats/format_transfers/datatype_transfer.cc" - "common/formats/format_transfers/format_transfer_transpose.cc" - "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_fractal_z.cc" - "common/formats/format_transfers/format_transfer_fractal_nz.cc" - "common/formats/format_transfers/format_transfer_fractal_zz.cc" - "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" - "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" - 
"common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" - "common/formats/format_transfers/format_transfer_fracz_nchw.cc" - "common/formats/format_transfers/format_transfer_fracz_nhwc.cc" - "common/formats/format_transfers/format_transfer_fracz_hwcn.cc" - "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" - "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" - "common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" - "common/formats/formats.cc" - "common/profiling/profiling_manager.cc" - "common/dump/dump_properties.cc" - "common/dump/dump_manager.cc" - "common/dump/dump_op.cc" - "common/dump/dump_server.cc" - "common/helper/model_cache_helper.cc" - "ge_local_engine/engine/host_cpu_engine.cc" - "common/ge/plugin_manager.cc" - "common/ge/op_tiling_manager.cc" - "init/gelib.cc" - "session/inner_session.cc" - "session/session_manager.cc" - "engine_manager/dnnengine_manager.cc" - "opskernel_manager/ops_kernel_manager.cc" - "opskernel_manager/ops_kernel_builder_manager.cc" - "graph/manager/graph_manager.cc" - "graph/manager/graph_manager_utils.cc" - "graph/manager/graph_context.cc" - "graph/preprocess/graph_preprocess.cc" - "graph/preprocess/multi_batch_options.cc" - "graph/preprocess/multi_batch_copy_graph.cc" - "graph/execute/graph_execute.cc" - "graph/load/graph_loader.cc" - "graph/optimize/graph_optimize.cc" - "graph/optimize/mem_rw_conflict_optimize.cc" - "graph/optimize/summary_optimize.cc" - "graph/build/graph_builder.cc" - "graph/partition/engine_place.cc" - "graph/partition/graph_partition.cc" - "graph/partition/dynamic_shape_partition.cc" - "graph/partition/stage_partition.cc" - "generator/ge_generator.cc" - "generator/generator_api.cc" - "graph/manager/graph_var_manager.cc" - "graph/manager/host_mem_manager.cc" - "graph/manager/rdma_pool_allocator.cc" - "graph/manager/graph_mem_allocator.cc" - "graph/manager/graph_caching_allocator.cc" - "model/ge_model.cc" - "model/ge_root_model.cc" - 
"graph/common/transop_util.cc" - "graph/passes/pass_manager.cc" - "graph/passes/resource_pair_add_control_pass.cc" - "graph/passes/resource_pair_remove_control_pass.cc" - "graph/passes/pass_utils.cc" - "graph/passes/base_pass.cc" - "graph/passes/bitcast_pass.cc" - "graph/passes/constant_folding_pass.cc" - "graph/passes/aicpu_constant_folding_pass.cc" - "graph/passes/reshape_remove_pass.cc" - "graph/passes/reshape_recovery_pass.cc" - "graph/passes/transop_breadth_fusion_pass.cc" - "graph/passes/transop_depth_fusion_pass.cc" - "graph/passes/transop_nearby_allreduce_fusion_pass.cc" - "graph/passes/same_transdata_breadth_fusion_pass.cc" - "graph/passes/transop_without_reshape_fusion_pass.cc" - "graph/passes/compile_nodes_pass.cc" - "graph/passes/variable_prepare_op_pass.cc" - "graph/passes/variable_ref_delete_op_pass.cc" - "graph/passes/variable_ref_useless_control_out_delete_pass.cc" - "graph/passes/subgraph_pass.cc" - "graph/passes/data_pass.cc" - "graph/passes/net_output_pass.cc" - "graph/passes/replace_transshape_pass.cc" - "graph/passes/constant_fuse_same_pass.cc" - "graph/passes/print_op_pass.cc" - "graph/passes/no_use_reshape_remove_pass.cc" - "graph/passes/iterator_op_pass.cc" - "graph/passes/input_output_connection_identify_pass.cc" - "graph/passes/atomic_addr_clean_pass.cc" - "graph/passes/mark_same_addr_pass.cc" - "graph/passes/mark_graph_unknown_status_pass.cc" - "graph/passes/mark_agnostic_pass.cc" - "graph/common/omg_util.cc" - "graph/common/bcast.cc" - "graph/common/local_context.cc" - "graph/passes/dimension_compute_pass.cc" - "graph/passes/dimension_adjust_pass.cc" - "graph/passes/get_original_format_pass.cc" - "graph/passes/shape_operate_op_remove_pass.cc" - "graph/passes/unused_op_remove_pass.cc" - "graph/passes/assert_pass.cc" - "graph/passes/dropout_pass.cc" - "graph/passes/infershape_pass.cc" - "graph/passes/unused_const_pass.cc" - "graph/passes/isolated_op_remove_pass.cc" - "graph/passes/permute_pass.cc" - 
"graph/passes/ctrl_edge_transfer_pass.cc" - "graph/passes/end_of_sequence_add_control_pass.cc" - "host_kernels/broadcast_gradient_args_kernel.cc" - "host_kernels/greater_kernel.cc" - "host_kernels/gather_v2_kernel.cc" - "host_kernels/maximum_kernel.cc" - "host_kernels/floormod_kernel.cc" - "host_kernels/floordiv_kernel.cc" - "host_kernels/range_kernel.cc" - "host_kernels/shape_kernel.cc" - "host_kernels/size_kernel.cc" - "host_kernels/shape_n_kernel.cc" - "host_kernels/rank_kernel.cc" - "host_kernels/broadcast_args_kernel.cc" - "host_kernels/fill_kernel.cc" - "host_kernels/empty_kernel.cc" - "host_kernels/expanddims_kernel.cc" - "host_kernels/reshape_kernel.cc" - "host_kernels/squeeze_kernel.cc" - "host_kernels/unsqueeze_kernel.cc" - "host_kernels/kernel_utils.cc" - "host_kernels/cast_kernel.cc" - "host_kernels/transdata_kernel.cc" - "host_kernels/unpack_kernel.cc" - "host_kernels/transpose_kernel.cc" - "host_kernels/permute_kernel.cc" - "host_kernels/pack_kernel.cc" - "host_kernels/concat_v2_kernel.cc" - "host_kernels/concat_offset_kernel.cc" - "host_kernels/strided_slice_kernel.cc" - "host_kernels/ssd_prior_box_kernel.cc" - "host_kernels/add_kernel.cc" - "host_kernels/sub_kernel.cc" - "host_kernels/mul_kernel.cc" - "host_kernels/reduce_prod_kernel.cc" - "host_kernels/rsqrt_kernel.cc" - "host_kernels/slice_kernel.cc" - "host_kernels/slice_d_kernel.cc" - "host_kernels/dynamic_stitch_kernel.cc" - "host_kernels/identity_kernel.cc" - "graph/passes/stop_gradient_pass.cc" - "graph/passes/prevent_gradient_pass.cc" - "graph/passes/identity_pass.cc" - "graph/passes/ref_identity_delete_op_pass.cc" - "graph/passes/placeholder_with_default_pass.cc" - "graph/passes/snapshot_pass.cc" - "graph/passes/guarantee_const_pass.cc" - "graph/passes/var_is_initialized_op_pass.cc" - "graph/passes/parallel_concat_start_op_pass.cc" - "graph/passes/folding_pass.cc" - "graph/passes/cast_translate_pass.cc" - "graph/passes/prune_pass.cc" - "graph/passes/merge_to_stream_merge_pass.cc" - 
"graph/passes/switch_to_stream_switch_pass.cc" - "graph/passes/attach_stream_label_pass.cc" - "graph/passes/multi_batch_pass.cc" - "graph/passes/multi_batch_clone_pass.cc" - "graph/passes/subexpression_migration_pass.cc" - "graph/passes/subgraph_const_migration_pass.cc" - "graph/passes/unused_args_clean_pass.cc" - "graph/passes/next_iteration_pass.cc" - "graph/passes/control_trigger_pass.cc" - "graph/passes/cond_pass.cc" - "graph/passes/cond_remove_pass.cc" - "graph/passes/for_pass.cc" - "graph/passes/enter_pass.cc" - "graph/passes/assign_pass.cc" - "graph/passes/addn_pass.cc" - "graph/passes/common_subexpression_elimination_pass.cc" - "graph/passes/transop_symmetry_elimination_pass.cc" - "graph/passes/save_pass.cc" - "graph/passes/switch_dead_branch_elimination.cc" - "graph/passes/switch_logic_remove_pass.cc" - "graph/passes/switch_data_edges_bypass.cc" - "graph/passes/merge_pass.cc" - "graph/passes/variable_format_pass.cc" - "graph/passes/variable_op_pass.cc" - "graph/passes/cast_remove_pass.cc" - "graph/passes/transpose_transdata_pass.cc" - "graph/passes/hccl_memcpy_pass.cc" - "graph/passes/flow_ctrl_pass.cc" - "graph/passes/global_step_insert_pass.cc" - "graph/passes/link_gen_mask_nodes_pass.cc" - "graph/passes/replace_with_empty_const_pass.cc" - "graph/passes/hccl_group_pass.cc" - "graph/passes/memcpy_addr_async_pass.cc" - "graph/passes/set_input_output_offset_pass.cc" - "graph/manager/model_manager/event_manager.cc" - "graph/manager/util/rt_context_util.cc" - "graph/manager/util/variable_accelerate_ctrl.cc" - "graph/manager/util/debug.cc" - "graph/load/new_model_manager/model_manager.cc" - "graph/load/new_model_manager/data_inputer.cc" - "graph/load/new_model_manager/davinci_model.cc" - "graph/load/new_model_manager/davinci_model_parser.cc" - "graph/load/new_model_manager/model_utils.cc" - "graph/load/new_model_manager/aipp_utils.cc" - "graph/load/new_model_manager/tbe_handle_store.cc" - "graph/load/new_model_manager/cpu_queue_schedule.cc" - 
"graph/load/new_model_manager/zero_copy_task.cc" - "graph/load/new_model_manager/zero_copy_offset.cc" - "graph/load/new_model_manager/data_dumper.cc" - "graph/load/new_model_manager/task_info/task_info.cc" - "graph/load/new_model_manager/task_info/event_record_task_info.cc" - "graph/load/new_model_manager/task_info/event_wait_task_info.cc" - "graph/load/new_model_manager/task_info/fusion_start_task_info.cc" - "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" - "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" - "graph/load/new_model_manager/task_info/kernel_task_info.cc" - "graph/load/new_model_manager/task_info/label_set_task_info.cc" - "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" - "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" - "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" - "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" - "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" - "graph/load/new_model_manager/task_info/stream_active_task_info.cc" - "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" - "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" - "graph/load/new_model_manager/task_info/end_graph_task_info.cc" - "graph/load/new_model_manager/task_info/model_exit_task_info.cc" - "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" - "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" - "single_op/task/op_task.cc" - "single_op/task/build_task_utils.cc" - "single_op/task/tbe_task_builder.cc" - "single_op/task/aicpu_task_builder.cc" - "single_op/task/aicpu_kernel_task_builder.cc" - "single_op/single_op.cc" - "single_op/single_op_model.cc" - "single_op/stream_resource.cc" - "single_op/single_op_manager.cc" - "hybrid/hybrid_davinci_model_stub.cc" - "ir_build/ge_ir_build.cc" - "ir_build/atc_ir_common.cc" - "graph/preprocess/insert_op/ge_aipp_op.cc" 
- "graph/preprocess/insert_op/util_insert_aipp_op.cc" - "hybrid/node_executor/aicpu/aicpu_ext_info.cc" - "graph/build/model_builder.cc" - "graph/build/task_generator.cc" - "graph/build/stream_allocator.cc" - "graph/build/logical_stream_allocator.cc" - "graph/build/stream_graph_optimizer.cc" - "graph/build/run_context.cc" - "graph/build/label_allocator.cc" - "graph/label/label_maker.cc" - "graph/label/if_label_maker.cc" - "graph/label/case_label_maker.cc" - "graph/label/while_label_maker.cc" - "graph/label/partitioned_call_label_maker.cc" - "analyzer/analyzer.cc" -) - -add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) - -target_compile_definitions(ge_compiler PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - REUSE_MEMORY=1 - FMK_SUPPORT_DUMP - FMK_HOST_INFER - COMPILE_OMG_PACKAGE -) - -target_compile_options(ge_compiler PRIVATE - -O2 -) - -target_include_directories(ge_compiler PRIVATE - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/analyzer - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common - ${METADEF_DIR} - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/external - ${GE_CODE_DIR}/../inc/cce - ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external - #### blue zone #### - ${ASCEND_DIR}/driver/include - ${ASCEND_DIR}/fwkacllib/include - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain -) - -target_link_libraries(ge_compiler - $ - ge_memory - -Wl,--no-as-needed - graph - ge_common - protobuf - register - c_sec - error_manager - slog - mmpa - runtime_compile - resource - -Wl,--as-needed - json - -lrt - -ldl -) - -############ libascendcl.so ############ -file(GENERATE OUTPUT ${CMAKE_BINARY_DIR}/dummy.c CONTENT "") -#add_library(dummy_obj OBJECT ${CMAKE_BINARY_DIR}/dummy.c) 
-#set(DUMMY_OBJ $) - -file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ascendcl_object) - -if(EXISTS ${STATIC_ACL_LIB}/libascendcl.a) - execute_process( - COMMAND ar x ${STATIC_ACL_LIB}/libascendcl.a - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ascendcl_object - ) - file(GLOB OBJECT_LIST ${CMAKE_CURRENT_BINARY_DIR}/ascendcl_object/*.o) -else() - set(OBJECT_LIST ${CMAKE_BINARY_DIR}/dummy.c) -endif() - -add_library(opensrc_ascendcl SHARED - ${OBJECT_LIST} -) -target_compile_options(opensrc_ascendcl PRIVATE - -O2 - -fvisibility=hidden -) -target_link_options(opensrc_ascendcl PRIVATE - -rdynamic - -Wl,--allow-multiple-definition - -Wl,-z,muldefs - -Wl,-Bsymbolic - -Wl,--exclude-libs,ALL -) -target_link_libraries(opensrc_ascendcl PRIVATE - -Wl,--whole-archive - ge_executor - ascendcl_static - ge_common_static - graph_static - protobuf_static - register_static - error_manager_static - adump_server - msprofiler - -Wl,--no-whole-archive - -Wl,--no-as-needed - c_sec - runtime - mmpa - slog - msprof - ascend_hal_stub - -Wl,--as-needed - -ldl - json -) - -set_target_properties(opensrc_ascendcl PROPERTIES - OUTPUT_NAME ascendcl -) - -################################################################## -add_custom_command( - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc - ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc - ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_prof.cc - COMMAND echo "Generating stub files." - && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/stub/gen_stubapi.py ${GE_CODE_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} - && mv ge_ir_build.cc stub_ge_ir_build.cc - && mv ge_api.cc stub_ge_api.cc - && mv ge_prof.cc stub_ge_prof.cc - && echo "Generating stub files end." 
- #WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} - #DEPENDS stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} -) - -add_custom_target(ge_stub - DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc - ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc - ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_prof.cc -) - -################################################################## -############ stub/libge_compiler.so ############ -add_library(atc_stub_ge_compiler SHARED - stub_ge_ir_build.cc -) - -add_dependencies(atc_stub_ge_compiler ge_stub) - -target_link_libraries(atc_stub_ge_compiler PRIVATE - $ -) - -set_target_properties(atc_stub_ge_compiler PROPERTIES - OUTPUT_NAME ge_compiler - LIBRARY_OUTPUT_DIRECTORY atc_stub -) - -target_include_directories(atc_stub_ge_compiler PRIVATE - ${GE_CODE_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/analyzer - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common - ${GE_CODE_DIR}/inc/external - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph - #### yellow zone #### - ${GE_CODE_DIR}/../inc/cce - ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external - #### blue zone #### - ${ASCEND_DIR}/driver/include - ${ASCEND_DIR}/fwkacllib/include -) - -############ stub/libge_runner.so ############ -add_library(fwk_stub_ge_runner SHARED - stub_ge_api.cc - stub_ge_prof.cc -) - -add_dependencies(fwk_stub_ge_runner ge_stub) - -target_link_libraries(fwk_stub_ge_runner PRIVATE - $ -) - -set_target_properties(fwk_stub_ge_runner PROPERTIES - OUTPUT_NAME ge_runner - LIBRARY_OUTPUT_DIRECTORY fwk_stub -) - -target_include_directories(fwk_stub_ge_runner PRIVATE - ${GE_CODE_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/analyzer - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph - #### yellow zone #### - 
${GE_CODE_DIR}/../inc/cce - ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external - #### blue zone #### - ${ASCEND_DIR}/driver/include - ${ASCEND_DIR}/fwkacllib/include -) - -############################################################### -add_custom_target( - engine_conf.json ALL - DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/engine_conf.json -) -add_custom_command( - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/engine_conf.json - COMMAND cp ${CMAKE_CURRENT_LIST_DIR}/engine_manager/engine_conf.json ${CMAKE_CURRENT_BINARY_DIR}/ -) - -############################################################### -add_custom_target( - optimizer_priority.pbtxt ALL - DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt -) -add_custom_command( - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt - COMMAND cp ${CMAKE_CURRENT_LIST_DIR}/opskernel_manager/optimizer_priority.pbtxt ${CMAKE_CURRENT_BINARY_DIR}/ -) - -############################################################### - -############ install ############ -set(INSTALL_BASE_DIR "") -set(INSTALL_LIBRARY_DIR lib) - -install(TARGETS ge_runner ge_compiler opensrc_ascendcl OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} -) - -install(TARGETS atc_stub_ge_compiler fwk_stub_ge_runner OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/stub -) - -install(FILES - ${CMAKE_CURRENT_BINARY_DIR}/engine_conf.json - ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt OPTIONAL - DESTINATION ${INSTALL_LIBRARY_DIR} -) diff --git a/ge/README.md b/ge/README.md deleted file mode 100755 index e69de29b..00000000 diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt deleted file mode 100755 index 685a6fe2..00000000 --- a/ge/common/CMakeLists.txt +++ /dev/null @@ -1,171 +0,0 @@ -set(PROTO_LIST - "${METADEF_DIR}/proto/om.proto" - "${METADEF_DIR}/proto/ge_ir.proto" - "${METADEF_DIR}/proto/insert_op.proto" - "${METADEF_DIR}/proto/task.proto" - "${METADEF_DIR}/proto/tensorflow/attr_value.proto" - "${METADEF_DIR}/proto/tensorflow/function.proto" - 
"${METADEF_DIR}/proto/tensorflow/graph.proto" - "${METADEF_DIR}/proto/tensorflow/node_def.proto" - "${METADEF_DIR}/proto/tensorflow/op_def.proto" - "${METADEF_DIR}/proto/tensorflow/resource_handle.proto" - "${METADEF_DIR}/proto/tensorflow/tensor.proto" - "${METADEF_DIR}/proto/tensorflow/tensor_shape.proto" - "${METADEF_DIR}/proto/tensorflow/types.proto" - "${METADEF_DIR}/proto/tensorflow/versions.proto" -) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) - -set(SRC_LIST - "context/ctx.cc" - "model_saver.cc" - "ge/datatype_util.cc" - "helper/om_file_helper.cc" - "helper/model_helper.cc" - "../model/ge_model.cc" - "auth/file_saver.cc" - "fp16_t.cc" - "math/fp16_math.cc" - "debug/memory_dumper.cc" - "formats/utils/formats_trans_utils.cc" - "dump/dump_properties.cc" - "formats/format_transfers/datatype_transfer.cc" - "formats/format_transfers/format_transfer_transpose.cc" - "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" - "formats/format_transfers/format_transfer_fractal_z.cc" - "formats/format_transfers/format_transfer_fractal_nz.cc" - "formats/format_transfers/format_transfer_fractal_zz.cc" - "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" - "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" - "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" - "formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" - "formats/format_transfers/format_transfer_fracz_nchw.cc" - "formats/format_transfers/format_transfer_fracz_nhwc.cc" - "formats/format_transfers/format_transfer_fracz_hwcn.cc" - "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" - "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" - "formats/format_transfers/format_transfer_nchw_fz_c04.cc" - "formats/formats.cc" - "ge_format_util.cc" - "fmk_error_codes.cc" - "util.cc" - "properties_manager.cc" - "types.cc" - "model_parser/base.cc" - "kernel_store.cc" - "tbe_kernel_store.cc" - 
"cust_aicpu_kernel_store.cc" - "op/attr_value_util.cc" - "op/ge_op_utils.cc" - "thread_pool.cc" - "ge/tbe_plugin_manager.cc" -) - -############ libge_common.so ############ -add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) -target_compile_definitions(ge_common PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - HOST_VISIBILITY - FMK_SUPPORT_DUMP - OS_CENTOS -) - -target_compile_options(ge_common PRIVATE - -fvisibility=hidden - -O2 - -Werror -) - -target_include_directories(ge_common PRIVATE - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/common - ${GE_CODE_DIR}/ge/common/op - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_DEPEND_DIR}/inc - ${GE_DEPEND_DIR}/inc/cce - #### blue zone #### - #${GE_DEPEND_DIR}/include - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain -) - -target_link_libraries(ge_common PRIVATE - $ - -Wl,--no-as-needed - graph - protobuf - register - c_sec - error_manager - slog - mmpa - -Wl,--as-needed - json - -lrt - -ldl -) - -############ libge_common.a ############ -add_library(ge_common_static STATIC ${SRC_LIST} ${PROTO_HDRS}) -target_compile_definitions(ge_common_static PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - HOST_VISIBILITY - FMK_SUPPORT_DUMP - OS_CENTOS -) - -target_compile_options(ge_common_static PRIVATE - -fvisibility=hidden - -O2 - -Werror -) - -target_include_directories(ge_common_static PRIVATE - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/ge/common - ${GE_CODE_DIR}/ge/common/op - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_DEPEND_DIR}/inc - 
${GE_DEPEND_DIR}/inc/cce - #### blue zone #### - #${GE_DEPEND_DIR}/include - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain -) - -target_link_libraries(ge_common_static PRIVATE - $ - protobuf - json - c_sec - -lrt - -ldl -) - -############ install ############ -set(INSTALL_BASE_DIR "") -set(INSTALL_LIBRARY_DIR lib) - -install(TARGETS ge_common OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} -) diff --git a/ge/common/proto/ge_ir.proto b/ge/common/proto/ge_ir.proto deleted file mode 100644 index 87886c84..00000000 --- a/ge/common/proto/ge_ir.proto +++ /dev/null @@ -1,206 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package ge.proto; - -enum DataType -{ - DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. 
- DT_FLOAT = 1; // float type - DT_FLOAT16 = 2; // fp16 type - DT_INT8 = 3; // int8 type - DT_UINT8 = 4; // uint8 type - DT_INT16 = 5; // int16 type - DT_UINT16 = 6; // uint16 type - DT_INT32 = 7; // - DT_INT64 = 8; // int64 type - DT_UINT32 = 9; // unsigned int32 - DT_UINT64 = 10; // unsigned int64 - DT_BOOL = 11; // bool type - DT_DOUBLE = 12; // double type - DT_STRING = 13; // string type - DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ - DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ - DT_COMPLEX64 = 16; // complex64 type - DT_COMPLEX128 = 17; // complex128 type - DT_QINT8 = 18; // qint8 type - DT_QINT16 = 19; // qint16 type - DT_QINT32 = 20; // qint32 type - DT_QUINT8 = 21; // quint8 type - DT_QUINT16 = 22; // quint16 type - DT_RESOURCE = 23; // resource type - DT_STRING_REF = 24; // string_ref type - DT_DUAL = 25; /**< dual output type */ -} - -message AttrDef -{ - message ListValue - { - enum ListValueType{ - VT_LIST_NONE = 0; - VT_LIST_STRING = 1; - VT_LIST_INT = 2; - VT_LIST_FLOAT = 3; - VT_LIST_BOOL = 4; - VT_LIST_BYTES = 5; - VT_LIST_TENSOR_DESC = 6; - VT_LIST_TENSOR = 7; - VT_LIST_GRAPH = 8; - VT_LIST_NAMED_ATTRS = 9; - VT_LIST_DATA_TYPE = 10; - } - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3; // "list(int)" - repeated float f = 4; // "list(float)" - repeated bool b = 5; // "list(bool)" - repeated bytes bt = 7; - repeated TensorDescriptor td = 8; - repeated TensorDef t = 9; - repeated GraphDef g = 10; - repeated NamedAttrs na = 11; - repeated int64 dt = 12; // list ge::DataType - - ListValueType val_type = 20; - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - oneof value - { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; // Used to support attr nesting - TensorDescriptor td = 11; // 
GeTensorDesc type - TensorDef t = 12; // GeTensor type - GraphDef g = 13; // Graph type - ListListInt list_list_int = 14; // List List Int type - int64 dt = 15; // ge::DataType - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs -{ - string name = 1; - map attr = 2; -} - -// Shape / dimension description, using row-major order -message ShapeDef -{ - repeated int64 dim = 1; // Size of each dimension -} - -// Multidimensional data description -message TensorDescriptor -{ - string name = 1; // Optional parameter, tensor name - - DataType dtype = 2; // tensor datatype - ShapeDef shape = 3; // Shape / dimension - string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" - - bool has_out_attr = 9; - int64 size = 10; - int64 weight_size = 11; - bool reuse_input = 12; - bool output_tensor = 13; - string device_type = 14; - bool input_tensor =15; - int64 real_dim_cnt = 16; - int64 reuse_input_index = 17; - int64 data_offset = 18; - int64 cmps_size = 19; - string cmps_tab = 20; - int64 cmps_tab_offset = 21; - - map attr = 5; // Set of extra parameter fields -} - -// GeTensor definition -message TensorDef -{ - TensorDescriptor desc = 1; // Tensor description - bytes data = 2; // Tensor data -} - - -// Operator description -message OpDef -{ - string name = 1; // name - string type = 2; // type - - repeated string input = 5; // input original op name + outgoing index. 
op_name:index - - map attr = 10; // Set of operator parameter fields - - bool has_out_attr = 20; - int64 id = 21; - int64 stream_id =22; - repeated string input_name = 23; - repeated string src_name = 24; - repeated int64 src_index = 25; - repeated string dst_name = 26; - repeated int64 dst_index = 27; - repeated int64 input_i = 28; - repeated int64 output_i = 29; - repeated int64 workspace = 30; - repeated int64 workspace_bytes = 31; - repeated bool is_input_const = 32; - repeated TensorDescriptor input_desc = 33; - repeated TensorDescriptor output_desc = 34; - repeated string subgraph_name = 35; -} - -// Graph definition -message GraphDef -{ - string name = 1; // name - - repeated string input = 4; // Graph input - repeated string output = 5; // Graph output - - repeated OpDef op = 6; // List of operators - - map attr = 11; // Extended field -} - -// model definition -message ModelDef -{ - string name = 1; // name - uint32 version = 2; // IR Proto verion - string custom_version = 3; // User model version number, passed in by user - - repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef - - map attr = 11; // Extended field -} - diff --git a/ge/common/proto/insert_op.proto b/ge/common/proto/insert_op.proto deleted file mode 100644 index a059e122..00000000 --- a/ge/common/proto/insert_op.proto +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -message InsertNewOps { - repeated AippOpParams aipp_op = 1; - repeated MultiShapeOpParams multi_shape_op = 2; -} - -message AippOpParams { - enum InputFormat { - UNDEFINED = 0; - YUV420SP_U8 = 1; - XRGB8888_U8 = 2; - RGB888_U8 = 3; - YUV400_U8 = 4; - NC1HWC0DI_FP16 = 5; - NC1HWC0DI_S8 = 6; - ARGB8888_U8 = 7; - YUYV_U8 = 8; - YUV422SP_U8 = 9; - AYUV444_U8 = 10; - RAW10 = 11; - RAW12 = 12; - RAW16 = 13; - RAW24 = 14; - RGB16 = 15; - RGB20 = 16; - RGB24 = 17; - RGB8_IR = 18; - RGB16_IR = 19; - RGB24_IR = 20; - } - - enum AippMode { - undefined = 0; - static = 1; - dynamic = 2; - } - - // AIPPģʽ£¬Çø·Ö¾²Ì¬AIPPºÍ¶¯Ì¬AIPP - AippMode aipp_mode = 1; - - // related_input_rank²ÎÊýΪ±ØÌÀàÐÍΪÕûÐÍ£¬ÅäÖ÷¶Î§>=0, <=ÊäÈëDataËã×ӵĸöÊý£¬Ä¬ÈÏֵΪ0¡£ - // ±êʶ¶ÔÄ£Ð͵ĵڼ¸¸öÊäÈë×öAIPP´¦Àí£¬ÀýÈçÄ£ÐÍÓÐÁ½¸öÊäÈ룬ÐèÒª¶ÔµÚ2¸öÊäÈë×öAIPP£¬ÔòÅäÖÃrelated_input_rankΪ1¡£ - uint32 related_input_rank = 2; - - // input_edge_idx²ÎÊýΪ¿ÉÑ¡£¬ÀàÐÍΪÕûÐÍ£¬ÅäÖ÷¶Î§Îª>=0¡£ - // ÅäÖøòÎÊýµÄ×÷Óã¬ÔÚÓÚ¶ÔDataËã×Ó²»Í¬µÄÊä³ö×ö²»Í¬µÄAIPP´¦Àí£¬Èç¹û¸Ã²ÎÊýûÓÐÅäÖã¬Ä¬È϶Ôrelated_input_rankÖ¸¶¨µÄÄ£ÐÍÊäÈëµÄËùÓÐÊä³ö±ß×öAIPP¡£ - // ÅäÖÃÖµ <= DataËã×ÓÊä³ö±ßµÄ¸öÊý¡£ - repeated uint32 input_edge_idx = 3; - - // [Begin] ¶¯Ì¬AIPP²ÎÊý£¬ÅäÖþ²Ì¬AIPPʱÎÞЧ - uint32 max_src_image_size = 4; - - // ÊÇ·ñÖ§³ÖÐýת¡£Ä¬Èϲ»Ö§³Ö£¬¿ªÆôÖ§³ÖÐýתʱ£¬»áÓжîÍâµÄ¿Õ¼äºÍÐÔÄÜËðʧ - bool support_rotation = 5; - - // [End] ¶¯Ì¬AIPP²ÎÊý - - - // [Begin] ¾²Ì¬AIPP²ÎÊý£¬ÅäÖö¯Ì¬AIPPʱÎÞЧ - InputFormat input_format = 51; - bool csc_switch = 52; - float cpadding_value = 53; - bool rbuv_swap_switch = 54; - bool ax_swap_switch = 55; - bool single_line_mode = 56; - - int32 src_image_size_w = 57; - int32 src_image_size_h = 58; - - bool crop = 59; - int32 load_start_pos_w = 60; - int32 load_start_pos_h = 61; - int32 crop_size_w = 62; - int32 crop_size_h = 63; - - bool resize = 64; - int32 resize_output_w = 65; - int32 resize_output_h = 
66; - - bool padding = 67; - int32 left_padding_size = 68; - int32 right_padding_size = 69; - int32 top_padding_size = 70; - int32 bottom_padding_size = 71; - - int32 mean_chn_0 = 10; - int32 mean_chn_1 = 11; - int32 mean_chn_2 = 12; - int32 mean_chn_3 = 19; - float min_chn_0 = 13; - float min_chn_1 = 14; - float min_chn_2 = 15; - float min_chn_3 = 20; - repeated float var_reci_chn_0 = 16; - repeated float var_reci_chn_1 = 17; - repeated float var_reci_chn_2 = 18; - repeated float var_reci_chn_3 = 21; - - repeated int32 matrix_r0c0 = 30; - repeated int32 matrix_r0c1 = 31; - repeated int32 matrix_r0c2 = 32; - repeated int32 matrix_r1c0 = 33; - repeated int32 matrix_r1c1 = 34; - repeated int32 matrix_r1c2 = 35; - repeated int32 matrix_r2c0 = 36; - repeated int32 matrix_r2c1 = 37; - repeated int32 matrix_r2c2 = 38; - repeated int32 output_bias_0 = 39; - repeated int32 output_bias_1 = 40; - repeated int32 output_bias_2 = 41; - repeated int32 input_bias_0 = 42; - repeated int32 input_bias_1 = 43; - repeated int32 input_bias_2 = 44; - - // [End] ¾²Ì¬AIPP²ÎÊý - - // The n number that is used for raw/rgbir data into f16 transformation. - // The transformation equation is x/(2^n). If set to 0, no transform is performed. - uint32 raw_rgbir_to_f16_n = 45; -} - -message MultiShapeOpParams { - enum MultiShapeMode { - batch = 0; //¶¯Ì¬batch - resolution = 1; //¶¯Ì¬·Ö±æÂÊ£¬À©Õ¹Óà - } - - MultiShapeMode mode = 1; //Ëã×Óģʽ - uint32 related_input_rank = 2; //ÐÂÔöËã×Ó²åÈëµ½ÄĸöÊäÈë - - - repeated uint32 batch_list = 11; //batch_listÖµ£¬batch_listµÄ¸öÊýÊÇ2µ½8Ö®¼ä -} diff --git a/ge/common/proto/om.proto b/ge/common/proto/om.proto deleted file mode 100644 index dd992191..00000000 --- a/ge/common/proto/om.proto +++ /dev/null @@ -1,401 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -enum TargetType -{ - MINI = 0; - TINY = 1; - LITE = 2; -} - -// offline model -message ModelDef { - string name = 1; - uint32 version = 2; - - uint64 memory_size = 10; - uint32 stream_num = 11; - uint32 event_num = 12; - uint64 weight_size = 13; - uint32 label_num = 15; - repeated OpDef op = 20; - TargetType target_type = 23; - - map attr = 30; -}; - -// operator define -message OpDef { - string name = 1; - string type = 2; - - uint32 id = 3; - uint32 stream_id = 4; - - repeated string input_name = 5; - - repeated string src_name = 8; - repeated int32 src_index = 9; - repeated int64 input = 10; - repeated int64 output = 11; - repeated TensorDescriptor input_desc = 12; - repeated TensorDescriptor output_desc = 13; - repeated WeightDef weights = 14; - repeated string dst_name = 15; - repeated int32 dst_index = 16; - - repeated int64 workspace = 20; - repeated uint32 workspace_bytes = 21; - - repeated string weight_name = 22; - repeated bool is_input_const = 23; - - map attr = 30; - - QuantizeFactorParams quantize_factor = 31; - - oneof op_params { - // start at 100 here - SendOpParams sender_param = 100; - RecvOpParams receiver_param = 200; - ConvolutionOpParams convolution_param = 300; - PoolingOpParams pooling_param = 400; - EltwiseOpParams eltwise_param = 500; - BatchNormOpParams batchnorm_param = 600; - ScaleOpParams scale_param = 700; - FullConnectionOpParams full_connection_param = 800; - SoftmaxOpParams softmax_param = 900; - ActivationOpParams activation_param = 1000; - ReshapeOpParams 
reshape_param = 1100; - } -}; - -message SendOpParams { - uint32 event_id = 1; -}; - -message RecvOpParams { - uint32 event_id = 1; -}; - -enum QuantizeScaleType -{ - VECTOR_SCALE = 0; - SCALAR_SCALE = 1; -} - -enum QuantizeScaleMode -{ - NORMAL_MODE = 0; - SQRT_MODE = 1; -} - -enum QuantizeAlgorithm -{ - NON_OFFSET_ALGO = 0; - HALF_OFFSET_ALGO = 1; - ALL_OFFSET_ALGO = 2; -} -message QuantizeFactor -{ - QuantizeScaleMode scale_mode = 1; - bytes scale_value = 2; - int64 scale_offset = 3; - bytes offset_data_value = 4; - int64 offset_data_offset = 5; - bytes offset_weight_value = 6; - int64 offset_weight_offset = 7; - bytes offset_pad_value = 8; - int64 offset_pad_offset = 9; -}; - -message QuantizeCalcFactor -{ - bytes offsetw = 1; - int64 offsetw_offset = 2; - bytes offsetd = 3; - int64 offsetd_offset = 4; - bytes scalereq = 5; - int64 scaledreq_offset = 6; - bytes offsetdnext = 7; - int64 offsetdnext_offset = 8; -} - -message QuantizeFactorParams -{ - QuantizeAlgorithm quantize_algo = 1; - QuantizeScaleType scale_type = 2; - QuantizeFactor quantize_param = 3; - QuantizeFactor dequantize_param = 4; - QuantizeFactor requantize_param = 5; - QuantizeCalcFactor quantizecalc_param = 6; -}; - -message ConvolutionOpParams { - int32 mode = 1; - int32 algo = 2; - int32 pad_mode = 3; - uint32 group = 4; - uint32 num_output = 5; - - repeated uint32 pad = 10; - repeated uint32 stride = 11; - repeated uint32 dilation = 12; - repeated uint32 kernel = 13; - - float alpha = 20; - float beta = 21; - - WeightDef filter = 40; - WeightDef bias = 41; - - bool relu_flag = 62; - repeated uint32 adj = 70; - repeated uint32 target_shape = 71; - repeated uint32 before_pad = 72; -}; - -message PoolingOpParams { - int32 mode = 1; - int32 nan_opt = 2; - int32 pad_mode = 3; - bool global_pooling = 4; - - repeated uint32 window = 10; - repeated uint32 pad = 11; - repeated uint32 stride = 12; - bool ceil_mode = 13; - int32 data_mode = 14; - - float alpha = 20; - float beta = 21; - repeated uint32 
before_pad = 22; -}; - -message EltwiseOpParams { - int32 mode = 1; - repeated float coeff = 2; - float alpha = 3; - float beta = 4; - repeated WeightDef weight = 5; - bool relu_flag = 6; -}; - -message ActivationOpParams { - int32 mode = 1; - float coef = 2; - float alpha = 3; - float beta = 4; -}; - -message BatchNormOpParams { - int32 mode = 1; - - float alpha = 2; - float beta = 3; - double epsilon = 4;//optinal,[default = 1e-5] - bool use_global_stats = 5; //optinal,by default true,testing mode - float moving_average_fraction = 6; //optinal,[default = .999]; - - WeightDef estimated_mean = 7; - WeightDef estimated_variance = 8; - - WeightDef scale = 9; - WeightDef bias = 10; -}; - -message ScaleOpParams { - WeightDef scale = 1; - WeightDef bias = 2; -}; - -message ReshapeOpParams { - float alpha = 1; - float beta = 2; - ShapeDef shape = 3; - int32 axis = 4; - int32 num_axes = 5; - int32 format = 6; -}; - -message SoftmaxOpParams { - int32 algo = 1; - int32 mode = 2; - float alpha = 3; - float beta = 4; -}; - -message FullConnectionOpParams { - WeightDef filter = 1; - WeightDef bias = 2; - uint32 num_output = 3; - bool relu_flag = 12; -}; - -message FlattenOpParams { - float alpha = 1; - float beta = 2; - int32 start_axis = 3; - int32 end_axis = 4; -} - -message AddLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message MulLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message AddOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message MulOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message SubOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message BiasAddOpParams { - float alpha = 1; - float beta = 2; - - WeightDef bias = 10; -}; - -message MatMulOpParams 
{ - float alpha = 1; - float beta = 2; - bool transposeX = 3; - bool transposeW = 4; - - WeightDef filter = 10; - WeightDef bias = 12; -}; - -message RsqrtOpParams { - float alpha = 1; - float beta = 2; -}; - - -message WeightDef { - int32 format = 1; - int32 data_type = 2; - ShapeDef shape = 3; - bytes data = 4; - int64 data_offset = 5; - uint32 cmps_size = 6; - bytes cmps_tab = 7; - int64 cmps_tab_offset = 10; - CompressInfo cmps_info = 8; - AllOffsetQuantizeInfo alloffset_quantize_info = 11; -} - -message ShapeDef { - repeated int64 dim = 1; -} - -enum DeviceType { - NPU = 0; // In default, we will use NPU. - CPU = 1; // CPU -} - -message AllOffsetQuantizeInfo { - float scale = 1; - int32 offset = 2; -} - -message TensorDescriptor { - int32 format = 1; - int32 data_type = 2; - repeated int64 dim = 3; - uint32 size = 4; - bool reuse_input = 5; - bool output_tensor = 7; - DeviceType device_type = 8; - bool input_tensor = 9; - uint32 real_dim_cnt = 10; - uint32 reuse_input_index = 11; - AllOffsetQuantizeInfo alloffset_quantize_info = 12; -} - -message CompressInfo { - int32 blockRow = 1; // block row - int32 blockCol = 2; // block col - int32 fractalK = 3; // fractal K - int32 fractalN = 4; // fractal N - int32 lastFractalK = 5; // K of last fractal - int32 lastFractalN = 6; // N of last fractal - int32 cubeSize = 7; // cube's length - int32 loadDir = 8; // data load directtiono 0:col load 1:row load -} - -message AttrDef { - message ListValue { - repeated string s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated uint32 u = 6 [packed = true]; // "list(uint)" - repeated bytes bt = 7; - } - - oneof value { - string s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - uint32 u = 6; // "uint32" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; - } 
-} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs { - string name = 1; - map attr = 2; -} - diff --git a/ge/common/proto/task.proto b/ge/common/proto/task.proto deleted file mode 100644 index 50ea061b..00000000 --- a/ge/common/proto/task.proto +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef 
label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id 
= 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/ge/common/proto/tensorflow/attr_value.proto b/ge/common/proto/tensorflow/attr_value.proto deleted file mode 100644 index 1cc67d62..00000000 --- a/ge/common/proto/tensorflow/attr_value.proto +++ /dev/null @@ -1,62 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "AttrValueProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensor.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing the value for an attr used to configure an Op. -// Comment indicates the corresponding attr type. Only the field matching the -// attr type may be filled. -message AttrValue { - // LINT.IfChange - message ListValue { - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated DataType type = 6 [packed = true]; // "list(type)" - repeated TensorShapeProto shape = 7; // "list(shape)" - repeated TensorProto tensor = 8; // "list(tensor)" - repeated NameAttrList func = 9; // "list(attr)" - } - // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) - - oneof value { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - DataType type = 6; // "type" - TensorShapeProto shape = 7; // "shape" - TensorProto tensor = 8; // "tensor" - ListValue list = 1; // any "list(...)" - - // "func" represents a function. func.name is a function's name or - // a primitive op's name. func.attr.first is the name of an attr - // defined for that function. func.attr.second is the value for - // that attr in the instantiation. 
- NameAttrList func = 10; - - // This is a placeholder only used in nodes defined inside a - // function. It indicates the attr value will be supplied when - // the function is instantiated. For example, let us suppose a - // node "N" in function "FN". "N" has an attr "A" with value - // placeholder = "foo". When FN is instantiated with attr "foo" - // set to "bar", the instantiated node N's attr A will have been - // given the value "bar". - string placeholder = 9; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NameAttrList { - string name = 1; - map attr = 2; -} diff --git a/ge/common/proto/tensorflow/function.proto b/ge/common/proto/tensorflow/function.proto deleted file mode 100644 index 075897c6..00000000 --- a/ge/common/proto/tensorflow/function.proto +++ /dev/null @@ -1,100 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "FunctionProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; -import "node_def.proto"; -import "op_def.proto"; - -// A library is a set of named functions. -message FunctionDefLibrary { - repeated FunctionDef function = 1; - repeated GradientDef gradient = 2; -} - -// A function can be instantiated when the runtime can bind every attr -// with a value. When a GraphDef has a call to a function, it must -// have binding for every attr defined in the signature. -// * device spec, etc. -message FunctionDef { - // The definition of the function's name, arguments, return values, - // attrs etc. - OpDef signature = 1; - - // Attributes specific to this function definition. - map attr = 5; - - // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. 
- reserved 2; - - // In both of the following fields, there is the need to specify an - // output that is used as either the input to another node (in - // `node_def`) or as a return value of the function (in `ret`). - // Unlike the NodeDefs in GraphDef, we need to be able to specify a - // list in some cases (instead of just single outputs). Also, we - // need to be able to deal with lists of unknown length (so the - // output index may not be known at function definition time). So - // we use the following format instead: - // * "fun_in" where "fun_in" is the name of a function input arg in - // the `signature` field above. This represents that input, whether - // it is a single tensor or a list. - // * "fun_in:0" gives the first element of a function input arg (a - // non-list input is considered a list of length 1 for these - // purposes). - // * "node:out" where "node" is the name of a node in `node_def` and - // "out" is the name one of its op's output arguments (the name - // comes from the OpDef of the node's op). This represents that - // node's output, whether it is a single tensor or a list. - // Note: We enforce that an op's output arguments are never - // renamed in the backwards-compatibility test. - // * "node:out:0" gives the first element of a node output arg (a - // non-list output is considered a list of length 1 for these - // purposes). - // - // NOT CURRENTLY SUPPORTED (but may be in the future): - // * "node:out:-1" gives last element in a node output list - // * "node:out:1:" gives a list with all but the first element in a - // node output list - // * "node:out::-1" gives a list with all but the last element in a - // node output list - - // The body of the function. Unlike the NodeDefs in a GraphDef, attrs - // may have values of type `placeholder` and the `input` field uses - // the "output" format above. - - // By convention, "op" in node_def is resolved by consulting with a - // user-defined library first. 
If not resolved, "func" is assumed to - // be a builtin op. - repeated NodeDef node_def = 3; - - // A mapping from the output arg names from `signature` to the - // outputs from `node_def` that should be returned by the function. - map ret = 4; -} - -// GradientDef defines the gradient function of a function defined in -// a function library. -// -// A gradient function g (specified by gradient_func) for a function f -// (specified by function_name) must follow the following: -// -// The function 'f' must be a numerical function which takes N inputs -// and produces M outputs. Its gradient function 'g', which is a -// function taking N + M inputs and produces N outputs. -// -// I.e. if we have -// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), -// then, g is -// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, -// dL/dy1, dL/dy2, ..., dL/dy_M), -// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the -// loss function). dL/dx_i is the partial derivative of L with respect -// to x_i. -message GradientDef { - string function_name = 1; // The function name. - string gradient_func = 2; // The gradient function's name. -} diff --git a/ge/common/proto/tensorflow/graph.proto b/ge/common/proto/tensorflow/graph.proto deleted file mode 100644 index d639a7d6..00000000 --- a/ge/common/proto/tensorflow/graph.proto +++ /dev/null @@ -1,56 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "GraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "node_def.proto"; -import "function.proto"; -import "versions.proto"; - -// Represents the graph of operations -message GraphDef { - repeated NodeDef node = 1; - - // Compatibility versions of the graph. See core/public/version.h for version - // history. The GraphDef version is distinct from the TensorFlow version, and - // each release of TensorFlow will support a range of GraphDef versions. 
- VersionDef versions = 4; - - // Deprecated single version field; use versions above instead. Since all - // GraphDef changes before "versions" was introduced were forward - // compatible, this field is entirely ignored. - int32 version = 3 [deprecated = true]; - - // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. - // - // "library" provides user-defined functions. - // - // Naming: - // * library.function.name are in a flat namespace. - // NOTE: We may need to change it to be hierarchical to support - // different orgs. E.g., - // { "/google/nn", { ... }}, - // { "/google/vision", { ... }} - // { "/org_foo/module_bar", { ... }} - // map named_lib; - // * If node[i].op is the name of one function in "library", - // node[i] is deemed as a function call. Otherwise, node[i].op - // must be a primitive operation supported by the runtime. - // - // - // Function call semantics: - // - // * The callee may start execution as soon as some of its inputs - // are ready. The caller may want to use Tuple() mechanism to - // ensure all inputs are ready in the same time. - // - // * The consumer of return values may start executing as soon as - // the return values the consumer depends on are ready. The - // consumer may want to use Tuple() mechanism to ensure the - // consumer does not start until all return values of the callee - // function are ready. 
- FunctionDefLibrary library = 2; -}; diff --git a/ge/common/proto/tensorflow/graph_library.proto b/ge/common/proto/tensorflow/graph_library.proto deleted file mode 100644 index e393d38d..00000000 --- a/ge/common/proto/tensorflow/graph_library.proto +++ /dev/null @@ -1,14 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; - -import "graph.proto"; - -message GeGraphDef { - string name = 1; - GraphDef graph = 2; -} - -message GraphDefLibrary { - repeated GeGraphDef graph_def = 1; -}; \ No newline at end of file diff --git a/ge/common/proto/tensorflow/node_def.proto b/ge/common/proto/tensorflow/node_def.proto deleted file mode 100644 index b9bc97ee..00000000 --- a/ge/common/proto/tensorflow/node_def.proto +++ /dev/null @@ -1,63 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "NodeProto"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; - -message NodeDef { - // The name given to this operator. Used for naming inputs, - // logging, visualization, etc. Unique within a single GraphDef. - // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". - string name = 1; - - // The operation name. There may be custom parameters in attrs. - // Op names starting with an underscore are reserved for internal use. - string op = 2; - - // Each input is "node:src_output" with "node" being a string name and - // "src_output" indicating which output tensor to use from "node". If - // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs - // may optionally be followed by control inputs that have the format - // "^node". - repeated string input = 3; - - // A (possibly partial) specification for the device on which this - // node should be placed. 
- // The expected syntax for this string is as follows: - // - // DEVICE_SPEC ::= PARTIAL_SPEC - // - // PARTIAL_SPEC ::= ("/" CONSTRAINT) * - // CONSTRAINT ::= ("job:" JOB_NAME) - // | ("replica:" [1-9][0-9]*) - // | ("task:" [1-9][0-9]*) - // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) - // - // Valid values for this string include: - // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) - // * "/job:worker/device:GPU:3" (partial specification) - // * "" (no specification) - // - // If the constraints do not resolve to a single device (or if this - // field is empty or not present), the runtime will attempt to - // choose a device automatically. - string device = 4; - - // Operation-specific graph-construction-time configuration. - // Note that this should include all attrs defined in the - // corresponding OpDef, including those with a value matching - // the default -- this allows the default to change and makes - // NodeDefs easier to interpret on their own. However, if - // an attr with a default is not specified in this list, the - // default will be used. - // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and - // one of the names from the corresponding OpDef's attr field). - // The values must have a type matching the corresponding OpDef - // attr's type field. - // Add some examples here showing best practices. - map attr = 5; -}; diff --git a/ge/common/proto/tensorflow/op_def.proto b/ge/common/proto/tensorflow/op_def.proto deleted file mode 100644 index 3485d045..00000000 --- a/ge/common/proto/tensorflow/op_def.proto +++ /dev/null @@ -1,164 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "OpDefProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; -import "types.proto"; - -// Defines an operation. 
A NodeDef in a GraphDef specifies an Op by -// using the "op" field which should match the name of a OpDef. -// LINT.IfChange -message OpDef { - // Op names starting with an underscore are reserved for internal use. - // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". - string name = 1; - - // For describing inputs and outputs. - message ArgDef { - // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". - string name = 1; - - // Human readable description. - string description = 2; - - // Describes the type of one or more tensors that are accepted/produced - // by this input/output arg. The only legal combinations are: - // * For a single tensor: either the "type" field is set or the - // "type_attr" field is set to the name of an attr with type "type". - // * For a sequence of tensors with the same type: the "number_attr" - // field will be set to the name of an attr with type "int", and - // either the "type" or "type_attr" field will be set as for - // single tensors. - // * For a sequence of tensors, the "type_list_attr" field will be set - // to the name of an attr with type "list(type)". - DataType type = 3; - string type_attr = 4; // if specified, attr must have type "type" - string number_attr = 5; // if specified, attr must have type "int" - // If specified, attr must have type "list(type)", and none of - // type, type_attr, and number_attr may be specified. - string type_list_attr = 6; - - // For inputs: if true, the inputs are required to be refs. - // By default, inputs can be either refs or non-refs. - // For outputs: if true, outputs are refs, otherwise they are not. - bool is_ref = 16; - }; - - // Description of the input(s). - repeated ArgDef input_arg = 2; - - // Description of the output(s). - repeated ArgDef output_arg = 3; - - // Description of the graph-construction-time configuration of this - // Op. That is to say, this describes the attr fields that will - // be specified in the NodeDef. 
- message AttrDef { - // A descriptive name for the argument. May be used, e.g. by the - // Python client, as a keyword argument name, and so should match - // the regexp "[a-z][a-z0-9_]+". - string name = 1; - - // One of the type names from attr_value.proto ("string", "list(string)", - // "int", etc.). - string type = 2; - - // A reasonable default for this attribute if the user does not supply - // a value. If not specified, the user must supply a value. - AttrValue default_value = 3; - - // Human-readable description. - string description = 4; - - - // --- Constraints --- - // These constraints are only in effect if specified. Default is no - // constraints. - - // For type == "int", this is a minimum value. For "list(___)" - // types, this is the minimum length. - bool has_minimum = 5; - int64 minimum = 6; - - // The set of allowed values. Has type that is the "list" version - // of the "type" field above (uses the "list" field of AttrValue). - // If type == "type" or "list(type)" above, then the "type" field - // of "allowed_values.list" has the set of allowed DataTypes. - // If type == "string" or "list(string)", then the "s" field of - // "allowed_values.list" has the set of allowed strings. - AttrValue allowed_values = 7; - } - repeated AttrDef attr = 4; - - // Optional deprecation based on GraphDef versions. - OpDeprecation deprecation = 8; - - // One-line human-readable description of what the Op does. - string summary = 5; - - // Additional, longer human-readable description of what the Op does. - string description = 6; - - // ------------------------------------------------------------------------- - // Which optimizations this operation can participate in. - - // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) - bool is_commutative = 18; - - // If is_aggregate is true, then this operation accepts N >= 2 - // inputs and produces 1 output all of the same type. 
Should be - // associative and commutative, and produce output with the same - // shape as the input. The optimizer may replace an aggregate op - // taking input from multiple devices with a tree of aggregate ops - // that aggregate locally within each device (and possibly within - // groups of nearby devices) before communicating. - bool is_aggregate = 16; // for things like add - - // Other optimizations go here, like - // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. - - // ------------------------------------------------------------------------- - // Optimization constraints. - - // Ops are marked as stateful if their behavior depends on some state beyond - // their input tensors (e.g. variable reading op) or if they have - // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops - // must always produce the same output for the same input and have - // no side-effects. - // - // By default Ops may be moved between devices. Stateful ops should - // either not be moved, or should only be moved if that state can also - // be moved (e.g. via some sort of save / restore). - // Stateful ops are guaranteed to never be optimized away by Common - // Subexpression Elimination (CSE). - bool is_stateful = 17; // for things like variables, queue - - // ------------------------------------------------------------------------- - // Non-standard options. - - // By default, all inputs to an Op must be initialized Tensors. Ops - // that may initialize tensors for the first time should set this - // field to true, to allow the Op to take an uninitialized Tensor as - // input. - bool allows_uninitialized_input = 19; // for Assign, etc. -}; -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) - -// Information about version-dependent deprecation of an op -message OpDeprecation { - // First GraphDef version at which the op is disallowed. 
- int32 version = 1; - - // Explanation of why it was deprecated and what to use instead. - string explanation = 2; -}; - -// A collection of OpDefs -message OpList { - repeated OpDef op = 1; -}; diff --git a/ge/common/proto/tensorflow/resource_handle.proto b/ge/common/proto/tensorflow/resource_handle.proto deleted file mode 100644 index a3452351..00000000 --- a/ge/common/proto/tensorflow/resource_handle.proto +++ /dev/null @@ -1,29 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "ResourceHandle"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Protocol buffer representing a handle to a tensorflow resource. Handles are -// not valid across executions, but can be serialized back and forth from within -// a single run. -message ResourceHandleProto { - // Unique name for the device containing the resource. - string device = 1; - - // Container in which this resource is placed. - string container = 2; - - // Unique name of this resource. - string name = 3; - - // Hash code for the type of the resource. Is only valid in the same device - // and in the same execution. - uint64 hash_code = 4; - - // For debug-only, the name of the type pointed to by this handle, if - // available. - string maybe_type_name = 5; -}; diff --git a/ge/common/proto/tensorflow/tensor.proto b/ge/common/proto/tensorflow/tensor.proto deleted file mode 100644 index d0a4d024..00000000 --- a/ge/common/proto/tensorflow/tensor.proto +++ /dev/null @@ -1,94 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TensorProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "resource_handle.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing a tensor. -message TensorProto { - DataType dtype = 1; - - // Shape of the tensor. 
- TensorShapeProto tensor_shape = 2; - - // Only one of the representations below is set, one of "tensor_contents" and - // the "xxx_val" attributes. We are not using oneof because as oneofs cannot - // contain repeated fields it would require another extra set of messages. - - // Version number. - // - // In version 0, if the "repeated xxx" representations contain only one - // element, that element is repeated to fill the shape. This makes it easy - // to represent a constant Tensor with a single value. - int32 version_number = 3; - - // Serialized raw tensor content from either Tensor::AsProtoTensorContent or - // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation - // can be used for all tensor types. The purpose of this representation is to - // reduce serialization overhead during RPC call by avoiding serialization of - // many repeated small items. - bytes tensor_content = 4; - - // Type specific representations that make it easy to create tensor protos in - // all languages. Only the representation corresponding to "dtype" can - // be set. The values hold the flattened representation of the tensor in - // row major order. - - // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll - // have some pointless zero padding for each value here. - repeated int32 half_val = 13 [packed = true]; - - // DT_FLOAT. - repeated float float_val = 5 [packed = true]; - - // DT_DOUBLE. - repeated double double_val = 6 [packed = true]; - - // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. - repeated int32 int_val = 7 [packed = true]; - - // DT_STRING - repeated bytes string_val = 8; - - // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real - // and imaginary parts of i-th single precision complex. - repeated float scomplex_val = 9 [packed = true]; - - // DT_INT64 - repeated int64 int64_val = 10 [packed = true]; - - // DT_BOOL - repeated bool bool_val = 11 [packed = true]; - - // DT_COMPLEX128. 
dcomplex_val(2*i) and dcomplex_val(2*i+1) are real - // and imaginary parts of i-th double precision complex. - repeated double dcomplex_val = 12 [packed = true]; - - // DT_RESOURCE - repeated ResourceHandleProto resource_handle_val = 14; - - // DT_VARIANT - repeated VariantTensorDataProto variant_val = 15; - - // DT_UINT32 - repeated uint32 uint32_val = 16 [packed = true]; - - // DT_UINT64 - repeated uint64 uint64_val = 17 [packed = true]; -}; - -// Protocol buffer representing the serialization format of DT_VARIANT tensors. -message VariantTensorDataProto { - // Name of the type of objects being serialized. - string type_name = 1; - // Portions of the object that are not Tensors. - bytes metadata = 2; - // Tensors contained within objects being serialized. - repeated TensorProto tensors = 3; -} diff --git a/ge/common/proto/tensorflow/tensor_shape.proto b/ge/common/proto/tensorflow/tensor_shape.proto deleted file mode 100644 index 4225a2e3..00000000 --- a/ge/common/proto/tensorflow/tensor_shape.proto +++ /dev/null @@ -1,45 +0,0 @@ -// Protocol buffer representing the shape of tensors. - -syntax = "proto3"; -option cc_enable_arenas = true; -option java_outer_classname = "TensorShapeProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -package domi.tensorflow; - -// Dimensions of a tensor. -message TensorShapeProto { - // One dimension of the tensor. - message Dim { - // Size of the tensor in that dimension. - // This value must be >= -1, but values of -1 are reserved for "unknown" - // shapes (values of -1 mean "unknown" dimension). Certain wrappers - // that work with TensorShapeProto may fail at runtime when deserializing - // a TensorShapeProto containing a dim value of -1. - int64 size = 1; - - // Optional name of the tensor dimension. - string name = 2; - }; - - // Dimensions of the tensor, such as {"input", 30}, {"output", 40} - // for a 30 x 40 2D tensor. 
If an entry has size -1, this - // corresponds to a dimension of unknown size. The names are - // optional. - // - // The order of entries in "dim" matters: It indicates the layout of the - // values in the tensor in-memory representation. - // - // The first entry in "dim" is the outermost dimension used to layout the - // values, the last entry is the innermost dimension. This matches the - // in-memory layout of RowMajor Eigen tensors. - // - // If "dim.size()" > 0, "unknown_rank" must be false. - repeated Dim dim = 2; - - // If true, the number of dimensions in the shape is unknown. - // - // If true, "dim.size()" must be 0. - bool unknown_rank = 3; -}; diff --git a/ge/common/proto/tensorflow/types.proto b/ge/common/proto/tensorflow/types.proto deleted file mode 100644 index ba7a72b3..00000000 --- a/ge/common/proto/tensorflow/types.proto +++ /dev/null @@ -1,74 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TypesProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// LINT.IfChange -enum DataType { - // Not a legal value for DataType. Used to indicate a DataType field - // has not been set. - DT_INVALID = 0; - - // Data types that all computation devices are expected to be - // capable to support. - DT_FLOAT = 1; - DT_DOUBLE = 2; - DT_INT32 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_INT8 = 6; - DT_STRING = 7; - DT_COMPLEX64 = 8; // Single-precision complex - DT_INT64 = 9; - DT_BOOL = 10; - DT_QINT8 = 11; // Quantized int8 - DT_QUINT8 = 12; // Quantized uint8 - DT_QINT32 = 13; // Quantized int32 - DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. - DT_QINT16 = 15; // Quantized int16 - DT_QUINT16 = 16; // Quantized uint16 - DT_UINT16 = 17; - DT_COMPLEX128 = 18; // Double-precision complex - DT_HALF = 19; - DT_RESOURCE = 20; - DT_VARIANT = 21; // Arbitrary C++ data types - DT_UINT32 = 22; - DT_UINT64 = 23; - - // Do not use! 
These are only for parameters. Every enum above - // should have a corresponding value below (verified by types_test). - DT_FLOAT_REF = 101; - DT_DOUBLE_REF = 102; - DT_INT32_REF = 103; - DT_UINT8_REF = 104; - DT_INT16_REF = 105; - DT_INT8_REF = 106; - DT_STRING_REF = 107; - DT_COMPLEX64_REF = 108; - DT_INT64_REF = 109; - DT_BOOL_REF = 110; - DT_QINT8_REF = 111; - DT_QUINT8_REF = 112; - DT_QINT32_REF = 113; - DT_BFLOAT16_REF = 114; - DT_QINT16_REF = 115; - DT_QUINT16_REF = 116; - DT_UINT16_REF = 117; - DT_COMPLEX128_REF = 118; - DT_HALF_REF = 119; - DT_RESOURCE_REF = 120; - DT_VARIANT_REF = 121; - DT_UINT32_REF = 122; - DT_UINT64_REF = 123; -} -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/c/c_api.h, -// https://www.tensorflow.org/code/tensorflow/go/tensor.go, -// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, -// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, -// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/ge/common/proto/tensorflow/versions.proto b/ge/common/proto/tensorflow/versions.proto deleted file mode 100644 index 48061218..00000000 --- a/ge/common/proto/tensorflow/versions.proto +++ /dev/null @@ -1,31 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "VersionsProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Version information for a piece of serialized data -// -// There are different types of versions for each type of data -// (GraphDef, etc.), but they all have the same common shape -// described here. -// -// Each consumer has "consumer" and "min_producer" versions (specified -// elsewhere). 
A consumer is allowed to consume this data if -// -// producer >= min_producer -// consumer >= min_consumer -// consumer not in bad_consumers -// -message VersionDef { - // The version of the code that produced this data. - int32 producer = 1; - - // Any consumer below this version is not allowed to consume this data. - int32 min_consumer = 2; - - // Specific consumer versions which are disallowed (e.g. due to bugs). - repeated int32 bad_consumers = 3; -}; diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt deleted file mode 100755 index b67f2fd4..00000000 --- a/ge/executor/CMakeLists.txt +++ /dev/null @@ -1,115 +0,0 @@ -set(PROTO_LIST - "${METADEF_DIR}/proto/om.proto" - "${METADEF_DIR}/proto/ge_ir.proto" - "${METADEF_DIR}/proto/insert_op.proto" - "${METADEF_DIR}/proto/task.proto" - "${METADEF_DIR}/proto/op_mapping_info.proto" - "${METADEF_DIR}/proto/dump_task.proto" -) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) - -set(SRC_LIST - "ge_executor.cc" - "../common/profiling/profiling_manager.cc" - "../common/ge/plugin_manager.cc" - "../common/ge/op_tiling_manager.cc" - "../common/dump/dump_properties.cc" - "../common/dump/dump_manager.cc" - "../common/dump/dump_op.cc" - "../graph/load/graph_loader.cc" - "../graph/execute/graph_execute.cc" - "../omm/csa_interact.cc" - "../graph/manager/graph_manager_utils.cc" - "../graph/manager/graph_var_manager.cc" - "../graph/manager/graph_mem_allocator.cc" - "../graph/manager/graph_caching_allocator.cc" - "../graph/manager/trans_var_data_utils.cc" - "../graph/manager/util/debug.cc" - "../graph/manager/rdma_pool_allocator.cc" - "../hybrid/node_executor/aicpu/aicpu_ext_info.cc" - "../model/ge_model.cc" - "../model/ge_root_model.cc" - "../graph/load/new_model_manager/davinci_model.cc" - "../graph/load/new_model_manager/davinci_model_parser.cc" - "../graph/load/new_model_manager/model_manager.cc" - "../graph/load/new_model_manager/tbe_handle_store.cc" - 
"../graph/load/new_model_manager/cpu_queue_schedule.cc" - "../graph/load/new_model_manager/model_utils.cc" - "../graph/load/new_model_manager/aipp_utils.cc" - "../graph/load/new_model_manager/data_inputer.cc" - "../graph/load/new_model_manager/data_dumper.cc" - "../graph/load/new_model_manager/zero_copy_task.cc" - "../graph/load/new_model_manager/zero_copy_offset.cc" - "../graph/load/new_model_manager/task_info/task_info.cc" - "../graph/load/new_model_manager/task_info/event_record_task_info.cc" - "../graph/load/new_model_manager/task_info/event_wait_task_info.cc" - "../graph/load/new_model_manager/task_info/fusion_start_task_info.cc" - "../graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" - "../graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" - "../graph/load/new_model_manager/task_info/kernel_task_info.cc" - "../graph/load/new_model_manager/task_info/label_set_task_info.cc" - "../graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" - "../graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" - "../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" - "../graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" - "../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" - "../graph/load/new_model_manager/task_info/stream_active_task_info.cc" - "../graph/load/new_model_manager/task_info/stream_switch_task_info.cc" - "../graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" - "../graph/load/new_model_manager/task_info/end_graph_task_info.cc" - "../graph/load/new_model_manager/task_info/model_exit_task_info.cc" - "../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" - "../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" - "../opskernel_manager/ops_kernel_builder_manager.cc" - "../single_op/single_op_manager.cc" - "../single_op/single_op_model.cc" - "../single_op/single_op.cc" - 
"../single_op/stream_resource.cc" - "../single_op/task/op_task.cc" - "../single_op/task/build_task_utils.cc" - "../single_op/task/tbe_task_builder.cc" - "../single_op/task/aicpu_task_builder.cc" - "../single_op/task/aicpu_kernel_task_builder.cc" - "../hybrid/hybrid_davinci_model_stub.cc" -) - -######## libge_executor.a ######## -add_library(ge_executor STATIC ${SRC_LIST} ${PROTO_HDRS}) - -target_compile_options(ge_executor PRIVATE - -Werror - -O2 -) - -target_compile_definitions(ge_executor PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - DAVINCI_SUPPORT_PROFILING -) - -target_include_directories(ge_executor PRIVATE - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/cce - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(ge_executor PRIVATE - $ - json - protobuf - c_sec - -lrt - -ldl -) diff --git a/ge/executor/proto/ge_ir.proto b/ge/executor/proto/ge_ir.proto deleted file mode 100644 index 87886c84..00000000 --- a/ge/executor/proto/ge_ir.proto +++ /dev/null @@ -1,206 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -syntax = "proto3"; - -package ge.proto; - -enum DataType -{ - DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. - DT_FLOAT = 1; // float type - DT_FLOAT16 = 2; // fp16 type - DT_INT8 = 3; // int8 type - DT_UINT8 = 4; // uint8 type - DT_INT16 = 5; // int16 type - DT_UINT16 = 6; // uint16 type - DT_INT32 = 7; // - DT_INT64 = 8; // int64 type - DT_UINT32 = 9; // unsigned int32 - DT_UINT64 = 10; // unsigned int64 - DT_BOOL = 11; // bool type - DT_DOUBLE = 12; // double type - DT_STRING = 13; // string type - DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ - DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ - DT_COMPLEX64 = 16; // complex64 type - DT_COMPLEX128 = 17; // complex128 type - DT_QINT8 = 18; // qint8 type - DT_QINT16 = 19; // qint16 type - DT_QINT32 = 20; // qint32 type - DT_QUINT8 = 21; // quint8 type - DT_QUINT16 = 22; // quint16 type - DT_RESOURCE = 23; // resource type - DT_STRING_REF = 24; // string_ref type - DT_DUAL = 25; /**< dual output type */ -} - -message AttrDef -{ - message ListValue - { - enum ListValueType{ - VT_LIST_NONE = 0; - VT_LIST_STRING = 1; - VT_LIST_INT = 2; - VT_LIST_FLOAT = 3; - VT_LIST_BOOL = 4; - VT_LIST_BYTES = 5; - VT_LIST_TENSOR_DESC = 6; - VT_LIST_TENSOR = 7; - VT_LIST_GRAPH = 8; - VT_LIST_NAMED_ATTRS = 9; - VT_LIST_DATA_TYPE = 10; - } - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3; // "list(int)" - repeated float f = 4; // "list(float)" - repeated bool b = 5; // "list(bool)" - repeated bytes bt = 7; - repeated TensorDescriptor td = 8; - repeated TensorDef t = 9; - repeated GraphDef g = 10; - repeated NamedAttrs na = 11; - repeated int64 dt = 12; // list ge::DataType - - ListValueType val_type = 20; - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - oneof value - { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - 
bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; // Used to support attr nesting - TensorDescriptor td = 11; // GeTensorDesc type - TensorDef t = 12; // GeTensor type - GraphDef g = 13; // Graph type - ListListInt list_list_int = 14; // List List Int type - int64 dt = 15; // ge::DataType - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs -{ - string name = 1; - map attr = 2; -} - -// Shape / dimension description, using row-major order -message ShapeDef -{ - repeated int64 dim = 1; // Size of each dimension -} - -// Multidimensional data description -message TensorDescriptor -{ - string name = 1; // Optional parameter, tensor name - - DataType dtype = 2; // tensor datatype - ShapeDef shape = 3; // Shape / dimension - string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" - - bool has_out_attr = 9; - int64 size = 10; - int64 weight_size = 11; - bool reuse_input = 12; - bool output_tensor = 13; - string device_type = 14; - bool input_tensor =15; - int64 real_dim_cnt = 16; - int64 reuse_input_index = 17; - int64 data_offset = 18; - int64 cmps_size = 19; - string cmps_tab = 20; - int64 cmps_tab_offset = 21; - - map attr = 5; // Set of extra parameter fields -} - -// GeTensor definition -message TensorDef -{ - TensorDescriptor desc = 1; // Tensor description - bytes data = 2; // Tensor data -} - - -// Operator description -message OpDef -{ - string name = 1; // name - string type = 2; // type - - repeated string input = 5; // input original op name + outgoing index. 
op_name:index - - map attr = 10; // Set of operator parameter fields - - bool has_out_attr = 20; - int64 id = 21; - int64 stream_id =22; - repeated string input_name = 23; - repeated string src_name = 24; - repeated int64 src_index = 25; - repeated string dst_name = 26; - repeated int64 dst_index = 27; - repeated int64 input_i = 28; - repeated int64 output_i = 29; - repeated int64 workspace = 30; - repeated int64 workspace_bytes = 31; - repeated bool is_input_const = 32; - repeated TensorDescriptor input_desc = 33; - repeated TensorDescriptor output_desc = 34; - repeated string subgraph_name = 35; -} - -// Graph definition -message GraphDef -{ - string name = 1; // name - - repeated string input = 4; // Graph input - repeated string output = 5; // Graph output - - repeated OpDef op = 6; // List of operators - - map attr = 11; // Extended field -} - -// model definition -message ModelDef -{ - string name = 1; // name - uint32 version = 2; // IR Proto verion - string custom_version = 3; // User model version number, passed in by user - - repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef - - map attr = 11; // Extended field -} - diff --git a/ge/executor/proto/insert_op.proto b/ge/executor/proto/insert_op.proto deleted file mode 100644 index a059e122..00000000 --- a/ge/executor/proto/insert_op.proto +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -message InsertNewOps { - repeated AippOpParams aipp_op = 1; - repeated MultiShapeOpParams multi_shape_op = 2; -} - -message AippOpParams { - enum InputFormat { - UNDEFINED = 0; - YUV420SP_U8 = 1; - XRGB8888_U8 = 2; - RGB888_U8 = 3; - YUV400_U8 = 4; - NC1HWC0DI_FP16 = 5; - NC1HWC0DI_S8 = 6; - ARGB8888_U8 = 7; - YUYV_U8 = 8; - YUV422SP_U8 = 9; - AYUV444_U8 = 10; - RAW10 = 11; - RAW12 = 12; - RAW16 = 13; - RAW24 = 14; - RGB16 = 15; - RGB20 = 16; - RGB24 = 17; - RGB8_IR = 18; - RGB16_IR = 19; - RGB24_IR = 20; - } - - enum AippMode { - undefined = 0; - static = 1; - dynamic = 2; - } - - // AIPPģʽ£¬Çø·Ö¾²Ì¬AIPPºÍ¶¯Ì¬AIPP - AippMode aipp_mode = 1; - - // related_input_rank²ÎÊýΪ±ØÌÀàÐÍΪÕûÐÍ£¬ÅäÖ÷¶Î§>=0, <=ÊäÈëDataËã×ӵĸöÊý£¬Ä¬ÈÏֵΪ0¡£ - // ±êʶ¶ÔÄ£Ð͵ĵڼ¸¸öÊäÈë×öAIPP´¦Àí£¬ÀýÈçÄ£ÐÍÓÐÁ½¸öÊäÈ룬ÐèÒª¶ÔµÚ2¸öÊäÈë×öAIPP£¬ÔòÅäÖÃrelated_input_rankΪ1¡£ - uint32 related_input_rank = 2; - - // input_edge_idx²ÎÊýΪ¿ÉÑ¡£¬ÀàÐÍΪÕûÐÍ£¬ÅäÖ÷¶Î§Îª>=0¡£ - // ÅäÖøòÎÊýµÄ×÷Óã¬ÔÚÓÚ¶ÔDataËã×Ó²»Í¬µÄÊä³ö×ö²»Í¬µÄAIPP´¦Àí£¬Èç¹û¸Ã²ÎÊýûÓÐÅäÖã¬Ä¬È϶Ôrelated_input_rankÖ¸¶¨µÄÄ£ÐÍÊäÈëµÄËùÓÐÊä³ö±ß×öAIPP¡£ - // ÅäÖÃÖµ <= DataËã×ÓÊä³ö±ßµÄ¸öÊý¡£ - repeated uint32 input_edge_idx = 3; - - // [Begin] ¶¯Ì¬AIPP²ÎÊý£¬ÅäÖþ²Ì¬AIPPʱÎÞЧ - uint32 max_src_image_size = 4; - - // ÊÇ·ñÖ§³ÖÐýת¡£Ä¬Èϲ»Ö§³Ö£¬¿ªÆôÖ§³ÖÐýתʱ£¬»áÓжîÍâµÄ¿Õ¼äºÍÐÔÄÜËðʧ - bool support_rotation = 5; - - // [End] ¶¯Ì¬AIPP²ÎÊý - - - // [Begin] ¾²Ì¬AIPP²ÎÊý£¬ÅäÖö¯Ì¬AIPPʱÎÞЧ - InputFormat input_format = 51; - bool csc_switch = 52; - float cpadding_value = 53; - bool rbuv_swap_switch = 54; - bool ax_swap_switch = 55; - bool single_line_mode = 56; - - int32 src_image_size_w = 57; - int32 src_image_size_h = 58; - - bool crop = 59; - int32 load_start_pos_w = 60; - int32 load_start_pos_h = 61; - int32 crop_size_w = 62; - int32 crop_size_h = 63; - - bool resize = 64; - int32 resize_output_w = 65; - int32 resize_output_h = 
66; - - bool padding = 67; - int32 left_padding_size = 68; - int32 right_padding_size = 69; - int32 top_padding_size = 70; - int32 bottom_padding_size = 71; - - int32 mean_chn_0 = 10; - int32 mean_chn_1 = 11; - int32 mean_chn_2 = 12; - int32 mean_chn_3 = 19; - float min_chn_0 = 13; - float min_chn_1 = 14; - float min_chn_2 = 15; - float min_chn_3 = 20; - repeated float var_reci_chn_0 = 16; - repeated float var_reci_chn_1 = 17; - repeated float var_reci_chn_2 = 18; - repeated float var_reci_chn_3 = 21; - - repeated int32 matrix_r0c0 = 30; - repeated int32 matrix_r0c1 = 31; - repeated int32 matrix_r0c2 = 32; - repeated int32 matrix_r1c0 = 33; - repeated int32 matrix_r1c1 = 34; - repeated int32 matrix_r1c2 = 35; - repeated int32 matrix_r2c0 = 36; - repeated int32 matrix_r2c1 = 37; - repeated int32 matrix_r2c2 = 38; - repeated int32 output_bias_0 = 39; - repeated int32 output_bias_1 = 40; - repeated int32 output_bias_2 = 41; - repeated int32 input_bias_0 = 42; - repeated int32 input_bias_1 = 43; - repeated int32 input_bias_2 = 44; - - // [End] ¾²Ì¬AIPP²ÎÊý - - // The n number that is used for raw/rgbir data into f16 transformation. - // The transformation equation is x/(2^n). If set to 0, no transform is performed. - uint32 raw_rgbir_to_f16_n = 45; -} - -message MultiShapeOpParams { - enum MultiShapeMode { - batch = 0; //¶¯Ì¬batch - resolution = 1; //¶¯Ì¬·Ö±æÂÊ£¬À©Õ¹Óà - } - - MultiShapeMode mode = 1; //Ëã×Óģʽ - uint32 related_input_rank = 2; //ÐÂÔöËã×Ó²åÈëµ½ÄĸöÊäÈë - - - repeated uint32 batch_list = 11; //batch_listÖµ£¬batch_listµÄ¸öÊýÊÇ2µ½8Ö®¼ä -} diff --git a/ge/executor/proto/om.proto b/ge/executor/proto/om.proto deleted file mode 100644 index dd992191..00000000 --- a/ge/executor/proto/om.proto +++ /dev/null @@ -1,401 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -enum TargetType -{ - MINI = 0; - TINY = 1; - LITE = 2; -} - -// offline model -message ModelDef { - string name = 1; - uint32 version = 2; - - uint64 memory_size = 10; - uint32 stream_num = 11; - uint32 event_num = 12; - uint64 weight_size = 13; - uint32 label_num = 15; - repeated OpDef op = 20; - TargetType target_type = 23; - - map attr = 30; -}; - -// operator define -message OpDef { - string name = 1; - string type = 2; - - uint32 id = 3; - uint32 stream_id = 4; - - repeated string input_name = 5; - - repeated string src_name = 8; - repeated int32 src_index = 9; - repeated int64 input = 10; - repeated int64 output = 11; - repeated TensorDescriptor input_desc = 12; - repeated TensorDescriptor output_desc = 13; - repeated WeightDef weights = 14; - repeated string dst_name = 15; - repeated int32 dst_index = 16; - - repeated int64 workspace = 20; - repeated uint32 workspace_bytes = 21; - - repeated string weight_name = 22; - repeated bool is_input_const = 23; - - map attr = 30; - - QuantizeFactorParams quantize_factor = 31; - - oneof op_params { - // start at 100 here - SendOpParams sender_param = 100; - RecvOpParams receiver_param = 200; - ConvolutionOpParams convolution_param = 300; - PoolingOpParams pooling_param = 400; - EltwiseOpParams eltwise_param = 500; - BatchNormOpParams batchnorm_param = 600; - ScaleOpParams scale_param = 700; - FullConnectionOpParams full_connection_param = 800; - SoftmaxOpParams softmax_param = 900; - ActivationOpParams activation_param = 1000; - ReshapeOpParams 
reshape_param = 1100; - } -}; - -message SendOpParams { - uint32 event_id = 1; -}; - -message RecvOpParams { - uint32 event_id = 1; -}; - -enum QuantizeScaleType -{ - VECTOR_SCALE = 0; - SCALAR_SCALE = 1; -} - -enum QuantizeScaleMode -{ - NORMAL_MODE = 0; - SQRT_MODE = 1; -} - -enum QuantizeAlgorithm -{ - NON_OFFSET_ALGO = 0; - HALF_OFFSET_ALGO = 1; - ALL_OFFSET_ALGO = 2; -} -message QuantizeFactor -{ - QuantizeScaleMode scale_mode = 1; - bytes scale_value = 2; - int64 scale_offset = 3; - bytes offset_data_value = 4; - int64 offset_data_offset = 5; - bytes offset_weight_value = 6; - int64 offset_weight_offset = 7; - bytes offset_pad_value = 8; - int64 offset_pad_offset = 9; -}; - -message QuantizeCalcFactor -{ - bytes offsetw = 1; - int64 offsetw_offset = 2; - bytes offsetd = 3; - int64 offsetd_offset = 4; - bytes scalereq = 5; - int64 scaledreq_offset = 6; - bytes offsetdnext = 7; - int64 offsetdnext_offset = 8; -} - -message QuantizeFactorParams -{ - QuantizeAlgorithm quantize_algo = 1; - QuantizeScaleType scale_type = 2; - QuantizeFactor quantize_param = 3; - QuantizeFactor dequantize_param = 4; - QuantizeFactor requantize_param = 5; - QuantizeCalcFactor quantizecalc_param = 6; -}; - -message ConvolutionOpParams { - int32 mode = 1; - int32 algo = 2; - int32 pad_mode = 3; - uint32 group = 4; - uint32 num_output = 5; - - repeated uint32 pad = 10; - repeated uint32 stride = 11; - repeated uint32 dilation = 12; - repeated uint32 kernel = 13; - - float alpha = 20; - float beta = 21; - - WeightDef filter = 40; - WeightDef bias = 41; - - bool relu_flag = 62; - repeated uint32 adj = 70; - repeated uint32 target_shape = 71; - repeated uint32 before_pad = 72; -}; - -message PoolingOpParams { - int32 mode = 1; - int32 nan_opt = 2; - int32 pad_mode = 3; - bool global_pooling = 4; - - repeated uint32 window = 10; - repeated uint32 pad = 11; - repeated uint32 stride = 12; - bool ceil_mode = 13; - int32 data_mode = 14; - - float alpha = 20; - float beta = 21; - repeated uint32 
before_pad = 22; -}; - -message EltwiseOpParams { - int32 mode = 1; - repeated float coeff = 2; - float alpha = 3; - float beta = 4; - repeated WeightDef weight = 5; - bool relu_flag = 6; -}; - -message ActivationOpParams { - int32 mode = 1; - float coef = 2; - float alpha = 3; - float beta = 4; -}; - -message BatchNormOpParams { - int32 mode = 1; - - float alpha = 2; - float beta = 3; - double epsilon = 4;//optinal,[default = 1e-5] - bool use_global_stats = 5; //optinal,by default true,testing mode - float moving_average_fraction = 6; //optinal,[default = .999]; - - WeightDef estimated_mean = 7; - WeightDef estimated_variance = 8; - - WeightDef scale = 9; - WeightDef bias = 10; -}; - -message ScaleOpParams { - WeightDef scale = 1; - WeightDef bias = 2; -}; - -message ReshapeOpParams { - float alpha = 1; - float beta = 2; - ShapeDef shape = 3; - int32 axis = 4; - int32 num_axes = 5; - int32 format = 6; -}; - -message SoftmaxOpParams { - int32 algo = 1; - int32 mode = 2; - float alpha = 3; - float beta = 4; -}; - -message FullConnectionOpParams { - WeightDef filter = 1; - WeightDef bias = 2; - uint32 num_output = 3; - bool relu_flag = 12; -}; - -message FlattenOpParams { - float alpha = 1; - float beta = 2; - int32 start_axis = 3; - int32 end_axis = 4; -} - -message AddLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message MulLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message AddOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message MulOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message SubOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message BiasAddOpParams { - float alpha = 1; - float beta = 2; - - WeightDef bias = 10; -}; - -message MatMulOpParams 
{ - float alpha = 1; - float beta = 2; - bool transposeX = 3; - bool transposeW = 4; - - WeightDef filter = 10; - WeightDef bias = 12; -}; - -message RsqrtOpParams { - float alpha = 1; - float beta = 2; -}; - - -message WeightDef { - int32 format = 1; - int32 data_type = 2; - ShapeDef shape = 3; - bytes data = 4; - int64 data_offset = 5; - uint32 cmps_size = 6; - bytes cmps_tab = 7; - int64 cmps_tab_offset = 10; - CompressInfo cmps_info = 8; - AllOffsetQuantizeInfo alloffset_quantize_info = 11; -} - -message ShapeDef { - repeated int64 dim = 1; -} - -enum DeviceType { - NPU = 0; // In default, we will use NPU. - CPU = 1; // CPU -} - -message AllOffsetQuantizeInfo { - float scale = 1; - int32 offset = 2; -} - -message TensorDescriptor { - int32 format = 1; - int32 data_type = 2; - repeated int64 dim = 3; - uint32 size = 4; - bool reuse_input = 5; - bool output_tensor = 7; - DeviceType device_type = 8; - bool input_tensor = 9; - uint32 real_dim_cnt = 10; - uint32 reuse_input_index = 11; - AllOffsetQuantizeInfo alloffset_quantize_info = 12; -} - -message CompressInfo { - int32 blockRow = 1; // block row - int32 blockCol = 2; // block col - int32 fractalK = 3; // fractal K - int32 fractalN = 4; // fractal N - int32 lastFractalK = 5; // K of last fractal - int32 lastFractalN = 6; // N of last fractal - int32 cubeSize = 7; // cube's length - int32 loadDir = 8; // data load directtiono 0:col load 1:row load -} - -message AttrDef { - message ListValue { - repeated string s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated uint32 u = 6 [packed = true]; // "list(uint)" - repeated bytes bt = 7; - } - - oneof value { - string s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - uint32 u = 6; // "uint32" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; - } 
-} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs { - string name = 1; - map attr = 2; -} - diff --git a/ge/executor/proto/op_mapping_info.proto b/ge/executor/proto/op_mapping_info.proto deleted file mode 100644 index 7b84a115..00000000 --- a/ge/executor/proto/op_mapping_info.proto +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -syntax = "proto3"; -package aicpu.dump; - -message Shape { - repeated uint64 dim = 1; -} - -message Output { - int32 data_type = 1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - string original_name = 5; - int32 original_output_index = 6; - int32 original_output_data_type = 7; - int32 original_output_format = 8; - uint64 size = 9; -} - -message Input { - int32 data_type =1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - uint64 size = 5; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - uint64 address = 2; - uint64 size = 3; -} - -message Op { - string op_name = 1; - string op_type = 2; -} - -message Task { - uint32 task_id = 1; - uint32 stream_id = 2; - Op op = 3; - repeated Output output = 4; - bool end_graph = 5; - repeated Input input = 6; - repeated OpBuffer buffer = 7; -} - -message OpMappingInfo { - string dump_path = 1; - oneof model_name_param { - string model_name = 2; - } - oneof model_id_param { - uint32 model_id = 3; - } - oneof step_id { - uint64 step_id_addr = 4; - } - oneof iterations_per_loop { - uint64 iterations_per_loop_addr = 5; - } - oneof loop_cond { - uint64 loop_cond_addr = 6; - } - uint32 flag = 7; // 0x01 load, 0x00 unload - repeated Task task = 8; - string dump_step = 9; -} \ No newline at end of file diff --git a/ge/executor/proto/task.proto b/ge/executor/proto/task.proto deleted file mode 100644 index 50ea061b..00000000 --- a/ge/executor/proto/task.proto +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // 
TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/ge/ge_local_engine/CMakeLists.txt b/ge/ge_local_engine/CMakeLists.txt deleted file mode 100755 index 7a858b29..00000000 --- a/ge/ge_local_engine/CMakeLists.txt +++ /dev/null @@ -1,225 +0,0 @@ -set(PROTO_LIST - "${METADEF_DIR}/proto/task.proto" -) - -set(SRC_LIST - "engine/ge_local_engine.cc" - "ops_kernel_store/ge_local_ops_kernel_info.cc" - 
"ops_kernel_store/op/op_factory.cc" - "ops_kernel_store/op/op.cc" - "ops_kernel_store/op/ge_deleted_op.cc" - "ops_kernel_store/op/no_op.cc" -) - -set(OPS_KERNEL_SRC_LIST - "ops_kernel_store/ge_local_ops_kernel_builder.cc" - "ops_kernel_store/op/op_factory.cc" - "ops_kernel_store/op/op.cc" - "ops_kernel_store/op/ge_deleted_op.cc" - "ops_kernel_store/op/no_op.cc" -) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) - -############ libge_local_engine.so ############ -add_library(ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) - -target_compile_options(ge_local_engine PRIVATE - -Werror -) - -target_include_directories(ge_local_engine PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(ge_local_engine PRIVATE - $ - -Wl,--no-as-needed - graph - protobuf - register - c_sec - slog - runtime - -Wl,--as-needed -) - -######### atclib/libge_local_engine.so ############# -add_library(atc_ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) - -target_compile_options(atc_ge_local_engine PRIVATE - -Werror -) - -target_compile_definitions(atc_ge_local_engine PRIVATE - COMPILE_OMG_PACKAGE -) - -target_include_directories(atc_ge_local_engine PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - 
-target_link_libraries(atc_ge_local_engine PRIVATE - $ - -Wl,--no-as-needed - graph - protobuf - register - c_sec - slog - runtime_compile - -Wl,--as-needed -) - -set_target_properties(atc_ge_local_engine PROPERTIES - OUTPUT_NAME ge_local_engine - LIBRARY_OUTPUT_DIRECTORY atclib -) - -############ libge_local_opskernel_builder.so ############ -add_library(ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS}) - -target_compile_options(ge_local_opskernel_builder PRIVATE - -Werror -) - -target_include_directories(ge_local_opskernel_builder PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(ge_local_opskernel_builder PRIVATE - $ - -Wl,--no-as-needed - protobuf - c_sec - slog - register - graph - -Wl,--as-needed -) - -############ atclib/libge_local_opskernel_builder.so ############ -add_library(atc_ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS}) - -target_compile_options(atc_ge_local_opskernel_builder PRIVATE - -Werror -) - -target_include_directories(atc_ge_local_opskernel_builder PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(atc_ge_local_opskernel_builder PRIVATE - $ - -Wl,--no-as-needed - protobuf - c_sec - slog - register - graph - 
-Wl,--as-needed -) - -set_target_properties(atc_ge_local_opskernel_builder PROPERTIES - OUTPUT_NAME ge_local_opskernel_builder - LIBRARY_OUTPUT_DIRECTORY atclib -) - -############ libge_local_opskernel_builder.a ############ -add_library(ge_local_opskernel_builder_static SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS}) - -target_compile_options(ge_local_opskernel_builder_static PRIVATE - -Werror -) - -target_include_directories(ge_local_opskernel_builder_static PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(ge_local_opskernel_builder_static PRIVATE - $ - protobuf - c_sec -) - -############ install ############ -set(INSTALL_BASE_DIR "") -set(INSTALL_LIBRARY_DIR lib) - -install(TARGETS ge_local_engine ge_local_opskernel_builder OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} -) - -install(TARGETS atc_ge_local_engine atc_ge_local_opskernel_builder OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/atclib -) diff --git a/ge/ge_local_engine/proto/task.proto b/ge/ge_local_engine/proto/task.proto deleted file mode 100644 index 50ea061b..00000000 --- a/ge/ge_local_engine/proto/task.proto +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // 
TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/ge/ge_runtime/CMakeLists.txt b/ge/ge_runtime/CMakeLists.txt deleted file mode 100644 index b4c7fe9e..00000000 --- a/ge/ge_runtime/CMakeLists.txt +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the 
License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -############ libge_runtime.so ############ -set(GE_SRC_LIST - "model_runner.cc" - "runtime_model.cc" - "output.cc" - "task/*.cc" -) - -add_library(ge_runtime SHARED ${GE_SRC_LIST}) - -target_compile_options(ge_runtime PRIVATE - -Werror - -O2 -) - -target_compile_definitions(ge_runtime PUBLIC - PROTOBUF_INLINE_NOT_IN_HEADERS=0 -) - -target_include_directories(ge_runtime PRIVATE - ${TOP_DIR} - ${TOP_DIR}/inc - ${TOP_DIR}/inc/graph - ${TOP_DIR}/inc/external - ${TOP_DIR}/inc/framework - ${TOP_DIR}/inc/framework/common - ${TOP_DIR}/inc/framework/ge_runtime - ${TOP_DIR}/inc/cce - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge -) - -target_link_libraries(ge_runtime PRIVATE - $ - -Wl,--no-as-needed - graph - slog - runtime - c_sec - -Wl,--as-needed - -lrt - -ldl -) - -############ install ############ -set(INSTALL_BASE_DIR "") -set(INSTALL_LIBRARY_DIR lib) - -install(TARGETS ge_runtime OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} -) diff --git a/ge/ge_runtime/module.mk b/ge/ge_runtime/module.mk deleted file mode 100755 index 43d81bfa..00000000 --- a/ge/ge_runtime/module.mk +++ /dev/null @@ -1,66 +0,0 @@ -LOCAL_PATH := $(call my-dir) - -# task.proto is old task, add it for ops_kernel_info_store -local_ge_runtime_src_files := \ - model_runner.cc \ - runtime_model.cc \ - output.cc \ - task/aicpu_task.cc \ - task/cce_task.cc \ - task/tbe_task.cc \ - task/event_record_task.cc \ - task/event_wait_task.cc \ - task/stream_active_task.cc \ - task/stream_switch_task.cc \ - 
task/hccl_task.cc \ - task/memcpy_async_task.cc \ - task/profiler_task.cc \ - -local_ge_runtime_include := \ - $(LOCAL_PATH)/ \ - $(TOPDIR)libc_sec/include \ - $(TOPDIR)inc/external \ - $(TOPDIR)inc/external/graph \ - $(TOPDIR)inc/framework \ - $(TOPDIR)inc/graph \ - $(TOPDIR)inc \ - $(LOCAL_PATH)/../ \ - third_party/protobuf/include - -local_ge_runtime_shared_library := \ - libruntime \ - libslog \ - libc_sec - -local_ge_runtime_ldflags := -lrt -ldl - -# compile device libge_runtime -include $(CLEAR_VARS) - -LOCAL_MODULE := libge_runtime -LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 -LOCAL_CFLAGS += -Werror -LOCAL_SRC_FILES := $(local_ge_runtime_src_files) -LOCAL_C_INCLUDES := $(local_ge_runtime_include) -LOCAL_SHARED_LIBRARIES := $(local_ge_runtime_shared_library) -LOCAL_LDFLAGS += $(local_ge_runtime_ldflags) - -include $(BUILD_SHARED_LIBRARY) - -# compile host libge_runtime -include $(CLEAR_VARS) - -LOCAL_MODULE := libge_runtime -LOCAL_CFLAGS += -Werror -LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -ifeq ($(DEBUG), 1) - LOCAL_CFLAGS += -g -O0 -else - LOCAL_CFLAGS += -O2 -endif -LOCAL_SRC_FILES := $(local_ge_runtime_src_files) -LOCAL_C_INCLUDES := $(local_ge_runtime_include) -LOCAL_SHARED_LIBRARIES := $(local_ge_runtime_shared_library) -LOCAL_LDFLAGS += $(local_ge_runtime_ldflags) - -include $(BUILD_HOST_SHARED_LIBRARY) diff --git a/ge/graph/build/memory/CMakeLists.txt b/ge/graph/build/memory/CMakeLists.txt deleted file mode 100644 index c568f2fe..00000000 --- a/ge/graph/build/memory/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -set(SRC_LIST - "memory_assigner.cc" - "graph_mem_assigner.cc" - "binary_block_mem_assigner.cc" - "block_mem_assigner.cc" - "hybrid_mem_assigner.cc" - "max_block_mem_assigner.cc" - "var_mem_assign_util.cc" -) - -############ libge_memory.a ############ -add_library(ge_memory STATIC ${SRC_LIST}) - -target_compile_options(ge_memory PRIVATE - -Werror - -O2 -) - -target_link_libraries(ge_memory PRIVATE - $ - protobuf - c_sec 
-) - -target_include_directories(ge_memory PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${GE_CODE_DIR}/inc/framework - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) diff --git a/ge/host_cpu_engine/CMakeLists.txt b/ge/host_cpu_engine/CMakeLists.txt deleted file mode 100644 index 63d219d0..00000000 --- a/ge/host_cpu_engine/CMakeLists.txt +++ /dev/null @@ -1,214 +0,0 @@ -set(PROTO_LIST - "${METADEF_DIR}/proto/task.proto" -) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) - -set(SRC_LIST - "engine/host_cpu_engine.cc" - "ops_kernel_store/host_cpu_ops_kernel_info.cc" - "ops_kernel_store/op/op_factory.cc" - "ops_kernel_store/op/host_op.cc" -) - -set(CPU_OPS_KERNEL_LIST - "ops_kernel_store/host_cpu_ops_kernel_builder.cc" -) - -############ libhost_cpu_engine.so ############ -add_library(host_cpu_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) - -target_compile_options(host_cpu_engine PRIVATE - -Werror -) - -target_include_directories(host_cpu_engine PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(host_cpu_engine PRIVATE - $ - -Wl,--no-as-needed - protobuf - c_sec - graph - register - slog - runtime - -Wl,--as-needed -) - -############ atcstub/libhost_cpu_engine.so ############ -add_library(atc_host_cpu_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) - -target_compile_options(atc_host_cpu_engine PRIVATE - -Werror -) - -target_compile_definitions(atc_host_cpu_engine PRIVATE - COMPILE_OMG_PACKAGE 
-) - -target_include_directories(atc_host_cpu_engine PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(atc_host_cpu_engine PRIVATE - $ - -Wl,--no-as-needed - protobuf - c_sec - graph - register - slog - runtime_compile - -Wl,--as-needed -) - -set_target_properties(atc_host_cpu_engine PROPERTIES - OUTPUT_NAME host_cpu_engine - LIBRARY_OUTPUT_DIRECTORY atclib -) - -############ libhost_cpu_opskernel_builder.so ############ -add_library(host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST}) - -target_compile_options(host_cpu_opskernel_builder PRIVATE - -Werror -) - -target_include_directories(host_cpu_opskernel_builder PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(host_cpu_opskernel_builder PRIVATE - $ - -Wl,--no-as-needed - protobuf - c_sec - slog - graph - register - -Wl,--as-needed -) - -############ atclib/libhost_cpu_opskernel_builder.so ############ -add_library(atc_host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST}) - -target_compile_options(atc_host_cpu_opskernel_builder PRIVATE - -Werror -) - -target_include_directories(atc_host_cpu_opskernel_builder PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - 
${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(atc_host_cpu_opskernel_builder PRIVATE - $ - -Wl,--no-as-needed - protobuf - c_sec - slog - graph - register - -Wl,--as-needed -) - -set_target_properties(atc_host_cpu_opskernel_builder PROPERTIES - OUTPUT_NAME host_cpu_opskernel_builder - LIBRARY_OUTPUT_DIRECTORY atclib -) - -############ libhost_cpu_opskernel_builder.a ############ -add_library(host_cpu_opskernel_builder_static SHARED ${CPU_OPS_KERNEL_LIST}) - -target_compile_options(host_cpu_opskernel_builder_static PRIVATE - -Werror -) - -target_include_directories(host_cpu_opskernel_builder_static PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(host_cpu_opskernel_builder_static PRIVATE - $ - protobuf - c_sec -) - -############ install ############ -set(INSTALL_BASE_DIR "") -set(INSTALL_LIBRARY_DIR lib) - -install(TARGETS host_cpu_engine host_cpu_opskernel_builder OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} -) - -install(TARGETS atc_host_cpu_engine atc_host_cpu_opskernel_builder OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/atclib -) diff --git a/ge/host_cpu_engine/proto/task.proto b/ge/host_cpu_engine/proto/task.proto deleted file mode 100644 index 36ae4847..00000000 --- a/ge/host_cpu_engine/proto/task.proto +++ /dev/null @@ -1 +0,0 @@ -../../proto/task.proto \ No newline at end of file diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt deleted file mode 
100644 index a5a334bd..00000000 --- a/ge/offline/CMakeLists.txt +++ /dev/null @@ -1,81 +0,0 @@ -set(PROTO_LIST - "${METADEF_DIR}/proto/om.proto" - "${METADEF_DIR}/proto/ge_ir.proto" - "${METADEF_DIR}/proto/insert_op.proto" - "${METADEF_DIR}/proto/task.proto" -) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) - -set(SRC_LIST - "main.cc" - "single_op_parser.cc" - "../session/omg.cc" - "../ir_build/atc_ir_common.cc" -) - -############ atc ############ -add_executable(atc ${SRC_LIST} ${PROTO_HDRS}) - -target_compile_options(atc PRIVATE - -Werror - -O2 -) - -target_compile_definitions(atc PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - COMPILE_OMG_PACKAGE -) - -target_include_directories(atc PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${GE_CODE_DIR} - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc/external - ${GE_CODE_DIR}/common/inc/external - ${GE_CODE_DIR}/common/inc/external/graph - ${GE_CODE_DIR}/inc - ${GE_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/graph - ${METADEF_DIR}/inc/register - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/graph - ${METADEF_DIR}/inc/external/register - ${PARSER_DIR} - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - ${GE_CODE_DIR}/../inc/common - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc - ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain -) - -target_link_libraries(atc PRIVATE - $ - protobuf - ge_common - register - c_sec - graph - error_manager - ge_compiler - parser_common - gflags - json - runtime_compile - slog - mmpa - -lrt - -ldl -) - -############ install ############ -set(INSTALL_BASE_DIR "") -set(INSTALL_LIBRARY_DIR lib) - -install(TARGETS atc OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} -) diff --git a/ge/offline/main.cc b/ge/offline/main.cc deleted file mode 100755 index 9fa2cfba..00000000 --- a/ge/offline/main.cc +++ /dev/null @@ -1,1334 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed 
under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "common/gflags_util.h" -#include "common/util.h" -#include "common/util/error_manager/error_manager.h" -#include "framework/common/debug/ge_log.h" -#include "ge/ge_api.h" -#include "generator/ge_generator.h" -#include "graph/anchor.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/graph.h" -#include "graph/op_desc.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/type_utils.h" -#include "init/gelib.h" -#include "ir_build/atc_ir_common.h" -#include "omg/omg.h" -#include "omg/parser/parser_factory.h" -#include "omg/parser/parser_inner_ctx.h" -#include "parser/common/register_tbe.h" -#include "register/op_registry.h" -#include "single_op_parser.h" - -using domi::BuildMode; -using domi::OpRegistrationData; -using domi::OpRegistry; -using domi::Status; -using domi::SUCCESS; -using ge::GEN_OM_MODEL; -using ge::GflagsUtils; -using ge::MODEL_TO_JSON; -using ge::ONLY_PRE_CHECK; -using ge::ParseInputShape; -using ge::PBTXT_TO_JSON; -using std::map; -using std::pair; -using std::shared_ptr; -using std::string; -using std::vector; - -static bool is_dynamic_input = false; - -// 310 limited 8G size -const char *const kGraphMemoryManagerMallocMaxSize = "8*1024*1024*1024"; -const char *const kModeSupport = "only support 0(model to framework model), " - "1(framework model to json), 3(only pre-check), 5(pbtxt to 
json)"; -const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow)"; - -// limit available mem size 2G -const long kMinAvailableMem = 2 * 1024 * 1024; - -DEFINE_string(model, "", "The model file."); -DEFINE_string(output, "", "The output file path&name."); -DEFINE_int32(framework, -1, "Framework type(0:Caffe; 1:MindSpore; 3:Tensorflow)."); -DEFINE_string(weight, "", "Optional; weight file. Required when framework is Caffe."); - -DEFINE_string(input_shape, "", - "Optional; shape of input data. Required when framework is caffe " - "or TensorFLow or MindSpore." - "Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\""); -DEFINE_bool(h, false, "show this help message"); -DEFINE_string(cal_conf, "", "Optional; the calibration config file."); - -DEFINE_string(insert_op_conf, "", "Optional; the config file to insert new op, for example AIPP op."); -DEFINE_string(op_name_map, "", "Optional; custom op name mapping file."); - -DEFINE_string(target, "", "Optional; mini."); - -DEFINE_string(om, "", "The model file to be converted to json."); -DEFINE_string(json, "", "The output json file path&name which is converted from a model."); -DEFINE_int32(mode, 0, - "Optional; run mode, 0(default): model => framework model; 1: " - "framework model => json; 3: only pre-check; 5: pbtxt => json."); - -#if !defined(__ANDROID__) && !defined(ANDROID) -DEFINE_int32(encrypt_mode, -1, "Optional; the encrypt flag. 0: encrypt; -1(default): not encrypt"); -DEFINE_string(encrypt_key, "", "Optional; the encrypt_key file."); -DEFINE_string(certificate, "", "Optional; the certificate file."); -DEFINE_string(hardware_key, "", "Optional; the ISV key file."); -DEFINE_string(private_key, "", "Optional; the private key file."); -#endif - -DEFINE_string(out_nodes, "", - "Optional; output nodes designated by users." - "Format: \"node_name1:0;node_name1:1;node_name2:0\""); - -DEFINE_string(precision_mode, "force_fp16", - "Optional; precision mode." 
- "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); - -DEFINE_string(input_format, "", - "Optional; input_format, format of input data, NCHW;NHWC." - "Format:\"NHWC\""); - -DEFINE_string(check_report, "check_result.json", "Optional; the pre-checking report file."); - -DEFINE_string(input_fp16_nodes, "", - "Optional; input node datatype is fp16 and format is NC1HWC0." - "Format:\"node_name1;node_name2\""); - -DEFINE_string(is_output_adjust_hw_layout, "", - "Optional; Net output node's datatype is fp16 and format is " - "NC1HWC0, or not." - "Format:\"false,true,false,true\""); - -DEFINE_string(is_input_adjust_hw_layout, "", - "Optional; Intput node's datatype is fp16 and format is " - "NC1HWC0, or not." - "Format:\"false,true,false,true\""); - -DEFINE_string(output_type, "", - "Optional; output type! " - "Support FP32,FP16,INT8,INT16,UINT16,UINT8,INT32,INT64,UINT32,UINT64,DOUBLE."); - -DEFINE_string(op_select_implmode, "", - "Optional; op select implmode! " - "Support high_precision, high_performance."); - -DEFINE_string(optypelist_for_implmode, "", - "Optional; Nodes need use implmode selected in op_select_implmode " - "Format:\"node_name1,node_name2\""); - -DEFINE_string(singleop, "", "Optional; If set, generate single op model with the given json file."); - -DEFINE_int32(disable_reuse_memory, 0, "Optional; If set to 1, disable reuse memory when generating if."); - -DEFINE_string(auto_tune_mode, "", "Optional; Set tune mode."); - -DEFINE_string(soc_version, "", "The soc version."); - -DEFINE_string(core_type, "AiCore", "Optional; If set to VectorCore, only use vector core."); - -DEFINE_string(aicore_num, "", "Optional; Set aicore num"); - -DEFINE_string(buffer_optimize, "l2_optimize", "Optional; buffer optimize"); - -DEFINE_string(fusion_switch_file, "", "Optional; Set fusion switch file path"); - -DEFINE_string(save_original_model, "", "Optional; enable output original offline model. 
false(default)"); - -DEFINE_string(dynamic_batch_size, "", - "Optional; If set, generate dynamic multi batch model. " - "Different batch sizes are split by ','." - "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); - -DEFINE_string(dynamic_image_size, "", - "Optional; If set, generate dynamic multi image size model." - "Different groups of image size are split by ';'," - "while different dimensions of each group are split by ','." - "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); - -DEFINE_string(dynamic_dims, "", - "Optional; If set, generate dynamic input size model. " - "Different groups of size are split by ';', while different dimensions of each group are split by ','." - "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); - -DEFINE_string(enable_small_channel, "0", "Optional; If set to 1, small channel is enabled."); - -DEFINE_string(enable_compress_weight, "false", - "Optional; enable compress weight. true: enable; false(default): disable"); - -DEFINE_string(compress_weight_conf, "", "Optional; the config file to compress weight"); - -DEFINE_string(enable_single_stream, "", "Optional; enable single stream. true: enable; false(default): disable"); - -DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, warning, error, null"); - -DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); - -DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 
0(default): close debug;" - "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); -DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass," - "multiple names can be set and separated by ','."); - -class GFlagUtils { - public: - /** - * @name InitGFlag - * @brief initialize gflag - * @return void - */ - static void InitGFlag(int argc, char *argv[]) { - // -help - gflags::SetUsageMessage( - "usage: ./atc \n" - "generate offline model example:\n" - "./atc --model=./alexnet.prototxt --weight=./alexnet.caffemodel \n" - "--framework=0 --output=./domi \n" - "generate offline model for single op example:\n" - "./atc --singleop=./op_list.json --output=./op_model \n" - "===== Basic Functionality =====\n" - "[General]\n" - " --h/help Show this help message\n" - " --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format " - "3: only pre-check; 5: convert pbtxt file to JSON format\n" - "\n[Input]\n" - " --model Model file\n" - " --weight Weight file. Required when framework is Caffe\n" - " --om The model file to be converted to json\n" - " --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow\n" - " --input_format Format of input data. E.g.: \"NCHW\"\n" - " --input_shape Shape of input data. Separate multiple nodes with semicolons (;)." - "Use double quotation marks (\") to enclose each argument.\n" - " E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" - " --dynamic_batch_size Set dynamic batch size. E.g: \"batchsize1,batchsize2,batchsize3\"\n" - " --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;)." - "Use double quotation marks (\") to enclose each argument.\n" - " E.g: \"imagesize1_height,imagesize1_width;imagesize2_height,imagesize2_width\"\n" - " --dynamic_dims Set dynamic dims. Separate multiple nodes with semicolons (;)." - "Use double quotation marks (\") to enclose each argument. 
E.g: \"dims1_n1,dims1_n2;dims2_n1,dims2_n2\"\n" - " --singleop Single op definition file. atc will generate offline " - "model(s) for single op if --singleop is set.\n" - "\n[Output]\n" - " --output Output file path&name(needn't suffix, will add " - ".om automatically). \n" - " If --singleop is set, this arg specifies the directory to " - "which the single op offline model will be generated\n" - " --output_type Set net output type. Support FP32, FP16, UINT8." - "E.g.: FP16, indicates that all out nodes are set to FP16.\n" - " \"node1:0:FP16;node2:1:FP32\", indicates setting the datatype of multiple out nodes.\n" - " --check_report The pre-checking report file. Default value is: " - "\"check_result.json\"\n" - " --json The output json file path&name which is " - "converted from a model\n" - "\n[Target]\n" - " --soc_version The soc version.\n" - " --core_type Set core type AiCore or VectorCore. VectorCore: use vector core. " - "Default value is: AiCore\n" - " --aicore_num Set aicore num\n" - "===== Advanced Functionality =====\n" - "[Feature]\n" - " --out_nodes Output nodes designated by users. Separate multiple nodes with semicolons (;)." - "Use double quotation marks (\") to enclose each argument.\n" - " E.g.: \"node_name1:0;node_name1:1;node_name2:0\"\n" - " --input_fp16_nodes Input node datatype is fp16. Separate multiple nodes with semicolons " - "(;)." - "Use double quotation marks (\") to enclose each argument." - "E.g.: \"node_name1;node_name2\"\n" - " --insert_op_conf Config file to insert new op\n" - " --op_name_map Custom op name mapping file\n" - " Note: A semicolon(;) cannot be included in each " - "path, otherwise the resolved path will not match the expected one.\n" - " --is_input_adjust_hw_layout Intput node datatype is fp16 and format is " - "NC1HWC0, used with input_fp16_nodes E.g.: \"true,true,false,true\"\n" - " --is_output_adjust_hw_layout Net output node datatype is fp16 and format is " - "NC1HWC0, used with out_nodes. 
E.g.: \"true,true,false,true\"\n" - "\n[Model Tuning]\n" - " --disable_reuse_memory The switch of reuse memory. Default value is : 0." - "0 means reuse memory, 1 means do not reuse memory.\n" - " --fusion_switch_file Set fusion switch file path\n" - " --enable_scope_fusion_passes validate the non-general scope fusion passes," - "multiple names can be set and separated by ','. E.g.: ScopePass1,ScopePass2,...\n" - " --enable_single_stream Enable single stream. true: enable; false(default): disable\n" - " --enable_small_channel Set enable small channel. 0(default): disable; 1: enable\n" - " --enable_compress_weight Enable compress weight. true: enable; false(default): disable\n" - " --compress_weight_conf Config file to compress weight\n" - " --buffer_optimize Set buffer optimize. \"l2_optimize\" (default). Set \"off_optimize\" to close\n" - "\n[Operator Tuning]\n" - " --precision_mode precision mode, support force_fp16(default), allow_mix_precision, " - "allow_fp32_to_fp16, must_keep_origin_dtype.\n" - " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" - " --op_select_implmode Set op select implmode. Support high_precision, high_performance." - "default: high_performance\n" - " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n" - " Separate multiple nodes with commas (,). Use double quotation marks (\") " - " to enclose each argument. E.g.: \"node_name1,node_name2\"\n" - " --op_debug_level Debug enable for TBE operator building.\n" - " 0 (default): Disable debug; 1: Enable TBE pipe_all, " - "and generate the operator CCE file and Python-CCE mapping file (.json);\n" - " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file " - "(.json), and enable the CCE compiler -O0-g.\n" - "\n[Debug]\n" - " --save_original_model Control whether to output original model. E.g.: true: output original model\n" - " --log Generate log with level. 
Support debug, info, warning, error, null\n" - " --dump_mode The switch of dump json with shape, to be used with mode 1." - "0(default): disable; 1: enable."); - - gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); - // Using gflags to analyze input parameters - GflagsUtils::ChangeHelpFlags(FLAGS_h); - gflags::HandleCommandLineHelpFlags(); - } - - static Status CheckDumpInfershapeJsonFlags() { - Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); - GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, - "check custom aicpu run so failed!"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), - return domi::FAILED, "Input parameter[--weight]'s value[%s] is invalid!", - FLAGS_weight.c_str()); - return domi::SUCCESS; - } - - static Status CheckFlags() { - Status ret = ge::SUCCESS; - // No model file information passed in - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_model == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"model"}); - ret = ge::FAILED, "Input parameter[--model]'s value is empty!"); - - // check param disable_reuse_memory - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - ge::CheckDisableReuseMemoryParamValid(to_string(FLAGS_disable_reuse_memory)) != ge::SUCCESS, - ret = ge::FAILED, "check disable_reuse_memory failed!"); - - // check optypelist_for_implmode and op_select_implmode - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, - FLAGS_op_select_implmode) != ge::SUCCESS, - ret = ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); - // No output file information passed in - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"output"}); - ret = ge::FAILED, "Input parameter[--output]'s value is empty!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - CheckFrameWorkValid(FLAGS_framework, 
FLAGS_weight) != ge::SUCCESS, - ret = ge::FAILED, - "CheckFrameWorkValid failed"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size, - FLAGS_dynamic_dims, FLAGS_input_shape, - FLAGS_input_format, is_dynamic_input) != ge::SUCCESS, - ret = ge::FAILED, "check dynamic size(batch size, image size or dims) failed!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - !FLAGS_insert_op_conf.empty() && !FLAGS_dynamic_dims.empty(), - ErrorManager::GetInstance().ATCReportErrMessage("E10001", - {"parameter", "value", "reason"}, - {"--insert_op_conf", FLAGS_insert_op_conf, - "dynamic dims function does not support aipp"}); - ret = ge::FAILED, "dynamic dims function does not support aipp"); - -#if !defined(__ANDROID__) && !defined(ANDROID) - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!CheckEncryptModeValid(FLAGS_encrypt_mode), ret = ge::FAILED, - "encrypt_mode %d not valid!!", FLAGS_encrypt_mode); - - if (FLAGS_encrypt_mode == 0) { // Encryption mode - GELOGI("ge will run with encrypt!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_encrypt_key), ret = ge::FAILED, - "encrypt_key file not found!!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_certificate), ret = ge::FAILED, - "certificate file not found!!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_hardware_key), ret = ge::FAILED, - "hardware_key file not found!!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_private_key), ret = ge::FAILED, - "private_key file not found!!"); - } else { // No encryption - GELOGI("ge will run without encrypt!"); - } -#endif - - /** - * Check the validity of the I / O file path - */ - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_model != "" && !ge::CheckInputPathValid(FLAGS_model, "--model"), ret = ge::FAILED, - "model file %s not found!!", FLAGS_model.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, 
"--weight"), - ret = ge::FAILED, "weight file %s not found!!", - FLAGS_weight.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "--cal_conf"), - ret = ge::FAILED, "calibration config file %s not found!!", - FLAGS_cal_conf.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "--op_name_map"), - ret = ge::FAILED, "op config file %s not found!!", - FLAGS_op_name_map.c_str()); - - GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(FLAGS_insert_op_conf)) == ge::SUCCESS, - ret = ge::FAILED, "check insert op conf failed!"); - - GE_CHK_BOOL_EXEC(ge::CheckCompressWeightParamValid( - FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, - ret = ge::FAILED, "check compress weight failed!"); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - !ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED, - "check_report file %s not found!!", FLAGS_check_report.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_mode == GEN_OM_MODEL && FLAGS_output != "" && - (!ge::CheckOutputPathValid(FLAGS_output, "--output") || !CheckPathWithName(FLAGS_output)), - ret = ge::FAILED, "output path %s is not valid!!", FLAGS_output.c_str()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_save_original_model != "" && - FLAGS_save_original_model != "true" && - FLAGS_save_original_model != "false", - ErrorManager::GetInstance().ATCReportErrMessage( - "E10005", {"parameter", "value"}, {"save_original_model", FLAGS_save_original_model}); - ret = ge::FAILED, - "Input parameter[--save_original_model]'s value[%s] must be true or false.", - FLAGS_save_original_model.c_str()); - GE_CHK_BOOL_EXEC(ge::CheckBufferOptimizeParamValid(FLAGS_buffer_optimize) == ge::SUCCESS, - ret = ge::FAILED, "check output type failed!"); - - GE_CHK_BOOL_EXEC( - ge::CheckEnableSingleStreamParamValid(std::string(FLAGS_enable_single_stream)) == ge::SUCCESS, - ret = 
ge::FAILED, "check enable single stream failed!"); - - return ret; - } - - /** - * Verifying the parameters of converting model to JSON - * 1. Fmk_model - * 2. out_json - **/ - static Status CheckConverJsonParamFlags() { - Status ret = ge::SUCCESS; - - // No model path passed in - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); - ret = ge::FAILED, - "Input parameter[--om]'s value is empty!!"); - - // JSON path not passed in - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_json == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"json"}); - ret = ge::FAILED, - "Input parameter[--json]'s value is empty!!"); - - // Check if the model path is valid - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"), - ret = ge::FAILED, - "model file path is invalid: %s.", FLAGS_om.c_str()); - - // Check whether the JSON path is valid - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_json != "" && !ge::CheckOutputPathValid(FLAGS_json, "--json"), - ret = ge::FAILED, - "json file path is invalid: %s.", FLAGS_json.c_str()); - - return ret; - } - - /** - * Check command line parameters for explicit settings - * true: Explicit setup - * false: Not set up - * */ - static bool CheckFlagSet(string flag) { - gflags::CommandLineFlagInfo info; - return !(gflags::GetCommandLineFlagInfo(flag.c_str(), &info) && info.is_default); - } - - private: - static bool CheckEncryptModeValid(const int encrypt_mode) { -#if !defined(__ANDROID__) && !defined(ANDROID) - if (encrypt_mode != 0 && encrypt_mode != -1) { - DOMI_LOGE("encrypt mode must be 0 or -1"); - return false; - } -#else - if (encrypt_mode != -1) { - DOMI_LOGE("encrypt mode must be -1"); - return false; - } -#endif - - return true; - } - - static Status CheckFrameWorkValid(int framework, const std::string weight_file) { - if (framework != (int32_t)domi::CAFFE && framework != (int32_t)domi::TENSORFLOW && - 
framework != (int32_t)domi::MINDSPORE && framework != (int32_t)domi::ONNX) { - // No framework information was passed in or the entered framework is illegal - ErrorManager::GetInstance().ATCReportErrMessage( - "E10007", {"parameter", "support"}, - {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow)"}); - DOMI_LOGE("Input parameter[--framework] is mandatory and it's value must be: " - "0(Caffe) or 1(MindSpore) or 3(TensorFlow)."); - return domi::PARAM_INVALID; - } - - if ((framework == (int32_t)domi::CAFFE) && (weight_file == "")) { - ErrorManager::GetInstance().ATCReportErrMessage("E10008", {"parameter"}, {"weight"}); - DOMI_LOGE("Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!"); - return domi::PARAM_INVALID; - } - - if ((framework == (int32_t)domi::TENSORFLOW) && (weight_file != "")) { - GELOGW("Parameter weight is ignored for TensorFlow."); - } - - if ((framework == (int32_t)domi::ONNX) && (weight_file != "")) { - GELOGW("Parameter weight is ignored for Onnx."); - } - return domi::SUCCESS; - } - - static bool CheckPathWithName(const std::string &fileName) { - // Determine file path length - if (fileName.size() > static_cast(PATH_MAX)) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)}); - GELOGE(ge::FAILED, "Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX); - return false; - } - - // Find the last separator - int slashPosition = fileName.size() - 1; - for (; slashPosition >= 0; slashPosition--) { - if (fileName[slashPosition] == '\\' || fileName[slashPosition] == '/') { - break; - } - } - - // Failure if no filename follows the path - if (slashPosition == static_cast(fileName.size() - 1)) { - ErrorManager::GetInstance().ATCReportErrMessage("E10022", {"parameter", "filename"}, {"output", fileName}); - DOMI_LOGE("Input parameter[--output]'s path[%s] not include file name", fileName.c_str()); - return false; - } - - return true; - } 
-}; - -void SetDynamicInputSizeOptions() { - if (!FLAGS_dynamic_batch_size.empty()) { - domi::GetContext().dynamic_batch_size = FLAGS_dynamic_batch_size; - } - if (!FLAGS_dynamic_image_size.empty()) { - domi::GetContext().dynamic_image_size = FLAGS_dynamic_image_size; - } - if (!FLAGS_dynamic_dims.empty()) { - domi::GetContext().dynamic_dims = FLAGS_dynamic_dims; - } -} - -/// Validate the non-general scope fusion pass. -/// The parameter is set to the name of the fusion rule. -/// Multiple names can be set and separated by ",". -void SetEnableScopeFusionPasses(const std::string pass_names) { - ge::GetParserContext().enable_scope_fusion_passes = pass_names; -} - -static bool CheckInputFormat() { - if (FLAGS_input_format.empty()) { - // Set default format - if (FLAGS_framework == static_cast(domi::TENSORFLOW)) { - FLAGS_input_format = "NHWC"; - } else { - FLAGS_input_format = "NCHW"; - } - return true; - } else if ((FLAGS_framework == static_cast(domi::CAFFE))) { // caffe - if (ge::caffe_support_input_format.find(FLAGS_input_format) != ge::caffe_support_input_format.end()) { - return true; - } - // only support NCHW ND - ErrorManager::GetInstance().ATCReportErrMessage( - "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, ge::kCaffeFormatSupport}); - GELOGE(ge::FAILED, - "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kCaffeFormatSupport); - return false; - } else if ((FLAGS_framework == static_cast(domi::TENSORFLOW))) { // tf - if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) { - return true; - } - // only support NCHW NHWC ND NCDHW NDHWC - ErrorManager::GetInstance().ATCReportErrMessage( - "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, ge::kTFFormatSupport}); - GELOGE(ge::FAILED, - "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kTFFormatSupport); - return false; - } else if (FLAGS_framework == 
static_cast(domi::ONNX)) { - if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) { - return true; - } - // only support NCHW ND - ErrorManager::GetInstance().ATCReportErrMessage( - "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, ge::kONNXFormatSupport}); - GELOGE(ge::FAILED, - "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kONNXFormatSupport); - return false; - } - return true; -} - -#if !defined(__ANDROID__) && !defined(ANDROID) -static void GetCustomOpPath(std::string &customop_path) { - GELOGI("Enter get custom op path schedule"); - std::string fmk_type = ge::TypeUtils::FmkTypeToSerialString(static_cast(FLAGS_framework)); - GELOGI("Framework type is %s.", fmk_type.c_str()); - - const char *path_env = std::getenv("ASCEND_OPP_PATH"); - if (path_env != nullptr) { - std::string path = path_env; - customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); - GELOGI("Get custom so path from env : %s", path_env); - return; - } - std::string path_base = ge::GELib::GetPath(); - GELOGI("path_base is %s", path_base.c_str()); - path_base = path_base.substr(0, path_base.rfind('/')); - path_base = path_base.substr(0, path_base.rfind('/') + 1); - customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type); - return; -} - -void GetPluginSoFileList(const string &path, vector &fileList, string &caffe_parser_path) { - // Support to split multiple so directories by ":" - GELOGI("path is %s", path.c_str()); - vector v_path = ge::StringUtils::Split(path, ':'); - for (size_t i = 0; i < v_path.size(); ++i) { - ge::FindParserSo(v_path[i], fileList, caffe_parser_path); - GELOGI("CustomOpLib full name = %s", v_path[i].c_str()); - } -} - -void LoadModelParserLib(std::string caffe_parser_path) { - if (FLAGS_framework == static_cast(domi::TENSORFLOW)) { - void *tf_handle = 
dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL); - if (tf_handle == nullptr) { - GELOGW("dlopen fmk library [libfmk_parser.so] failed."); - return; - } - GELOGI("plugin load libfmk_parser.so success."); - } else if (FLAGS_framework == static_cast(domi::CAFFE)) { - // What we are dealing with here is that the user modifies the caffe.proto scenario. - // If no lib_Caffe_Parser.so is found under the plugin path, use the default lib_Caffe_Parser.so path. - caffe_parser_path = caffe_parser_path.empty() ? "lib_caffe_parser.so" : caffe_parser_path; - - void *handle = dlopen(caffe_parser_path.c_str(), RTLD_NOW | RTLD_GLOBAL); - if (handle == nullptr) { - GELOGW("dlopen failed, plugin name:%s. Message(%s).", caffe_parser_path.c_str(), dlerror()); - return; - } - GELOGI("plugin load %s success.", caffe_parser_path.c_str()); - // According to the dependency, the Caffe parsing module of the framework is loaded here( libfmk_parser.so). - // (depend on the lib_caffe_parser.so) - void *fmk_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL); - if (fmk_handle == nullptr) { - GELOGW("dlopen fmk library [libfmk_parser.so] failed."); - if (dlclose(handle) != 0) { - GELOGW("dlclose lib_caffe_parser.so failed."); - } - return; - } - GELOGI("plugin load libfmk_parser.so success."); - } else if (FLAGS_framework == static_cast(domi::ONNX)) { - void *handle = dlopen("libfmk_onnx_parser.so", RTLD_NOW | RTLD_GLOBAL); - if (handle == nullptr) { - GELOGW("dlopen fmk library [libfmk_onnx_parser.so] failed."); - return; - } - GELOGI("plugin load libfmk_onnx_parser.so success."); - } else { - GELOGW("Framework:%s is not support.", - ge::TypeUtils::FmkTypeToSerialString(static_cast(FLAGS_framework)).c_str()); - return; - } - return; -} - -void LoadCustomOpLib(bool need_load_ops_plugin) { - std::string plugin_path; - GetCustomOpPath(plugin_path); - - vector fileList; - string caffe_parser_path = ""; - - // whether there are files in the plugin so path - GetPluginSoFileList(plugin_path, 
fileList, caffe_parser_path); - - // no file - if (fileList.empty() && caffe_parser_path.empty()) { - GELOGW("can not find any plugin file in plugin_path: %s", plugin_path.c_str()); - } - - LoadModelParserLib(caffe_parser_path); - if (!need_load_ops_plugin) { - GELOGI("No need to load ops plugin so."); - return; - } - OpRegistry::Instance()->registrationDatas.clear(); - // load other so files except lib_caffe_parser.so in the plugin so path - for (auto elem : fileList) { - ge::StringUtils::Trim(elem); - - void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); - if (handle == nullptr) { - GELOGW("dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); - } else { - GELOGI("plugin load %s success.", elem.c_str()); - } - } - - std::vector registrationDatas = OpRegistry::Instance()->registrationDatas; - for (OpRegistrationData reg_data : registrationDatas) { - if (reg_data.GetFrameworkType() == static_cast(FLAGS_framework)) { - (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data); - (void)OpRegistry::Instance()->Register(reg_data); - } - } -} - -void SaveCustomCaffeProtoPath() { - GELOGI("Enter save custom caffe proto path."); - - std::string path_base = ge::GELib::GetPath(); - GELOGI("path_base is %s", path_base.c_str()); - path_base = path_base.substr(0, path_base.rfind('/')); - path_base = path_base.substr(0, path_base.rfind('/') + 1); - ge::GetParserContext().caffe_proto_path = path_base + "include/proto/"; - - string customop_path; - const char *path_env = std::getenv("ASCEND_OPP_PATH"); - if (path_env != nullptr) { - std::string path = path_env; - customop_path = path + "/framework/custom/caffe/"; - GELOGI("Get custom proto path from env : %s", path_env); - ge::GetParserContext().custom_proto_path = customop_path; - return; - } - customop_path = path_base + "ops/framework/custom/caffe/"; - ge::GetParserContext().custom_proto_path = customop_path; - return; -} - -#endif - -Status CreateInputsForInference(const ge::Graph &graph, vector 
&inputs) { - auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); - GE_CHECK_NOTNULL(compute_graph); - for (ge::NodePtr &input_node : compute_graph->GetAllNodes()) { - GE_CHECK_NOTNULL(input_node); - ge::OpDescPtr op = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op); - if (op->GetType() == ge::DATA) { - GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); - ge::GeTensorDesc tensor = op->GetInputDesc(0); - string data_op_name = op->GetName(); - GELOGI("Data op name is: %s", data_op_name.c_str()); - ge::GeShape data_shape; - auto iter = domi::GetContext().input_dims.find(data_op_name); - if (iter != domi::GetContext().input_dims.end()) { - data_shape = ge::GeShape(iter->second); - GELOGI("Data op get shape from Context."); - } else { - data_shape = tensor.GetShape(); - GELOGI("Data op get shape from InputDesc in geir graph."); - } - - ge::DataType data_type = tensor.GetDataType(); - string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type); - GELOGI("Data op get data type:%s from InputDesc in geir graph.", data_type_str.c_str()); - - ge::GeTensor input_tensor; - ge::GeTensorDesc desc(data_shape, ge::Format(domi::GetContext().format), data_type); - input_tensor.SetTensorDesc(desc); - inputs.push_back(input_tensor); - } - } - GELOGI("Build ME model, inputs size is: %zu", inputs.size()); - return ge::SUCCESS; -} - -domi::Status GenerateInfershapeJson() { - if (!CheckInputFormat()) { - GELOGE(ge::FAILED, "Check input_format failed"); - return domi::FAILED; - } - Status ret = GFlagUtils::CheckDumpInfershapeJsonFlags(); - GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check flags failed!"); - - ge::GeGenerator ge_generator; - std::map options; - ge::Status geRet = ge_generator.Initialize(options, domi::GetContext()); - if (geRet != ge::SUCCESS) { - DOMI_LOGE("GeGenerator initialize failed!"); - return domi::FAILED; - } - - ge::Graph graph; - std::map atc_params; - atc_params.insert(std::pair("input_format", 
FLAGS_input_format)); - ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework, - "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false); - if (ret != ge::SUCCESS) { - DOMI_LOGE("ATC Parse graph domi::FAILED"); - (void)ge_generator.Finalize(); - return domi::FAILED; - } - - geRet = ge_generator.GenerateInfershapeGraph(graph); - if (geRet != ge::SUCCESS) { - DOMI_LOGE("ATC GenerateInfershapeJson failed"); - (void)ge_generator.Finalize(); - return domi::FAILED; - } - if (DumpInfershapeJson(graph, FLAGS_json.c_str()) != SUCCESS) { - DOMI_LOGE("ATC DumpInfershapeJson failed"); - (void)ge_generator.Finalize(); - return domi::FAILED; - } - (void)ge_generator.Finalize(); - return ge::SUCCESS; -} - -static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) { - Status ret = ge::SUCCESS; - if (fwk_type == -1) { - ret = ge::ConvertOmModelToJson(model_file.c_str(), json_file.c_str()); - return ret; - } - - if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE) && (fwk_type != domi::ONNX)) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E10001", {"parameter", "value", "reason"}, - {"--framework", std::to_string(fwk_type), kModelToJsonSupport}); - GELOGE(ge::FAILED, "Invalid value for --framework[%d], %s.", fwk_type, kModelToJsonSupport); - ret = ge::FAILED; - } - - if (FLAGS_dump_mode != "0" && FLAGS_dump_mode != "1") { - ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"dump_mode"}); - GELOGE(ge::FAILED, "Input parameter[--dump_mode]'s value must be 1 or 0."); - ret = ge::FAILED; - } - - if (ret != ge::SUCCESS) return ret; - - // Need to save caffe.proto path - SaveCustomCaffeProtoPath(); - - if (FLAGS_dump_mode == "0") { - // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so. 
- LoadCustomOpLib(false); - ret = ge::ConvertFwkModelToJson((domi::FrameworkType)fwk_type, model_file.c_str(), json_file.c_str()); - } else if (FLAGS_dump_mode == "1") { - // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so and ops plugin so. - LoadCustomOpLib(true); - ret = GenerateInfershapeJson(); - } - - return ret; -} - -domi::Status GenerateModel(std::map &options, std::string output) { - ge::GeGenerator ge_generator; - ge::Status geRet = ge::SUCCESS; - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - geRet = ge::GELib::Initialize(options); - if (geRet != ge::SUCCESS) { - DOMI_LOGE("GE initialize failed!"); - return domi::FAILED; - } - } - geRet = ge_generator.Initialize(options, domi::GetContext()); - if (geRet != ge::SUCCESS) { - DOMI_LOGE("GeGenerator initialize failed!"); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - - ge::Graph graph; - std::vector inputs; - if (FLAGS_framework == domi::MINDSPORE) { - // load model from file - ge::Model load_model = ge::Model("loadmodel", "version2"); - auto ret1 = load_model.LoadFromFile(FLAGS_model); - if (ret1 != ge::GRAPH_SUCCESS) { - ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model}); - DOMI_LOGE("Load model from %s failed, please check model file or " - "input parameter[--framework] is correct", FLAGS_model.c_str()); - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - - graph = load_model.GetGraph(); - - GE_CHK_STATUS_EXEC(ge::InitDomiOmgContext(FLAGS_input_shape, FLAGS_input_format, "", is_dynamic_input), - GELOGE(ge::FAILED, "ATC Generate call InitDomiOmgContext ret fail"); - (void)ge_generator.Finalize(); (void)ge::GELib::GetInstance()->Finalize(); return domi::FAILED); - - Status ret = CreateInputsForInference(graph, inputs); - if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "create inputs for 
inference failed."); - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - - } else { - std::map atc_params; - atc_params.insert(std::pair("input_shape", FLAGS_input_shape)); - atc_params.insert(std::pair("out_nodes", FLAGS_out_nodes)); - atc_params.insert(std::pair("input_format", FLAGS_input_format)); - atc_params.insert(std::pair("check_report", FLAGS_check_report)); - atc_params.insert(std::pair("input_fp16_nodes", FLAGS_input_fp16_nodes)); - atc_params.insert(std::pair("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout)); - atc_params.insert(std::pair("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout)); - atc_params.insert(std::pair("compress_weight_conf", FLAGS_compress_weight_conf)); - atc_params.insert(std::pair(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); - atc_params.insert(std::pair("output", output)); - - Status ret = - ParseGraph(graph, atc_params, FLAGS_model.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType)FLAGS_framework, - FLAGS_op_name_map.c_str(), FLAGS_target.c_str(), (ge::RunMode)FLAGS_mode, is_dynamic_input); - - // in ONLY_PRE_CHECK mode, pre-checking report has already saved in ParseGraph - if (FLAGS_mode == ge::ONLY_PRE_CHECK) { - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - if (ret != ge::SUCCESS) { - DOMI_LOGE("ATC precheck fail."); - return domi::FAILED; - } - return domi::SUCCESS; - } - - if (ret != ge::SUCCESS) { - DOMI_LOGE("ATC Parse graph domi::FAILED"); - DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. 
(for test case - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) { - DOMI_LOGE("Set output node info fail."); - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - } - - geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); - if (geRet != ge::SUCCESS) { - DOMI_LOGE("GE GenerateOfflineModel execute failed"); - DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. (for test case - // checking error log) - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - (void)ge_generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return ge::SUCCESS; -} - -static void SetEnvForSingleOp(std::map &options) { - string flag_on = "1"; - string flag_off = "0"; - options.emplace(ge::GE_FE_FLAG, flag_on); - options.emplace(ge::STREAM_NUM, "1"); // single op only use one stream - options.emplace(ge::RUN_FLAG, flag_off); - options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off); - options.emplace(ge::SINGLE_OP_FLAG, flag_on); - options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode); - options.emplace(ge::SOC_VERSION, FLAGS_soc_version); - options.emplace(ge::CORE_TYPE, FLAGS_core_type); - options.emplace(ge::AICORE_NUM, FLAGS_aicore_num); - options.emplace(ge::OP_SELECT_IMPL_MODE, FLAGS_op_select_implmode); - options.emplace(ge::OPTYPELIST_FOR_IMPLMODE, FLAGS_optypelist_for_implmode); - options.emplace(ge::AUTO_TUNE_MODE, FLAGS_auto_tune_mode); - options.emplace(ge::GRAPH_MEMORY_MAX_SIZE, kGraphMemoryManagerMallocMaxSize); - options.emplace(ge::OP_DEBUG_LEVEL, to_string(FLAGS_op_debug_level)); -} - -domi::Status GenerateSingleOp(const std::string& json_file_path) { - if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) { - DOMI_LOGE("output path %s is not valid!", FLAGS_output.c_str()); - 
return domi::FAILED; - } - // check optypelist_for_implmode and op_select_implmode - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, - return ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); - - std::map options; - // need to be changed when ge.ini plan is done - SetEnvForSingleOp(options); - - auto ret = ge::GELib::Initialize(options); - if (ret != ge::SUCCESS) { - DOMI_LOGE("GE initialize failed!"); - return domi::FAILED; - } - - ge::GeGenerator generator; - ret = generator.Initialize(options, domi::GetContext()); - if (ret != SUCCESS) { - DOMI_LOGE("GeGenerator initialize failed!"); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - - vector build_params; - if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) { - DOMI_LOGE("parse single op json file failed"); - (void)generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return domi::FAILED; - } - - int index = 0; - for (auto ¶m : build_params) { - string output_path; - if (!FLAGS_output.empty()) { - output_path = FLAGS_output + "/"; - } - output_path += param.file_name; - ret = generator.BuildSingleOpModel(param.op_desc, param.inputs, param.outputs, output_path); - if (ret != SUCCESS) { - DOMI_LOGE("Compile op failed. ge ret = %u, op index = %d", ret, index); - ret = domi::FAILED; - break; - } - GELOGI("Compile op success. op index = %d, output = %s", index, output_path.c_str()); - index += 1; - } - - (void)generator.Finalize(); - (void)ge::GELib::GetInstance()->Finalize(); - return ret; -} - -domi::Status GenerateOmModel() { - if (!CheckInputFormat()) { - GELOGE(ge::FAILED, "Check input_format failed"); - return domi::FAILED; - } - Status ret = GFlagUtils::CheckFlags(); - GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, - "Check flags failed! 
Please check whether some atc params that include semicolons[;] use double " - "quotation marks (\") to enclose each argument such as out_nodes, input_shape, dynamic_image_size"); -#if !defined(__ANDROID__) && !defined(ANDROID) - // Load custom operator Library - LoadCustomOpLib(true); - - SaveCustomCaffeProtoPath(); - - ret = ge::CheckCustomAiCpuOpLib(); - - GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "check custom aicpu run so failed!"); -#endif - - const int f_stream_num = 1; - std::map options; - options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(FLAGS_framework))); - options.insert(std::pair(string(ge::STREAM_NUM), to_string(f_stream_num))); - options.insert(std::pair(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf)); - options.insert(std::pair(string(ge::ENCRYPT_MODE), to_string(FLAGS_encrypt_mode))); - options.insert(std::pair(string(ge::EK_FILE), FLAGS_encrypt_key)); - options.insert(std::pair(string(ge::CERT_FILE), FLAGS_certificate)); - options.insert(std::pair(string(ge::HW_KEY_FILE), FLAGS_hardware_key)); - options.insert(std::pair(string(ge::PRIVATE_KEY_FILE), FLAGS_private_key)); - options.insert(std::pair(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); - options.insert(std::pair(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); - options.insert(std::pair(string(ge::PRECISION_MODE), FLAGS_precision_mode)); - - options.insert(std::pair(string(ge::RUN_FLAG), to_string(0))); - options.insert(std::pair(string(ge::TRAIN_FLAG), to_string(0))); - - if (!FLAGS_output_type.empty()) { - options.insert(std::pair(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); - } - - options.insert(std::pair(string(ge::OP_SELECT_IMPL_MODE), FLAGS_op_select_implmode)); - options.insert(std::pair(string(ge::OPTYPELIST_FOR_IMPLMODE), FLAGS_optypelist_for_implmode)); - - if (!FLAGS_input_fp16_nodes.empty()) { - GELOGI("FLAGS_input_fp16_nodes : %s .", FLAGS_input_fp16_nodes.c_str()); - options.insert(std::pair(ge::INPUT_FP16_NODES, 
FLAGS_input_fp16_nodes)); - } - - options.insert(std::pair(string(ge::AUTO_TUNE_MODE), FLAGS_auto_tune_mode)); - - options.insert( - std::pair(string(ge::OPTION_EXEC_DISABLE_REUSED_MEMORY), to_string(FLAGS_disable_reuse_memory))); - - options.insert(std::pair(string(ge::SOC_VERSION), FLAGS_soc_version)); - - options.insert(std::pair(string(ge::CORE_TYPE), FLAGS_core_type)); - - options.insert(std::pair(string(ge::AICORE_NUM), FLAGS_aicore_num)); - - options.insert(std::pair(string(ge::BUFFER_OPTIMIZE), FLAGS_buffer_optimize)); - - options.insert(std::pair(string(ge::ENABLE_SMALL_CHANNEL), FLAGS_enable_small_channel)); - - options.insert(std::pair(string(ge::FUSION_SWITCH_FILE), FLAGS_fusion_switch_file)); - - options.insert(std::pair(string(ge::ENABLE_COMPRESS_WEIGHT), - (FLAGS_enable_compress_weight == "true") ? - ge::kEnableCompressWeightTrue : ge::kEnableCompressWeightFalse)); - - options.insert(std::pair(string(ge::GRAPH_MEMORY_MAX_SIZE), kGraphMemoryManagerMallocMaxSize)); - - options.insert(std::pair(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream)); - - SetDynamicInputSizeOptions(); - - if (!FLAGS_save_original_model.empty()) { - options.insert(std::pair(string(ge::SAVE_ORIGINAL_MODEL), FLAGS_save_original_model)); - options.insert(std::pair(string(ge::ORIGINAL_MODEL_FILE), FLAGS_output + "_original.om")); - } - - options.insert(std::pair(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level))); - // set enable scope fusion passes - SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes); - // print atc option map - ge::PrintOptionMap(options, "atc option"); - - // When the ATC module is transferred to a model, the suffix ".om" is automatically added to the model name - FLAGS_output = FLAGS_output + ".om"; - ret = GenerateModel(options, FLAGS_output); - if (ret != domi::SUCCESS) { - return domi::FAILED; - } - - return domi::SUCCESS; -} - -domi::Status ConvertModelToJson() { - Status ret = GFlagUtils::CheckConverJsonParamFlags(); - 
GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check convert json params flags failed!"); - - ret = ConvertModelToJson(FLAGS_framework, FLAGS_om, FLAGS_json); - - GE_IF_BOOL_EXEC(ret != domi::SUCCESS, return domi::FAILED); - return domi::SUCCESS; -} - -bool CheckRet(domi::Status ret) { - if (ret != domi::SUCCESS) { - if (FLAGS_mode == ONLY_PRE_CHECK) { - GELOGW("ATC precheck failed."); - } else if (FLAGS_mode == GEN_OM_MODEL) { - GELOGW("ATC generate offline model failed."); - } else if (FLAGS_mode == MODEL_TO_JSON) { - GELOGW("ATC convert model to json file failed."); - } else if (FLAGS_mode == PBTXT_TO_JSON) { - GELOGW("ATC convert pbtxt to json file failed."); - } else { - return false; - } - return false; - } - - if (FLAGS_mode == ONLY_PRE_CHECK) { - GELOGI("ATC precheck success."); - } else if (FLAGS_mode == GEN_OM_MODEL) { - GELOGI("ATC generate offline model success."); - } else if (FLAGS_mode == MODEL_TO_JSON) { - GELOGI("ATC convert model to json file success."); - } else if (FLAGS_mode == PBTXT_TO_JSON) { - GELOGI("ATC convert pbtxt to json file success."); - } - return true; -} - -domi::Status ConvertPbtxtToJson() { - Status ret = GFlagUtils::CheckConverJsonParamFlags(); - if (ret != domi::SUCCESS) { - GELOGE(ge::FAILED, "Check convert json params flags failed!"); - return domi::FAILED; - } - - ret = ge::ConvertPbtxtToJson(FLAGS_om.c_str(), FLAGS_json.c_str()); - if (ret != domi::SUCCESS) { - GELOGE(ge::FAILED, "ConvertPbtxtToJson fail."); - return domi::FAILED; - } - - return domi::SUCCESS; -} - -int init(int argc, char* argv[]) { - GFlagUtils::InitGFlag(argc, argv); - // set log level - int ret = -1; - const std::set log_level = {"null", "debug", "info", "warning", "error"}; - if (log_level.count(FLAGS_log) == 0) { - std::cout << "E10010: invalid value for --log:" << FLAGS_log - <<", only support debug, info, warning, error, null"<< std::endl; - return ret; - } - - ret = ge::CheckLogParamValidAndSetLogLevel(FLAGS_log); - if (ret != 0) { - 
return ret; - } - - std::string path_base = ge::GELib::GetPath(); - ret = ErrorManager::GetInstance().Init(path_base); - if (ret != 0) { - DOMI_LOGE("ErrorManager init fail !"); - return ret; - } - - return 0; -} - -long GetMemInfo(const std::string &key) { - std::string file_path = "/proc/meminfo"; - std::ifstream fs(file_path, std::ifstream::in); - if (!fs.is_open()) { - GELOGW("Can not open %s .", file_path.c_str()); - return 0; - } - std::string line; - while (getline(fs, line)) { // line not with \n - if (line.find(key) != std::string::npos) { - GELOGI("Find mem [%s] info line [%s]", key.c_str(), line.c_str()); - fs.close(); - size_t pos = line.find(":"); - if (pos == std::string::npos) { - return 0; - } - std::string current_mem_info_str = line.substr(pos + 1); - ge::StringUtils::Trim(current_mem_info_str); - GELOGI("Find mem [%s] info [%s].", key.c_str(), current_mem_info_str.c_str()); - return stol(current_mem_info_str); - } - } - fs.close(); // close the file - return 0; -} - -bool CheckMemInfo() { - if (FLAGS_auto_tune_mode.empty()) { - return true; - } - // only check current available mem when auto_tune_mode is set. - long current_mem_available = GetMemInfo("MemAvailable"); - GELOGI("Get mem available [%lu].", current_mem_available); - std::cout << "Current available mem is " << current_mem_available << "kB." << std::endl; - if ((current_mem_available > 0) && (current_mem_available < kMinAvailableMem)) { - GELOGE(ge::PARAM_INVALID, "Current available mem [%lu] can not be smaller than [%lu] .", - current_mem_available, kMinAvailableMem); - ErrorManager::GetInstance().ATCReportErrMessage("E10044", {"value", "min_value"}, - {to_string(current_mem_available), to_string(kMinAvailableMem)}); - return false; - } - return true; -} - -int main(int argc, char* argv[]) { - Status ret = domi::SUCCESS; - std::cout << "ATC start working now, please wait for a moment." 
<< std::endl; - - // Initialize - if (init(argc, argv) != 0) { - std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; - return -1; - } - do { - if (!CheckMemInfo()) { - GELOGE(ge::PARAM_INVALID, "Current available mem is too small"); - ret = domi::FAILED; - break; - } - if (!FLAGS_singleop.empty()) { - ret = GenerateSingleOp(FLAGS_singleop); - break; - } - - // default mode(mode:0), Open source model to model - if (GEN_OM_MODEL == FLAGS_mode || ONLY_PRE_CHECK == FLAGS_mode) { - GE_IF_BOOL_EXEC(GenerateOmModel() != domi::SUCCESS, ret = domi::FAILED; break); - } else if (MODEL_TO_JSON == FLAGS_mode) { // Mode 1, transfer model to JSON - GE_CHK_BOOL_EXEC(ConvertModelToJson() == domi::SUCCESS, ret = domi::FAILED; - break, "ATC ConvertJson execute failed!!"); - } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) { - GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; - break, "ATC convert pbtxt to json execute failed!!"); - } else { - ErrorManager::GetInstance().ATCReportErrMessage( - "E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport}); - GELOGE(ge::PARAM_INVALID, "Invalid value for --mode[%d], %s.", FLAGS_mode, kModeSupport); - ret = domi::FAILED; - break; - } - } while (0); - - if (!CheckRet(ret)) { - std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; - int result = ErrorManager::GetInstance().OutputErrMessage(STDOUT_FILENO); - if (result != 0) { - DOMI_LOGE("ErrorManager outputErrMessage fail !"); - } - GELOGI("Current mem available mem is [%lu]", GetMemInfo("MemAvailable")); - return ret; - } else { - std::cout << "ATC run success, welcome to the next use." 
<< std::endl; - (void)ErrorManager::GetInstance().OutputMessage(STDOUT_FILENO); - return 0; - } -} /*lint +e530*/ diff --git a/ge/offline/module.mk b/ge/offline/module.mk deleted file mode 100755 index 42b217db..00000000 --- a/ge/offline/module.mk +++ /dev/null @@ -1,52 +0,0 @@ - -LOCAL_PATH := $(call my-dir) - -include $(CLEAR_VARS) - -LOCAL_MODULE := atc - -LOCAL_CFLAGS += -Werror -LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 - -LOCAL_SRC_FILES := \ - main.cc \ - single_op_parser.cc \ - ../session/omg.cc \ - ../ir_build/atc_ir_common.cc \ - -LOCAL_C_INCLUDES := \ - $(LOCAL_PATH)/../ ./ \ - $(TOPDIR)inc \ - $(TOPDIR)inc/external \ - $(TOPDIR)inc/external/graph \ - $(TOPDIR)inc/framework \ - $(TOPDIR)inc/framework/domi \ - $(TOPDIR)libc_sec/include \ - $(TOPDIR)inc/common/util \ - third_party/json/include \ - third_party/gflags/include \ - third_party/protobuf/include \ - proto/om.proto \ - proto/ge_ir.proto \ - proto/task.proto \ - proto/insert_op.proto \ - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libge_common \ - libprotobuf \ - libslog \ - libgraph \ - libregister \ - liberror_manager \ - libge_compiler \ - libruntime_compile \ - libparser_common \ - liberror_manager \ - -LOCAL_STATIC_LIBRARIES := libgflags - -LOCAL_LDFLAGS := -lrt -ldl - -include $(BUILD_HOST_EXECUTABLE) - diff --git a/ge/offline/proto/ge_ir.proto b/ge/offline/proto/ge_ir.proto deleted file mode 100644 index f60a0f89..00000000 --- a/ge/offline/proto/ge_ir.proto +++ /dev/null @@ -1 +0,0 @@ -../../../../inc/common/proto/ge_ir.proto \ No newline at end of file diff --git a/ge/offline/proto/insert_op.proto b/ge/offline/proto/insert_op.proto deleted file mode 100644 index 27b233e5..00000000 --- a/ge/offline/proto/insert_op.proto +++ /dev/null @@ -1 +0,0 @@ -../../../../inc/common/proto/insert_op.proto \ No newline at end of file diff --git a/ge/offline/proto/om.proto b/ge/offline/proto/om.proto deleted file mode 100644 index 91c581bb..00000000 --- 
a/ge/offline/proto/om.proto +++ /dev/null @@ -1 +0,0 @@ -../../../../inc/common/proto/om.proto \ No newline at end of file diff --git a/ge/offline/proto/task.proto b/ge/offline/proto/task.proto deleted file mode 100644 index 36ae4847..00000000 --- a/ge/offline/proto/task.proto +++ /dev/null @@ -1 +0,0 @@ -../../proto/task.proto \ No newline at end of file diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc deleted file mode 100644 index 34ac7d5f..00000000 --- a/ge/offline/single_op_parser.cc +++ /dev/null @@ -1,503 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "single_op_parser.h" - -#include -#include -#include -#include - -#include - -#include "framework/common/debug/ge_log.h" -#include "common/util/error_manager/error_manager.h" -#include "common/ge_inner_error_codes.h" -#include "framework/common/util.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/operator_factory_impl.h" - -using Json = nlohmann::json; -using std::string; -using std::vector; -using std::map; - -namespace ge { -namespace { -constexpr char const *kKeyOp = "op"; -constexpr char const *kKeyInputDesc = "input_desc"; -constexpr char const *kKeyOutputDesc = "output_desc"; -constexpr char const *kKeyAttr = "attr"; -constexpr char const *kKeyName = "name"; -constexpr char const *kKeyType = "type"; -constexpr char const *kKeyShape = "shape"; -constexpr char const *kKeyShapeRange = "shape_range"; -constexpr char const *kKeyValue = "value"; -constexpr char const *kKeyFormat = "format"; -constexpr char const *kFileSuffix = ".om"; -constexpr int kDumpJsonIndent = 2; -constexpr int kShapeRangePairSize = 2; -constexpr int kShapeRangeLow = 0; -constexpr int kShapeRangeHigh = 1; - -map kAttrTypeDict = { - {"bool", GeAttrValue::VT_BOOL}, - {"int", GeAttrValue::VT_INT}, - {"float", GeAttrValue::VT_FLOAT}, - {"string", GeAttrValue::VT_STRING}, - {"list_bool", GeAttrValue::VT_LIST_BOOL}, - {"list_int", GeAttrValue::VT_LIST_INT}, - {"list_float", GeAttrValue::VT_LIST_FLOAT}, - {"list_string", GeAttrValue::VT_LIST_STRING}, - {"list_list_int", GeAttrValue::VT_LIST_LIST_INT}, - {"data_type", GeAttrValue::VT_DATA_TYPE}, -}; - -map kDataTypeDict = { - {"bool", DT_BOOL}, - {"int8", DT_INT8}, - {"uint8", DT_UINT8}, - {"int16", DT_INT16}, - {"uint16", DT_UINT16}, - {"int32", DT_INT32}, - {"uint32", DT_UINT32}, - {"int64", DT_INT64}, - {"uint64", DT_UINT64}, - {"float16", DT_FLOAT16}, - {"half", DT_FLOAT16}, - {"fp16", DT_FLOAT16}, - {"float", DT_FLOAT}, - {"float32", DT_FLOAT}, - {"double", DT_DOUBLE}, -}; - -map 
kFormatDict = { - {"nchw", FORMAT_NCHW}, - {"nhwc", FORMAT_NHWC}, - {"nd", FORMAT_ND}, - {"fractal_nz", FORMAT_FRACTAL_NZ}, - {"fractal_z", FORMAT_FRACTAL_Z}, - {"nc1hwc0", FORMAT_NC1HWC0}, -}; -} - -template -void SetAttrValue(const Json &j, SingleOpAttr &attr) { - attr.value.SetValue(j.at(kKeyValue).get()); -} - -template -T GetValue(const map &dict, string &key, T default_val) { - transform(key.begin(), key.end(), key.begin(), ::tolower); - auto it = dict.find(key); - if (it == dict.end()) { - return default_val; - } - - return it->second; -} - -void from_json(const Json &j, SingleOpTensorDesc &desc) { - desc.dims = j.at(kKeyShape).get>(); - auto it = j.find(kKeyShapeRange); - if (it != j.end()) { - desc.dim_ranges = j.at(kKeyShapeRange).get>>(); - } - string format_str = j.at(kKeyFormat).get(); - string type_str = j.at(kKeyType).get(); - desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); - desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); - auto tensor_name = j.find(kKeyName); - if (tensor_name != j.end()) { - desc.name = tensor_name->get(); - } -} - -void from_json(const Json &j, SingleOpAttr &attr) { - attr.name = j.at(kKeyName).get(); - attr.type = j.at(kKeyType).get(); - auto it = kAttrTypeDict.find(attr.type); - if (it == kAttrTypeDict.end()) { - GELOGE(UNSUPPORTED, "Parse attr[%s] failed. 
Unsupported type: %s", attr.name.c_str(), attr.type.c_str()); - return; - } - - switch (it->second) { - case GeAttrValue::VT_BOOL: - SetAttrValue(j, attr); - break; - case GeAttrValue::VT_INT: - SetAttrValue(j, attr); - break; - case GeAttrValue::VT_FLOAT: - SetAttrValue(j, attr); - break; - case GeAttrValue::VT_STRING: - SetAttrValue(j, attr); - break; - case GeAttrValue::VT_LIST_BOOL: - SetAttrValue>(j, attr); - break; - case GeAttrValue::VT_LIST_INT: - SetAttrValue>(j, attr); - break; - case GeAttrValue::VT_LIST_FLOAT: - SetAttrValue>(j, attr); - break; - case GeAttrValue::VT_LIST_STRING: - SetAttrValue>(j, attr); - break; - case GeAttrValue::VT_LIST_LIST_INT: - SetAttrValue>>(j, attr); - break; - case GeAttrValue::VT_DATA_TYPE: - SetAttrValue(j, attr); - break; - default: - GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str()); - break; - } -} - -void from_json(const Json &j, SingleOpDesc &desc) { - desc.op = j.at(kKeyOp).get(); - - auto input_desc = j.find(kKeyInputDesc); - if (input_desc != j.end()) { - desc.input_desc = input_desc->get>(); - } - - auto output_desc = j.find(kKeyOutputDesc); - if (output_desc != j.end()) { - desc.output_desc = output_desc->get>(); - } - - auto attr_field = j.find(kKeyAttr); - if (attr_field != j.end()) { - desc.attrs = attr_field->get>(); - } -} - -Status SingleOpParser::ReadJsonFile(const std::string &file, Json &json_obj) { - std::string real_path = RealPath(file.c_str()); - if (real_path.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10023", {"value"}, {file}); - GELOGE(FAILED, "Input parameter[--singleop]'s value[%s] is not a valid path.", file.c_str()); - return INTERNAL_ERROR; - } - - std::ifstream ifs(real_path); - if (!ifs.is_open()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10024", {"value"}, {file}); - GELOGE(FAILED, "Open file[%s] provided in input parameter[--singleop] failed.", file.c_str()); - return FAILED; - } - try { - ifs >> 
json_obj; - } catch (const std::exception &e) { - ErrorManager::GetInstance().ATCReportErrMessage("E10025", {"realpath", "errmsg"}, {real_path, e.what()}); - GELOGE(PARAM_INVALID, "Parse file[%s] provided in input parameter[--singleop] failed, exception = %s.", - real_path.c_str(), e.what()); - return PARAM_INVALID; - } - - ifs.close(); - return SUCCESS; -} - -bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { - if (op_desc.op.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10026"); - GELOGE(PARAM_INVALID, "Op name is empty"); - return false; - } - - int index = 0; - for (auto &tensor_desc : op_desc.input_desc) { - if (tensor_desc.type == DT_UNDEFINED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(false, "Input's dataType is invalid when the index is %d", index); - return false; - } - - if (tensor_desc.format == FORMAT_RESERVED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index); - return false; - } - ++index; - } - - index = 0; - for (auto &tensor_desc : op_desc.output_desc) { - if (tensor_desc.type == DT_UNDEFINED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index); - return false; - } - - if (tensor_desc.format == FORMAT_RESERVED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"output", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Output's format is invalid when the index is %d", index); - return false; - } - ++index; - } - - for (auto &attr : op_desc.attrs) { - if (attr.name.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10029"); - GELOGE(PARAM_INVALID, "attr name is empty"); - return false; - } - - if 
(attr.value.IsEmpty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10030", {"attrname"}, {attr.name}); - GELOGE(PARAM_INVALID, "Parse attr \"%s\" failed. ", attr.name.c_str()); - return false; - } - } - - return true; -} - -std::unique_ptr SingleOpParser::CreateOpDesc(const string &op_type) { - return std::unique_ptr(new(std::nothrow) OpDesc(op_type, op_type)); -} - -Status SingleOpParser::ConvertToBuildParam(int index, - const SingleOpDesc &single_op_desc, - SingleOpBuildParam &build_param) { - auto op_desc = CreateOpDesc(single_op_desc.op); - if (op_desc == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to create instance of opDesc"); - return MEMALLOC_FAILED; - } - - std::stringstream file_name; - file_name << index; - file_name << "_" << single_op_desc.op; - for (auto &desc : single_op_desc.input_desc) { - file_name << "_" << desc.type << "_" << desc.format; - for (auto dim : desc.dims) { - file_name << "_" << dim; - } - GeTensorDesc ge_tensor_desc(GeShape(desc.dims), - desc.format, - desc.type); - ge_tensor_desc.SetOriginFormat(desc.format); - GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); - TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); - TensorUtils::SetInputTensor(ge_tensor_desc, true); - TensorUtils::SetOutputTensor(ge_tensor_desc, false); - if (desc.name.empty()) { - op_desc->AddInputDesc(ge_tensor_desc); - } else { - op_desc->AddInputDesc(desc.name, ge_tensor_desc); - } - build_param.inputs.emplace_back(ge_tensor_desc); - } - - for (auto &desc : single_op_desc.output_desc) { - file_name << "_" << desc.type << "_" << desc.format; - for (auto dim : desc.dims) { - file_name << "_" << dim; - } - - GeTensorDesc ge_tensor_desc(GeShape(desc.dims), - desc.format, - desc.type); - ge_tensor_desc.SetOriginFormat(desc.format); - GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); - TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); - 
TensorUtils::SetInputTensor(ge_tensor_desc, false); - TensorUtils::SetOutputTensor(ge_tensor_desc, true); - if (desc.name.empty()) { - op_desc->AddOutputDesc(ge_tensor_desc); - } else { - op_desc->AddOutputDesc(desc.name, ge_tensor_desc); - } - build_param.outputs.emplace_back(ge_tensor_desc); - } - - for (const auto &attr : single_op_desc.attrs) { - op_desc->SetAttr(attr.name, attr.value); - } - - if (VerifyOpInputOutputSizeByIr(*op_desc) != SUCCESS) { - GELOGE(PARAM_INVALID, "Verify op [%s] input or output size failed.", op_desc->GetType().c_str()); - return PARAM_INVALID; - } - - file_name << kFileSuffix; - build_param.file_name = file_name.str(); - build_param.op_desc.reset(op_desc.release()); - return SUCCESS; -} - -Status SingleOpParser::VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc) { - ge::Operator operator_ir = ge::OperatorFactory::CreateOperator("tmp_operator", current_op_desc.GetType()); - if (!operator_ir.IsEmpty()) { - auto opdesc_ir = ge::OpDescUtils::GetOpDescFromOperator(operator_ir); - GE_CHECK_NOTNULL(opdesc_ir); - size_t current_opdesc_inputs_num = current_op_desc.GetInputsSize(); - size_t ir_opdesc_inputs_num = opdesc_ir->GetInputsSize(); - if (current_opdesc_inputs_num < ir_opdesc_inputs_num) { - string reason = "is smaller than the ir needed input size " + std::to_string(ir_opdesc_inputs_num); - ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, - {current_op_desc.GetName(), "input size " + std::to_string(current_opdesc_inputs_num), reason}); - GELOGE(PARAM_INVALID, "This op [%s] input size %zu is smaller than the ir needed input size %zu", - current_op_desc.GetName().c_str(), current_opdesc_inputs_num, ir_opdesc_inputs_num); - return PARAM_INVALID; - } - size_t current_opdesc_outputs_num = current_op_desc.GetOutputsSize(); - size_t ir_opdesc_outputs_num = opdesc_ir->GetOutputsSize(); - if (current_opdesc_outputs_num < ir_opdesc_outputs_num) { - string reason = "is smaller than the ir needed output 
size " + std::to_string(ir_opdesc_outputs_num); - ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, - {current_op_desc.GetName(), "output size " + std::to_string(current_opdesc_outputs_num), reason}); - GELOGE(PARAM_INVALID, "This op [%s] output size %zu is smaller than the ir needed output size %zu", - current_op_desc.GetName().c_str(), current_opdesc_outputs_num, ir_opdesc_outputs_num); - return PARAM_INVALID; - } - } - return SUCCESS; -} - -Status SingleOpParser::SetShapeRange(const std::string &op_name, - const SingleOpTensorDesc &tensor_desc, - GeTensorDesc &ge_tensor_desc) { - auto num_shape_ranges = tensor_desc.dim_ranges.size(); - GELOGD("Number of shape ranges = %zu", num_shape_ranges); - auto it = std::find(tensor_desc.dims.begin(), tensor_desc.dims.end(), ge::UNKNOWN_DIM_NUM); - if (it != tensor_desc.dims.end()) { - if (tensor_desc.dims != ge::UNKNOWN_RANK) { - ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, - {op_name, - "shape", - "has unknown rank but dim size is not one"}); - GELOGE(PARAM_INVALID, "Invalid tensor shape: [%s]", ge_tensor_desc.MutableShape().ToString().c_str()); - return PARAM_INVALID; - } - if (!tensor_desc.dim_ranges.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, - {op_name, - "shape range", - "is not needed while the rank the shape is unknown"}); - GELOGE(PARAM_INVALID, "Shape range is not needed while the rank the shape is unknown"); - return PARAM_INVALID; - } - - GELOGD("Shape is unknown rank, do not set shape range"); - return SUCCESS; - } - - std::vector> shape_range; - size_t range_index = 0; - for (auto dim : tensor_desc.dims) { - if (dim >= 0) { - shape_range.emplace_back(dim, dim); - GELOGD("Adding shape range: [%ld, %ld]", dim, dim); - } else { - GELOGD("To get shape range by index = %zu", range_index); - if (range_index >= num_shape_ranges) { - string reason = "is smaller than the unknown 
dim size " + std::to_string(++range_index); - ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, - {op_name, - "shape range size " + std::to_string(num_shape_ranges), - reason}); - GELOGE(PARAM_INVALID, "The number of shape_range mismatches that of unknown dims."); - return PARAM_INVALID; - } - - auto &range = tensor_desc.dim_ranges[range_index]; - if (range.size() != kShapeRangePairSize) { - string reason = "has " + std::to_string(range.size()) + " item(s)"; - ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, - {op_name, - "shape range " + std::to_string(range_index), - reason}); - GELOGE(PARAM_INVALID, "Invalid shape range entry. index = %zu, size = %zu", range_index, range.size()); - return PARAM_INVALID; - } - - shape_range.emplace_back(range[kShapeRangeLow], range[kShapeRangeHigh]); - GELOGD("Adding shape range: [%ld, %ld]", range[kShapeRangeLow], range[kShapeRangeHigh]); - ++range_index; - } - } - - if (num_shape_ranges != range_index) { - string reason = "is greater than the unknown dim size " + std::to_string(range_index); - ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, - {op_name, - "shape range size " + std::to_string(num_shape_ranges), - reason}); - GELOGE(PARAM_INVALID, - "The number of shape_range(%zu) mismatches that of unknown dims(%zu).", - num_shape_ranges, - range_index); - return PARAM_INVALID; - } - - if (range_index > 0) { - ge_tensor_desc.SetShapeRange(shape_range); - } - - return SUCCESS; -} - -Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector &op_list) { - Json single_op_list_json; - auto ret = ReadJsonFile(file, single_op_list_json); - if (ret != SUCCESS) { - return ret; - } - - int index = 0; - for (const Json &single_op_json : single_op_list_json) { - SingleOpDesc single_op_desc; - try { - GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str()); - 
single_op_desc = single_op_json; - } catch (const nlohmann::json::exception &e) { - ErrorManager::GetInstance().ATCReportErrMessage("E10032", {"index", "jsonfile", "exception"}, - {std::to_string(index), file, e.what()}); - GELOGE(PARAM_INVALID, "Parse the index[%d] of op failed when read json file[%s], exception %s", - index, file.c_str(), e.what()); - return PARAM_INVALID; - } - - if (!Validate(single_op_desc)) { - GELOGE(PARAM_INVALID, "Validate the index[%d] of op failed when read json file[%s].", index, file.c_str()); - return PARAM_INVALID; - } - - SingleOpBuildParam param; - ret = ConvertToBuildParam(index, single_op_desc, param); - if (ret != SUCCESS) { - return ret; - } - - op_list.emplace_back(param); - GELOGI("Parse the index[%d] of op success", index); - index += 1; - } - - return SUCCESS; -} -} // namespace ge - diff --git a/ge/offline/single_op_parser.h b/ge/offline/single_op_parser.h deleted file mode 100644 index 9a1bd962..00000000 --- a/ge/offline/single_op_parser.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef ACL_TOOLS_COMPILE_PARSER_H -#define ACL_TOOLS_COMPILE_PARSER_H - -#include -#include - -#include - -#include "ge/ge_api_error_codes.h" -#include "graph/types.h" -#include "graph/ge_attr_value.h" -#include "graph/op_desc.h" - -namespace ge { -struct SingleOpTensorDesc { - std::string name; - std::vector dims; - std::vector> dim_ranges; - ge::Format format = ge::FORMAT_RESERVED; - ge::DataType type = ge::DT_UNDEFINED; -}; - -struct SingleOpAttr { - std::string name; - std::string type; - ge::GeAttrValue value; -}; - -struct SingleOpDesc { - std::string op; - std::vector input_desc; - std::vector output_desc; - std::vector attrs; -}; - -struct SingleOpBuildParam { - ge::OpDescPtr op_desc; - std::vector inputs; - std::vector outputs; - std::string file_name; -}; - -void from_json(const nlohmann::json &json, SingleOpTensorDesc &desc); - -void from_json(const nlohmann::json &json, SingleOpAttr &desc); - -void from_json(const nlohmann::json &json, SingleOpDesc &desc); - -class SingleOpParser { - public: - static Status ParseSingleOpList(const std::string &file, std::vector &op_list); - - private: - static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj); - static bool Validate(const SingleOpDesc &op_desc); - static std::unique_ptr CreateOpDesc(const std::string &op_type); - static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param); - static Status VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc); - static Status SetShapeRange(const std::string &op_name, const SingleOpTensorDesc &tensor_desc, GeTensorDesc &ge_tensor_desc); -}; -} // namespace ge - -#endif // ACL_TOOLS_COMPILE_PARSER_H diff --git a/ge/plugin/engine/CMakeLists.txt b/ge/plugin/engine/CMakeLists.txt deleted file mode 100644 index 87a6d682..00000000 --- a/ge/plugin/engine/CMakeLists.txt +++ /dev/null @@ -1,49 +0,0 @@ -set(SRC_LIST - "dnnengines.cc" - "engine_manage.cc" -) - -############ libengine.so ############ 
-add_library(engine SHARED ${SRC_LIST}) - -target_compile_options(engine PRIVATE - -Werror -) - -target_compile_definitions(engine PRIVATE - REUSE_MEMORY=1 - PROTOBUF_INLINE_NOT_IN_HEADERS=0 -) - -target_include_directories(engine PRIVATE - ${GE_CODE_DIR}/ge - ${GE_CODE_DIR}/inc/ - ${GE_CODE_DIR}/inc/framework - ${GE_CODE_DIR}/inc/framework/common - ${GE_CODE_DIR}/inc/external - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/ge - #### yellow zone #### - ${GE_CODE_DIR}/../inc - #### blue zone #### - ${GE_CODE_DIR}/third_party/fwkacllib/inc -) - -target_link_libraries(engine PRIVATE - $ - -Wl,--no-as-needed - slog - -Wl,--as-needed - -lrt - -ldl -) - -############ install ############ -set(INSTALL_BASE_DIR "") -set(INSTALL_LIBRARY_DIR lib) - -install(TARGETS engine OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} -) diff --git a/ge/proto/caffe/caffe.proto b/ge/proto/caffe/caffe.proto deleted file mode 100644 index 3f45aae2..00000000 --- a/ge/proto/caffe/caffe.proto +++ /dev/null @@ -1,1821 +0,0 @@ -syntax = "proto2"; - -package domi.caffe; - -// Specifies the shape (dimensions) of a Blob. -message BlobShape { - repeated int64 dim = 1 [packed = true]; -} - -message BlobProto { - optional BlobShape shape = 7; - repeated float data = 5 [packed = true]; - repeated float diff = 6 [packed = true]; - repeated double double_data = 8 [packed = true]; - repeated double double_diff = 9 [packed = true]; - optional bytes int8_data = 10; - repeated int32 int32_data = 11 [packed = true]; - repeated uint64 uint64_data = 12 [packed = true]; - // 4D dimensions -- deprecated. Use "shape" instead. - optional int32 num = 1 [default = 0]; - optional int32 channels = 2 [default = 0]; - optional int32 height = 3 [default = 0]; - optional int32 width = 4 [default = 0]; -} - -// The BlobProtoVector is simply a way to pass multiple blobproto instances -// around. 
-message BlobProtoVector { - repeated BlobProto blobs = 1; -} - -message Datum { - optional int32 channels = 1; - optional int32 height = 2; - optional int32 width = 3; - // the actual image data, in bytes - optional bytes data = 4; - optional int32 label = 5; - // Optionally, the datum could also hold float data. - repeated float float_data = 6; - // If true data contains an encoded image that need to be decoded - optional bool encoded = 7 [default = false]; -} - -message FillerParameter { - // The filler type. - optional string type = 1 [default = 'constant']; - optional float value = 2 [default = 0]; // the value in constant filler - optional float min = 3 [default = 0]; // the min value in uniform filler - optional float max = 4 [default = 1]; // the max value in uniform filler - optional float mean = 5 [default = 0]; // the mean value in Gaussian filler - optional float std = 6 [default = 1]; // the std value in Gaussian filler - // The expected number of non-zero output weights for a given input in - // Gaussian filler -- the default -1 means don't perform sparsification. - optional int32 sparse = 7 [default = -1]; - // Normalize the filler variance by fan_in, fan_out, or their average. - // Applies to 'xavier' and 'msra' fillers. - enum VarianceNorm { - FAN_IN = 0; - FAN_OUT = 1; - AVERAGE = 2; - } - optional VarianceNorm variance_norm = 8 [default = FAN_IN]; -} - -message NetParameter { - optional string name = 1; // consider giving the network a name - // DEPRECATED. See InputParameter. The input blobs to the network. - repeated string input = 3; - // DEPRECATED. See InputParameter. The shape of the input blobs. - repeated BlobShape input_shape = 8; - - // 4D input dimensions -- deprecated. Use "input_shape" instead. - // If specified, for each input blob there should be four - // values specifying the num, channels, height and width of the input blob. - // Thus, there should be a total of (4 * #input) numbers. 
- repeated int32 input_dim = 4; - - // Whether the network will force every layer to carry out backward operation. - // If set False, then whether to carry out backward is determined - // automatically according to the net structure and learning rates. - optional bool force_backward = 5 [default = false]; - // The current "state" of the network, including the phase, level, and stage. - // Some layers may be included/excluded depending on this state and the states - // specified in the layers' include and exclude fields. - optional NetState state = 6; - - // Print debugging information about results while running Net::Forward, - // Net::Backward, and Net::Update. - optional bool debug_info = 7 [default = false]; - - // The layers that make up the net. Each of their configurations, including - // connectivity and behavior, is specified as a LayerParameter. - repeated LayerParameter layer = 100; // ID 100 so layers are printed last. - - // DEPRECATED: use 'layer' instead. - repeated V1LayerParameter layers = 2; -} - -// NOTE -// Update the next available ID when you add a new SolverParameter field. -// -// SolverParameter next available ID: 42 (last added: layer_wise_reduce) -message SolverParameter { - ////////////////////////////////////////////////////////////////////////////// - // Specifying the train and test networks - // - // Exactly one train net must be specified using one of the following fields: - // train_net_param, train_net, net_param, net - // One or more test nets may be specified using any of the following fields: - // test_net_param, test_net, net_param, net - // If more than one test net field is specified (e.g., both net and - // test_net are specified), they will be evaluated in the field order given - // above: (1) test_net_param, (2) test_net, (3) net_param/net. - // A test_iter must be specified for each test_net. - // A test_level and/or a test_stage may also be specified for each test_net. 
- ////////////////////////////////////////////////////////////////////////////// - - // Proto filename for the train net, possibly combined with one or more - // test nets. - optional string net = 24; - // Inline train net param, possibly combined with one or more test nets. - optional NetParameter net_param = 25; - - optional string train_net = 1; // Proto filename for the train net. - repeated string test_net = 2; // Proto filenames for the test nets. - optional NetParameter train_net_param = 21; // Inline train net params. - repeated NetParameter test_net_param = 22; // Inline test net params. - - // The states for the train/test nets. Must be unspecified or - // specified once per net. - // - // By default, all states will have solver = true; - // train_state will have phase = TRAIN, - // and all test_state's will have phase = TEST. - // Other defaults are set according to the NetState defaults. - optional NetState train_state = 26; - repeated NetState test_state = 27; - - // The number of iterations for each test net. - repeated int32 test_iter = 3; - - // The number of iterations between two testing phases. - optional int32 test_interval = 4 [default = 0]; - optional bool test_compute_loss = 19 [default = false]; - // If true, run an initial test pass before the first iteration, - // ensuring memory availability and printing the starting value of the loss. - optional bool test_initialization = 32 [default = true]; - optional float base_lr = 5; // The base learning rate - // the number of iterations between displaying info. If display = 0, no info - // will be displayed. - optional int32 display = 6; - // Display the loss averaged over the last average_loss iterations - optional int32 average_loss = 33 [default = 1]; - optional int32 max_iter = 7; // the maximum number of iterations - // accumulate gradients over `iter_size` x `batch_size` instances - optional int32 iter_size = 36 [default = 1]; - - // The learning rate decay policy. 
The currently implemented learning rate - // policies are as follows: - // - fixed: always return base_lr. - // - step: return base_lr * gamma ^ (floor(iter / step)) - // - exp: return base_lr * gamma ^ iter - // - inv: return base_lr * (1 + gamma * iter) ^ (- power) - // - multistep: similar to step but it allows non uniform steps defined by - // stepvalue - // - poly: the effective learning rate follows a polynomial decay, to be - // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) - // - sigmoid: the effective learning rate follows a sigmod decay - // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) - // - // where base_lr, max_iter, gamma, step, stepvalue and power are defined - // in the solver parameter protocol buffer, and iter is the current iteration. - optional string lr_policy = 8; - optional float gamma = 9; // The parameter to compute the learning rate. - optional float power = 10; // The parameter to compute the learning rate. - optional float momentum = 11; // The momentum value. - optional float weight_decay = 12; // The weight decay. - // regularization types supported: L1 and L2 - // controlled by weight_decay - optional string regularization_type = 29 [default = "L2"]; - // the stepsize for learning rate policy "step" - optional int32 stepsize = 13; - // the stepsize for learning rate policy "multistep" - repeated int32 stepvalue = 34; - - // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, - // whenever their actual L2 norm is larger. - optional float clip_gradients = 35 [default = -1]; - - optional int32 snapshot = 14 [default = 0]; // The snapshot interval - optional string snapshot_prefix = 15; // The prefix for the snapshot. - // whether to snapshot diff in the results or not. Snapshotting diff will help - // debugging but the final protocol buffer size will be much larger. 
- optional bool snapshot_diff = 16 [default = false]; - enum SnapshotFormat { - HDF5 = 0; - BINARYPROTO = 1; - } - optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; - // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. - enum SolverMode { - CPU = 0; - GPU = 1; - } - optional SolverMode solver_mode = 17 [default = GPU]; - // the device_id will that be used in GPU mode. Use device_id = 0 in default. - optional int32 device_id = 18 [default = 0]; - // If non-negative, the seed with which the Solver will initialize the Caffe - // random number generator -- useful for reproducible results. Otherwise, - // (and by default) initialize using a seed derived from the system clock. - optional int64 random_seed = 20 [default = -1]; - - // type of the solver - optional string type = 40 [default = "SGD"]; - - // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam - optional float delta = 31 [default = 1e-8]; - // parameters for the Adam solver - optional float momentum2 = 39 [default = 0.999]; - - // RMSProp decay value - // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) - optional float rms_decay = 38 [default = 0.99]; - - // If true, print information about the state of the net that may help with - // debugging learning problems. - optional bool debug_info = 23 [default = false]; - - // If false, don't save a snapshot after training finishes. 
- optional bool snapshot_after_train = 28 [default = true]; - - // DEPRECATED: old solver enum types, use string instead - enum SolverType { - SGD = 0; - NESTEROV = 1; - ADAGRAD = 2; - RMSPROP = 3; - ADADELTA = 4; - ADAM = 5; - } - // DEPRECATED: use type instead of solver_type - optional SolverType solver_type = 30 [default = SGD]; - - // Overlap compute and communication for data parallel training - optional bool layer_wise_reduce = 41 [default = true]; -} - -// A message that stores the solver snapshots -message SolverState { - optional int32 iter = 1; // The current iteration - optional string learned_net = 2; // The file that stores the learned net. - repeated BlobProto history = 3; // The history for sgd solvers - optional int32 current_step = 4 [default = 0]; // The current step for learning rate -} - -enum Phase { - TRAIN = 0; - TEST = 1; -} - -message NetState { - optional Phase phase = 1 [default = TEST]; - optional int32 level = 2 [default = 0]; - repeated string stage = 3; -} - -message NetStateRule { - // Set phase to require the NetState have a particular phase (TRAIN or TEST) - // to meet this rule. - optional Phase phase = 1; - - // Set the minimum and/or maximum levels in which the layer should be used. - // Leave undefined to meet the rule regardless of level. - optional int32 min_level = 2; - optional int32 max_level = 3; - - // Customizable sets of stages to include or exclude. - // The net must have ALL of the specified stages and NONE of the specified - // "not_stage"s to meet the rule. - // (Use multiple NetStateRules to specify conjunctions of stages.) - repeated string stage = 4; - repeated string not_stage = 5; -} - -// Specifies training parameters (multipliers on global learning constants, -// and the name and other settings used for weight sharing). -message ParamSpec { - // The names of the parameter blobs -- useful for sharing parameters among - // layers, but never required otherwise. 
To share a parameter between two - // layers, give it a (non-empty) name. - optional string name = 1; - - // Whether to require shared weights to have the same shape, or just the same - // count -- defaults to STRICT if unspecified. - optional DimCheckMode share_mode = 2; - enum DimCheckMode { - // STRICT (default) requires that num, channels, height, width each match. - STRICT = 0; - // PERMISSIVE requires only the count (num*channels*height*width) to match. - PERMISSIVE = 1; - } - - // The multiplier on the global learning rate for this parameter. - optional float lr_mult = 3 [default = 1.0]; - - // The multiplier on the global weight decay for this parameter. - optional float decay_mult = 4 [default = 1.0]; -} - -// NOTE -// Update the next available ID when you add a new LayerParameter field. -// -// LayerParameter next available layer-specific ID: 151 (last added: smooth_l1_loss_param) -message LayerParameter { - optional string name = 1; // the layer name - optional string type = 2; // the layer type - repeated string bottom = 3; // the name of each bottom blob - repeated string top = 4; // the name of each top blob - - // The train / test phase for computation. - optional Phase phase = 10; - - // The amount of weight to assign each top blob in the objective. - // Each layer assigns a default value, usually of either 0 or 1, - // to each top blob. - repeated float loss_weight = 5; - - // Specifies training parameters (multipliers on global learning constants, - // and the name and other settings used for weight sharing). - repeated ParamSpec param = 6; - - // The blobs containing the numeric parameters of the layer. - repeated BlobProto blobs = 7; - - // Specifies whether to backpropagate to each bottom. If unspecified, - // Caffe will automatically infer whether each input needs backpropagation - // to compute parameter gradients. 
If set to true for some inputs, - // backpropagation to those inputs is forced; if set false for some inputs, - // backpropagation to those inputs is skipped. - // - // The size must be either 0 or equal to the number of bottoms. - repeated bool propagate_down = 11; - - // Rules controlling whether and when a layer is included in the network, - // based on the current NetState. You may specify a non-zero number of rules - // to include OR exclude, but not both. If no include or exclude rules are - // specified, the layer is always included. If the current NetState meets - // ANY (i.e., one or more) of the specified rules, the layer is - // included/excluded. - repeated NetStateRule include = 8; - repeated NetStateRule exclude = 9; - - // Parameters for data pre-processing. - optional TransformationParameter transform_param = 100; - - // Parameters shared by loss layers. - optional LossParameter loss_param = 101; - - // Layer type-specific parameters. - // - // Note: certain layers may have more than one computational engine - // for their implementation. These layers include an Engine type and - // engine parameter for selecting the implementation. - // The default for the engine is set by the ENGINE switch at compile-time. 
- optional AccuracyParameter accuracy_param = 102; - optional ArgMaxParameter argmax_param = 103; - optional BatchNormParameter batch_norm_param = 139; - optional BiasParameter bias_param = 141; - optional ConcatParameter concat_param = 104; - optional ContrastiveLossParameter contrastive_loss_param = 105; - optional ConvolutionParameter convolution_param = 106; - optional CropParameter crop_param = 144; - optional DataParameter data_param = 107; - optional DetectionOutputParameter detection_output_param = 150; - optional DropoutParameter dropout_param = 108; - optional DummyDataParameter dummy_data_param = 109; - optional EltwiseParameter eltwise_param = 110; - optional ELUParameter elu_param = 140; - optional EmbedParameter embed_param = 137; - optional ExpParameter exp_param = 111; - optional FlattenParameter flatten_param = 135; - optional HDF5DataParameter hdf5_data_param = 112; - optional HDF5OutputParameter hdf5_output_param = 113; - optional HingeLossParameter hinge_loss_param = 114; - optional ImageDataParameter image_data_param = 115; - optional InfogainLossParameter infogain_loss_param = 116; - optional InnerProductParameter inner_product_param = 117; - optional InputParameter input_param = 143; - optional LogParameter log_param = 134; - optional LRNParameter lrn_param = 118; - optional MemoryDataParameter memory_data_param = 119; - optional MVNParameter mvn_param = 120; - optional ParameterParameter parameter_param = 145; - optional PoolingParameter pooling_param = 121; - optional PowerParameter power_param = 122; - optional PReLUParameter prelu_param = 131; - optional PythonParameter python_param = 130; - optional RecurrentParameter recurrent_param = 146; - optional ReductionParameter reduction_param = 136; - optional ReLUParameter relu_param = 123; - optional ReshapeParameter reshape_param = 133; - optional ScaleParameter scale_param = 142; - optional SigmoidParameter sigmoid_param = 124; - optional SmoothL1LossParameter smooth_l1_loss_param = 148; - 
optional SoftmaxParameter softmax_param = 125; - optional SPPParameter spp_param = 132; - optional SliceParameter slice_param = 126; - optional TanHParameter tanh_param = 127; - optional ThresholdParameter threshold_param = 128; - optional TileParameter tile_param = 138; - optional WindowDataParameter window_data_param = 129; - optional PermuteParameter permute_param = 202; - optional PriorBoxParameter prior_box_param = 203; - optional NormalizeParameter norm_param = 206; - optional PSROIPoolingParameter psroi_pooling_param = 207; - optional FreespaceExtractParameter freespace_extract_param = 151; - optional PostprocessParameter postprocess_param = 152; - optional SpatialTransformParameter spatial_transform_param = 153; - optional ROIAlignParameter roi_align_param = 154; - optional ReorgParameter reorg_param = 155; - optional RegionParameter region_param = 156; - optional ReverseParameter reverse_param = 157; - optional InterpParameter interp_param = 158; - optional ShuffleChannelParameter shuffle_channel_param = 159; - optional UpsampleParameter upsample_param = 160; - optional ROIPoolingParameter roi_pooling_param = 161; - optional YoloParameter yolo_param = 199; - optional YoloV3DetectionOutputParameter yolov3_detection_output_param = 200; - optional ProposalParameter proposal_param = 201; - optional FSRDetectionOutputParameter fsrdetectionoutput_param = 222; - optional SSDDetectionOutputParameter ssddetectionoutput_param = 232; - optional YoloV2DetectionOutputParameter yolov2_detection_output_param = 204; - optional QuantParameter quant_param = 208; - optional CondTakeParameter condtake_param = 233; - optional MatrixInverseParameter matrix_inverse_param = 210; - optional WarpPerspectiveParameter warp_perspective_param = 234; - optional BatchMatMulParameter batch_matmul_param = 235; - optional SpatialTransformerParameter st_param = 5000; - optional YoloV3DetectionOutputV2Parameter yolov3_detection_output_v2_param = 5001; -} - -// Message that stores parameters 
used to apply transformation -// to the data layer's data -message TransformationParameter { - // For data pre-processing, we can do simple scaling and subtracting the - // data mean, if provided. Note that the mean subtraction is always carried - // out before scaling. - optional float scale = 1 [default = 1]; - // Specify if we want to randomly mirror data. - optional bool mirror = 2 [default = false]; - // Specify if we would like to randomly crop an image. - optional uint32 crop_size = 3 [default = 0]; - // mean_file and mean_value cannot be specified at the same time - optional string mean_file = 4; - // if specified can be repeated once (would substract it from all the channels) - // or can be repeated the same number of times as channels - // (would subtract them from the corresponding channel) - repeated float mean_value = 5; - // Force the decoded image to have 3 color channels. - optional bool force_color = 6 [default = false]; - // Force the decoded image to have 1 color channels. - optional bool force_gray = 7 [default = false]; -} - -// Message that stores parameters shared by loss layers -message LossParameter { - // If specified, ignore instances with the given label. - optional int32 ignore_label = 1; - // How to normalize the loss for loss layers that aggregate across batches, - // spatial dimensions, or other dimensions. Currently only implemented in - // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers. - enum NormalizationMode { - // Divide by the number of examples in the batch times spatial dimensions. - // Outputs that receive the ignore label will NOT be ignored in computing - // the normalization factor. - FULL = 0; - // Divide by the total number of output locations that do not take the - // ignore_label. If ignore_label is not set, this behaves like FULL. - VALID = 1; - // Divide by the batch size. - BATCH_SIZE = 2; - // Do not normalize the loss. 
- NONE = 3; - } - // For historical reasons, the default normalization for - // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID. - optional NormalizationMode normalization = 3 [default = VALID]; - // Deprecated. Ignored if normalization is specified. If normalization - // is not specified, then setting this to false will be equivalent to - // normalization = BATCH_SIZE to be consistent with previous behavior. - optional bool normalize = 2; -} - -// Messages that store parameters used by individual layer types follow, in -// alphabetical order. - -message AccuracyParameter { - // When computing accuracy, count as correct by comparing the true label to - // the top k scoring classes. By default, only compare to the top scoring - // class (i.e. argmax). - optional uint32 top_k = 1 [default = 1]; - - // The "label" axis of the prediction blob, whose argmax corresponds to the - // predicted label -- may be negative to index from the end (e.g., -1 for the - // last axis). For example, if axis == 1 and the predictions are - // (N x C x H x W), the label blob is expected to contain N*H*W ground truth - // labels with integer values in {0, 1, ..., C-1}. - optional int32 axis = 2 [default = 1]; - - // If specified, ignore instances with the given label. - optional int32 ignore_label = 3; -} - -message ArgMaxParameter { - // If true produce pairs (argmax, maxval) - optional bool out_max_val = 1 [default = false]; - optional uint32 top_k = 2 [default = 1]; - // The axis along which to maximise -- may be negative to index from the - // end (e.g., -1 for the last axis). - // By default ArgMaxLayer maximizes over the flattened trailing dimensions - // for each index of the first / num dimension. - optional int32 axis = 3; -} - -message ConcatParameter { - // The axis along which to concatenate -- may be negative to index from the - // end (e.g., -1 for the last axis). Other axes must have the - // same dimension for all the bottom blobs. 
- // By default, ConcatLayer concatenates blobs along the "channels" axis (1). - optional int32 axis = 2 [default = 1]; - - // DEPRECATED: alias for "axis" -- does not support negative indexing. - optional uint32 concat_dim = 1 [default = 1]; -} - -message BatchNormParameter { - // If false, normalization is performed over the current mini-batch - // and global statistics are accumulated (but not yet used) by a moving - // average. - // If true, those accumulated mean and variance values are used for the - // normalization. - // By default, it is set to false when the network is in the training - // phase and true when the network is in the testing phase. - optional bool use_global_stats = 1; - // What fraction of the moving average remains each iteration? - // Smaller values make the moving average decay faster, giving more - // weight to the recent values. - // Each iteration updates the moving average @f$S_{t-1}@f$ with the - // current mean @f$ Y_t @f$ by - // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ - // is the moving_average_fraction parameter. - optional float moving_average_fraction = 2 [default = .999]; - // Small value to add to the variance estimate so that we don't divide by - // zero. - optional float eps = 3 [default = 1e-5]; -} - -message BiasParameter { - // The first axis of bottom[0] (the first input Blob) along which to apply - // bottom[1] (the second input Blob). May be negative to index from the end - // (e.g., -1 for the last axis). - // - // For example, if bottom[0] is 4D with shape 100x3x40x60, the output - // top[0] will have the same shape, and bottom[1] may have any of the - // following shapes (for the given value of axis): - // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 - // (axis == 1 == -3) 3; 3x40; 3x40x60 - // (axis == 2 == -2) 40; 40x60 - // (axis == 3 == -1) 60 - // Furthermore, bottom[1] may have the empty shape (regardless of the value of - // "axis") -- a scalar bias. 
- optional int32 axis = 1 [default = 1]; - - // (num_axes is ignored unless just one bottom is given and the bias is - // a learned parameter of the layer. Otherwise, num_axes is determined by the - // number of axes by the second bottom.) - // The number of axes of the input (bottom[0]) covered by the bias - // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. - // Set num_axes := 0, to add a zero-axis Blob: a scalar. - optional int32 num_axes = 2 [default = 1]; - - // (filler is ignored unless just one bottom is given and the bias is - // a learned parameter of the layer.) - // The initialization for the learned bias parameter. - // Default is the zero (0) initialization, resulting in the BiasLayer - // initially performing the identity operation. - optional FillerParameter filler = 3; - optional bool bias_from_blob = 4 [default = true]; -} - -message ContrastiveLossParameter { - // margin for dissimilar pair - optional float margin = 1 [default = 1.0]; - // The first implementation of this cost did not exactly match the cost of - // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. - // legacy_version = false (the default) uses (margin - d)^2 as proposed in the - // Hadsell paper. New models should probably use this version. - // legacy_version = true uses (margin - d^2). This is kept to support / - // reproduce existing models and results - optional bool legacy_version = 2 [default = false]; -} - -message ConvolutionParameter { - optional uint32 num_output = 1; // The number of outputs for the layer - optional bool bias_term = 2 [default = true]; // whether to have bias terms - - // Pad, kernel size, and stride are all given as a single value for equal - // dimensions in all spatial dimensions, or once per spatial dimension. 
- repeated uint32 pad = 3; // The padding size; defaults to 0 - repeated uint32 kernel_size = 4; // The kernel size - repeated uint32 stride = 6; // The stride; defaults to 1 - // Factor used to dilate the kernel, (implicitly) zero-filling the resulting - // holes. (Kernel dilation is sometimes referred to by its use in the - // algorithme à trous from Holschneider et al. 1987.) - repeated uint32 dilation = 18; // The dilation; defaults to 1 - - // For 2D convolution only, the *_h and *_w versions may also be used to - // specify both spatial dimensions. - optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) - optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) - optional uint32 kernel_h = 11; // The kernel height (2D only) - optional uint32 kernel_w = 12; // The kernel width (2D only) - optional uint32 stride_h = 13; // The stride height (2D only) - optional uint32 stride_w = 14; // The stride width (2D only) - - optional uint32 group = 5 [default = 1]; // The group size for group conv - - optional FillerParameter weight_filler = 7; // The filler for the weight - optional FillerParameter bias_filler = 8; // The filler for the bias - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 15 [default = DEFAULT]; - - // The axis to interpret as "channels" when performing convolution. - // Preceding dimensions are treated as independent inputs; - // succeeding dimensions are treated as "spatial". - // With (N, C, H, W) inputs, and axis == 1 (the default), we perform - // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for - // groups g>1) filters across the spatial axes (H, W) of the input. - // With (N, C, D, H, W) inputs, and axis == 1, we perform - // N independent 3D convolutions, sliding (C/g)-channels - // filters across the spatial axes (D, H, W) of the input. 
- optional int32 axis = 16 [default = 1]; - - // Whether to force use of the general ND convolution, even if a specific - // implementation for blobs of the appropriate number of spatial dimensions - // is available. (Currently, there is only a 2D-specific convolution - // implementation; for input blobs with num_axes != 2, this option is - // ignored and the ND implementation will be used.) - optional bool force_nd_im2col = 17 [default = false]; -} - -message CropParameter { - // To crop, elements of the first bottom are selected to fit the dimensions - // of the second, reference bottom. The crop is configured by - // - the crop `axis` to pick the dimensions for cropping - // - the crop `offset` to set the shift for all/each dimension - // to align the cropped bottom with the reference bottom. - // All dimensions up to but excluding `axis` are preserved, while - // the dimensions including and trailing `axis` are cropped. - // If only one `offset` is set, then all dimensions are offset by this amount. - // Otherwise, the number of offsets must equal the number of cropped axes to - // shift the crop in each dimension accordingly. - // Note: standard dimensions are N,C,H,W so the default is a spatial crop, - // and `axis` may be negative to index from the end (e.g., -1 for the last - // axis). - optional int32 axis = 1 [default = 2]; - repeated uint32 offset = 2; -} - -message DataParameter { - enum DB { - LEVELDB = 0; - LMDB = 1; - } - // Specify the data source. - optional string source = 1; - // Specify the batch size. - optional uint32 batch_size = 4; - // The rand_skip variable is for the data layer to skip a few data points - // to avoid all asynchronous sgd clients to start at the same point. The skip - // point would be set as rand_skip * rand(0,1). Note that rand_skip should not - // be larger than the number of keys in the database. - // DEPRECATED. Each solver accesses a different subset of the database. 
- optional uint32 rand_skip = 7 [default = 0]; - optional DB backend = 8 [default = LEVELDB]; - // DEPRECATED. See TransformationParameter. For data pre-processing, we can do - // simple scaling and subtracting the data mean, if provided. Note that the - // mean subtraction is always carried out before scaling. - optional float scale = 2 [default = 1]; - optional string mean_file = 3; - // DEPRECATED. See TransformationParameter. Specify if we would like to randomly - // crop an image. - optional uint32 crop_size = 5 [default = 0]; - // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror - // data. - optional bool mirror = 6 [default = false]; - // Force the encoded image to have 3 color channels - optional bool force_encoded_color = 9 [default = false]; - // Prefetch queue (Increase if data feeding bandwidth varies, within the - // limit of device memory for GPU training) - optional uint32 prefetch = 10 [default = 4]; -} - -message DropoutParameter { - optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio - optional bool scale_train = 2 [default = true]; // scale train or test phase -} - -// DummyDataLayer fills any number of arbitrarily shaped blobs with random -// (or constant) data generated by "Fillers" (see "message FillerParameter"). -message DummyDataParameter { - // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N - // shape fields, and 0, 1 or N data_fillers. - // - // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. - // If 1 data_filler is specified, it is applied to all top blobs. If N are - // specified, the ith is applied to the ith top blob. - repeated FillerParameter data_filler = 1; - repeated BlobShape shape = 6; - - // 4D dimensions -- deprecated. Use "shape" instead. 
- repeated uint32 num = 2; - repeated uint32 channels = 3; - repeated uint32 height = 4; - repeated uint32 width = 5; -} - -message EltwiseParameter { - enum EltwiseOp { - PROD = 0; - SUM = 1; - MAX = 2; - } - optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation - repeated float coeff = 2; // blob-wise coefficient for SUM operation - - // Whether to use an asymptotically slower (for >2 inputs) but stabler method - // of computing the gradient for the PROD operation. (No effect for SUM op.) - optional bool stable_prod_grad = 3 [default = true]; -} - -// Message that stores parameters used by ELULayer -message ELUParameter { - // Described in: - // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate - // Deep Network Learning by Exponential Linear Units (ELUs). arXiv - optional float alpha = 1 [default = 1]; -} - -// Message that stores parameters used by EmbedLayer -message EmbedParameter { - optional uint32 num_output = 1; // The number of outputs for the layer - // The input is given as integers to be interpreted as one-hot - // vector indices with dimension num_input. Hence num_input should be - // 1 greater than the maximum possible input value. - optional uint32 input_dim = 2; - - optional bool bias_term = 3 [default = true]; // Whether to use a bias term - optional FillerParameter weight_filler = 4; // The filler for the weight - optional FillerParameter bias_filler = 5; // The filler for the bias - -} - -// Message that stores parameters used by ExpLayer -message ExpParameter { - // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. - // Or if base is set to the default (-1), base is set to e, - // so y = exp(shift + scale * x). 
- optional float base = 1 [default = -1.0]; - optional float scale = 2 [default = 1.0]; - optional float shift = 3 [default = 0.0]; -} - -/// Message that stores parameters used by FlattenLayer -message FlattenParameter { - // The first axis to flatten: all preceding axes are retained in the output. - // May be negative to index from the end (e.g., -1 for the last axis). - optional int32 axis = 1 [default = 1]; - - // The last axis to flatten: all following axes are retained in the output. - // May be negative to index from the end (e.g., the default -1 for the last - // axis). - optional int32 end_axis = 2 [default = -1]; -} - -// Message that stores parameters used by HDF5DataLayer -message HDF5DataParameter { - // Specify the data source. - optional string source = 1; - // Specify the batch size. - optional uint32 batch_size = 2; - - // Specify whether to shuffle the data. - // If shuffle == true, the ordering of the HDF5 files is shuffled, - // and the ordering of data within any given HDF5 file is shuffled, - // but data between different files are not interleaved; all of a file's - // data are output (in a random order) before moving onto another file. - optional bool shuffle = 3 [default = false]; -} - -message HDF5OutputParameter { - optional string file_name = 1; -} - -message HingeLossParameter { - enum Norm { - L1 = 1; - L2 = 2; - } - // Specify the Norm to use L1 or L2 - optional Norm norm = 1 [default = L1]; -} - -message ImageDataParameter { - // Specify the data source. - optional string source = 1; - // Specify the batch size. - optional uint32 batch_size = 4 [default = 1]; - // The rand_skip variable is for the data layer to skip a few data points - // to avoid all asynchronous sgd clients to start at the same point. The skip - // point would be set as rand_skip * rand(0,1). Note that rand_skip should not - // be larger than the number of keys in the database. 
- optional uint32 rand_skip = 7 [default = 0]; - // Whether or not ImageLayer should shuffle the list of files at every epoch. - optional bool shuffle = 8 [default = false]; - // It will also resize images if new_height or new_width are not zero. - optional uint32 new_height = 9 [default = 0]; - optional uint32 new_width = 10 [default = 0]; - // Specify if the images are color or gray - optional bool is_color = 11 [default = true]; - // DEPRECATED. See TransformationParameter. For data pre-processing, we can do - // simple scaling and subtracting the data mean, if provided. Note that the - // mean subtraction is always carried out before scaling. - optional float scale = 2 [default = 1]; - optional string mean_file = 3; - // DEPRECATED. See TransformationParameter. Specify if we would like to randomly - // crop an image. - optional uint32 crop_size = 5 [default = 0]; - // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror - // data. - optional bool mirror = 6 [default = false]; - optional string root_folder = 12 [default = ""]; -} - -message InfogainLossParameter { - // Specify the infogain matrix source. - optional string source = 1; - optional int32 axis = 2 [default = 1]; // axis of prob -} - -message InnerProductParameter { - optional uint32 num_output = 1; // The number of outputs for the layer - optional bool bias_term = 2 [default = true]; // whether to have bias terms - optional FillerParameter weight_filler = 3; // The filler for the weight - optional FillerParameter bias_filler = 4; // The filler for the bias - - // The first axis to be lumped into a single inner product computation; - // all preceding axes are retained in the output. - // May be negative to index from the end (e.g., -1 for the last axis). - optional int32 axis = 5 [default = 1]; - // Specify whether to transpose the weight matrix or not. - // If transpose == true, any operations will be performed on the transpose - // of the weight matrix. 
The weight matrix itself is not going to be transposed - // but rather the transfer flag of operations will be toggled accordingly. - optional bool transpose = 6 [default = false]; -} - -message InputParameter { - // This layer produces N >= 1 top blob(s) to be assigned manually. - // Define N shapes to set a shape for each top. - // Define 1 shape to set the same shape for every top. - // Define no shape to defer to reshaping manually. - repeated BlobShape shape = 1; -} - -// Message that stores parameters used by LogLayer -message LogParameter { - // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. - // Or if base is set to the default (-1), base is set to e, - // so y = ln(shift + scale * x) = log_e(shift + scale * x) - optional float base = 1 [default = -1.0]; - optional float scale = 2 [default = 1.0]; - optional float shift = 3 [default = 0.0]; -} - -// Message that stores parameters used by LRNLayer -message LRNParameter { - optional uint32 local_size = 1 [default = 5]; - optional float alpha = 2 [default = 1.]; - optional float beta = 3 [default = 0.75]; - enum NormRegion { - ACROSS_CHANNELS = 0; - WITHIN_CHANNEL = 1; - } - optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; - optional float k = 5 [default = 1.]; - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 6 [default = DEFAULT]; -} - -message MemoryDataParameter { - optional uint32 batch_size = 1; - optional uint32 channels = 2; - optional uint32 height = 3; - optional uint32 width = 4; -} - -message MVNParameter { - // This parameter can be set to false to normalize mean only - optional bool normalize_variance = 1 [default = true]; - - // This parameter can be set to true to perform DNN-like MVN - optional bool across_channels = 2 [default = false]; - - // Epsilon for not dividing by zero while normalizing variance - optional float eps = 3 [default = 1e-9]; -} - -message ParameterParameter { - optional BlobShape shape = 1; -} 
- -message PoolingParameter { - enum PoolMethod { - MAX = 0; - AVE = 1; - STOCHASTIC = 2; - } - optional PoolMethod pool = 1 [default = MAX]; // The pooling method - // Pad, kernel size, and stride are all given as a single value for equal - // dimensions in height and width or as Y, X pairs. - optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) - optional uint32 pad_h = 9 [default = 0]; // The padding height - optional uint32 pad_w = 10 [default = 0]; // The padding width - optional uint32 kernel_size = 2; // The kernel size (square) - optional uint32 kernel_h = 5; // The kernel height - optional uint32 kernel_w = 6; // The kernel width - optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) - optional uint32 stride_h = 7; // The stride height - optional uint32 stride_w = 8; // The stride width - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 11 [default = DEFAULT]; - // If global_pooling then it will pool over the size of the bottom by doing - // kernel_h = bottom->height and kernel_w = bottom->width - optional bool global_pooling = 12 [default = false]; - optional bool ceil_mode = 13 [default = true]; - // How to calculate the output size - using ceil (default) or floor rounding. - enum RoundMode { - CEIL = 0; - FLOOR = 1; - } - optional RoundMode round_mode = 14 [default = CEIL]; -} - -message PowerParameter { - // PowerLayer computes outputs y = (shift + scale * x) ^ power. - optional float power = 1 [default = 1.0]; - optional float scale = 2 [default = 1.0]; - optional float shift = 3 [default = 0.0]; -} - -message PythonParameter { - optional string module = 1; - optional string layer = 2; - // This value is set to the attribute `param_str` of the `PythonLayer` object - // in Python before calling the `setup()` method. This could be a number, - // string, dictionary in Python dict format, JSON, etc. 
You may parse this - // string in `setup` method and use it in `forward` and `backward`. - optional string param_str = 3 [default = '']; - // Whether this PythonLayer is shared among worker solvers during data parallelism. - // If true, each worker solver sequentially run forward from this layer. - // This value should be set true if you are using it as a data layer. - optional bool share_in_parallel = 4 [default = false]; -} - -// Message that stores parameters used by RecurrentLayer -message RecurrentParameter { - // The dimension of the output (and usually hidden state) representation -- - // must be explicitly set to non-zero. - optional uint32 num_output = 1 [default = 0]; - - optional FillerParameter weight_filler = 2; // The filler for the weight - optional FillerParameter bias_filler = 3; // The filler for the bias - - // Whether to enable displaying debug_info in the unrolled recurrent net. - optional bool debug_info = 4 [default = false]; - - // Whether to add as additional inputs (bottoms) the initial hidden state - // blobs, and add as additional outputs (tops) the final timestep hidden state - // blobs. The number of additional bottom/top blobs required depends on the - // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. - optional bool expose_hidden = 5 [default = false]; -} - -// Message that stores parameters used by ReductionLayer -message ReductionParameter { - enum ReductionOp { - SUM = 1; - ASUM = 2; - SUMSQ = 3; - MEAN = 4; - } - - optional ReductionOp operation = 1 [default = SUM]; // reduction operation - - // The first axis to reduce to a scalar -- may be negative to index from the - // end (e.g., -1 for the last axis). - // (Currently, only reduction along ALL "tail" axes is supported; reduction - // of axis M through N, where N < num_axes - 1, is unsupported.) - // Suppose we have an n-axis bottom Blob with shape: - // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). 
- // If axis == m, the output Blob will have shape - // (d0, d1, d2, ..., d(m-1)), - // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) - // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. - // If axis == 0 (the default), the output Blob always has the empty shape - // (count 1), performing reduction across the entire input -- - // often useful for creating new loss functions. - optional int32 axis = 2 [default = 0]; - - optional float coeff = 3 [default = 1.0]; // coefficient for output -} - -// Message that stores parameters used by ReLULayer -message ReLUParameter { - // Allow non-zero slope for negative inputs to speed up optimization - // Described in: - // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities - // improve neural network acoustic models. In ICML Workshop on Deep Learning - // for Audio, Speech, and Language Processing. - optional float negative_slope = 1 [default = 0]; - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 2 [default = DEFAULT]; -} - -message ReshapeParameter { - // Specify the output dimensions. If some of the dimensions are set to 0, - // the corresponding dimension from the bottom layer is used (unchanged). - // Exactly one dimension may be set to -1, in which case its value is - // inferred from the count of the bottom blob and the remaining dimensions. - // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: - // - // layer { - // type: "Reshape" bottom: "input" top: "output" - // reshape_param { ... 
} - // } - // - // If "input" is 2D with shape 2 x 8, then the following reshape_param - // specifications are all equivalent, producing a 3D blob "output" with shape - // 2 x 2 x 4: - // - // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } - // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } - // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } - // reshape_param { shape { dim: 0 dim:-1 dim: 4 } } - // - optional BlobShape shape = 1; - - // axis and num_axes control the portion of the bottom blob's shape that are - // replaced by (included in) the reshape. By default (axis == 0 and - // num_axes == -1), the entire bottom blob shape is included in the reshape, - // and hence the shape field must specify the entire output shape. - // - // axis may be non-zero to retain some portion of the beginning of the input - // shape (and may be negative to index from the end; e.g., -1 to begin the - // reshape after the last axis, including nothing in the reshape, - // -2 to include only the last axis, etc.). - // - // For example, suppose "input" is a 2D blob with shape 2 x 8. - // Then the following ReshapeLayer specifications are all equivalent, - // producing a blob "output" with shape 2 x 2 x 4: - // - // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } - // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } - // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } - // - // num_axes specifies the extent of the reshape. - // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on - // input axes in the range [axis, axis+num_axes]. - // num_axes may also be -1, the default, to include all remaining axes - // (starting from axis). - // - // For example, suppose "input" is a 2D blob with shape 2 x 8. - // Then the following ReshapeLayer specifications are equivalent, - // producing a blob "output" with shape 1 x 2 x 8. 
- // - // reshape_param { shape { dim: 1 dim: 2 dim: 8 } } - // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 } - // reshape_param { shape { dim: 1 } num_axes: 0 } - // - // On the other hand, these would produce output blob shape 2 x 1 x 8: - // - // reshape_param { shape { dim: 2 dim: 1 dim: 8 } } - // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 } - // - optional int32 axis = 2 [default = 0]; - optional int32 num_axes = 3 [default = -1]; -} - - -message ScaleParameter { - // The first axis of bottom[0] (the first input Blob) along which to apply - // bottom[1] (the second input Blob). May be negative to index from the end - // (e.g., -1 for the last axis). - // - // For example, if bottom[0] is 4D with shape 100x3x40x60, the output - // top[0] will have the same shape, and bottom[1] may have any of the - // following shapes (for the given value of axis): - // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 - // (axis == 1 == -3) 3; 3x40; 3x40x60 - // (axis == 2 == -2) 40; 40x60 - // (axis == 3 == -1) 60 - // Furthermore, bottom[1] may have the empty shape (regardless of the value of - // "axis") -- a scalar multiplier. - optional int32 axis = 1 [default = 1]; - - // (num_axes is ignored unless just one bottom is given and the scale is - // a learned parameter of the layer. Otherwise, num_axes is determined by the - // number of axes by the second bottom.) - // The number of axes of the input (bottom[0]) covered by the scale - // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. - // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar. - optional int32 num_axes = 2 [default = 1]; - - // (filler is ignored unless just one bottom is given and the scale is - // a learned parameter of the layer.) - // The initialization for the learned scale parameter. - // Default is the unit (1) initialization, resulting in the ScaleLayer - // initially performing the identity operation. 
- optional FillerParameter filler = 3; - - // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but - // may be more efficient). Initialized with bias_filler (defaults to 0). - optional bool bias_term = 4 [default = false]; - optional FillerParameter bias_filler = 5; - optional bool scale_from_blob = 6 [default = true]; -} - -message SigmoidParameter { - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 1 [default = DEFAULT]; -} - -message SliceParameter { - // The axis along which to slice -- may be negative to index from the end - // (e.g., -1 for the last axis). - // By default, SliceLayer concatenates blobs along the "channels" axis (1). - optional int32 axis = 3 [default = 1]; - repeated uint32 slice_point = 2; - - // DEPRECATED: alias for "axis" -- does not support negative indexing. - optional uint32 slice_dim = 1 [default = 1]; -} - -message SmoothL1LossParameter { - // SmoothL1Loss(x) = - // 0.5 * (sigma * x) ** 2 -- if x < 1.0 / sigma / sigma - // |x| - 0.5 / sigma / sigma -- otherwise - optional float sigma = 1 [default = 1]; -} - -// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer -message SoftmaxParameter { - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 1 [default = DEFAULT]; - - // The axis along which to perform the softmax -- may be negative to index - // from the end (e.g., -1 for the last axis). - // Any other axes will be evaluated as independent softmaxes. - optional int32 axis = 2 [default = 1]; -} - -message TanHParameter { - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 1 [default = DEFAULT]; -} - -// Message that stores parameters used by TileLayer -message TileParameter { - // The index of the axis to tile. - optional int32 axis = 1 [default = 1]; - - // The number of copies (tiles) of the blob to output. 
- optional int32 tiles = 2; -} - -// Message that stores parameters used by ThresholdLayer -message ThresholdParameter { - optional float threshold = 1 [default = 0]; // Strictly positive values -} - -message WindowDataParameter { - // Specify the data source. - optional string source = 1; - // For data pre-processing, we can do simple scaling and subtracting the - // data mean, if provided. Note that the mean subtraction is always carried - // out before scaling. - optional float scale = 2 [default = 1]; - optional string mean_file = 3; - // Specify the batch size. - optional uint32 batch_size = 4; - // Specify if we would like to randomly crop an image. - optional uint32 crop_size = 5 [default = 0]; - // Specify if we want to randomly mirror data. - optional bool mirror = 6 [default = false]; - // Foreground (object) overlap threshold - optional float fg_threshold = 7 [default = 0.5]; - // Background (non-object) overlap threshold - optional float bg_threshold = 8 [default = 0.5]; - // Fraction of batch that should be foreground objects - optional float fg_fraction = 9 [default = 0.25]; - // Amount of contextual padding to add around a window - // (used only by the window_data_layer) - optional uint32 context_pad = 10 [default = 0]; - // Mode for cropping out a detection window - // warp: cropped window is warped to a fixed size and aspect ratio - // square: the tightest square around the window is cropped - optional string crop_mode = 11 [default = "warp"]; - // cache_images: will load all images in memory for faster access - optional bool cache_images = 12 [default = false]; - // append root_folder to locate images - optional string root_folder = 13 [default = ""]; -} - -message SPPParameter { - enum PoolMethod { - MAX = 0; - AVE = 1; - STOCHASTIC = 2; - } - optional uint32 pyramid_height = 1; - optional PoolMethod pool = 2 [default = MAX]; // The pooling method - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 6 [default = 
DEFAULT]; -} - -// DEPRECATED: use LayerParameter. -message V1LayerParameter { - repeated string bottom = 2; - repeated string top = 3; - optional string name = 4; - repeated NetStateRule include = 32; - repeated NetStateRule exclude = 33; - enum LayerType { - NONE = 0; - ABSVAL = 35; - ACCURACY = 1; - ARGMAX = 30; - BNLL = 2; - CONCAT = 3; - CONTRASTIVE_LOSS = 37; - CONVOLUTION = 4; - DATA = 5; - DECONVOLUTION = 39; - DROPOUT = 6; - DUMMY_DATA = 32; - EUCLIDEAN_LOSS = 7; - ELTWISE = 25; - EXP = 38; - FLATTEN = 8; - HDF5_DATA = 9; - HDF5_OUTPUT = 10; - HINGE_LOSS = 28; - IM2COL = 11; - IMAGE_DATA = 12; - INFOGAIN_LOSS = 13; - INNER_PRODUCT = 14; - LRN = 15; - MEMORY_DATA = 29; - MULTINOMIAL_LOGISTIC_LOSS = 16; - MVN = 34; - POOLING = 17; - POWER = 26; - RELU = 18; - SIGMOID = 19; - SIGMOID_CROSS_ENTROPY_LOSS = 27; - SILENCE = 36; - SOFTMAX = 20; - SOFTMAX_LOSS = 21; - SPLIT = 22; - SLICE = 33; - TANH = 23; - WINDOW_DATA = 24; - THRESHOLD = 31; - QUANT = 208; - DEQUANT = 209; - } - optional LayerType type = 5; - repeated BlobProto blobs = 6; - repeated string param = 1001; - repeated DimCheckMode blob_share_mode = 1002; - enum DimCheckMode { - STRICT = 0; - PERMISSIVE = 1; - } - repeated float blobs_lr = 7; - repeated float weight_decay = 8; - repeated float loss_weight = 35; - optional AccuracyParameter accuracy_param = 27; - optional ArgMaxParameter argmax_param = 23; - optional ConcatParameter concat_param = 9; - optional ContrastiveLossParameter contrastive_loss_param = 40; - optional ConvolutionParameter convolution_param = 10; - optional DataParameter data_param = 11; - optional DropoutParameter dropout_param = 12; - optional DummyDataParameter dummy_data_param = 26; - optional EltwiseParameter eltwise_param = 24; - optional ExpParameter exp_param = 41; - optional HDF5DataParameter hdf5_data_param = 13; - optional HDF5OutputParameter hdf5_output_param = 14; - optional HingeLossParameter hinge_loss_param = 29; - optional ImageDataParameter image_data_param = 
15; - optional InfogainLossParameter infogain_loss_param = 16; - optional InnerProductParameter inner_product_param = 17; - optional LRNParameter lrn_param = 18; - optional MemoryDataParameter memory_data_param = 22; - optional MVNParameter mvn_param = 34; - optional PoolingParameter pooling_param = 19; - optional PowerParameter power_param = 21; - optional ReLUParameter relu_param = 30; - optional SigmoidParameter sigmoid_param = 38; - optional SoftmaxParameter softmax_param = 39; - optional SliceParameter slice_param = 31; - optional TanHParameter tanh_param = 37; - optional ThresholdParameter threshold_param = 25; - optional WindowDataParameter window_data_param = 20; - optional TransformationParameter transform_param = 36; - optional LossParameter loss_param = 42; - optional V0LayerParameter layer = 1; -} - -// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters -// in Caffe. We keep this message type around for legacy support. -message V0LayerParameter { - optional string name = 1; // the layer name - optional string type = 2; // the string to specify the layer type - - // Parameters to specify layers with inner products. 
- optional uint32 num_output = 3; // The number of outputs for the layer - optional bool biasterm = 4 [default = true]; // whether to have bias terms - optional FillerParameter weight_filler = 5; // The filler for the weight - optional FillerParameter bias_filler = 6; // The filler for the bias - - optional uint32 pad = 7 [default = 0]; // The padding size - optional uint32 kernelsize = 8; // The kernel size - optional uint32 group = 9 [default = 1]; // The group size for group conv - optional uint32 stride = 10 [default = 1]; // The stride - enum PoolMethod { - MAX = 0; - AVE = 1; - STOCHASTIC = 2; - } - optional PoolMethod pool = 11 [default = MAX]; // The pooling method - optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio - - optional uint32 local_size = 13 [default = 5]; // for local response norm - optional float alpha = 14 [default = 1.]; // for local response norm - optional float beta = 15 [default = 0.75]; // for local response norm - optional float k = 22 [default = 1.]; - - // For data layers, specify the data source - optional string source = 16; - // For data pre-processing, we can do simple scaling and subtracting the - // data mean, if provided. Note that the mean subtraction is always carried - // out before scaling. - optional float scale = 17 [default = 1]; - optional string meanfile = 18; - // For data layers, specify the batch size. - optional uint32 batchsize = 19; - // For data layers, specify if we would like to randomly crop an image. - optional uint32 cropsize = 20 [default = 0]; - // For data layers, specify if we want to randomly mirror data. - optional bool mirror = 21 [default = false]; - - // The blobs containing the numeric parameters of the layer - repeated BlobProto blobs = 50; - // The ratio that is multiplied on the global learning rate. If you want to - // set the learning ratio for one blob, you need to set it for all blobs. 
- repeated float blobs_lr = 51; - // The weight decay that is multiplied on the global weight decay. - repeated float weight_decay = 52; - - // The rand_skip variable is for the data layer to skip a few data points - // to avoid all asynchronous sgd clients to start at the same point. The skip - // point would be set as rand_skip * rand(0,1). Note that rand_skip should not - // be larger than the number of keys in the database. - optional uint32 rand_skip = 53 [default = 0]; - - // Fields related to detection (det_*) - // foreground (object) overlap threshold - optional float det_fg_threshold = 54 [default = 0.5]; - // background (non-object) overlap threshold - optional float det_bg_threshold = 55 [default = 0.5]; - // Fraction of batch that should be foreground objects - optional float det_fg_fraction = 56 [default = 0.25]; - - // optional bool OBSOLETE_can_clobber = 57 [default = true]; - - // Amount of contextual padding to add around a window - // (used only by the window_data_layer) - optional uint32 det_context_pad = 58 [default = 0]; - - // Mode for cropping out a detection window - // warp: cropped window is warped to a fixed size and aspect ratio - // square: the tightest square around the window is cropped - optional string det_crop_mode = 59 [default = "warp"]; - - // For ReshapeLayer, one needs to specify the new dimensions. - optional int32 new_num = 60 [default = 0]; - optional int32 new_channels = 61 [default = 0]; - optional int32 new_height = 62 [default = 0]; - optional int32 new_width = 63 [default = 0]; - - // Whether or not ImageLayer should shuffle the list of files at every epoch. - // It will also resize images if new_height or new_width are not zero. - optional bool shuffle_images = 64 [default = false]; - - // For ConcatLayer, one needs to specify the dimension for concatenation, and - // the other dimensions must be the same for all the bottom blobs. - // By default it will concatenate blobs along the channels dimension. 
- optional uint32 concat_dim = 65 [default = 1]; - - optional HDF5OutputParameter hdf5_output_param = 1001; -} - -message PReLUParameter { - // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers: - // Surpassing Human-Level Performance on ImageNet Classification, 2015. - - // Initial value of a_i. Default is a_i=0.25 for all i. - optional FillerParameter filler = 1; - // Whether or not slope parameters are shared across channels. - optional bool channel_shared = 2 [default = false]; -} - -// Message that stores parameters used by DetectionOutputLayer -//message DetectionOutputParameter { -// optional int32 num_classes = 1 [default = 21]; -// optional float nms_threshold = 2 [default = 0.3]; -// optional int32 top_k = 3; -// optional float confidence_threshold = 4 [default = 0.8]; -//} - -// Message that store parameters used by PriorBoxLayer -message PriorBoxParameter { - // Encode/decode type. - enum CodeType { - CORNER = 1; - CENTER_SIZE = 2; - CORNER_SIZE = 3; - } - // Minimum box size (in pixels). Required! - repeated float min_size = 1; - // Maximum box size (in pixels). Required! - repeated float max_size = 2; - // Various of aspect ratios. Duplicate ratios will be ignored. - // If none is provided, we use default ratio 1. - repeated float aspect_ratio = 3; - // If true, will flip each aspect ratio. - // For example, if there is aspect ratio "r", - // we will generate aspect ratio "1.0/r" as well. - optional bool flip = 4 [default = true]; - // If true, will clip the prior so that it is within [0, 1] - optional bool clip = 5 [default = false]; - // Variance for adjusting the prior bboxes. - repeated float variance = 6; - // By default, we calculate img_height, img_width, step_x, step_y based on - // bottom[0] (feat) and bottom[1] (img). Unless these values are explicitely - // provided. - // Explicitly provide the img_size. - optional uint32 img_size = 7; - // Either img_size or img_h/img_w should be specified; not both. 
- optional uint32 img_h = 8; - optional uint32 img_w = 9; - - // Explicitly provide the step size. - optional float step = 10; - // Either step or step_h/step_w should be specified; not both. - optional float step_h = 11; - optional float step_w = 12; - - // Offset to the top left corner of each cell. - optional float offset = 13 [default = 0.5]; -} - -// Message that stores parameters used by PermutetLayer -message PermuteParameter { - // The new orders of the axes of data. Notice it should be with - // in the same range as the input data, and it starts from 0. - // Do not provide repeated order. - repeated uint32 order = 1; -} - -message NormalizeParameter { - optional bool across_spatial = 1 [default = true]; - // Initial value of scale. Default is 1.0 for all - optional FillerParameter scale_filler = 2; - // Whether or not scale parameters are shared across channels. - optional bool channel_shared = 3 [default = true]; - // Epsilon for not dividing by zero while normalizing variance - optional float eps = 4 [default = 1e-10]; -} - -// needed by ssd -message SaveOutputParameter { - // Output directory. If not empty, we will save the results. - optional string output_directory = 1; - // Output name prefix. - optional string output_name_prefix = 2; - // Output format. - // VOC - PASCAL VOC output format. - // COCO - MS COCO output format. - optional string output_format = 3; - // If you want to output results, must also provide the following two files. - // Otherwise, we will ignore saving results. - // label map file. - optional string label_map_file = 4; - // A file which contains a list of names and sizes with same order - // of the input DB. The file is in the following format: - // name height width - // ... - optional string name_size_file = 5; - // Number of test images. It can be less than the lines specified in - // name_size_file. For example, when we only want to evaluate on part - // of the test images. 
- optional uint32 num_test_image = 6; - // The resize parameter used in saving the data. - // optional ResizeParameter resize_param = 7; -} - -message NonMaximumSuppressionParameter { - // Threshold to be used in nms. - optional float nms_threshold = 1 [default = 0.3]; - // Maximum number of results to be kept. - optional int32 top_k = 2; - // Parameter for adaptive nms. - optional float eta = 3 [default = 1.0]; -} - -message GeneralNmsParameter { - optional int32 post_top_k = 1 ; - optional float nms_threshold = 2 [default = 0]; - optional float iou_threshold_decay = 3 [default = 1.0]; - optional float coor_scale_factor = 4 [default = 1.0]; -} - -// Message that store parameters used by DetectionOutputLayer, ssd/fasterRcnn -message DetectionOutputParameter { - optional int32 num_classes = 1; - optional bool share_location = 2 [default = true]; - optional int32 background_label_id = 3 [default = 0]; - optional NonMaximumSuppressionParameter nms_param = 4; - optional SaveOutputParameter save_output_param = 5; - optional PriorBoxParameter.CodeType code_type = 6 [default = CENTER_SIZE]; - optional bool variance_encoded_in_target = 8 [default = true]; - optional int32 keep_top_k = 7; - optional float confidence_threshold = 9; - optional float nms_threshold = 13; - optional int32 top_k = 14; - optional int32 boxes = 15 [default = 1]; - optional bool relative = 17 [default = true]; - optional float objectness_threshold = 18 [default = 0.5]; - optional float class_threshold = 19 [default = 0.5]; - repeated float biases = 20; - optional GeneralNmsParameter general_nms_param = 21; - optional float objectness_score = 22; -} -message PSROIPoolingParameter { - required float spatial_scale = 1; - required int32 output_dim = 2; // output channel number - required int32 group_size = 3; // number of groups to encode position-sensitive score maps -} -// Message that stores parameters used by FreespaceExtractLayer -message FreespaceExtractParameter { - optional float org_height = 1; 
-} - -// Message that stores parameters used by DetectpostprocessLayer -message PostprocessParameter { - optional float nms_thresh = 1 [default = 0.3]; - optional float conf_thresh = 2 [default = 0.5]; - optional uint32 post_nms_topn = 3 [default = 100]; - optional uint32 cls_num = 4 [default = 12]; - repeated float bbox_reg_weights = 5; -} - -// Message that stores parameters used by SpatialTransformLayer -message SpatialTransformParameter { - optional uint32 output_h = 1 [default = 0]; - optional uint32 output_w = 2 [default = 0]; - optional float border_value = 3 [default = 0]; - repeated float affine_transform = 4; - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 15 [default = DEFAULT]; -} -message ROIAlignParameter { - // Pad, kernel size, and stride are all given as a single value for equal - // dimensions in height and width or as Y, X pairs. - optional uint32 pooled_h = 1 [default = 0]; // The pooled output height - optional uint32 pooled_w = 2 [default = 0]; // The pooled output width - // Multiplicative spatial scale factor to translate ROI coords from their - // input scale to the scale used when pooling - optional float spatial_scale = 3 [default = 1]; - optional int32 sampling_ratio = 4 [default = -1]; - optional int32 roi_end_mode = 5 [default = 0]; -} - -message RegionParameter { - optional uint32 classes = 1 [default = 20]; // Category of classification - optional uint32 coords = 2 [default = 4]; // Coordinates of box - optional uint32 boxes = 3 [default = 1]; // Number of boxes predicted per grid - optional uint32 softmax = 4 [default = 0]; - optional string softmax_tree = 5 [default = ""]; - optional uint32 background = 6 [default = 0]; -} -message ReorgParameter{ - optional uint32 stride = 2 [default = 2]; - optional bool reverse = 1 [default = false]; -} -message ReverseParameter{ - repeated int32 axis = 1; -} -message InterpParameter{ - optional int32 height = 1 [default = 0];//Height of output - optional 
int32 width = 2 [default = 0];//Width of output - optional int32 zoom_factor = 3 [default = 1];//zoom factor - optional int32 shrink_factor = 4 [default = 1];//shrink factor - optional int32 pad_beg = 5 [default = 0];//padding at begin of input - optional int32 pad_end = 6 [default = 0];//padding at end of input -} -message ShuffleChannelParameter{ - optional uint32 group = 1[default = 1]; // The number of group -} -message UpsampleParameter{ - optional float scale = 1[default = 1]; - optional int32 stride = 2[default = 2]; - optional int32 stride_h = 3[default = 2]; - optional int32 stride_w = 4[default=2]; -} -message ROIPoolingParameter { - required int32 pooled_h = 1; - required int32 pooled_w = 2; - optional float spatial_scale = 3 [default=0.0625]; - optional float spatial_scale_h = 4; - optional float spatial_scale_w = 5; -} - -message YoloParameter { - optional int32 boxes = 1 [default = 3]; - optional int32 coords = 2 [default = 4]; - optional int32 classes = 3 [default = 80]; - optional string yolo_version = 4 [default = "V3"]; - optional bool softmax = 5 [default = false]; - optional bool background = 6 [default = false]; - optional bool softmaxtree = 7 [default = false]; -} - -message YoloV3DetectionOutputParameter { - optional int32 boxes = 1 [default = 3]; - optional int32 classes = 2 [default = 80]; - optional bool relative = 3 [default = true]; - optional float obj_threshold = 4 [default = 0.5]; - optional float score_threshold = 5 [default = 0.5]; - optional float iou_threshold = 6 [default = 0.45]; - optional int32 pre_nms_topn = 7 [default = 512]; - optional int32 post_nms_topn = 8 [default = 1024]; - repeated float biases_high = 9; - repeated float biases_mid = 10; - repeated float biases_low = 11; - optional int32 coords = 12 [default = 4]; - repeated float biases = 13; - optional bool resize_origin_img_to_net = 14 [default = false]; -} - -message YoloV3DetectionOutputV2Parameter { - optional int32 boxes = 1 [default = 3]; - optional int32 
classes = 2 [default = 80]; - optional bool relative = 3 [default = true]; - optional float obj_threshold = 4 [default = 0.5]; - optional float score_threshold = 5 [default = 0.5]; - optional float iou_threshold = 6 [default = 0.45]; - optional int32 pre_nms_topn = 7 [default = 512]; - optional int32 post_nms_topn = 8 [default = 1024]; - repeated float biases_high = 9; - repeated float biases_mid = 10; - repeated float biases_low = 11; - optional int32 coords = 12 [default = 4]; - repeated float biases = 13; - optional bool resize_origin_img_to_net = 14 [default = false]; - optional int32 out_box_dim = 15 [default = 3]; -} - -message ProposalParameter { - optional float feat_stride = 1 [default = 16]; - optional float base_size = 2 [default = 16]; - optional float min_size = 3 [default = 16]; - repeated float ratio = 4; - repeated float scale = 5; - optional int32 pre_nms_topn = 6 [default = 3000]; - optional int32 post_nms_topn = 7 [default = 304]; - optional float iou_threshold = 8 [default = 0.7]; - optional bool output_actual_rois_num = 9 [default = false]; -} - -message FSRDetectionOutputParameter { - required int32 num_classes = 1; - required float score_threshold = 2; - required float iou_threshold = 3; - optional int32 batch_rois = 4 [default = 1]; -} - -message SSDDetectionOutputParameter { - required int32 num_classes= 1 [default = 2]; - optional bool share_location = 2 [default = true]; - optional int32 background_label_id = 3 [default = 0]; - optional float iou_threshold = 4 [default = 0.3]; - optional int32 top_k = 5 [default = 200]; - optional float eta = 6 [default = 1.0]; - optional bool variance_encoded_in_target = 7 [default = false]; - optional int32 code_type = 8 [default = 1]; - optional int32 keep_top_k = 9 [default = -1]; - optional float confidence_threshold = 10 [default = 0.0]; -} -message YoloV2DetectionOutputParameter { - optional int32 boxes = 1 [default = 5]; - optional int32 classes = 2 [default = 80]; - optional bool relative = 3 
[default = true]; - optional float obj_threshold = 4 [default = 0.5]; - optional float score_threshold = 5 [default = 0.5]; - optional float iou_threshold = 6 [default = 0.45]; - optional int32 pre_nms_topn = 7 [default = 512]; - optional int32 post_nms_topn = 8 [default = 1024]; - repeated float biases = 9; - optional int32 coords = 10 [default = 4]; - optional bool resize_origin_img_to_net = 11 [default = false]; -} - -message QuantParameter { - optional float scale = 2; - optional bytes offset = 3; -} - -message BatchMatMulParameter{ - optional bool adj_x1 = 1 [default = false]; - optional bool adj_x2 = 2 [default = false]; -} - -message CondTakeParameter { - required string mode = 1; - required float val = 2; - optional float eps = 3 [default = 1e-06]; -} - -message MatrixInverseParameter { - optional bool adjoint = 1 [default = false]; -} - -message WarpPerspectiveParameter { - required int32 out_height = 1; - required int32 out_width = 2; - optional float constant = 3; - optional string border_type = 4 [default = 'BORDER_CONSTANT']; -} - -message SpatialTransformerParameter { - // How to use the parameter passed by localisation network - optional string transform_type = 1 [default = "affine"]; - // What is the sampling technique - optional string sampler_type = 2 [default = "bilinear"]; - - // If not set,stay same with the input dimension H and W - optional int32 output_H = 3; - optional int32 output_W = 4; - // If false, only compute dTheta, DO NOT compute dU - optional bool to_compute_dU = 5 [default = true]; - - // The default value for some parameters - optional double theta_1_1 = 6; - optional double theta_1_2 = 7; - optional double theta_1_3 = 8; - optional double theta_2_1 = 9; - optional double theta_2_2 = 10; - optional double theta_2_3 = 11; -} diff --git a/ge/proto/dump_task.proto b/ge/proto/dump_task.proto deleted file mode 100644 index ecdf4792..00000000 --- a/ge/proto/dump_task.proto +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2019-2020 
Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; -package toolkit.dumpdata; - -enum OutputDataType { - DT_UNDEFINED = 0; - DT_FLOAT = 1; - DT_FLOAT16 = 2; - DT_INT8 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_UINT16 = 6; - DT_INT32 = 7; - DT_INT64 = 8; - DT_UINT32 = 9; - DT_UINT64 = 10; - DT_BOOL = 11; - DT_DOUBLE = 12; - DT_STRING = 13; - DT_DUAL_SUB_INT8 = 14; - DT_DUAL_SUB_UINT8 = 15; - DT_COMPLEX64 = 16; - DT_COMPLEX128 = 17; - DT_QINT8 = 18; - DT_QINT16 = 19; - DT_QINT32 = 20; - DT_QUINT8 = 21; - DT_QUINT16 = 22; - DT_RESOURCE = 23; - DT_STRING_REF = 24; - DT_DUAL = 25; -} - -enum OutputFormat { - FORMAT_NCHW = 0; - FORMAT_NHWC = 1; - FORMAT_ND = 2; - FORMAT_NC1HWC0 = 3; - FORMAT_FRACTAL_Z = 4; - FORMAT_NC1C0HWPAD = 5; - FORMAT_NHWC1C0 = 6; - FORMAT_FSR_NCHW = 7; - FORMAT_FRACTAL_DECONV = 8; - FORMAT_C1HWNC0 = 9; - FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; - FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; - FORMAT_NC1HWC0_C04 = 12; - FORMAT_FRACTAL_Z_C04 = 13; - FORMAT_CHWN = 14; - FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; - FORMAT_HWCN = 16; - FORMAT_NC1KHKWHWC0 = 17; - FORMAT_BN_WEIGHT = 18; - FORMAT_FILTER_HWCK = 19; - FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; - FORMAT_HASHTABLE_LOOKUP_KEYS = 21; - FORMAT_HASHTABLE_LOOKUP_VALUE = 22; - FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; - FORMAT_HASHTABLE_LOOKUP_HITS=24; - FORMAT_C1HWNCoC0 = 25; - FORMAT_MD = 26; - FORMAT_NDHWC = 27; - FORMAT_FRACTAL_ZZ = 28; - 
FORMAT_FRACTAL_NZ = 29; - FORMAT_RESERVED = 30; -} - -message OriginalOp { - string name = 1; - uint32 output_index = 2; - OutputDataType data_type = 3; - OutputFormat format = 4; -} - -message Shape { - repeated uint64 dim = 1; -} - -message OpOutput { - OutputDataType data_type = 1; - OutputFormat format = 2; - Shape shape = 3; - OriginalOp original_op = 4; // the original op corresponding to the output - bytes data = 5; - uint64 size = 6; -} - -message OpInput { - OutputDataType data_type = 1; - OutputFormat format = 2; - Shape shape = 3; - bytes data = 4; - uint64 size = 5; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - bytes data = 2; - uint64 size = 3; -} - -message DumpData{ - string version = 1; - uint64 dump_time = 2; - repeated OpOutput output = 3; - repeated OpInput input = 4; - repeated OpBuffer buffer = 5; -} diff --git a/ge/proto/ge_api.proto b/ge/proto/ge_api.proto deleted file mode 100755 index ac5b3b3a..00000000 --- a/ge/proto/ge_api.proto +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -syntax = "proto3"; -package ge.api_pb; - -import "ge_ir.proto"; - -// GE initialize -message GEInitialize { - map options = 1; -}; - -// initialize response -message GEInitializeResponse { - uint32 status = 1; - uint32 clientId = 2; -}; - -// GE finalize -message GEFinalize { - bool final = 1; - uint32 clientId = 2; -}; - -message GEFinalizeResponse { - uint32 status = 1; -}; - -// GE Session -message CreateSession{ - map options = 1; -}; - -message CreateSessionResponse { - uint32 status = 1; - uint64 sessionId = 2; -}; - -//GE AddGraph -//model serialize :: serializegraph -message SessionAddGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; - ge.proto.GraphDef graph = 3; -}; - -message SessionAddGraphResponse { - uint32 status = 1; -}; - -//GE SessionRemoveGraph -message SessionRemoveGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; -}; - -message SessionRemoveGraphResponse { - uint32 status = 1; -}; - -message SessionRunGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; - repeated ge.proto.TensorDef tensor = 3; -}; - -message SessionBuildGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; - repeated ge.proto.TensorDef tensor = 3; - string savePath = 4; -}; - -message SessionRunGraphResponse { - uint32 status = 1; - repeated ge.proto.TensorDef tensor = 2; -}; - -message SessionBuildGraphResponse { - uint32 status = 1; -}; - -message DestroySession{ - bool final = 1; - uint64 sessionId = 2; -}; - -message DestroySessionResponse { - uint32 status = 1; -}; diff --git a/ge/proto/ge_ir.proto b/ge/proto/ge_ir.proto deleted file mode 100644 index 87886c84..00000000 --- a/ge/proto/ge_ir.proto +++ /dev/null @@ -1,206 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package ge.proto; - -enum DataType -{ - DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. - DT_FLOAT = 1; // float type - DT_FLOAT16 = 2; // fp16 type - DT_INT8 = 3; // int8 type - DT_UINT8 = 4; // uint8 type - DT_INT16 = 5; // int16 type - DT_UINT16 = 6; // uint16 type - DT_INT32 = 7; // - DT_INT64 = 8; // int64 type - DT_UINT32 = 9; // unsigned int32 - DT_UINT64 = 10; // unsigned int64 - DT_BOOL = 11; // bool type - DT_DOUBLE = 12; // double type - DT_STRING = 13; // string type - DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ - DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ - DT_COMPLEX64 = 16; // complex64 type - DT_COMPLEX128 = 17; // complex128 type - DT_QINT8 = 18; // qint8 type - DT_QINT16 = 19; // qint16 type - DT_QINT32 = 20; // qint32 type - DT_QUINT8 = 21; // quint8 type - DT_QUINT16 = 22; // quint16 type - DT_RESOURCE = 23; // resource type - DT_STRING_REF = 24; // string_ref type - DT_DUAL = 25; /**< dual output type */ -} - -message AttrDef -{ - message ListValue - { - enum ListValueType{ - VT_LIST_NONE = 0; - VT_LIST_STRING = 1; - VT_LIST_INT = 2; - VT_LIST_FLOAT = 3; - VT_LIST_BOOL = 4; - VT_LIST_BYTES = 5; - VT_LIST_TENSOR_DESC = 6; - VT_LIST_TENSOR = 7; - VT_LIST_GRAPH = 8; - VT_LIST_NAMED_ATTRS = 9; - VT_LIST_DATA_TYPE = 10; - } - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3; // "list(int)" - repeated float f = 4; // "list(float)" - repeated bool b = 5; // "list(bool)" - repeated bytes bt = 7; - repeated TensorDescriptor td = 8; - repeated 
TensorDef t = 9; - repeated GraphDef g = 10; - repeated NamedAttrs na = 11; - repeated int64 dt = 12; // list ge::DataType - - ListValueType val_type = 20; - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - oneof value - { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; // Used to support attr nesting - TensorDescriptor td = 11; // GeTensorDesc type - TensorDef t = 12; // GeTensor type - GraphDef g = 13; // Graph type - ListListInt list_list_int = 14; // List List Int type - int64 dt = 15; // ge::DataType - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs -{ - string name = 1; - map attr = 2; -} - -// Shape / dimension description, using row-major order -message ShapeDef -{ - repeated int64 dim = 1; // Size of each dimension -} - -// Multidimensional data description -message TensorDescriptor -{ - string name = 1; // Optional parameter, tensor name - - DataType dtype = 2; // tensor datatype - ShapeDef shape = 3; // Shape / dimension - string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" - - bool has_out_attr = 9; - int64 size = 10; - int64 weight_size = 11; - bool reuse_input = 12; - bool output_tensor = 13; - string device_type = 14; - bool input_tensor =15; - int64 real_dim_cnt = 16; - int64 reuse_input_index = 17; - int64 data_offset = 18; - int64 cmps_size = 19; - string cmps_tab = 20; - int64 cmps_tab_offset = 21; - - map attr = 5; // Set of extra parameter fields -} - -// GeTensor definition -message TensorDef -{ - TensorDescriptor desc = 1; // Tensor description - bytes data = 2; // Tensor data -} - - -// Operator description -message OpDef -{ - string name = 1; // name - string type = 2; // type - - repeated string 
input = 5; // input original op name + outgoing index. op_name:index - - map attr = 10; // Set of operator parameter fields - - bool has_out_attr = 20; - int64 id = 21; - int64 stream_id =22; - repeated string input_name = 23; - repeated string src_name = 24; - repeated int64 src_index = 25; - repeated string dst_name = 26; - repeated int64 dst_index = 27; - repeated int64 input_i = 28; - repeated int64 output_i = 29; - repeated int64 workspace = 30; - repeated int64 workspace_bytes = 31; - repeated bool is_input_const = 32; - repeated TensorDescriptor input_desc = 33; - repeated TensorDescriptor output_desc = 34; - repeated string subgraph_name = 35; -} - -// Graph definition -message GraphDef -{ - string name = 1; // name - - repeated string input = 4; // Graph input - repeated string output = 5; // Graph output - - repeated OpDef op = 6; // List of operators - - map attr = 11; // Extended field -} - -// model definition -message ModelDef -{ - string name = 1; // name - uint32 version = 2; // IR Proto verion - string custom_version = 3; // User model version number, passed in by user - - repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef - - map attr = 11; // Extended field -} - diff --git a/ge/proto/insert_op.proto b/ge/proto/insert_op.proto deleted file mode 100644 index a059e122..00000000 --- a/ge/proto/insert_op.proto +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -message InsertNewOps { - repeated AippOpParams aipp_op = 1; - repeated MultiShapeOpParams multi_shape_op = 2; -} - -message AippOpParams { - enum InputFormat { - UNDEFINED = 0; - YUV420SP_U8 = 1; - XRGB8888_U8 = 2; - RGB888_U8 = 3; - YUV400_U8 = 4; - NC1HWC0DI_FP16 = 5; - NC1HWC0DI_S8 = 6; - ARGB8888_U8 = 7; - YUYV_U8 = 8; - YUV422SP_U8 = 9; - AYUV444_U8 = 10; - RAW10 = 11; - RAW12 = 12; - RAW16 = 13; - RAW24 = 14; - RGB16 = 15; - RGB20 = 16; - RGB24 = 17; - RGB8_IR = 18; - RGB16_IR = 19; - RGB24_IR = 20; - } - - enum AippMode { - undefined = 0; - static = 1; - dynamic = 2; - } - - // AIPPģʽ£¬Çø·Ö¾²Ì¬AIPPºÍ¶¯Ì¬AIPP - AippMode aipp_mode = 1; - - // related_input_rank²ÎÊýΪ±ØÌÀàÐÍΪÕûÐÍ£¬ÅäÖ÷¶Î§>=0, <=ÊäÈëDataËã×ӵĸöÊý£¬Ä¬ÈÏֵΪ0¡£ - // ±êʶ¶ÔÄ£Ð͵ĵڼ¸¸öÊäÈë×öAIPP´¦Àí£¬ÀýÈçÄ£ÐÍÓÐÁ½¸öÊäÈ룬ÐèÒª¶ÔµÚ2¸öÊäÈë×öAIPP£¬ÔòÅäÖÃrelated_input_rankΪ1¡£ - uint32 related_input_rank = 2; - - // input_edge_idx²ÎÊýΪ¿ÉÑ¡£¬ÀàÐÍΪÕûÐÍ£¬ÅäÖ÷¶Î§Îª>=0¡£ - // ÅäÖøòÎÊýµÄ×÷Óã¬ÔÚÓÚ¶ÔDataËã×Ó²»Í¬µÄÊä³ö×ö²»Í¬µÄAIPP´¦Àí£¬Èç¹û¸Ã²ÎÊýûÓÐÅäÖã¬Ä¬È϶Ôrelated_input_rankÖ¸¶¨µÄÄ£ÐÍÊäÈëµÄËùÓÐÊä³ö±ß×öAIPP¡£ - // ÅäÖÃÖµ <= DataËã×ÓÊä³ö±ßµÄ¸öÊý¡£ - repeated uint32 input_edge_idx = 3; - - // [Begin] ¶¯Ì¬AIPP²ÎÊý£¬ÅäÖþ²Ì¬AIPPʱÎÞЧ - uint32 max_src_image_size = 4; - - // ÊÇ·ñÖ§³ÖÐýת¡£Ä¬Èϲ»Ö§³Ö£¬¿ªÆôÖ§³ÖÐýתʱ£¬»áÓжîÍâµÄ¿Õ¼äºÍÐÔÄÜËðʧ - bool support_rotation = 5; - - // [End] ¶¯Ì¬AIPP²ÎÊý - - - // [Begin] ¾²Ì¬AIPP²ÎÊý£¬ÅäÖö¯Ì¬AIPPʱÎÞЧ - InputFormat input_format = 51; - bool csc_switch = 52; - float cpadding_value = 53; - bool rbuv_swap_switch = 54; - bool ax_swap_switch = 55; - bool single_line_mode = 56; - - int32 src_image_size_w = 57; - int32 src_image_size_h = 58; - - bool crop = 59; - int32 load_start_pos_w = 60; - int32 load_start_pos_h = 61; - int32 crop_size_w = 62; - int32 crop_size_h = 63; - - bool resize = 64; - int32 resize_output_w = 65; - int32 resize_output_h = 
66; - - bool padding = 67; - int32 left_padding_size = 68; - int32 right_padding_size = 69; - int32 top_padding_size = 70; - int32 bottom_padding_size = 71; - - int32 mean_chn_0 = 10; - int32 mean_chn_1 = 11; - int32 mean_chn_2 = 12; - int32 mean_chn_3 = 19; - float min_chn_0 = 13; - float min_chn_1 = 14; - float min_chn_2 = 15; - float min_chn_3 = 20; - repeated float var_reci_chn_0 = 16; - repeated float var_reci_chn_1 = 17; - repeated float var_reci_chn_2 = 18; - repeated float var_reci_chn_3 = 21; - - repeated int32 matrix_r0c0 = 30; - repeated int32 matrix_r0c1 = 31; - repeated int32 matrix_r0c2 = 32; - repeated int32 matrix_r1c0 = 33; - repeated int32 matrix_r1c1 = 34; - repeated int32 matrix_r1c2 = 35; - repeated int32 matrix_r2c0 = 36; - repeated int32 matrix_r2c1 = 37; - repeated int32 matrix_r2c2 = 38; - repeated int32 output_bias_0 = 39; - repeated int32 output_bias_1 = 40; - repeated int32 output_bias_2 = 41; - repeated int32 input_bias_0 = 42; - repeated int32 input_bias_1 = 43; - repeated int32 input_bias_2 = 44; - - // [End] ¾²Ì¬AIPP²ÎÊý - - // The n number that is used for raw/rgbir data into f16 transformation. - // The transformation equation is x/(2^n). If set to 0, no transform is performed. - uint32 raw_rgbir_to_f16_n = 45; -} - -message MultiShapeOpParams { - enum MultiShapeMode { - batch = 0; //¶¯Ì¬batch - resolution = 1; //¶¯Ì¬·Ö±æÂÊ£¬À©Õ¹Óà - } - - MultiShapeMode mode = 1; //Ëã×Óģʽ - uint32 related_input_rank = 2; //ÐÂÔöËã×Ó²åÈëµ½ÄĸöÊäÈë - - - repeated uint32 batch_list = 11; //batch_listÖµ£¬batch_listµÄ¸öÊýÊÇ2µ½8Ö®¼ä -} diff --git a/ge/proto/om.proto b/ge/proto/om.proto deleted file mode 100644 index dd992191..00000000 --- a/ge/proto/om.proto +++ /dev/null @@ -1,401 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -enum TargetType -{ - MINI = 0; - TINY = 1; - LITE = 2; -} - -// offline model -message ModelDef { - string name = 1; - uint32 version = 2; - - uint64 memory_size = 10; - uint32 stream_num = 11; - uint32 event_num = 12; - uint64 weight_size = 13; - uint32 label_num = 15; - repeated OpDef op = 20; - TargetType target_type = 23; - - map attr = 30; -}; - -// operator define -message OpDef { - string name = 1; - string type = 2; - - uint32 id = 3; - uint32 stream_id = 4; - - repeated string input_name = 5; - - repeated string src_name = 8; - repeated int32 src_index = 9; - repeated int64 input = 10; - repeated int64 output = 11; - repeated TensorDescriptor input_desc = 12; - repeated TensorDescriptor output_desc = 13; - repeated WeightDef weights = 14; - repeated string dst_name = 15; - repeated int32 dst_index = 16; - - repeated int64 workspace = 20; - repeated uint32 workspace_bytes = 21; - - repeated string weight_name = 22; - repeated bool is_input_const = 23; - - map attr = 30; - - QuantizeFactorParams quantize_factor = 31; - - oneof op_params { - // start at 100 here - SendOpParams sender_param = 100; - RecvOpParams receiver_param = 200; - ConvolutionOpParams convolution_param = 300; - PoolingOpParams pooling_param = 400; - EltwiseOpParams eltwise_param = 500; - BatchNormOpParams batchnorm_param = 600; - ScaleOpParams scale_param = 700; - FullConnectionOpParams full_connection_param = 800; - SoftmaxOpParams softmax_param = 900; - ActivationOpParams activation_param = 1000; - ReshapeOpParams 
reshape_param = 1100; - } -}; - -message SendOpParams { - uint32 event_id = 1; -}; - -message RecvOpParams { - uint32 event_id = 1; -}; - -enum QuantizeScaleType -{ - VECTOR_SCALE = 0; - SCALAR_SCALE = 1; -} - -enum QuantizeScaleMode -{ - NORMAL_MODE = 0; - SQRT_MODE = 1; -} - -enum QuantizeAlgorithm -{ - NON_OFFSET_ALGO = 0; - HALF_OFFSET_ALGO = 1; - ALL_OFFSET_ALGO = 2; -} -message QuantizeFactor -{ - QuantizeScaleMode scale_mode = 1; - bytes scale_value = 2; - int64 scale_offset = 3; - bytes offset_data_value = 4; - int64 offset_data_offset = 5; - bytes offset_weight_value = 6; - int64 offset_weight_offset = 7; - bytes offset_pad_value = 8; - int64 offset_pad_offset = 9; -}; - -message QuantizeCalcFactor -{ - bytes offsetw = 1; - int64 offsetw_offset = 2; - bytes offsetd = 3; - int64 offsetd_offset = 4; - bytes scalereq = 5; - int64 scaledreq_offset = 6; - bytes offsetdnext = 7; - int64 offsetdnext_offset = 8; -} - -message QuantizeFactorParams -{ - QuantizeAlgorithm quantize_algo = 1; - QuantizeScaleType scale_type = 2; - QuantizeFactor quantize_param = 3; - QuantizeFactor dequantize_param = 4; - QuantizeFactor requantize_param = 5; - QuantizeCalcFactor quantizecalc_param = 6; -}; - -message ConvolutionOpParams { - int32 mode = 1; - int32 algo = 2; - int32 pad_mode = 3; - uint32 group = 4; - uint32 num_output = 5; - - repeated uint32 pad = 10; - repeated uint32 stride = 11; - repeated uint32 dilation = 12; - repeated uint32 kernel = 13; - - float alpha = 20; - float beta = 21; - - WeightDef filter = 40; - WeightDef bias = 41; - - bool relu_flag = 62; - repeated uint32 adj = 70; - repeated uint32 target_shape = 71; - repeated uint32 before_pad = 72; -}; - -message PoolingOpParams { - int32 mode = 1; - int32 nan_opt = 2; - int32 pad_mode = 3; - bool global_pooling = 4; - - repeated uint32 window = 10; - repeated uint32 pad = 11; - repeated uint32 stride = 12; - bool ceil_mode = 13; - int32 data_mode = 14; - - float alpha = 20; - float beta = 21; - repeated uint32 
before_pad = 22; -}; - -message EltwiseOpParams { - int32 mode = 1; - repeated float coeff = 2; - float alpha = 3; - float beta = 4; - repeated WeightDef weight = 5; - bool relu_flag = 6; -}; - -message ActivationOpParams { - int32 mode = 1; - float coef = 2; - float alpha = 3; - float beta = 4; -}; - -message BatchNormOpParams { - int32 mode = 1; - - float alpha = 2; - float beta = 3; - double epsilon = 4;//optinal,[default = 1e-5] - bool use_global_stats = 5; //optinal,by default true,testing mode - float moving_average_fraction = 6; //optinal,[default = .999]; - - WeightDef estimated_mean = 7; - WeightDef estimated_variance = 8; - - WeightDef scale = 9; - WeightDef bias = 10; -}; - -message ScaleOpParams { - WeightDef scale = 1; - WeightDef bias = 2; -}; - -message ReshapeOpParams { - float alpha = 1; - float beta = 2; - ShapeDef shape = 3; - int32 axis = 4; - int32 num_axes = 5; - int32 format = 6; -}; - -message SoftmaxOpParams { - int32 algo = 1; - int32 mode = 2; - float alpha = 3; - float beta = 4; -}; - -message FullConnectionOpParams { - WeightDef filter = 1; - WeightDef bias = 2; - uint32 num_output = 3; - bool relu_flag = 12; -}; - -message FlattenOpParams { - float alpha = 1; - float beta = 2; - int32 start_axis = 3; - int32 end_axis = 4; -} - -message AddLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message MulLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message AddOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message MulOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message SubOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message BiasAddOpParams { - float alpha = 1; - float beta = 2; - - WeightDef bias = 10; -}; - -message MatMulOpParams 
{ - float alpha = 1; - float beta = 2; - bool transposeX = 3; - bool transposeW = 4; - - WeightDef filter = 10; - WeightDef bias = 12; -}; - -message RsqrtOpParams { - float alpha = 1; - float beta = 2; -}; - - -message WeightDef { - int32 format = 1; - int32 data_type = 2; - ShapeDef shape = 3; - bytes data = 4; - int64 data_offset = 5; - uint32 cmps_size = 6; - bytes cmps_tab = 7; - int64 cmps_tab_offset = 10; - CompressInfo cmps_info = 8; - AllOffsetQuantizeInfo alloffset_quantize_info = 11; -} - -message ShapeDef { - repeated int64 dim = 1; -} - -enum DeviceType { - NPU = 0; // In default, we will use NPU. - CPU = 1; // CPU -} - -message AllOffsetQuantizeInfo { - float scale = 1; - int32 offset = 2; -} - -message TensorDescriptor { - int32 format = 1; - int32 data_type = 2; - repeated int64 dim = 3; - uint32 size = 4; - bool reuse_input = 5; - bool output_tensor = 7; - DeviceType device_type = 8; - bool input_tensor = 9; - uint32 real_dim_cnt = 10; - uint32 reuse_input_index = 11; - AllOffsetQuantizeInfo alloffset_quantize_info = 12; -} - -message CompressInfo { - int32 blockRow = 1; // block row - int32 blockCol = 2; // block col - int32 fractalK = 3; // fractal K - int32 fractalN = 4; // fractal N - int32 lastFractalK = 5; // K of last fractal - int32 lastFractalN = 6; // N of last fractal - int32 cubeSize = 7; // cube's length - int32 loadDir = 8; // data load directtiono 0:col load 1:row load -} - -message AttrDef { - message ListValue { - repeated string s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated uint32 u = 6 [packed = true]; // "list(uint)" - repeated bytes bt = 7; - } - - oneof value { - string s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - uint32 u = 6; // "uint32" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; - } 
-} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs { - string name = 1; - map attr = 2; -} - diff --git a/ge/proto/op_mapping_info.proto b/ge/proto/op_mapping_info.proto deleted file mode 100644 index 7b84a115..00000000 --- a/ge/proto/op_mapping_info.proto +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -syntax = "proto3"; -package aicpu.dump; - -message Shape { - repeated uint64 dim = 1; -} - -message Output { - int32 data_type = 1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - string original_name = 5; - int32 original_output_index = 6; - int32 original_output_data_type = 7; - int32 original_output_format = 8; - uint64 size = 9; -} - -message Input { - int32 data_type =1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - uint64 size = 5; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - uint64 address = 2; - uint64 size = 3; -} - -message Op { - string op_name = 1; - string op_type = 2; -} - -message Task { - uint32 task_id = 1; - uint32 stream_id = 2; - Op op = 3; - repeated Output output = 4; - bool end_graph = 5; - repeated Input input = 6; - repeated OpBuffer buffer = 7; -} - -message OpMappingInfo { - string dump_path = 1; - oneof model_name_param { - string model_name = 2; - } - oneof model_id_param { - uint32 model_id = 3; - } - oneof step_id { - uint64 step_id_addr = 4; - } - oneof iterations_per_loop { - uint64 iterations_per_loop_addr = 5; - } - oneof loop_cond { - uint64 loop_cond_addr = 6; - } - uint32 flag = 7; // 0x01 load, 0x00 unload - repeated Task task = 8; - string dump_step = 9; -} \ No newline at end of file diff --git a/ge/proto/task.proto b/ge/proto/task.proto deleted file mode 100644 index 50ea061b..00000000 --- a/ge/proto/task.proto +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // 
TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/ge/proto/tensorflow/attr_value.proto b/ge/proto/tensorflow/attr_value.proto deleted file mode 100644 index 1cc67d62..00000000 --- a/ge/proto/tensorflow/attr_value.proto +++ /dev/null @@ -1,62 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "AttrValueProtos"; -option java_multiple_files = 
true; -option java_package = "org.tensorflow.framework"; - -import "tensor.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing the value for an attr used to configure an Op. -// Comment indicates the corresponding attr type. Only the field matching the -// attr type may be filled. -message AttrValue { - // LINT.IfChange - message ListValue { - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated DataType type = 6 [packed = true]; // "list(type)" - repeated TensorShapeProto shape = 7; // "list(shape)" - repeated TensorProto tensor = 8; // "list(tensor)" - repeated NameAttrList func = 9; // "list(attr)" - } - // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) - - oneof value { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - DataType type = 6; // "type" - TensorShapeProto shape = 7; // "shape" - TensorProto tensor = 8; // "tensor" - ListValue list = 1; // any "list(...)" - - // "func" represents a function. func.name is a function's name or - // a primitive op's name. func.attr.first is the name of an attr - // defined for that function. func.attr.second is the value for - // that attr in the instantiation. - NameAttrList func = 10; - - // This is a placeholder only used in nodes defined inside a - // function. It indicates the attr value will be supplied when - // the function is instantiated. For example, let us suppose a - // node "N" in function "FN". "N" has an attr "A" with value - // placeholder = "foo". When FN is instantiated with attr "foo" - // set to "bar", the instantiated node N's attr A will have been - // given the value "bar". - string placeholder = 9; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. 
E.g., MatMul[T=float]. -message NameAttrList { - string name = 1; - map attr = 2; -} diff --git a/ge/proto/tensorflow/function.proto b/ge/proto/tensorflow/function.proto deleted file mode 100644 index 075897c6..00000000 --- a/ge/proto/tensorflow/function.proto +++ /dev/null @@ -1,100 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "FunctionProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; -import "node_def.proto"; -import "op_def.proto"; - -// A library is a set of named functions. -message FunctionDefLibrary { - repeated FunctionDef function = 1; - repeated GradientDef gradient = 2; -} - -// A function can be instantiated when the runtime can bind every attr -// with a value. When a GraphDef has a call to a function, it must -// have binding for every attr defined in the signature. -// * device spec, etc. -message FunctionDef { - // The definition of the function's name, arguments, return values, - // attrs etc. - OpDef signature = 1; - - // Attributes specific to this function definition. - map attr = 5; - - // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. - reserved 2; - - // In both of the following fields, there is the need to specify an - // output that is used as either the input to another node (in - // `node_def`) or as a return value of the function (in `ret`). - // Unlike the NodeDefs in GraphDef, we need to be able to specify a - // list in some cases (instead of just single outputs). Also, we - // need to be able to deal with lists of unknown length (so the - // output index may not be known at function definition time). So - // we use the following format instead: - // * "fun_in" where "fun_in" is the name of a function input arg in - // the `signature` field above. This represents that input, whether - // it is a single tensor or a list. 
- // * "fun_in:0" gives the first element of a function input arg (a - // non-list input is considered a list of length 1 for these - // purposes). - // * "node:out" where "node" is the name of a node in `node_def` and - // "out" is the name one of its op's output arguments (the name - // comes from the OpDef of the node's op). This represents that - // node's output, whether it is a single tensor or a list. - // Note: We enforce that an op's output arguments are never - // renamed in the backwards-compatibility test. - // * "node:out:0" gives the first element of a node output arg (a - // non-list output is considered a list of length 1 for these - // purposes). - // - // NOT CURRENTLY SUPPORTED (but may be in the future): - // * "node:out:-1" gives last element in a node output list - // * "node:out:1:" gives a list with all but the first element in a - // node output list - // * "node:out::-1" gives a list with all but the last element in a - // node output list - - // The body of the function. Unlike the NodeDefs in a GraphDef, attrs - // may have values of type `placeholder` and the `input` field uses - // the "output" format above. - - // By convention, "op" in node_def is resolved by consulting with a - // user-defined library first. If not resolved, "func" is assumed to - // be a builtin op. - repeated NodeDef node_def = 3; - - // A mapping from the output arg names from `signature` to the - // outputs from `node_def` that should be returned by the function. - map ret = 4; -} - -// GradientDef defines the gradient function of a function defined in -// a function library. -// -// A gradient function g (specified by gradient_func) for a function f -// (specified by function_name) must follow the following: -// -// The function 'f' must be a numerical function which takes N inputs -// and produces M outputs. Its gradient function 'g', which is a -// function taking N + M inputs and produces N outputs. -// -// I.e. 
if we have -// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), -// then, g is -// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, -// dL/dy1, dL/dy2, ..., dL/dy_M), -// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the -// loss function). dL/dx_i is the partial derivative of L with respect -// to x_i. -message GradientDef { - string function_name = 1; // The function name. - string gradient_func = 2; // The gradient function's name. -} diff --git a/ge/proto/tensorflow/graph.proto b/ge/proto/tensorflow/graph.proto deleted file mode 100644 index d639a7d6..00000000 --- a/ge/proto/tensorflow/graph.proto +++ /dev/null @@ -1,56 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "GraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "node_def.proto"; -import "function.proto"; -import "versions.proto"; - -// Represents the graph of operations -message GraphDef { - repeated NodeDef node = 1; - - // Compatibility versions of the graph. See core/public/version.h for version - // history. The GraphDef version is distinct from the TensorFlow version, and - // each release of TensorFlow will support a range of GraphDef versions. - VersionDef versions = 4; - - // Deprecated single version field; use versions above instead. Since all - // GraphDef changes before "versions" was introduced were forward - // compatible, this field is entirely ignored. - int32 version = 3 [deprecated = true]; - - // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. - // - // "library" provides user-defined functions. - // - // Naming: - // * library.function.name are in a flat namespace. - // NOTE: We may need to change it to be hierarchical to support - // different orgs. E.g., - // { "/google/nn", { ... }}, - // { "/google/vision", { ... }} - // { "/org_foo/module_bar", { ... 
}} - // map named_lib; - // * If node[i].op is the name of one function in "library", - // node[i] is deemed as a function call. Otherwise, node[i].op - // must be a primitive operation supported by the runtime. - // - // - // Function call semantics: - // - // * The callee may start execution as soon as some of its inputs - // are ready. The caller may want to use Tuple() mechanism to - // ensure all inputs are ready in the same time. - // - // * The consumer of return values may start executing as soon as - // the return values the consumer depends on are ready. The - // consumer may want to use Tuple() mechanism to ensure the - // consumer does not start until all return values of the callee - // function are ready. - FunctionDefLibrary library = 2; -}; diff --git a/ge/proto/tensorflow/graph_library.proto b/ge/proto/tensorflow/graph_library.proto deleted file mode 100644 index e393d38d..00000000 --- a/ge/proto/tensorflow/graph_library.proto +++ /dev/null @@ -1,14 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; - -import "graph.proto"; - -message GeGraphDef { - string name = 1; - GraphDef graph = 2; -} - -message GraphDefLibrary { - repeated GeGraphDef graph_def = 1; -}; \ No newline at end of file diff --git a/ge/proto/tensorflow/node_def.proto b/ge/proto/tensorflow/node_def.proto deleted file mode 100644 index b9bc97ee..00000000 --- a/ge/proto/tensorflow/node_def.proto +++ /dev/null @@ -1,63 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "NodeProto"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; - -message NodeDef { - // The name given to this operator. Used for naming inputs, - // logging, visualization, etc. Unique within a single GraphDef. - // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". - string name = 1; - - // The operation name. There may be custom parameters in attrs. 
- // Op names starting with an underscore are reserved for internal use. - string op = 2; - - // Each input is "node:src_output" with "node" being a string name and - // "src_output" indicating which output tensor to use from "node". If - // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs - // may optionally be followed by control inputs that have the format - // "^node". - repeated string input = 3; - - // A (possibly partial) specification for the device on which this - // node should be placed. - // The expected syntax for this string is as follows: - // - // DEVICE_SPEC ::= PARTIAL_SPEC - // - // PARTIAL_SPEC ::= ("/" CONSTRAINT) * - // CONSTRAINT ::= ("job:" JOB_NAME) - // | ("replica:" [1-9][0-9]*) - // | ("task:" [1-9][0-9]*) - // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) - // - // Valid values for this string include: - // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) - // * "/job:worker/device:GPU:3" (partial specification) - // * "" (no specification) - // - // If the constraints do not resolve to a single device (or if this - // field is empty or not present), the runtime will attempt to - // choose a device automatically. - string device = 4; - - // Operation-specific graph-construction-time configuration. - // Note that this should include all attrs defined in the - // corresponding OpDef, including those with a value matching - // the default -- this allows the default to change and makes - // NodeDefs easier to interpret on their own. However, if - // an attr with a default is not specified in this list, the - // default will be used. - // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and - // one of the names from the corresponding OpDef's attr field). - // The values must have a type matching the corresponding OpDef - // attr's type field. - // Add some examples here showing best practices. 
- map attr = 5; -}; diff --git a/ge/proto/tensorflow/op_def.proto b/ge/proto/tensorflow/op_def.proto deleted file mode 100644 index 3485d045..00000000 --- a/ge/proto/tensorflow/op_def.proto +++ /dev/null @@ -1,164 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "OpDefProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; -import "types.proto"; - -// Defines an operation. A NodeDef in a GraphDef specifies an Op by -// using the "op" field which should match the name of a OpDef. -// LINT.IfChange -message OpDef { - // Op names starting with an underscore are reserved for internal use. - // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". - string name = 1; - - // For describing inputs and outputs. - message ArgDef { - // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". - string name = 1; - - // Human readable description. - string description = 2; - - // Describes the type of one or more tensors that are accepted/produced - // by this input/output arg. The only legal combinations are: - // * For a single tensor: either the "type" field is set or the - // "type_attr" field is set to the name of an attr with type "type". - // * For a sequence of tensors with the same type: the "number_attr" - // field will be set to the name of an attr with type "int", and - // either the "type" or "type_attr" field will be set as for - // single tensors. - // * For a sequence of tensors, the "type_list_attr" field will be set - // to the name of an attr with type "list(type)". - DataType type = 3; - string type_attr = 4; // if specified, attr must have type "type" - string number_attr = 5; // if specified, attr must have type "int" - // If specified, attr must have type "list(type)", and none of - // type, type_attr, and number_attr may be specified. 
- string type_list_attr = 6; - - // For inputs: if true, the inputs are required to be refs. - // By default, inputs can be either refs or non-refs. - // For outputs: if true, outputs are refs, otherwise they are not. - bool is_ref = 16; - }; - - // Description of the input(s). - repeated ArgDef input_arg = 2; - - // Description of the output(s). - repeated ArgDef output_arg = 3; - - // Description of the graph-construction-time configuration of this - // Op. That is to say, this describes the attr fields that will - // be specified in the NodeDef. - message AttrDef { - // A descriptive name for the argument. May be used, e.g. by the - // Python client, as a keyword argument name, and so should match - // the regexp "[a-z][a-z0-9_]+". - string name = 1; - - // One of the type names from attr_value.proto ("string", "list(string)", - // "int", etc.). - string type = 2; - - // A reasonable default for this attribute if the user does not supply - // a value. If not specified, the user must supply a value. - AttrValue default_value = 3; - - // Human-readable description. - string description = 4; - - - // --- Constraints --- - // These constraints are only in effect if specified. Default is no - // constraints. - - // For type == "int", this is a minimum value. For "list(___)" - // types, this is the minimum length. - bool has_minimum = 5; - int64 minimum = 6; - - // The set of allowed values. Has type that is the "list" version - // of the "type" field above (uses the "list" field of AttrValue). - // If type == "type" or "list(type)" above, then the "type" field - // of "allowed_values.list" has the set of allowed DataTypes. - // If type == "string" or "list(string)", then the "s" field of - // "allowed_values.list" has the set of allowed strings. - AttrValue allowed_values = 7; - } - repeated AttrDef attr = 4; - - // Optional deprecation based on GraphDef versions. - OpDeprecation deprecation = 8; - - // One-line human-readable description of what the Op does. 
- string summary = 5; - - // Additional, longer human-readable description of what the Op does. - string description = 6; - - // ------------------------------------------------------------------------- - // Which optimizations this operation can participate in. - - // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) - bool is_commutative = 18; - - // If is_aggregate is true, then this operation accepts N >= 2 - // inputs and produces 1 output all of the same type. Should be - // associative and commutative, and produce output with the same - // shape as the input. The optimizer may replace an aggregate op - // taking input from multiple devices with a tree of aggregate ops - // that aggregate locally within each device (and possibly within - // groups of nearby devices) before communicating. - bool is_aggregate = 16; // for things like add - - // Other optimizations go here, like - // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. - - // ------------------------------------------------------------------------- - // Optimization constraints. - - // Ops are marked as stateful if their behavior depends on some state beyond - // their input tensors (e.g. variable reading op) or if they have - // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops - // must always produce the same output for the same input and have - // no side-effects. - // - // By default Ops may be moved between devices. Stateful ops should - // either not be moved, or should only be moved if that state can also - // be moved (e.g. via some sort of save / restore). - // Stateful ops are guaranteed to never be optimized away by Common - // Subexpression Elimination (CSE). - bool is_stateful = 17; // for things like variables, queue - - // ------------------------------------------------------------------------- - // Non-standard options. - - // By default, all inputs to an Op must be initialized Tensors. 
Ops - // that may initialize tensors for the first time should set this - // field to true, to allow the Op to take an uninitialized Tensor as - // input. - bool allows_uninitialized_input = 19; // for Assign, etc. -}; -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) - -// Information about version-dependent deprecation of an op -message OpDeprecation { - // First GraphDef version at which the op is disallowed. - int32 version = 1; - - // Explanation of why it was deprecated and what to use instead. - string explanation = 2; -}; - -// A collection of OpDefs -message OpList { - repeated OpDef op = 1; -}; diff --git a/ge/proto/tensorflow/resource_handle.proto b/ge/proto/tensorflow/resource_handle.proto deleted file mode 100644 index a3452351..00000000 --- a/ge/proto/tensorflow/resource_handle.proto +++ /dev/null @@ -1,29 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "ResourceHandle"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Protocol buffer representing a handle to a tensorflow resource. Handles are -// not valid across executions, but can be serialized back and forth from within -// a single run. -message ResourceHandleProto { - // Unique name for the device containing the resource. - string device = 1; - - // Container in which this resource is placed. - string container = 2; - - // Unique name of this resource. - string name = 3; - - // Hash code for the type of the resource. Is only valid in the same device - // and in the same execution. - uint64 hash_code = 4; - - // For debug-only, the name of the type pointed to by this handle, if - // available. 
- string maybe_type_name = 5; -}; diff --git a/ge/proto/tensorflow/tensor.proto b/ge/proto/tensorflow/tensor.proto deleted file mode 100644 index d0a4d024..00000000 --- a/ge/proto/tensorflow/tensor.proto +++ /dev/null @@ -1,94 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TensorProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "resource_handle.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing a tensor. -message TensorProto { - DataType dtype = 1; - - // Shape of the tensor. - TensorShapeProto tensor_shape = 2; - - // Only one of the representations below is set, one of "tensor_contents" and - // the "xxx_val" attributes. We are not using oneof because as oneofs cannot - // contain repeated fields it would require another extra set of messages. - - // Version number. - // - // In version 0, if the "repeated xxx" representations contain only one - // element, that element is repeated to fill the shape. This makes it easy - // to represent a constant Tensor with a single value. - int32 version_number = 3; - - // Serialized raw tensor content from either Tensor::AsProtoTensorContent or - // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation - // can be used for all tensor types. The purpose of this representation is to - // reduce serialization overhead during RPC call by avoiding serialization of - // many repeated small items. - bytes tensor_content = 4; - - // Type specific representations that make it easy to create tensor protos in - // all languages. Only the representation corresponding to "dtype" can - // be set. The values hold the flattened representation of the tensor in - // row major order. - - // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll - // have some pointless zero padding for each value here. 
- repeated int32 half_val = 13 [packed = true]; - - // DT_FLOAT. - repeated float float_val = 5 [packed = true]; - - // DT_DOUBLE. - repeated double double_val = 6 [packed = true]; - - // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. - repeated int32 int_val = 7 [packed = true]; - - // DT_STRING - repeated bytes string_val = 8; - - // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real - // and imaginary parts of i-th single precision complex. - repeated float scomplex_val = 9 [packed = true]; - - // DT_INT64 - repeated int64 int64_val = 10 [packed = true]; - - // DT_BOOL - repeated bool bool_val = 11 [packed = true]; - - // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real - // and imaginary parts of i-th double precision complex. - repeated double dcomplex_val = 12 [packed = true]; - - // DT_RESOURCE - repeated ResourceHandleProto resource_handle_val = 14; - - // DT_VARIANT - repeated VariantTensorDataProto variant_val = 15; - - // DT_UINT32 - repeated uint32 uint32_val = 16 [packed = true]; - - // DT_UINT64 - repeated uint64 uint64_val = 17 [packed = true]; -}; - -// Protocol buffer representing the serialization format of DT_VARIANT tensors. -message VariantTensorDataProto { - // Name of the type of objects being serialized. - string type_name = 1; - // Portions of the object that are not Tensors. - bytes metadata = 2; - // Tensors contained within objects being serialized. - repeated TensorProto tensors = 3; -} diff --git a/ge/proto/tensorflow/tensor_shape.proto b/ge/proto/tensorflow/tensor_shape.proto deleted file mode 100644 index 4225a2e3..00000000 --- a/ge/proto/tensorflow/tensor_shape.proto +++ /dev/null @@ -1,45 +0,0 @@ -// Protocol buffer representing the shape of tensors. - -syntax = "proto3"; -option cc_enable_arenas = true; -option java_outer_classname = "TensorShapeProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -package domi.tensorflow; - -// Dimensions of a tensor. 
-message TensorShapeProto { - // One dimension of the tensor. - message Dim { - // Size of the tensor in that dimension. - // This value must be >= -1, but values of -1 are reserved for "unknown" - // shapes (values of -1 mean "unknown" dimension). Certain wrappers - // that work with TensorShapeProto may fail at runtime when deserializing - // a TensorShapeProto containing a dim value of -1. - int64 size = 1; - - // Optional name of the tensor dimension. - string name = 2; - }; - - // Dimensions of the tensor, such as {"input", 30}, {"output", 40} - // for a 30 x 40 2D tensor. If an entry has size -1, this - // corresponds to a dimension of unknown size. The names are - // optional. - // - // The order of entries in "dim" matters: It indicates the layout of the - // values in the tensor in-memory representation. - // - // The first entry in "dim" is the outermost dimension used to layout the - // values, the last entry is the innermost dimension. This matches the - // in-memory layout of RowMajor Eigen tensors. - // - // If "dim.size()" > 0, "unknown_rank" must be false. - repeated Dim dim = 2; - - // If true, the number of dimensions in the shape is unknown. - // - // If true, "dim.size()" must be 0. - bool unknown_rank = 3; -}; diff --git a/ge/proto/tensorflow/types.proto b/ge/proto/tensorflow/types.proto deleted file mode 100644 index ba7a72b3..00000000 --- a/ge/proto/tensorflow/types.proto +++ /dev/null @@ -1,74 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TypesProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// LINT.IfChange -enum DataType { - // Not a legal value for DataType. Used to indicate a DataType field - // has not been set. - DT_INVALID = 0; - - // Data types that all computation devices are expected to be - // capable to support. 
- DT_FLOAT = 1; - DT_DOUBLE = 2; - DT_INT32 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_INT8 = 6; - DT_STRING = 7; - DT_COMPLEX64 = 8; // Single-precision complex - DT_INT64 = 9; - DT_BOOL = 10; - DT_QINT8 = 11; // Quantized int8 - DT_QUINT8 = 12; // Quantized uint8 - DT_QINT32 = 13; // Quantized int32 - DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. - DT_QINT16 = 15; // Quantized int16 - DT_QUINT16 = 16; // Quantized uint16 - DT_UINT16 = 17; - DT_COMPLEX128 = 18; // Double-precision complex - DT_HALF = 19; - DT_RESOURCE = 20; - DT_VARIANT = 21; // Arbitrary C++ data types - DT_UINT32 = 22; - DT_UINT64 = 23; - - // Do not use! These are only for parameters. Every enum above - // should have a corresponding value below (verified by types_test). - DT_FLOAT_REF = 101; - DT_DOUBLE_REF = 102; - DT_INT32_REF = 103; - DT_UINT8_REF = 104; - DT_INT16_REF = 105; - DT_INT8_REF = 106; - DT_STRING_REF = 107; - DT_COMPLEX64_REF = 108; - DT_INT64_REF = 109; - DT_BOOL_REF = 110; - DT_QINT8_REF = 111; - DT_QUINT8_REF = 112; - DT_QINT32_REF = 113; - DT_BFLOAT16_REF = 114; - DT_QINT16_REF = 115; - DT_QUINT16_REF = 116; - DT_UINT16_REF = 117; - DT_COMPLEX128_REF = 118; - DT_HALF_REF = 119; - DT_RESOURCE_REF = 120; - DT_VARIANT_REF = 121; - DT_UINT32_REF = 122; - DT_UINT64_REF = 123; -} -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/c/c_api.h, -// https://www.tensorflow.org/code/tensorflow/go/tensor.go, -// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, -// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, -// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/ge/proto/tensorflow/versions.proto b/ge/proto/tensorflow/versions.proto deleted file mode 100644 index 48061218..00000000 --- a/ge/proto/tensorflow/versions.proto +++ 
/dev/null @@ -1,31 +0,0 @@ -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "VersionsProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Version information for a piece of serialized data -// -// There are different types of versions for each type of data -// (GraphDef, etc.), but they all have the same common shape -// described here. -// -// Each consumer has "consumer" and "min_producer" versions (specified -// elsewhere). A consumer is allowed to consume this data if -// -// producer >= min_producer -// consumer >= min_consumer -// consumer not in bad_consumers -// -message VersionDef { - // The version of the code that produced this data. - int32 producer = 1; - - // Any consumer below this version is not allowed to consume this data. - int32 min_consumer = 2; - - // Specific consumer versions which are disallowed (e.g. due to bugs). - repeated int32 bad_consumers = 3; -}; diff --git a/ge/session/readme.txt b/ge/session/readme.txt deleted file mode 100644 index d8d0f393..00000000 --- a/ge/session/readme.txt +++ /dev/null @@ -1,3 +0,0 @@ -GE -SessionManager -InnerSession diff --git a/inc/common/blocking_queue.h b/inc/common/blocking_queue.h new file mode 100644 index 00000000..12b02773 --- /dev/null +++ b/inc/common/blocking_queue.h @@ -0,0 +1,141 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_BLOCKING_QUEUE_H_ +#define INC_COMMON_BLOCKING_QUEUE_H_ + +#include +#include +#include +#include + +static const int kDefaultMaxQueueSize = 2048; + +template +class BlockingQueue { + public: + explicit BlockingQueue(uint32_t max_size = kDefaultMaxQueueSize) : max_size_(max_size), is_stoped_(false) {} + + ~BlockingQueue() {} + + bool Pop(T &item) { + std::unique_lock lock(mutex_); + + while (queue_.empty() && !is_stoped_) { + empty_cond_.wait(lock); + } + + if (is_stoped_) { + return false; + } + + item = std::move(queue_.front()); + queue_.pop_front(); + + full_cond_.notify_one(); + + return true; + } + + bool Push(const T &item, bool is_wait = true) { + std::unique_lock lock(mutex_); + + while (queue_.size() >= max_size_ && !is_stoped_) { + if (!is_wait) { + return false; + } + full_cond_.wait(lock); + } + + if (is_stoped_) { + return false; + } + + queue_.push_back(item); + + empty_cond_.notify_one(); + + return true; + } + + bool Push(T &&item, bool is_wait = true) { + std::unique_lock lock(mutex_); + + while (queue_.size() >= max_size_ && !is_stoped_) { + if (!is_wait) { + return false; + } + full_cond_.wait(lock); + } + + if (is_stoped_) { + return false; + } + + queue_.emplace_back(std::move(item)); + + empty_cond_.notify_one(); + + return true; + } + + void Stop() { + { + std::unique_lock lock(mutex_); + is_stoped_ = true; + } + + full_cond_.notify_all(); + empty_cond_.notify_all(); + } + + void Restart() { + std::unique_lock lock(mutex_); + is_stoped_ = false; + } + + // if the queue is stoped ,need call this function to release the unprocessed items + std::list GetRemainItems() { + std::unique_lock lock(mutex_); + + if (!is_stoped_) { + return std::list(); + } + + return queue_; + } + + bool IsFull() { + std::unique_lock lock(mutex_); + return queue_.size() >= max_size_; + } + + void Clear() { + std::unique_lock 
lock(mutex_); + queue_.clear(); + } + + private: + std::list queue_; + std::mutex mutex_; + std::condition_variable empty_cond_; + std::condition_variable full_cond_; + uint32_t max_size_; + + bool is_stoped_; +}; + +#endif // INC_COMMON_BLOCKING_QUEUE_H_ diff --git a/inc/common/dynamic_aipp.h b/inc/common/dynamic_aipp.h new file mode 100644 index 00000000..a687853f --- /dev/null +++ b/inc/common/dynamic_aipp.h @@ -0,0 +1,104 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_DYNAMIC_AIPP_H_ +#define INC_COMMON_DYNAMIC_AIPP_H_ + +#include + +/** + * @ingroup dnn + * @brief struct define of dynamic aipp batch parameter. 
/**
 * @ingroup dnn
 * @brief Per-batch dynamic AIPP parameters: crop / resize(SCF) / padding /
 *        rotate switches and geometry, plus DTC normalization coefficients.
 *        Field order and sizes are ABI-sensitive; do not reorder.
 */
typedef struct tagAippDynamicBatchPara {
  int8_t cropSwitch;     // crop switch
  int8_t scfSwitch;      // resize switch
  int8_t paddingSwitch;  // 0: no padding
                         // 1: pad with configured values sfr_filling_hblank_ch0 ~ sfr_filling_hblank_ch2
                         // 2: pad with source picture data, single row/column copy
                         // 3: pad with source picture data, block copy
                         // 4: pad with source picture data, mirror copy
  int8_t rotateSwitch;   // rotate switch, 0: no rotation,
                         // 1: rotate 90° clockwise, 2: rotate 180° clockwise, 3: rotate 270° clockwise
  int8_t reserve[4];     // padding to align the following int32_t fields
  int32_t cropStartPosW;  // horizontal start position of cropping
  int32_t cropStartPosH;  // vertical start position of cropping
  int32_t cropSizeW;      // crop width
  int32_t cropSizeH;      // crop height

  int32_t scfInputSizeW;   // input width of scf (resize unit)
  int32_t scfInputSizeH;   // input height of scf
  int32_t scfOutputSizeW;  // output width of scf
  int32_t scfOutputSizeH;  // output height of scf

  int32_t paddingSizeTop;     // top padding size
  int32_t paddingSizeBottom;  // bottom padding size
  int32_t paddingSizeLeft;    // left padding size
  int32_t paddingSizeRight;   // right padding size

  int16_t dtcPixelMeanChn0;  // mean value of channel 0
  int16_t dtcPixelMeanChn1;  // mean value of channel 1
  int16_t dtcPixelMeanChn2;  // mean value of channel 2
  int16_t dtcPixelMeanChn3;  // mean value of channel 3

  uint16_t dtcPixelMinChn0;      // min value of channel 0
  uint16_t dtcPixelMinChn1;      // min value of channel 1
  uint16_t dtcPixelMinChn2;      // min value of channel 2
  uint16_t dtcPixelMinChn3;      // min value of channel 3
  uint16_t dtcPixelVarReciChn0;  // sfr_dtc_pixel_variance_reci_ch0 (reciprocal of variance, fixed point)
  uint16_t dtcPixelVarReciChn1;  // sfr_dtc_pixel_variance_reci_ch1
  uint16_t dtcPixelVarReciChn2;  // sfr_dtc_pixel_variance_reci_ch2
  uint16_t dtcPixelVarReciChn3;  // sfr_dtc_pixel_variance_reci_ch3

  int8_t reserve1[16];  // pad struct to a 32-byte multiple for UB copy
} kAippDynamicBatchPara;
/**
 * @ingroup dnn
 * @brief Dynamic AIPP parameter header shared by all batches.
 *        Size: lite: 64 + 96 * batchNum bytes; tiny: 64 + 64 * batchNum bytes.
 *        Field order and sizes are ABI-sensitive; do not reorder.
 */
typedef struct tagAippDynamicPara {
  uint8_t inputFormat;    // input format: YUV420SP_U8 / XRGB8888_U8 / RGB888_U8
  int8_t cscSwitch;       // csc (color-space conversion) switch
  int8_t rbuvSwapSwitch;  // rb/ub swap switch
  int8_t axSwapSwitch;    // RGBA->ARGB, YUVA->AYUV swap switch
  int8_t batchNum;        // number of per-batch parameter entries
  int8_t reserve1[3];     // padding to align the following int32_t fields
  int32_t srcImageSizeW;  // source image width
  int32_t srcImageSizeH;  // source image height
  int16_t cscMatrixR0C0;  // csc_matrix_r0_c0
  int16_t cscMatrixR0C1;  // csc_matrix_r0_c1
  int16_t cscMatrixR0C2;  // csc_matrix_r0_c2
  int16_t cscMatrixR1C0;  // csc_matrix_r1_c0
  int16_t cscMatrixR1C1;  // csc_matrix_r1_c1
  int16_t cscMatrixR1C2;  // csc_matrix_r1_c2
  int16_t cscMatrixR2C0;  // csc_matrix_r2_c0
  int16_t cscMatrixR2C1;  // csc_matrix_r2_c1
  int16_t cscMatrixR2C2;  // csc_matrix_r2_c2
  int16_t reserve2[3];
  uint8_t cscOutputBiasR0;  // output bias for RGB to YUV, element of row 0, unsigned
  uint8_t cscOutputBiasR1;  // output bias for RGB to YUV, element of row 1, unsigned
  uint8_t cscOutputBiasR2;  // output bias for RGB to YUV, element of row 2, unsigned
  uint8_t cscInputBiasR0;   // input bias for YUV to RGB, element of row 0, unsigned
  uint8_t cscInputBiasR1;   // input bias for YUV to RGB, element of row 1, unsigned
  uint8_t cscInputBiasR2;   // input bias for YUV to RGB, element of row 2, unsigned
  uint8_t reserve3[2];
  int8_t reserve4[16];  // pad header to a 32-byte multiple for UB copy

  // First of batchNum consecutive per-batch entries; further entries follow
  // this struct in memory (see the size formula above).
  kAippDynamicBatchPara aippBatchPara;
} kAippDynamicPara;
#ifndef INC_COMMON_NPU_ERROR_DEFINE_H_
#define INC_COMMON_NPU_ERROR_DEFINE_H_

// Where the error originated: host side or device side.
typedef enum tagHiAiNpuLocal {
  HIAI_HOST = 1,
  HIAI_DEVICE = 2,
} HiAiNpuLocal;

// Whether the code denotes a recoverable error or an exception.
typedef enum tagHiAiNpuCodeType {
  ERROR_CODE = 1,
  EXCEPTION_CODE = 2,
} HiAiNpuCodeType;

// Severity of the error.
typedef enum tagHiAiNpuErrLevel {
  NONE_LEVEL = 0,
  SUGGESTION_LEVEL = 1,
  NORMAL_LEVEL = 2,
  SERIOUS_LEVEL = 3,
  CRITICAL_ERROR = 4,
} HiAiNpuErrLevel;

// Module that reported the error.
typedef enum tagHiAiNpuModuleId {
  HIAI_DRIVER = 1,
  HIAI_CTRLCPU = 2,
  HIAI_TS = 3,
  HIAI_RUNTIME = 4,
  HIAI_AICPU = 5,
  HIAI_CCE = 6,
  HIAI_TVM = 7,
  HIAI_FRAMEWORK = 8,
  HiAI_ENGINE = 9,
  HIAI_DVPP = 10,
  HIAI_AIPP = 11,
  HIAI_LOWPOWER = 12,
  HIAI_MDC = 13,
  HIAI_COMPILE = 14,
  HIAI_TOOLCHIAN = 15,
  HIAI_ALG = 16,
  HIAI_PROFILING = 17,
  HIAI_HCCL = 18,
  HIAI_SIMULATION = 19,
  HIAI_BIOS = 20,
  HIAI_SEC = 21,
  HIAI_TINY = 22,
  HIAI_DP = 23,
} HiAiNpuModuleId;

/* bits 31-30: hiai locality (host/device) */
#define HIAI_NPULOCAL_MASK 0xC0000000
#define SHIFT_LOCAL_MASK 30
#define HIAI_NPULOCAL_VAL_MASK 0x3
/* bits 29-28: hiai aicpu code type */
#define HIAI_CODE_TYPE_MASK 0x30000000
#define SHIFT_CODE_MASK 28
#define HIAI_CODE_TYPE_VAL_MASK 0x3
/* bits 27-25: hiai error level */
#define HIAI_ERROR_LEVEL_MASK 0x0E000000
#define SHIFT_ERROR_LVL_MASK 25
#define HIAI_ERROR_LEVEL_VAL_MASK 0x7
/* bits 24-17: hiai module id */
#define HIAI_MODE_ID_MASK 0x01FE0000
#define SHIFT_MODE_MASK 17
#define HIAI_MODE_ID_VAL_MASK 0xFF

/* Each helper masks its argument to the field width, shifts it into place,
 * and masks the result to the field's bit range. */
#define HIAI_NPU_LOC_BIT(a) \
  (HIAI_NPULOCAL_MASK & ((unsigned int)((HiAiNpuLocal)(a)) & HIAI_NPULOCAL_VAL_MASK) << SHIFT_LOCAL_MASK)
#define HIAI_NPU_CODE_TYPE_BIT(a) \
  (HIAI_CODE_TYPE_MASK & ((unsigned int)((HiAiNpuCodeType)(a)) & HIAI_CODE_TYPE_VAL_MASK) << SHIFT_CODE_MASK)
#define HIAI_NPU_ERR_LEV_BIT(a) \
  (HIAI_ERROR_LEVEL_MASK & ((unsigned int)((HiAiNpuErrLevel)(a)) & HIAI_ERROR_LEVEL_VAL_MASK) << SHIFT_ERROR_LVL_MASK)
#define HIAI_NPU_MOD_ID_BIT(a) \
  (HIAI_MODE_ID_MASK & ((unsigned int)((HiAiNpuModuleId)(a)) & HIAI_MODE_ID_VAL_MASK) << SHIFT_MODE_MASK)

/* Composes the high 15 bits of an error code from locality, code type,
 * severity and module id; the low bits are left for a per-module code. */
#define HIAI_NPU_ERR_CODE_HEAD(npuLocal, codeType, errLevel, moduleId) \
  (HIAI_NPU_LOC_BIT(npuLocal) + HIAI_NPU_CODE_TYPE_BIT(codeType) + HIAI_NPU_ERR_LEV_BIT(errLevel) + \
   HIAI_NPU_MOD_ID_BIT(moduleId))

#endif  // INC_COMMON_NPU_ERROR_DEFINE_H_
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ +#define INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ + +#include +#include +#include +#include + +using std::string; +namespace ge { +// when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD +struct GETaskKernelHcclInfo { + string input_name; + string hccl_type; + void *inputDataAddr; + void *outputDataAddr; + void *workSpaceAddr; + int32_t count; + int32_t dataType; + int32_t opType; + int64_t rootId; + uint64_t workSpaceMemSize; + std::vector dims; + std::vector hcclStreamList; +}; + +struct GETaskInfo { + uint32_t id; + uint16_t type; + uint32_t streamID; + void *stream; // rtKernelLaunch input argument + void *event; + void *privateDef; + uint32_t privateDefLen; + void *opsKernelStorePtr; + + std::vector kernelHcclInfo; +}; + +struct HcomOpertion { + std::string hcclType; + void *inputPtr; + void *outputPtr; + uint64_t count; + int32_t dataType; + int32_t opType; + int32_t root; +}; + +struct HcomRemoteAccessAddrInfo { + uint32_t remotetRankID; + uint64_t remoteAddr; // host embedding table address + uint64_t localAddr; // device HBM address + uint64_t length; // memory Length in Bytes +}; + +} // namespace ge +#endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ diff --git a/inc/common/opskernel/ops_kernel_info_store.h b/inc/common/opskernel/ops_kernel_info_store.h new file mode 100644 index 00000000..ce1464d4 --- /dev/null +++ b/inc/common/opskernel/ops_kernel_info_store.h @@ -0,0 +1,88 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ +#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ + +#include +#include +#include +#include +#include "./ge_task_info.h" +#include "./ops_kernel_info_types.h" +#include "cce/aicpu_engine_struct.h" +#include "cce/fwk_adpt_struct.h" +#include "common/ge_inner_error_codes.h" +#include "graph/node.h" +#include "proto/task.pb.h" +using std::map; +using std::string; +using std::to_string; +using std::vector; + +namespace ge { +class OpDesc; + +class OpsKernelInfoStore { + public: + OpsKernelInfoStore() {} + + virtual ~OpsKernelInfoStore() {} + + // initialize opsKernelInfoStore + virtual Status Initialize(const map &options) = 0; /*lint -e148*/ + + // close opsKernelInfoStore + virtual Status Finalize() = 0; /*lint -e148*/ + + virtual Status CreateSession(const std::map &session_options) { return SUCCESS; } + + virtual Status DestroySession(const std::map &session_options) { return SUCCESS; } + + // get all opsKernelInfo + virtual void GetAllOpsKernelInfo(map &infos) const = 0; + + // whether the opsKernelInfoStore is supported based on the operator attribute + virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; + + virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, + bool realQuery = false) const { + return CheckSupported(opDescPtr, un_supported_reason); + } + // opsFlag opsFlag[0] indicates constant folding is supported or not + virtual void opsFlagCheck(const ge::Node &node, std::string 
&opsFlag){}; + + // memory allocation requirement + virtual Status CalcOpRunningParam(Node &node) = 0; /*lint -e148*/ + + // generate task for op。 + virtual Status GenerateTask(const Node &node, RunContext &context, + std::vector &tasks) = 0; /*lint -e148*/ + + // only call fe engine interface to compile single op + virtual Status CompileOp(vector &node_vec) { return SUCCESS; } + virtual Status CompileOpRun(vector &node_vec) { return SUCCESS; } + // load task for op + virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } + + // only call aicpu interface to generate task struct + virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } + + // only call aicpu interface to generate task struct + virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } +}; +} // namespace ge +#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ diff --git a/inc/common/opskernel/ops_kernel_info_types.h b/inc/common/opskernel/ops_kernel_info_types.h new file mode 100644 index 00000000..684c1abc --- /dev/null +++ b/inc/common/opskernel/ops_kernel_info_types.h @@ -0,0 +1,66 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ +#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ + +#include +#include +#include +#include "graph/buffer.h" +#include "runtime/rt_model.h" + +using std::string; + +namespace ge { +/*lint -e148*/ +struct RunContext { + rtModel_t model; + rtStream_t stream; + uint64_t sessionId; + uint64_t dataMemSize; + uint8_t *dataMemBase; + uint64_t weightMemSize; + uint8_t *weightMemBase; + ge::Buffer weightsBuffer; + std::vector graphStreamList; // all streams of graph, order by ge stream id(0,1,...) + std::vector graphEventList; // all events of graph, order by ge event id(0,1,...) + std::vector graphLabelList; // all labels of graph, order by ge label id(0,1,...) +}; + +/*lint +e148*/ + +struct Task { + uint32_t id; + uint16_t type; + void *stream; + void *event; +}; + +struct OpInfo { + string engine; // which engin + /*lint -e148*/ + string opKernelLib; // which opsKernelStore + int computeCost; // compute cost + bool flagPartial; // whether to support is related to shape + bool flagAsync; // Whether to support asynchronous + bool isAtomic; // whether to support atomic addr clean + string opFileName; // op file name + string opFuncName; // op function name +}; +} // namespace ge + +#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ diff --git a/inc/common/optimizer/graph_optimizer.h b/inc/common/optimizer/graph_optimizer.h new file mode 100644 index 00000000..253aaae1 --- /dev/null +++ b/inc/common/optimizer/graph_optimizer.h @@ -0,0 +1,71 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ +#define INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ + +#include +#include +#include "./graph_optimizer_types.h" +#include "common/ge_inner_error_codes.h" +#include "common/opskernel/ops_kernel_info_types.h" +#include "graph/compute_graph.h" + +using std::map; +using std::string; + +/*lint -e148*/ +namespace ge { +class GraphOptimizer { + public: + virtual ~GraphOptimizer() {} + + // initialize graphOptimizer + virtual Status Initialize(const map &options) = 0; + + // close graphOptimizer + virtual Status Finalize() = 0; + + // optimize original graph for FE quant optimize + virtual Status OptimizeGraphPrepare(ComputeGraph &graph) { return SUCCESS; } + + // optimize graph before build for RTS + virtual Status OptimizeGraphBeforeBuild(ComputeGraph &graph) { return SUCCESS; } + + // optimize original graph, using in graph preparation stage + virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; + + // optimize original graph, using for conversion operator insert in graph preparation stage + virtual Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) { return SUCCESS; } + + // optimize fused graph + virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; + + // optimize whole graph, using after graph merged stage + virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; + + // get attribute of graph optimizer + virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; + + // optimize streamed Graph + virtual Status OptimizeStreamGraph(ComputeGraph &graph, const 
RunContext &context) { return SUCCESS; } + + // op compile + virtual Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) { return SUCCESS; } +}; +} // namespace ge +/*lint +e148*/ +#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ diff --git a/inc/common/optimizer/graph_optimizer_types.h b/inc/common/optimizer/graph_optimizer_types.h new file mode 100644 index 00000000..9e1ec96b --- /dev/null +++ b/inc/common/optimizer/graph_optimizer_types.h @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ +#define INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ + +#include +#include +namespace ge { +enum OPTIMIZER_SCOPE { + UNIT = 0, + ENGINE, +}; + +struct GraphOptimizerAttribute { + std::string engineName; + OPTIMIZER_SCOPE scope; +}; +} // namespace ge + +#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ diff --git a/inc/common/util/ai_core/common/aicore_util_attr_define.h b/inc/common/util/ai_core/common/aicore_util_attr_define.h new file mode 100644 index 00000000..ba28d7b3 --- /dev/null +++ b/inc/common/util/ai_core/common/aicore_util_attr_define.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_ATTR_DEFINE_H_
#define INC_COMMON_UTILS_AI_CORE_COMMON_ATTR_DEFINE_H_

// NOTE(review): the extracted patch lost this #include target; <string> is the
// only stdlib name used below — confirm against the upstream header.
#include <string>

namespace fe {
// Names of the attributes FE reads/writes on op descriptors.

// Fusion scope id assigned during scope allocation.
static const std::string SCOPE_ID_ATTR = "fusion_scope";

// Implementation type chosen by FE (see OpImplType).
static const std::string FE_IMPLY_TYPE = "_fe_imply_type";

// Type of the parent op for fused sub-ops.
static const std::string PARENT_OP_TYPE = "parentOpType";

// Serialized L2-fusion task info attached to a task.
static const std::string ATTR_NAME_TASK_L2_FUSION_INFO_EXTEND_PTR = "task_l2_fusion_info_extend_content";

// Data-dump reference marker.
static const std::string ATTR_DATA_DUMP_REF = "_datadump_ref";

// Serialized L2-fusion extension payload.
static const std::string ATTR_NAME_L2_FUSION_EXTEND_PTR = "l2_fusion_extend_content";

// Flags set once the corresponding cache-level optimization has been applied.
static const std::string L1_OPTIMIZED = "l1_optimized";

static const std::string L2_OPTIMIZED = "l2_optimized";

// Per-op slice information used by graph slicing.
static const std::string OP_SLICE_INFO = "_op_slice_info";
}  // namespace fe
#endif
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_TYPES_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_TYPES_H_ + +#include "graph/anchor.h" +#include "graph/types.h" +#include "runtime/kernel.h" +#include +#include +#include + +namespace fe { +struct FusionOpSrc { + uint32_t src_op_id; + ge::AnchorPtr src_anchor; + int32_t fusion_src_index; + int32_t fusion_dst_index; +}; + +struct FusionOpDst { + uint32_t dst_op_id; + ge::AnchorPtr dst_anchor; +}; + +struct FusionDataFlow { + std::pair edge; + std::pair node_dataindex_pair; +}; + +typedef struct tagL2FusionData { + uint32_t l2Index; + uint64_t l2Addr; + uint64_t l2PageNum; +} L2FusionData_t; +typedef std::map L2FusionDataMap_t; + +typedef struct tagFeSmDesc { + rtL2Ctrl_t l2ctrl; + std::string nodeName[8]; + uint8_t outputIndex[8]; +} feSmDesc_t; + +typedef struct TagTaskL2FusionInfo { + std::string nodeName; + feSmDesc_t l2Info; + L2FusionDataMap_t input; + L2FusionDataMap_t output; + uint32_t isUsed; +} TaskL2FusionInfo_t; + +using L2FusionInfoPtr = std::shared_ptr; + +typedef struct ToOpStruct { + int64_t opL1Space = 0; + std::vector opL1FusionType; + int64_t opL1WorkspaceFlag = 0; // for workspace flag + int64_t opL1WorkspaceSize = 0; + std::vector> validInputShape; + std::vector> validOutputShape; + std::vector> sliceInputOffset; // conv & pooling & ReadSelect + std::vector> sliceOutputOffset; // WriteSelect + std::vector totalShape; + uint32_t splitIndex = 0; + ToOpStruct() { + // set invalid value for essential variable + opL1Space = -1; + opL1WorkspaceSize = -1; + } +} ToOpStruct_t; + +enum OpImplType { + EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op + EN_IMPL_CUSTOM_TIK, // custom tik op + EN_IMPL_CUSTOM_TBE, // custom tbe op + EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op + EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op + EN_IMPL_HW_TIK, // Huawei built-in tik op + EN_IMPL_HW_TBE, 
// Huawei built-in tbe op + EN_IMPL_RL, // RL op + EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op + EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op + EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op + EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op + EN_RESERVED // reserved value +}; + +static const std::map DATATYPE_SIZE_MAP{{ge::DT_FLOAT, sizeof(float)}, + {ge::DT_FLOAT16, sizeof(int16_t)}, + {ge::DT_INT8, sizeof(int8_t)}, + {ge::DT_INT32, sizeof(int32_t)}, + {ge::DT_UINT8, sizeof(uint8_t)}, + {ge::DT_UINT32, sizeof(uint32_t)}, + {ge::DT_INT16, sizeof(int16_t)}, + {ge::DT_UINT16, sizeof(uint16_t)}, + {ge::DT_INT64, sizeof(int64_t)}, + {ge::DT_UINT64, sizeof(uint64_t)}, + {ge::DT_DOUBLE, sizeof(double)}, + {ge::DT_BOOL, sizeof(bool)}, + {ge::DT_DUAL, sizeof(float) + sizeof(int8_t)}, + {ge::DT_DUAL_SUB_UINT8, sizeof(int8_t)}, + {ge::DT_DUAL_SUB_INT8, sizeof(int8_t)}}; +} // namespace fe +#endif diff --git a/inc/common/util/ai_core/common/graph_comm.h b/inc/common/util/ai_core/common/graph_comm.h new file mode 100644 index 00000000..d672e056 --- /dev/null +++ b/inc/common/util/ai_core/common/graph_comm.h @@ -0,0 +1,107 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_GRAPH_COMMON_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_GRAPH_COMMON_H_ + +#include "graph/compute_graph.h" +#include "common/aicore_util_types.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +#include +#include +#include +#include + +namespace fe { + +using kScopeNodeMap_t = std::map>; +using kScopeNodePair_t = std::pair>; + +class GraphCommImpl; +using GraphCommImplPtr = std::unique_ptr; + +class GraphComm { + public: + GraphComm(const string &engineName); + virtual ~GraphComm(); + GraphComm(const GraphComm &in) = delete; + GraphComm &operator=(const GraphComm &in) = delete; + + Status GetscopeNodeMap(ge::ComputeGraph &graph, kScopeNodeMap_t &fusionMap); + + Status CopyFusionOpNodes(vector &fusInputEdgeList, vector &fusOutputEdgeList, + vector &fusNodelist, ge::OpDescPtr fusionOpDesc, + ge::ComputeGraphPtr fusionGraph); + + Status CopyFusionOpEdges(ge::OpDescPtr fusionOpDesc, ge::ComputeGraph &origGraph, ge::ComputeGraphPtr fusionGraph); + + Status GetNodeDataFlowMap(const ge::NodePtr &fusNode, + std::map> &fusionOpAnchorsMap, + ge::kFusionDataFlowVec_t &fusDataflowList, const int &mapType); + + Status GetFusionNodeEdgeList(std::vector &fusNodelist, std::vector &fusInputEdgeList, + std::vector &fusOutputEdgeList); + void ClearFusionSrc(); + + void ClearFusionDst(); + + void AddFusionOutputSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, const int32_t &fusion_src_index, + std::pair &node_dataindex_pair); + + void AddFusionInputSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, const int32_t &fusion_dst_index, + std::pair &node_dataindex_pair); + + void SaveFusionDst(const uint32_t &dst_op_id, ge::AnchorPtr dst_anchor); + + bool IsFusionDstExist(const uint32_t &dst_op_id, const ge::AnchorPtr &dst_anchor); + + bool GetFusionSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, int32_t &fusion_src_index, + int32_t &fusion_dst_index); + + Status 
GetFusionNodeCtrlEdgeList(vector &fusNodelist, vector &fusInputCtrlEdgeList, + vector &fusOutputCtrlEdgeList); + + Status MergeFusionNodeEdgeList(ge::NodePtr &fusNode, vector &fusNodelist, + vector &fusInputEdgeList, vector &fusOutputEdgeList); + + Status MergeFusionNodeCtrlEdgeList(ge::NodePtr &fusNode, vector &fusNodelist, + vector &fusInputEdgeList, + vector &fusOutputEdgeList); + + string GetEngineName(); + + private: + Status MergeFusionNodeInputEdgeList(ge::NodePtr fusNode, std::vector &fusNodelist, + std::vector &fusInputEdgeList); + Status MergeFusionNodeOutputEdgeList(ge::NodePtr fusNode, std::vector &fusNodelist, + std::vector &fusOutputEdgeList); + + string engineName_; + + std::vector exist_fusion_src_list_; + std::vector exist_fusion_dst_list_; + + // std::vector> + ge::kFusionDataFlowVec_t fusion_input_dataflow_list_; + + // std::vector> + ge::kFusionDataFlowVec_t fusion_output_dataflow_list_; + + GraphCommImplPtr graphCommImplPtr_; +}; +} // namespace fe +#endif diff --git a/inc/common/util/ai_core/common/scope_allocator.h b/inc/common/util/ai_core/common/scope_allocator.h new file mode 100644 index 00000000..6cebb286 --- /dev/null +++ b/inc/common/util/ai_core/common/scope_allocator.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_SCOPE_ALLOCATOR_H_
#define INC_COMMON_UTILS_AI_CORE_COMMON_SCOPE_ALLOCATOR_H_

#include "graph/op_desc.h"

namespace fe {
// Hands out fusion-scope ids and reads/writes the scope-id attribute on op
// descriptors. Non-copyable; holds the current id as state.
class ScopeAllocator {
 public:
  ScopeAllocator();
  virtual ~ScopeAllocator();
  ScopeAllocator(const ScopeAllocator& in) = delete;
  ScopeAllocator& operator=(const ScopeAllocator& in) = delete;

 public:
  // Resets the allocator state. NOTE(review): implementation not visible
  // here — presumably reinitializes scopeId; confirm in the .cc file.
  void Init();
  // Returns the current (last handed out) scope id.
  int64_t GetCurrentScopeId();
  // Allocates and returns a new scope id.
  int64_t AllocateScopeId(void);
  // True when the op descriptor carries a scope-id attribute.
  bool HasScopeAttr(ge::ConstOpDescPtr opdef);
  // Reads the scope-id attribute into scopeId; returns false when absent.
  bool GetScopeAttr(ge::ConstOpDescPtr opdef, int64_t& scopeId);
  // Writes the scope-id attribute; returns false on failure.
  bool SetScopeAttr(ge::OpDescPtr opdef, int64_t scopeId);
  // Resets the internal counter to scopeId.
  bool ResetScopeId(int64_t scopeId);

 private:
  int64_t scopeId;  // current scope id counter
};
}  // namespace fe
#endif
#ifndef AICORE_PARAM_CALCULATOR
#define AICORE_PARAM_CALCULATOR

#include "graph/node.h"
#include "graph_optimizer/graph_optimize_register_error_codes.h"

namespace fe {
// Computes the running parameters (e.g. memory requirements) of an AI Core
// node. Thin, stateless wrapper; the work happens in CalcOpRunningParam.
class AICoreParamCalculator {
 public:
  AICoreParamCalculator();

  ~AICoreParamCalculator();

  // Calculates the running parameters for @p node.
  // @return SUCCESS or a fe error code.
  Status CalcOpRunningParam(ge::Node &node);
};
}  // namespace fe
#endif  // AICORE_PARAM_CALCULATOR
+ */ + +#ifndef TENSORSIZE_CALCULATOR_H +#define TENSORSIZE_CALCULATOR_H + +#include "graph_optimizer/graph_optimize_register_error_codes.h" + +#include +#include +#include "graph/compute_graph.h" +#include "graph/op_desc.h" + +namespace fe { +class TensorSizeCalculator { + public: + /** + * Calculate the tensor size of input and output of each opdesc + * @param opDesc opdesc object + * @param opImplType op impl type + * @return status SUCCESS or FAILED + */ + static Status CalculateOpTensorSize(ge::OpDesc &opDesc); + + private: + static Status CalcInputOpTensorSize(ge::OpDesc &opDesc, int32_t &outputRealCalcFlag); + + static Status CalcOutputOpTensorSize(ge::OpDesc &opDesc, int32_t &outputRealCalcFlag); +}; +} // namespace fe + +#endif // TENSORSIZE_CALCULATOR_H diff --git a/inc/common/util/compress/compress.h b/inc/common/util/compress/compress.h new file mode 100644 index 00000000..e350f9e5 --- /dev/null +++ b/inc/common/util/compress/compress.h @@ -0,0 +1,37 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPRESS_H +#define COMPRESS_H + +#include + +enum CmpStatus { RET_SUCCESS = 0, RET_ERROR = -1 }; + +struct CompressConfig { + size_t inputSize; // length of data to compress + size_t engineNum; // how many decompress engines + size_t maxRatio; // how much size of a basic compression block, only 64 supported now (8x: 64 4x: 32) + size_t channel; // channels of L2 or DDR. 
For load balance + size_t fractalSize; // size of compressing block + bool isTight; // whether compose compressed data tightly + size_t init_offset; +}; + +CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, + size_t& compressedLength); + +#endif // COMPRESS_H diff --git a/inc/common/util/compress/compress_weight.h b/inc/common/util/compress/compress_weight.h new file mode 100644 index 00000000..34ea47d1 --- /dev/null +++ b/inc/common/util/compress/compress_weight.h @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMPRESS_WEIGHT_H +#define COMPRESS_WEIGHT_H + +#include "compress.h" + +const int SHAPE_SIZE_WEIGHT = 4; + +struct CompressOpConfig { + int64_t wShape[SHAPE_SIZE_WEIGHT]; + size_t compressTilingK; + size_t compressTilingN; + struct CompressConfig compressConfig; +}; + +extern "C" CmpStatus CompressWeightsConv2D(const char *const input, char *const zipBuffer, char *const infoBuffer, + CompressOpConfig *const param); +#endif // COMPRESS_WEIGHT_H diff --git a/inc/common/util/error_manager/error_manager.h b/inc/common/util/error_manager/error_manager.h new file mode 100644 index 00000000..438e68a7 --- /dev/null +++ b/inc/common/util/error_manager/error_manager.h @@ -0,0 +1,94 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef ERROR_MANAGER_H_ +#define ERROR_MANAGER_H_ + +#include +#include +#include + +class ErrorManager { + public: + /// + /// @brief Obtain ErrorManager instance + /// @return ErrorManager instance + /// + static ErrorManager &GetInstance(); + + /// + /// @brief init + /// @param [in] path: current so path + /// @return int 0(success) -1(fail) + /// + int Init(std::string path); + + /// + /// @brief Report error message + /// @param [in] error_code: error code + /// @param [in] args_map: parameter map + /// @return int 0(success) -1(fail) + /// + int ReportErrMessage(std::string error_code, const std::map &args_map); + + /// + /// @brief output error message + /// @param [in] handle: print handle + /// @return int 0(success) -1(fail) + /// + int OutputErrMessage(int handle); + + /// + /// @brief output message + /// @param [in] handle: print handle + /// @return int 0(success) -1(fail) + /// + int OutputMessage(int handle); + + /// + /// @brief Report error message + /// @param [in] key: vector parameter key + /// @param [in] value: vector parameter value + /// + void ATCReportErrMessage(std::string error_code, const std::vector &key = {}, + const std::vector &value = {}); + + private: + struct ErrorInfo { + std::string error_id; + std::string error_message; + std::vector arg_list; + }; + + ErrorManager() {} + ~ErrorManager() {} + + ErrorManager(const ErrorManager &) = delete; + ErrorManager(ErrorManager &&) = delete; + ErrorManager &operator=(const ErrorManager &) = delete; + ErrorManager &operator=(ErrorManager &&) = delete; + + int ParseJsonFile(std::string path); + + int ReadJsonFile(const std::string &file_path, void *handle); + + bool is_init_ = false; + std::map error_map_; + std::vector error_messages_; + std::vector warning_messages_; +}; + +#endif // ERROR_MANAGER_H_ diff --git a/inc/common/util/platform_info.h b/inc/common/util/platform_info.h new file mode 100644 index 00000000..8d2a0579 --- /dev/null +++ b/inc/common/util/platform_info.h @@ 
-0,0 +1,101 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PLATFORM_INFO_H +#define PLATFORM_INFO_H + +#include +#include +#include +#include "platform_info_def.h" + +using std::map; +using std::string; +using std::vector; + +namespace fe { +class PlatformInfoManager { + public: + PlatformInfoManager(const PlatformInfoManager &) = delete; + PlatformInfoManager &operator=(const PlatformInfoManager &) = delete; + + static PlatformInfoManager &Instance(); + uint32_t InitializePlatformInfo(); + uint32_t Finalize(); + + uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); + + uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); + + void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo); + + private: + PlatformInfoManager(); + ~PlatformInfoManager(); + + uint32_t LoadIniFile(string iniFileRealPath); + + void Trim(string &str); + + uint32_t LoadConfigFile(string realPath); + + string RealPath(const std::string &path); + + string GetSoFilePath(); + + void ParseVersion(map &versionMap, string &socVersion, PlatformInfo &platformInfoTemp); + + void ParseSocInfo(map &socInfoMap, PlatformInfo &platformInfoTemp); + + void ParseCubeOfAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); + + void ParseBufferOfAICoreSpec(map &aiCoreSpecMap, PlatformInfo 
&platformInfoTemp); + + void ParseUBOfAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); + + void ParseUnzipOfAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); + + void ParseAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); + + void ParseBufferOfAICoreMemoryRates(map &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); + + void ParseAICoreMemoryRates(map &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); + + void ParseUBOfAICoreMemoryRates(map &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); + + void ParseAICoreintrinsicDtypeMap(map &aiCoreintrinsicDtypeMap, PlatformInfo &platformInfoTemp); + + void ParseVectorCoreSpec(map &vectorCoreSpecMap, PlatformInfo &platformInfoTemp); + + void ParseVectorCoreMemoryRates(map &vectorCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); + + void ParseCPUCache(map &CPUCacheMap, PlatformInfo &platformInfoTemp); + + void ParseVectorCoreintrinsicDtypeMap(map &vectorCoreintrinsicDtypeMap, + PlatformInfo &platformInfoTemp); + + uint32_t ParsePlatformInfoFromStrToStruct(map> &contentInfoMap, string &socVersion, + PlatformInfo &platformInfoTemp); + + uint32_t AssemblePlatformInfoVector(map> &contentInfoMap); + + private: + bool initFlag_; + map platformInfoMap_; + OptionalInfo optiCompilationInfo_; +}; +} // namespace fe +#endif diff --git a/inc/common/util/platform_info_def.h b/inc/common/util/platform_info_def.h new file mode 100644 index 00000000..c660e8f1 --- /dev/null +++ b/inc/common/util/platform_info_def.h @@ -0,0 +1,140 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PLATFORM_INFO_DEF_H +#define PLATFORM_INFO_DEF_H + +#include +#include +#include + +using std::map; +using std::string; +using std::vector; + +namespace fe { +enum MemoryType { DDR = 0, HBM }; + +enum L2Type { Cache = 0, Buff }; + +typedef struct tagStrInfo { + string aicVersion; + string ccecAICVersion; + string ccecAIVVersion; + string isSupportAIcpuCompiler; +} StrInfo; + +typedef struct tagSoCInfo { + uint32_t aiCoreCnt; + uint32_t vectorCoreCnt; + uint32_t aiCpuCnt; + MemoryType memoryType; + uint64_t memorySize; + L2Type l2Type; + uint64_t l2Size; + uint32_t l2PageNum; +} SoCInfo; + +typedef struct tagAiCoreSpec { + double cubeFreq; + uint64_t cubeMSize; + uint64_t cubeNSize; + uint64_t cubeKSize; + uint64_t vecCalcSize; + uint64_t l0ASize; + uint64_t l0BSize; + uint64_t l0CSize; + uint64_t l1Size; + uint64_t smaskBuffer; + uint64_t ubSize; + uint64_t ubblockSize; + uint64_t ubbankSize; + uint64_t ubbankNum; + uint64_t ubburstInOneBlock; + uint64_t ubbankGroupNum; + uint32_t unzipEngines; + uint32_t unzipMaxRatios; + uint32_t unzipChannels; + uint8_t unzipIsTight; +} AiCoreSpec; + +typedef struct tagAiCoreMemoryRates { + double ddrRate; + double ddrReadRate; + double ddrWriteRate; + double l2Rate; + double l2ReadRate; + double l2WriteRate; + double l1ToL0ARate; + double l1ToL0BRate; + double l1ToUBRate; + double l0CToUBRate; + double ubToL2Rate; + double ubToDdrRate; + double ubToL1Rate; +} AiCoreMemoryRates; + +typedef struct tagVectorCoreSpec { + double vecFreq; + uint64_t vecCalcSize; + uint64_t smaskBuffer; + uint64_t 
ubSize; + uint64_t ubblockSize; + uint64_t ubbankSize; + uint64_t ubbankNum; + uint64_t ubburstInOneBlock; + uint64_t ubbankGroupNum; + uint64_t vectorRegSize; + uint64_t predicateRegSize; + uint64_t addressRegSize; +} VectorCoreSpec; + +typedef struct tagVectorCoreMemoryRates { + double ddrRate; + double ddrReadRate; + double ddrWriteRate; + double l2Rate; + double l2ReadRate; + double l2WriteRate; + double ubToL2Rate; + double ubToDdrRate; +} VectorCoreMemoryRates; + +typedef struct tagCPUCache { + uint32_t AICPUSyncBySW; + uint32_t TSCPUSyncBySW; +} CPUCache; + +typedef struct tagPlatformInfo { + StrInfo strInfo; + SoCInfo socInfo; + AiCoreSpec aiCoreSpec; + AiCoreMemoryRates aiCoreMemoryRates; + map> aiCoreIntrinsicDtypeMap; + VectorCoreSpec vectorCoreSpec; + VectorCoreMemoryRates vectorCoreMemoryRates; + CPUCache cpucache; + map> vectorCoreIntrinsicDtypeMap; +} PlatformInfo; + +typedef struct tagOptionalInfo { + string socVersion; + string coreType; + uint32_t aiCoreNum; + string l1FusionFlag; +} OptionalInfo; +} // namespace fe +#endif diff --git a/inc/external/graph/attr_value.h b/inc/external/graph/attr_value.h new file mode 100644 index 00000000..af430f9b --- /dev/null +++ b/inc/external/graph/attr_value.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ +#define INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ + +#include +#include +#include +#include + +#include "./ge_error_codes.h" + +using std::make_shared; +using std::map; +using std::pair; +using std::string; +using std::to_string; +using std::unique_ptr; +using std::vector; + +namespace ge { +class AttrValueImpl; +/*lint -e148*/ +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { + public: + using INT = int64_t; + using FLOAT = float; + using STR = std::string; + + AttrValue(); + ~AttrValue() = default; + + // GetValue, not list type + template + graphStatus GetValue(DT &val) const { + T valGet; + auto status = GetValue(valGet); + if (status != GRAPH_SUCCESS) { + return status; + } + val = DT(valGet); + return GRAPH_SUCCESS; + } + + template + static T CreateFrom(DT &&val) { + return val; + } + + std::shared_ptr impl; + + private: +#define VALUE_SET_GET_DEC(DT) graphStatus GetValue(DT &val) const; + VALUE_SET_GET_DEC(AttrValue::STR) + VALUE_SET_GET_DEC(AttrValue::INT) + VALUE_SET_GET_DEC(AttrValue::FLOAT) +#undef VALUE_SET_GET_DEC +}; +/*lint +e148*/ +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ diff --git a/inc/external/graph/ge_error_codes.h b/inc/external/graph/ge_error_codes.h new file mode 100644 index 00000000..d815a22d --- /dev/null +++ b/inc/external/graph/ge_error_codes.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ +#define INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ + +namespace ge { +#ifdef HOST_VISIBILITY +#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_HOST_VISIBILITY +#endif +#ifdef DEV_VISIBILITY +#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_DEV_VISIBILITY +#endif + +using graphStatus = uint32_t; +const graphStatus GRAPH_FAILED = 0xFFFFFFFF; +const graphStatus GRAPH_SUCCESS = 0; +const graphStatus GRAPH_PARAM_INVALID = 50331649; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ diff --git a/inc/external/graph/graph.h b/inc/external/graph/graph.h new file mode 100644 index 00000000..30886733 --- /dev/null +++ b/inc/external/graph/graph.h @@ -0,0 +1,81 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_EXTERNAL_GRAPH_GRAPH_H_ +#define INC_EXTERNAL_GRAPH_GRAPH_H_ + +#include +#include +#include +#include + +#include "./operator.h" + +namespace ge { +class GraphImpl; + +using GraphImplPtr = std::shared_ptr; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { + friend class GraphUtils; + + public: + explicit Graph(const std::string &name); + + Graph() = default; + + ~Graph() = default; + + Graph &SetInputs(const std::vector &inputs); + + Graph &SetOutputs(const std::vector &outputs); + + Graph &SetOutputs(const std::vector>> &output_indexs); + + Graph &SetOutputs(const std::vector> &outputs); + + Graph &SetTargets(const std::vector &targets); + + bool IsValid() const; + + graphStatus AddOp(const ge::Operator &op); + + graphStatus FindOpByName(const string &name, ge::Operator &op) const; + + graphStatus FindOpByType(const string &type, std::vector &ops) const; + + graphStatus GetAllOpName(std::vector &op_name) const; + + graphStatus SaveToFile(const string &file_name) const; + + graphStatus LoadFromFile(const string &file_name); + + const std::string &GetName() const; + + /// + /// Set is need train iteration. + /// If set true, it means this graph need to be run iteration some + /// times(according variant "npu_runconfig/iterations_per_loop"). + /// @param need_iteration need_iteration:whether to set iteration or not + /// + void SetNeedIteration(bool need_iteration); + + private: + GraphImplPtr impl_{nullptr}; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_GRAPH_H_ diff --git a/inc/external/graph/inference_context.h b/inc/external/graph/inference_context.h new file mode 100644 index 00000000..69079142 --- /dev/null +++ b/inc/external/graph/inference_context.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ +#define INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ + +#include +#include +#include + +#include "./tensor.h" +#include "./types.h" + +namespace ge { +class InferenceContext; +using InferenceContextPtr = std::shared_ptr; + +class ShapeAndTypeImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ShapeAndType { + public: + ShapeAndType(); + ~ShapeAndType() = default; + + ShapeAndType(const Shape &shape, DataType dataType); + + void SetShape(const Shape &shape); + + void SetType(DataType dataType); + + Shape GetShape() const; + + DataType GetDataType() const; + + private: + std::shared_ptr shape_and_type_impl_; +}; + +class InferenceContextImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { + public: + ~InferenceContext() = default; + InferenceContext(const InferenceContext &context) = delete; + InferenceContext(const InferenceContext &&context) = delete; + InferenceContext &operator=(const InferenceContext &context) = delete; + InferenceContext &operator=(const InferenceContext &&context) = delete; + + void SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types); + const std::vector> &GetInputHandleShapesAndTypes() const; + const std::vector> &GetOutputHandleShapesAndTypes() const; + void SetOutputHandleShapesAndTypes(const std::vector> &shapes_and_types); + void SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types); + + void SetMarks(const std::vector &marks); + const std::vector &GetMarks() const; + + static std::unique_ptr Create(); + + private: + 
explicit InferenceContext(std::unique_ptr &impl); + std::shared_ptr inference_context_impl_; +}; +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ diff --git a/inc/external/graph/operator.h b/inc/external/graph/operator.h new file mode 100644 index 00000000..81d726eb --- /dev/null +++ b/inc/external/graph/operator.h @@ -0,0 +1,289 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_OPERATOR_H_ +#define INC_EXTERNAL_GRAPH_OPERATOR_H_ + +#include +#include +#include +#include +#include + +#include "./ge_error_codes.h" +#include "./inference_context.h" +#include "./tensor.h" + +#ifndef USER_GE_LOGI +#define USER_GE_LOGI(...) +#endif // USER_GE_LOGI + +#ifndef USER_GE_LOGW +#define USER_GE_LOGW(...) +#endif // USER_GE_LOGW + +#ifndef USER_GE_LOGE +#define USER_GE_LOGE(...) 
+#endif // USER_GE_LOGE + +#define DYNAMIC_OUTPUT_TD_NUM(name) ("__dynamic_output_" + name + "_cnt") +#define DYNAMIC_INPUT_TD_NUM(name) ("__dynamic_input_" + name + "_cnt") + +namespace ge { +class Operator; +class OperatorImpl; +class NodeUtils; +class NamedAttrs; +class Graph; +class AttrValue; +class Node; + +using SubgraphBuilder = std::function; +using OperatorImplPtr = std::shared_ptr; +using OperatorPtr = std::shared_ptr; + +class OpIO; +using OutHandler = std::shared_ptr; +using InHandler = std::shared_ptr; + +using std::function; +using std::shared_ptr; +using std::string; + +/*lint -e148*/ +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { + public: + friend class OperatorImpl; + friend class GraphBuilderImpl; + friend class NodeUtils; + + using OpInt = int64_t; + using OpFloat = float; + using OpString = string; + using OpBool = bool; + using OpTensor = Tensor; + using OpType = ge::DataType; + using OpNamedAttrs = ge::NamedAttrs; + using OpListInt = std::vector; + using OpListFloat = std::vector; + using OpListString = std::vector; + using OpListBool = std::vector; + using OpListTensor = std::vector; + using OpBytes = std::vector; + using OpListListInt = std::vector>; + using OpListType = std::vector; + using OpListNamedAttrs = std::vector; + + Operator() {} + + explicit Operator(const string &type); + + Operator(const string &name, const string &type); // lint !e148 + + virtual ~Operator() = default; + + bool IsEmpty() const; + + string GetName() const; + + string GetOpType() const; + + // Only has one output index = 0 + Operator &SetInput(const string &dst_name, const Operator &src_oprt); + + Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148 + + Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index); + + Operator &AddControlInput(const Operator &src_oprt); + + graphStatus GetInputConstData(const string &dst_name, Tensor &data) const; + + TensorDesc 
GetInputDesc(const string &name) const; + + TensorDesc GetInputDesc(uint32_t index) const; + + int GetDynamicOutputNum(const string &name) const; + + int GetDynamicInputNum(const string &name) const; + + graphStatus TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const; + + graphStatus UpdateInputDesc(const string &name, const TensorDesc &tensor_desc); + + TensorDesc GetOutputDesc(const string &name) const; + + TensorDesc GetOutputDesc(uint32_t index) const; + + graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148 + + TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; + + graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 + + TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; + + graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 + + graphStatus InferShapeAndType(); // lint !e148 + + void SetInferenceContext(const InferenceContextPtr &inference_context); + InferenceContextPtr GetInferenceContext() const; + + graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148 + + size_t GetInputsSize() const; + + size_t GetOutputsSize() const; + + const std::map GetAllAttrNamesAndTypes() const; + + Operator &SetAttr(const string &name, int64_t attr_value); + Operator &SetAttr(const string &name, int32_t attr_value); + Operator &SetAttr(const string &name, uint32_t attr_value); + graphStatus GetAttr(const string &name, int64_t &attr_value) const; + graphStatus GetAttr(const string &name, int32_t &attr_value) const; + graphStatus GetAttr(const string &name, uint32_t &attr_value) const; + Operator &SetAttr(const string &name, const std::vector &attr_value); + Operator &SetAttr(const string &name, const std::vector &attr_value); + Operator &SetAttr(const string &name, const std::vector &attr_value); + Operator &SetAttr(const 
string &name, std::initializer_list &&attr_value); + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + + Operator &SetAttr(const string &name, float attr_value); + graphStatus GetAttr(const string &name, float &attr_value) const; + Operator &SetAttr(const string &name, const std::vector &attr_value); + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + Operator &SetAttr(const string &name, AttrValue &&attr_value); + graphStatus GetAttr(const string &name, AttrValue &attr_value) const; + + Operator &SetAttr(const string &name, const string &attr_value); + graphStatus GetAttr(const string &name, string &attr_value) const; + Operator &SetAttr(const string &name, const std::vector &attr_value); + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + + Operator &SetAttr(const string &name, bool attr_value); + graphStatus GetAttr(const string &name, bool &attr_value) const; + Operator &SetAttr(const string &name, const std::vector &attr_value); + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + + Operator &SetAttr(const string &name, const Tensor &attr_value); + graphStatus GetAttr(const string &name, Tensor &attr_value) const; + Operator &SetAttr(const string &name, const std::vector &attr_value); + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + + // Bytes type + Operator &SetAttr(const string &name, const OpBytes &attr_value); + // Bytes type + graphStatus GetAttr(const string &name, OpBytes &attr_value) const; + + Operator &SetAttr(const string &name, const std::vector> &attr_value); + graphStatus GetAttr(const string &name, std::vector> &attr_value) const; + + Operator &SetAttr(const string &name, const std::vector &attr_value); + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + + 
Operator &SetAttr(const string &name, const ge::DataType &attr_value); + graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; + + // func type + Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); + graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; + Operator &SetAttr(const string &name, const std::vector &attr_value); + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + + void BreakConnect() const; + + size_t GetSubgraphNamesCount() const; + std::vector GetSubgraphNames() const; + SubgraphBuilder GetSubgraphBuilder(const string &name) const; + Graph GetSubgraph(const string &name) const; + SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; + Graph GetDynamicSubgraph(const string &name, uint32_t index) const; + + protected: + void AttrRegister(const string &name, float attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, int64_t attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, const string &attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, bool attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, const Tensor &attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, const OpBytes &attr_value); + void AttrRegister(const string &name, const std::vector> &attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, const ge::DataType &attr_value); + void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + + explicit Operator(OperatorImplPtr 
&&op_impl); + + void InputRegister(const string &name); + + void OptionalInputRegister(const string &name); + + void InferFuncRegister(const std::function &func); + + void VerifierFuncRegister(const std::function &func); + + void InferFormatFuncRegister(const std::function &func); + + void OutputRegister(const string &name); + + void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); + + void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); + + void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); + + void RequiredAttrRegister(const string &name); + + graphStatus VerifyAll(); // lint !e148 + + // Only has one output index = 0 + Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); + + Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, + const string &name); // lint !e148 + + void SubgraphRegister(const string &ir_name, bool dynamic); + void SubgraphCountRegister(const string &ir_name, uint32_t count); + void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); + + private: + Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148 + + OutHandler GetOutput(const string &name) const; + + OutHandler GetOutput(uint32_t index) const; + + OperatorImplPtr GetOperatorImplPtr() const; + + OperatorImplPtr operator_impl_{nullptr}; + + graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const; + + std::shared_ptr GetNode() const; +}; +/*lint +e148*/ +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_OPERATOR_H_ diff --git a/inc/external/graph/operator_factory.h b/inc/external/graph/operator_factory.h new file mode 100644 index 00000000..f9ec7669 --- /dev/null +++ b/inc/external/graph/operator_factory.h @@ -0,0 +1,68 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ +#define INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ + +#include +#include +#include +#include + +#include "./operator.h" +#include "./ge_error_codes.h" + +namespace ge { +using OpCreator = std::function; +using InferShapeFunc = std::function; +using InferFormatFunc = std::function; +using VerifyFunc = std::function; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactory { + public: + static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); + + static graphStatus GetOpsTypeList(std::vector &all_ops); + + static bool IsExistOp(const string &operator_type); +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorCreatorRegister { + public: + OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator); + ~OperatorCreatorRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferShapeFuncRegister { + public: + InferShapeFuncRegister(const std::string &operator_type, const InferShapeFunc &infer_shape_func); + ~InferShapeFuncRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferFormatFuncRegister { + public: + InferFormatFuncRegister(const std::string &operator_type, const InferFormatFunc &infer_format_func); + ~InferFormatFuncRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY VerifyFuncRegister { + 
public: + VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func); + ~VerifyFuncRegister() = default; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ diff --git a/inc/external/graph/operator_reg.h b/inc/external/graph/operator_reg.h new file mode 100644 index 00000000..759c70f2 --- /dev/null +++ b/inc/external/graph/operator_reg.h @@ -0,0 +1,376 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ +#define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ + +#include +#include +#include +#include + +#include "graph/operator.h" +#include "graph/operator_factory.h" +#include "graph/tensor.h" +#include "graph/types.h" +#include "graph/graph.h" + +namespace ge { +using std::function; +using std::string; +using std::vector; + +class OpReg { + public: + OpReg &N() { return *this; } + + OpReg &ATTR() { return *this; } + + OpReg &REQUIRED_ATTR() { return *this; } + + OpReg &INPUT() { return *this; } + + OpReg &OPTIONAL_INPUT() { return *this; } + + OpReg &OUTPUT() { return *this; } + + OpReg &GRAPH() { return *this; } + + OpReg &DYNAMIC_GRAPH() { return *this; } + + OpReg &INFER_SHAPE_AND_TYPE() { return *this; } +}; + +#define REG_OP(x) \ + namespace op { \ + class x : public Operator { \ + typedef x _THIS_TYPE; \ + \ + public: \ + explicit x(const string &name) : Operator(name, #x) { __##x(); } \ + x() : Operator(#x) { __##x(); } \ + \ + private: \ + void __##x() { \ + OpReg() + +#define ATTR(x, Type, ...) 
\ + N(); \ + __attr_##x(); \ + } \ + \ + public: \ + static const string name_attr_##x() { return #x; } \ + Op##Type get_attr_##x() const { \ + Op##Type ret = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return ret; \ + } \ + return ret; \ + } \ + _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function &v) { return *this; } \ + \ + private: \ + void __attr_##x() { \ + Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ + string attr_name(#x); \ + (void)OpReg() + +#define REQUIRED_ATTR(x, Type) \ + N(); \ + __required_attr_##x(); \ + } \ + \ + public: \ + static const string name_attr_##x() { return #x; } \ + Op##Type get_attr_##x() const { \ + Op##Type ret; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return ret; \ + } \ + return ret; \ + } \ + _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function &v) { return *this; } \ + \ + private: \ + void __required_attr_##x() { \ + Operator::RequiredAttrRegister(#x); \ + string attr_name(#x); \ + (void)OpReg() + +#define INPUT(x, t) \ + N(); \ + __input_##x(); \ + } \ + \ + public: \ + static const string name_in_##x() { return #x; } \ + _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \ + Operator::SetInput(#x, v, srcName); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ + Operator::SetInput(#x, v, index); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x(Operator &v) { \ + Operator::SetInput(#x, v); \ + return *this; \ + } \ + TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \ + graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ + return Operator::UpdateInputDesc(#x, tensorDesc); \ + } \ + \ + private: \ + void __input_##x() { \ + Operator::InputRegister(#x); \ + (void)OpReg() + +#define OPTIONAL_INPUT(x, t) 
\ + N(); \ + __optional_input_##x(); \ + } \ + \ + public: \ + static const string name_in_##x() { return #x; } \ + _THIS_TYPE &set_input_##x(Operator &v) { \ + Operator::SetInput(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \ + Operator::SetInput(#x, v, srcName); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ + Operator::SetInput(#x, v, index); \ + return *this; \ + } \ + TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \ + graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ + return Operator::UpdateInputDesc(#x, tensorDesc); \ + } \ + \ + private: \ + void __optional_input_##x() { \ + Operator::OptionalInputRegister(#x); \ + (void)OpReg() + +#define OUTPUT(x, t) \ + N(); \ + __out_##x(); \ + } \ + \ + public: \ + static const string name_out_##x() { return #x; } \ + TensorDesc get_output_desc_##x() const { return Operator::GetOutputDesc(#x); } \ + graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \ + return Operator::UpdateOutputDesc(#x, tensorDesc); \ + } \ + \ + private: \ + void __out_##x() { \ + Operator::OutputRegister(#x); \ + (void)OpReg() + +#define DYNAMIC_INPUT(x, t) \ + N(); \ + __dy_input_##x(); \ + } \ + \ + public: \ + _THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \ + Operator::DynamicInputRegister(#x, num, isPushBack); \ + return *this; \ + } \ + _THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \ + Operator::DynamicInputRegisterByIndex(#x, num, index); \ + return *this; \ + } \ + TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { return Operator::GetDynamicInputDesc(#x, index); } \ + graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ + return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ + } \ + _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \ + 
Operator::SetInput(#x, dstIndex, v); \ + return *this; \ + } \ + _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \ + Operator::SetInput(#x, dstIndex, v, srcName); \ + return *this; \ + } \ + \ + private: \ + void __dy_input_##x() { \ + Operator::DynamicInputRegister(#x, 0, true); \ + (void)OpReg() + +#define DYNAMIC_OUTPUT(x, t) \ + N(); \ + __dy_output_##x(); \ + } \ + \ + public: \ + _THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \ + Operator::DynamicOutputRegister(#x, num, isPushBack); \ + return *this; \ + } \ + TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { return Operator::GetDynamicOutputDesc(#x, index); } \ + graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ + return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ + } \ + \ + private: \ + void __dy_output_##x() { \ + Operator::DynamicOutputRegister(#x, 0, true); \ + (void)OpReg() + +#define GRAPH(x) \ + N(); \ + __graph_##x(); \ + } \ + \ + public: \ + static const string name_graph_##x() { return #x; } \ + SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \ + _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ + Operator::SetSubgraphBuilder(#x, 0, v); \ + return *this; \ + } \ + Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \ + \ + private: \ + void __graph_##x() { \ + Operator::SubgraphRegister(#x, false); \ + Operator::SubgraphCountRegister(#x, 1); \ + (void)OpReg() + +#define DYNAMIC_GRAPH(x) \ + N(); \ + __graph_##x(); \ + } \ + \ + public: \ + static const string name_graph_##x() { return #x; } \ + _THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \ + Operator::SubgraphCountRegister(#x, num); \ + return *this; \ + } \ + SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \ + return Operator::GetDynamicSubgraphBuilder(#x, index); \ + } \ + Graph 
get_dynamic_subgraph_##x(uint32_t index) const { return Operator::GetDynamicSubgraph(#x, index); } \ + _THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \ + Operator::SetSubgraphBuilder(#x, index, v); \ + return *this; \ + } \ + \ + private: \ + void __graph_##x() { \ + Operator::SubgraphRegister(#x, true); \ + (void)OpReg() + +#define PASTE(g_register, y) g_register##y +#define __OP_END_IMPL__(x, y) \ + N(); \ + } \ + static_assert( \ + std::is_same::value, \ + "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ + } \ + ; \ + static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ + } +#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) + +// Specialized shape inferencer macro + +#define IMPLEMT_INFERFUNC(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) + +#define IMPLEMT_COMMON_INFERFUNC(func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(Operator &op) + +#define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) + +// Specialized verifier macro + +#define IMPLEMT_VERIFIER(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op) + +#define INFER_VERIFY_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } + +#define COMMON_INFER_VERIFY_FUNC(x) [&](Operator &v) { return x(v); } + +#define INFER_FORMAT_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } + +#define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x) + +#define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x) +// Infer format func register 
+#define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \ + static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x) + +// Shape inferencer & verifier register macro + +#define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) + +#define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC(x), __COUNTER__) + +#define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) + +// Infer format func reg +#define INFER_FORMAT_FUNC_REG(op_name, x) \ + __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__) + +// Common shape inferencer + +#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ + [](Operator op) -> graphStatus { \ + auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \ + auto x_type = op.GetInputDesc(in_name).GetDataType(); \ + TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ + op_output_desc.SetShape(ge::Shape(x_shape)); \ + op_output_desc.SetOriginShape(ge::Shape(x_shape)); \ + op_output_desc.SetDataType(x_type); \ + return op.UpdateOutputDesc(out_name, op_output_desc); \ + } + +graphStatus BroadCastInfer(const function()> &get_in1_shape, + const function()> &get_in2_shape, + const function &y_shape)> &set_out_shape); + +#define BROADCAST_INFER(in1_name, in2_name, out_name) \ + [](Operator op) -> graphStatus { \ + return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \ + [&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \ + [&](const vector &y_shape) { \ + TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ + op_output_desc.SetShape(ge::Shape(y_shape)); \ + (void)op.UpdateOutputDesc(out_name, op_output_desc); \ + }); \ + } +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ diff --git a/inc/external/graph/tensor.h b/inc/external/graph/tensor.h new file mode 100644 index 
00000000..800e1037 --- /dev/null +++ b/inc/external/graph/tensor.h @@ -0,0 +1,131 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_TENSOR_H_ +#define INC_EXTERNAL_GRAPH_TENSOR_H_ + +#include +#include +#include +#include +#include + +#include "./ge_error_codes.h" +#include "./types.h" + +namespace ge { +class ShapeImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape { + public: + Shape(); + ~Shape() = default; + explicit Shape(const std::vector &dims); + + size_t GetDimNum() const; + // If the idx is invalid, return 0 + int64_t GetDim(size_t idx) const; + graphStatus SetDim(size_t idx, int64_t value); + std::vector GetDims() const; + int64_t GetShapeSize() const; + + private: + std::shared_ptr impl_; +}; + +class TensorDescImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { + public: + TensorDesc(); + ~TensorDesc() = default; + explicit TensorDesc(Shape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); + // Copy + TensorDesc(const TensorDesc &desc); + // Move + TensorDesc(TensorDesc &&desc); + // Copy + TensorDesc &operator=(const TensorDesc &desc); + // Move + TensorDesc &operator=(TensorDesc &&desc); + + void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); + Shape GetShape() const; + void SetShape(const Shape &shape); + // set shape with -2, it stand for unknown shape + graphStatus 
SetUnknownDimNumShape(); + // for unknown shape + graphStatus SetShapeRange(const std::vector> &range); + graphStatus GetShapeRange(std::vector> &range) const; + + Format GetFormat() const; + void SetFormat(Format format); + + Shape GetOriginShape() const; + void SetOriginShape(const Shape &originShape); + + Format GetOriginFormat() const; + void SetOriginFormat(Format originFormat); + + DataType GetDataType() const; + void SetDataType(DataType dt); + + std::string GetName() const; + void SetName(const std::string &name); + + // Attr acess + void SetSize(int64_t size); + int64_t GetSize() const; + + int64_t GetRealDimCnt() const; + void SetRealDimCnt(const int64_t realDimCnt); + + private: + std::shared_ptr impl; +}; + +class TensorImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor { + public: + Tensor(); + ~Tensor() = default; + explicit Tensor(const TensorDesc &tensorDesc); + Tensor(const TensorDesc &tensorDesc, const std::vector &data); + Tensor(const TensorDesc &tensorDesc, const uint8_t *data, size_t size); + Tensor(TensorDesc &&tensorDesc, std::vector &&data); + + TensorDesc GetTensorDesc() const; + graphStatus SetTensorDesc(const TensorDesc &tensorDesc); + + const uint8_t *GetData() const; + uint8_t *GetData(); + size_t GetSize() const; + + graphStatus SetData(std::vector &&data); + graphStatus SetData(const std::vector &data); + graphStatus SetData(const uint8_t *data, size_t size); + graphStatus SetData(const std::string &data); + graphStatus SetData(const std::vector &data); + graphStatus IsValid(); + + Tensor Clone() const; + + private: + std::shared_ptr impl; + friend class TensorAdapter; +}; +} // namespace ge +/*lint +e148*/ + +#endif // INC_EXTERNAL_GRAPH_TENSOR_H_ diff --git a/inc/external/graph/types.h b/inc/external/graph/types.h new file mode 100644 index 00000000..a1245c9d --- /dev/null +++ b/inc/external/graph/types.h @@ -0,0 +1,240 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_TYPES_H_ +#define INC_EXTERNAL_GRAPH_TYPES_H_ + +#include +#include +#include + +namespace ge { +static const int64_t UNKNOWN_DIM = -1; +static const int64_t UNKNOWN_DIM_NUM = -2; +static const std::vector UNKNOWN_SHAPE = {-1}; +static const std::vector UNKNOWN_RANK = {-2}; + +#ifdef HOST_VISIBILITY +#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_HOST_VISIBILITY +#endif +#ifdef DEV_VISIBILITY +#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_DEV_VISIBILITY +#endif + +enum DataType { + DT_FLOAT = 0, // float type + DT_FLOAT16 = 1, // fp16 type + DT_INT8 = 2, // int8 type + DT_INT16 = 6, // int16 type + DT_UINT16 = 7, // uint16 type + DT_UINT8 = 4, // uint8 type + DT_INT32 = 3, // + DT_INT64 = 9, // int64 type + DT_UINT32 = 8, // unsigned int32 + DT_UINT64 = 10, // unsigned int64 + DT_BOOL = 12, // bool type + DT_DOUBLE = 11, // double type + DT_STRING = 13, // string type + DT_DUAL_SUB_INT8 = 14, // dual output int8 type + DT_DUAL_SUB_UINT8 = 15, // dual output uint8 type + DT_COMPLEX64 = 16, // complex64 type + DT_COMPLEX128 = 17, // complex128 type + DT_QINT8 = 18, // qint8 type + DT_QINT16 = 19, // qint16 type + DT_QINT32 = 20, // qint32 type + DT_QUINT8 = 21, // quint8 type + DT_QUINT16 = 22, // quint16 type + DT_RESOURCE = 23, // resource type + DT_STRING_REF = 24, // string ref type + DT_DUAL = 25, // dual output 
type + DT_UNDEFINED // Used to indicate a DataType field has not been set. +}; + +inline int GetSizeByDataType(DataType data_type) { + static int data_type_size[DT_UNDEFINED] = { + 4, // DT_FLOAT = 0, float type + 2, // DT_FLOAT16 = 1, fp16 type + 1, // DT_INT8 = 2, int8 type + 4, // DT_INT32 = 3, + 1, // DT_UINT8 = 4, uint8 type + -1, + 2, // DT_INT16 = 6, int16 type + 2, // DT_UINT16 = 7, uint16 type + 4, // DT_UINT32 = 8, unsigned int32 + 8, // DT_INT64 = 9, int64 type + 8, // DT_UINT64 = 10, unsigned int64 + 8, // DT_DOUBLE = 11, double type + 1, // DT_BOOL = 12, bool type + -1, // DT_STRING = 13, string type + 1, // DT_DUAL_SUB_INT8 = 14, dual output int8 type + 1, // DT_DUAL_SUB_UINT8 = 15, dual output uint8 type + 8, // DT_COMPLEX64 = 16, complex64 type + 16, // DT_COMPLEX128 = 17, complex128 type + 1, // DT_QINT8 = 18, qint8 type + 2, // DT_QINT16 = 19, qint16 type + 4, // DT_QINT32 = 20, qint32 type + 1, // DT_QUINT8 = 21, quint8 type + 2, // DT_QUINT16 = 22, quint16 type + -1, // DT_RESOURCE = 23, resource type + -1, // DT_STRING_REF = 24, string ref type + 5, // DT_DUAL = 25, dual output type (float + int8) + // DT_UNDEFINED Used to indicate a DataType field has not been set. 
+ }; + if (data_type >= DT_UNDEFINED) { + return -1; + } + return data_type_size[data_type]; +} + +enum Format { + FORMAT_NCHW = 0, // NCHW + FORMAT_NHWC, // NHWC + FORMAT_ND, // Nd Tensor + FORMAT_NC1HWC0, // NC1HWC0 + FORMAT_FRACTAL_Z, // FRACTAL_Z + FORMAT_NC1C0HWPAD, + FORMAT_NHWC1C0, + FORMAT_FSR_NCHW, + FORMAT_FRACTAL_DECONV, + FORMAT_C1HWNC0, + FORMAT_FRACTAL_DECONV_TRANSPOSE, + FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, + FORMAT_NC1HWC0_C04, // NC1HWC0, C0 =4 + FORMAT_FRACTAL_Z_C04, // FRACZ, C0 =4 + FORMAT_CHWN, + FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, + FORMAT_HWCN, + FORMAT_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format + FORMAT_BN_WEIGHT, + FORMAT_FILTER_HWCK, // filter input tensor format + FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20, + FORMAT_HASHTABLE_LOOKUP_KEYS, + FORMAT_HASHTABLE_LOOKUP_VALUE, + FORMAT_HASHTABLE_LOOKUP_OUTPUT, + FORMAT_HASHTABLE_LOOKUP_HITS = 24, + FORMAT_C1HWNCoC0, + FORMAT_MD, + FORMAT_NDHWC, + FORMAT_FRACTAL_ZZ, + FORMAT_FRACTAL_NZ, + FORMAT_NCDHW, + FORMAT_DHWCN, // 3D filter input tensor format + FORMAT_NDC1HWC0, + FORMAT_FRACTAL_Z_3D, + FORMAT_CN, + FORMAT_NC, + FORMAT_DHWNC, + FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format + FORMAT_FRACTAL_ZN_LSTM, + FORMAT_FRACTAL_Z_G, + FORMAT_RESERVED, + FORMAT_ALL, + FORMAT_NULL +}; + +// for unknown shape op type +enum UnknowShapeOpType { + DEPEND_IN_SHAPE = 1, // op out shape get by input shape + DEPEND_CONST_VALUE = 2, // op out shape get by const op value + DEPEND_SHAPE_RANGE = 3, // op out shape get by range + DEPEND_COMPUTE = 4 // op out shape get by totally computing +}; + +struct TensorDescInfo { + Format format_ = FORMAT_RESERVED; // tbe op register support format + DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype +}; + +enum DeviceType { + NPU = 0, + CPU = 1, +}; + +class TensorTypeImpl; +struct TensorType { + explicit TensorType(DataType dt); + + TensorType(const std::initializer_list &types); + + static 
TensorType ALL() { + return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, + DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, + DT_QUINT8, DT_RESOURCE, DT_STRING, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType QuantifiedType() { return TensorType{DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, DT_QUINT8}; } + + static TensorType OrdinaryType() { + return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, + DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType BasicType() { + return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, + DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, + DT_QUINT16, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType NumberType() { + return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, + DT_INT8, DT_QINT32, DT_QINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType RealNumberType() { + return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, + DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType ComplexDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64}; } + + static TensorType IntegerDataType() { + return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType SignedDataType() { return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8}; } + + static TensorType UnsignedDataType() { return TensorType{DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; } + + static TensorType FloatingDataType() { return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } + + static TensorType IndexNumberType() { return TensorType{DT_INT32, DT_INT64}; } + + static TensorType UnaryDataType() { return 
TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } + + static TensorType FLOAT() { return TensorType{DT_FLOAT, DT_FLOAT16}; } + + std::shared_ptr tensor_type_impl_; +}; +} // namespace ge + +namespace domi { +enum class ImplyType : unsigned int { + BUILDIN = 0, // Built in operator, normally executed by OME + TVM, // Compile to TVM bin file for execution + CUSTOM, // User defined calculation logic, executed by CPU + AI_CPU, // AICPU + CCE, // Cce + GELOCAL, // GE local, do node need execute by device + HCCL, // Hccl + INVALID = 0xFFFFFFFF, +}; +} // namespace domi + +#endif // INC_EXTERNAL_GRAPH_TYPES_H_ diff --git a/inc/external/register/register.h b/inc/external/register/register.h new file mode 100644 index 00000000..f3091fae --- /dev/null +++ b/inc/external/register/register.h @@ -0,0 +1,163 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "graph/operator.h" +#include "register/register_error_codes.h" +#include "register/register_fmk_types.h" +#include "register/register_types.h" + +using std::make_shared; +using std::map; +using std::pair; +using std::string; +using std::to_string; +using std::unique_ptr; +using std::vector; + +/*lint -e148*/ +namespace ge { +class Operator; +class TensorDesc; +class Tensor; +class TBEPluginManager; +} // namespace ge + +namespace google { +namespace protobuf { +class Message; +} +} // namespace google + +namespace domi { +const int64_t kMaxNameLength = 1048576; // 1M + +enum DynamicType { kInvalid = 0, kInput = 1, kOutput = 2 }; +struct DynamicInputOutputInfo { + DynamicType type; // input/output + const char *port_name; + int64_t port_name_len; + const char *attr_name; + int64_t attr_name_len; + DynamicInputOutputInfo() + : type(kInvalid), port_name(nullptr), port_name_len(0), attr_name(nullptr), attr_name_len(0) {} + DynamicInputOutputInfo(DynamicType type, const char *port_name, int64_t port_name_len, const char *attr_name, + int64_t attr_name_len) + : type(type), + port_name(port_name), + port_name_len(port_name_len), + attr_name(attr_name), + attr_name_len(attr_name_len) {} +}; +Status AutoMappingByOpFn(const ge::Operator &op_src, ge::Operator &op); +Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op, + const vector &dynamic_name_attr_value); +Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); +Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, + std::map> dynamic_name_attr_value, + int in_pos = -1, int out_pos = -1); +Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function &input, + const std::function &output); +Status AutoMappingSubgraphIndex(const 
ge::Graph &graph, + const std::function &input, + const std::function &output); +using google::protobuf::Message; +class OpRegistrationDataImpl; + +using ParseParamFunc = std::function; +using ParseParamByOpFunc = std::function; +using FusionParseParamFunc = + std::function, ge::Operator &)>; +using FusionParseParamByOpFunc = std::function &, ge::Operator &)>; +using ParseSubgraphFunc = std::function; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { + public: + OpRegistrationData(const std::string &om_optype); + + ~OpRegistrationData(); + + OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type); + + OpRegistrationData &OriginOpType(const std::initializer_list &ori_optype_list); + + OpRegistrationData &OriginOpType(const std::string &ori_optype); + + OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); + + OpRegistrationData &ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn); + + OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); + + OpRegistrationData &FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn); + + OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); + + OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); + + OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); + + OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); + + OpRegistrationData &InputReorderVector(const vector &input_order); + + domi::ImplyType GetImplyType() const; + std::string GetOmOptype() const; + std::set GetOriginOpTypeSet() const; + domi::FrameworkType GetFrameworkType() const; + ParseParamFunc GetParseParamFn() const; + ParseParamByOpFunc GetParseParamByOperatorFn() const; + FusionParseParamFunc GetFusionParseParamFn() const; + FusionParseParamByOpFunc GetFusionParseParamByOpFn() const; + ParseSubgraphFunc 
GetParseSubgraphPostFn() const; + + private: + std::shared_ptr impl_; + friend class OpRegistry; + friend class OpRegistrationTbe; + friend class ge::TBEPluginManager; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { + public: + OpReceiver(OpRegistrationData ®_data); + ~OpReceiver() {} +}; + +#define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name) +#define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name) +#define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \ + static OpReceiver register_op##ctr __attribute__((unused)) = OpRegistrationData(name) +} // namespace domi + +namespace ge { +using OpRegistrationData = domi::OpRegistrationData; +using OpReceiver = domi::OpReceiver; +} // namespace ge +/*lint +e148*/ +#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ diff --git a/inc/external/register/register_error_codes.h b/inc/external/register/register_error_codes.h new file mode 100644 index 00000000..5e0ed79f --- /dev/null +++ b/inc/external/register/register_error_codes.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ + +#define SYSID_FWK 3 // Subsystem ID +#define MODID_COMMON 0 // Common module ID + +#define DECLARE_ERRORNO(sysid, modid, name, value) \ + const domi::Status name = \ + ((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); + +#define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) + +namespace domi { +using Status = uint32_t; + +// General error code +DECLARE_ERRORNO(0, 0, SUCCESS, 0); +DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); +DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 +DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201); +} // namespace domi + +#endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ diff --git a/inc/external/register/register_fmk_types.h b/inc/external/register/register_fmk_types.h new file mode 100644 index 00000000..97616060 --- /dev/null +++ b/inc/external/register/register_fmk_types.h @@ -0,0 +1,37 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ + +#include + +namespace domi { +/// +/// @ingroup domi_omg +/// @brief AI framework types +/// +enum FrameworkType { + CAFFE = 0, + MINDSPORE = 1, + TENSORFLOW = 3, + ANDROID_NN, + ONNX, + FRAMEWORK_RESERVED, +}; +} // namespace domi + +#endif // INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ diff --git a/inc/external/register/register_types.h b/inc/external/register/register_types.h new file mode 100644 index 00000000..08d72713 --- /dev/null +++ b/inc/external/register/register_types.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ + +namespace domi { +#ifdef HOST_VISIBILITY +#define FMK_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define FMK_FUNC_HOST_VISIBILITY +#endif +#ifdef DEV_VISIBILITY +#define FMK_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define FMK_FUNC_DEV_VISIBILITY +#endif + +/// CCE defined constant + +/// +/// @ingroup domi +/// @brief original tensor type +/// +typedef enum tagDomiTensorFormat { + DOMI_TENSOR_NCHW = 0, // < NCHW + DOMI_TENSOR_NHWC, // < NHWC + DOMI_TENSOR_ND, // < Nd Tensor + DOMI_TENSOR_NC1HWC0, // < NC1HWC0 + DOMI_TENSOR_FRACTAL_Z, // < FRACTAL_Z + DOMI_TENSOR_NC1C0HWPAD, + DOMI_TENSOR_NHWC1C0, + DOMI_TENSOR_FSR_NCHW, + DOMI_TENSOR_FRACTAL_DECONV, + DOMI_TENSOR_BN_WEIGHT, + DOMI_TENSOR_CHWN, // Android NN Depth CONV + DOMI_TENSOR_FILTER_HWCK, // filter input tensor format + DOMI_TENSOR_NDHWC, + DOMI_TENSOR_NCDHW, + DOMI_TENSOR_DHWCN, // 3D filter input tensor format + DOMI_TENSOR_DHWNC, + DOMI_TENSOR_RESERVED +} domiTensorFormat_t; +} // namespace domi + +#endif // INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ diff --git a/inc/external/register/scope/scope_fusion_pass_register.h b/inc/external/register/scope/scope_fusion_pass_register.h new file mode 100644 index 00000000..8e5605a7 --- /dev/null +++ b/inc/external/register/scope/scope_fusion_pass_register.h @@ -0,0 +1,334 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ +#define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ + +#include +#include +#include +#include +#include +#include "ge/ge_api_error_codes.h" +#include "register/register_error_codes.h" +#include "register/register_types.h" +#include "graph/operator.h" + +#define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \ + do { \ + if (!(cond)) { \ + if ((fusion_rlt) != nullptr) { \ + (fusion_rlt)->SetType(ge::kScopeInvalidType); \ + } \ + return; \ + } \ + } while (0) + +namespace domi { +class TensorFlowModelParser; +} // namespace domi +namespace ge { +const int32_t kFusionDisableIndex = 99999; +const char *const kScopeToMultiNodes = "ScopeToMultiNodes"; +const char *const kScopeInvalidType = "ScopeInvalidType"; +const char *const kInputFromFusionScope = "InputFromFusionScope"; +const char *const kOutputToFusionScope = "OutputToFusionScope"; +class ScopePattern; +using ScopeFusionPatterns = std::vector>; + +class ScopePassManager; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { + public: + Scope(); + Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); + ~Scope(); + + const std::string &Name() const; + const std::string &SubType() const; + const std::unordered_map &AllNodesMap() const; + Scope *GetSubScope(const std::string &scope_name) const; + const std::string LastName() const; + const std::vector &GetAllSubScopes() const; + const Scope *GetFatherScope() const; + + private: + class ScopeImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; + friend class ScopeTree; + friend class NodeOpTypeFeature; + friend class NodeAttrFeature; + friend class ScopeFeature; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult { + public: + FusionScopesResult(); + Status Init(); + ~FusionScopesResult(); + 
void SetName(const std::string &name); + void SetType(const std::string &type); + void SetDescription(const std::string &description); + const std::string &Name() const; + const std::vector &Nodes() const; + void InsertInputs(const std::string &inner_op_name, const std::vector &index_map); + void InsertOutputs(const std::string &inner_op_name, const std::vector &index_map); + + class InnerNodeInfo { + public: + explicit InnerNodeInfo(const std::string &fusion_node_name); + InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type); + InnerNodeInfo(InnerNodeInfo &&other) noexcept; + InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept; + InnerNodeInfo(const InnerNodeInfo &) = delete; + InnerNodeInfo &operator=(const InnerNodeInfo &) = delete; + ~InnerNodeInfo(); + InnerNodeInfo &SetName(const std::string &name); + InnerNodeInfo &SetType(const std::string &type); + InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx); + InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx); + ge::graphStatus BuildInnerNode(); + ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format); + ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format); + ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format); + ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format); + ge::Operator *MutableOperator(); + + std::string GetName() const; + std::string GetType() const; + std::vector> GetInputs() const; + std::vector> GetOutputs() const; + + private: + class InnerNodeInfoImpl; + std::unique_ptr impl_; + }; + + InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type); + InnerNodeInfo *MutableRecentInnerNode(); + InnerNodeInfo *MutableInnerNode(uint32_t index); + ge::graphStatus CheckInnerNodesInfo(); + 
+ private: + class FusionScopesResultImpl; + std::unique_ptr impl_; + friend class ScopeGraph; + friend class ScopeBasePass; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree { + public: + ScopeTree(); + Status Init(); + ScopeTree(const ScopeTree &scopetree) = delete; + ScopeTree &operator=(const ScopeTree &scopetree) = delete; + ~ScopeTree(); + + const std::vector &GetAllScopes() const; + + private: + class ScopeTreeImpl; + std::unique_ptr impl_; + friend class ScopeGraph; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph { + public: + ScopeGraph(); + Status Init(); + ScopeGraph(const ScopeGraph &scope_graph) = delete; + ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete; + ~ScopeGraph(); + + const ScopeTree *GetScopeTree() const; + const std::unordered_map &GetNodesMap() const; + + private: + class ScopeGraphImpl; + std::unique_ptr impl_; + friend class ScopePassManager; + friend class ScopeBasePass; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue { + public: + ScopeAttrValue(); + ScopeAttrValue(ScopeAttrValue const &attr_value); + ScopeAttrValue &operator=(ScopeAttrValue const &attr_value); + ~ScopeAttrValue(); + + void SetIntValue(int64_t value); + void SetFloatValue(float value); + void SetStringValue(std::string value); + void SetBoolValue(bool value); + + private: + class ScopeAttrValueImpl; + std::unique_ptr impl_; + friend class NodeAttrFeature; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature { + public: + virtual bool Match(const Scope *scope) = 0; + virtual ~ScopeBaseFeature(){}; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature { + public: + NodeOpTypeFeature(std::string nodeType, int num, int step = 0); + NodeOpTypeFeature(NodeOpTypeFeature const &feature); + NodeOpTypeFeature 
&operator=(NodeOpTypeFeature const &feature); + ~NodeOpTypeFeature(); + bool Match(const Scope *scope) override; + + private: + class NodeOpTypeFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature { + public: + NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue &attr_value); + NodeAttrFeature(NodeAttrFeature const &feature); + NodeAttrFeature &operator=(NodeAttrFeature const &feature); + ~NodeAttrFeature(); + bool Match(const Scope *scope) override; + + private: + class NodeAttrFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature { + public: + ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", std::string sub_scope_mask = "", + int step = 0); + ScopeFeature(ScopeFeature const &feature); + ScopeFeature &operator=(ScopeFeature const &feature); + ~ScopeFeature(); + bool Match(const Scope *scope) override; + + private: + class ScopeFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern { + public: + ScopePattern(); + ~ScopePattern(); + + ScopePattern &SetSubType(const std::string &sub_type); + ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature); + ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature); + ScopePattern &AddScopeFeature(ScopeFeature feature); + + private: + class ScopePatternImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult { + public: + ScopesResult(); + ScopesResult(ScopesResult const &result); + ScopesResult &operator=(ScopesResult const &result); + ~ScopesResult(); + + void SetScopes(std::vector &scopes); + void SetNodes(std::vector &nodes); + + private: + class ScopesResultImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY 
GE_FUNC_DEV_VISIBILITY ScopeBasePass { + public: + ScopeBasePass(); + virtual ~ScopeBasePass(); + + protected: + // Subclasses implement respective fusion strategies and build the Patterns + virtual std::vector DefinePatterns() = 0; + // Define the name of the scope pass + virtual std::string PassName() = 0; + // Subclasses implement respective multi-scope or operator fusion methods across scopes + virtual Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, + std::vector &results) = 0; + // Subclasses implement their own results and set the input and output of the final fusion operator + virtual void GenerateFusionResult(const std::vector &scopes, FusionScopesResult *fusion_rlt) = 0; + + private: + class ScopeBasePassImpl; + std::unique_ptr impl_; + friend class ge::ScopePassManager; + friend class ScopeBasePassImpl; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry { + public: + using CreateFn = ScopeBasePass *(*)(); + ~ScopeFusionPassRegistry(); + + static ScopeFusionPassRegistry &GetInstance() { + static ScopeFusionPassRegistry instance; + return instance; + } + + void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general); + + private: + ScopeFusionPassRegistry(); + class ScopeFusionPassRegistryImpl; + /*lint -e148*/ + std::unique_ptr impl_; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil { + public: + static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value); + static void FreeScopePatterns(ScopeFusionPatterns &patterns); + static void FreeOneBatchPattern(std::vector &one_batch_pattern); +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar { + public: + ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), bool is_general); + ~ScopeFusionPassRegistrar() {} +}; + +#define REGISTER_SCOPE_FUSION_PASS(pass_name, 
scope_pass, is_general) \ + REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, scope_pass, is_general) + +#define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \ + REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) + +#define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \ + static ::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \ + ::ge::ScopeFusionPassRegistrar( \ + pass_name, []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, is_general) +} // namespace ge + +#endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ diff --git a/inc/framework/omg/parser/model_parser.h b/inc/framework/omg/parser/model_parser.h deleted file mode 100644 index 3a8aa6ce..00000000 --- a/inc/framework/omg/parser/model_parser.h +++ /dev/null @@ -1,111 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ -#define INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ - -#include -#include "framework/common/types.h" -#include "framework/omg/omg_inner_types.h" -#include "graph/attr_value.h" -#include "graph/compute_graph.h" -#include "graph/ge_tensor.h" -#include "graph/graph.h" -#include "graph/op_desc.h" -#include "graph/operator.h" -#include "graph/range_vistor.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" - -using Status = domi::Status; - -namespace domi { -using GetGraphCallback = std::function( - const google::protobuf::Message *root_proto, const std::string &graph)>; -class ModelParser { - public: - ModelParser() {} - - virtual ~ModelParser() {} - - /** - * @ingroup domi_omg - * @brief Analyze network model data - * @param [in] file Network model file path - * @param [in|out] graph Save the network information after analysis - * @return SUCCESS - * @return Others failed - */ - virtual Status Parse(const char *file, ge::Graph &graph) = 0; - - /** - * @ingroup domi_omg - * @brief Parse relevant data from memory and save it to graph - * @param [in] input Model file memory data - * @param [in|out] graph A graph for saving the model information after analysis - * @return SUCCESS - * @return FAILED - * @author - */ - virtual Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) = 0; - - /** - * @ingroup domi_omg - * @brief Analyze network model data - * @param [in] proto network model - * @param [in|out] graph Save the network information after analysis - * @return SUCCESS - * @return Others failed - */ - virtual Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) = 0; - - /** - * @ingroup domi_omg - * @brief Analyze callback model data in subgraph - * @param [in] proto network model - * @param [in] callback callback of subgraph - * @param [in|out] 
graph Save the network information after analysis - * @return SUCCESS - * @return Others failed - */ - virtual Status ParseProtoWithSubgraph(const google::protobuf::Message *proto, - GetGraphCallback callback, - ge::ComputeGraphPtr &graph) = 0; - /** - * @ingroup domi_omg - * @brief Convert model files to JSON format - * @param [in] model_file Model file path to be converted - * @param [out] json_file Converted JSON file path - * @return SUCCESS - * @return Others failed - */ - virtual Status ToJson(const char *model_file, const char *json_file) { return domi::SUCCESS; } - - /* - * @ingroup domi_omg - * @brief Convert network data type - * @param [in] type Data type to be converted - * @return ge::DataType - */ - virtual ge::DataType ConvertToGeDataType(const uint32_t type) = 0; - - virtual Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) = 0; -}; -} // namespace domi - -#endif // INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ diff --git a/inc/framework/omg/parser/op_parser.h b/inc/framework/omg/parser/op_parser.h deleted file mode 100644 index 251c0447..00000000 --- a/inc/framework/omg/parser/op_parser.h +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ -#define INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ - -#include -#include "common/types.h" -#include "omg/omg_inner_types.h" -#include "proto/om.pb.h" -#include "graph/ge_tensor.h" -#include "graph/op_desc.h" -#include "graph/utils/op_desc_utils.h" - -using google::protobuf::Message; -using Status = domi::Status; - -namespace ge { -/** - * @ingroup domi_omg - * @brief Used to analyze operator information - * - */ -class OpParser { - public: - /** - * @ingroup domi_omg - * @brief Deconstructor - */ - virtual ~OpParser() {} - - /** - * @ingroup domi_omg - * @brief Analytic operator parameters - * @param [in] op_src Parameter data to be resolved - * @param [out] graph Parsed parameter data - * @return SUCCESS - * @return FAILED - */ - virtual Status ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) = 0; - - /** - * @ingroup domi_omg - * @brief Analytic operator parameters - * @param [in] op_src Parameter data to be resolved - * @param [out] Operator parameter data - * @return SUCCESS - * @return FAILED - */ - virtual Status ParseParams(const Message *op_src, ge::Operator &op_dest) = 0; - - /** - * @ingroup domi_omg - * @brief Analytic operator weight information - * @param [in] op_src Weight data to be resolved - * @param [out] op_dest Weight data after analysis - * @return SUCCESS - * @return FAILED - */ - virtual Status ParseWeights(const Message *op_src, ge::NodePtr &node) = 0; - - /** - * @ingroup domi_omg - * @brief Get the format information according to the parameters in the operator - * @param [in] op_src Parameter data to be resolved - * @param [out] format Output the parsed format - * @return SUCCESS - * @return FAILED - */ - virtual Status GetFormat(const Message *op_src, domi::domiTensorFormat_t &format) { - (void)op_src; - // Indicates that the op does not provide a value for format - format = domi::DOMI_TENSOR_RESERVED; - return domi::SUCCESS; - } -}; -} // namespace ge - -#endif // 
INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ diff --git a/inc/framework/omg/parser/parser_api.h b/inc/framework/omg/parser/parser_api.h deleted file mode 100644 index 382bdfde..00000000 --- a/inc/framework/omg/parser/parser_api.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ -#define INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ - -#include -#include -#include -#include "ge/ge_api_error_codes.h" - -namespace ge { -// Initialize parser -Status ParserInitialize(const std::map& options); -// Finalize parser, release all resources -Status ParserFinalize(); -} // namespace ge -#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ diff --git a/inc/framework/omg/parser/parser_factory.h b/inc/framework/omg/parser/parser_factory.h deleted file mode 100644 index 90d441d7..00000000 --- a/inc/framework/omg/parser/parser_factory.h +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ -#define INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ - -#include -#include -#include -#include -#include "framework/common/types.h" -#include "framework/omg/omg_inner_types.h" - -using Status = domi::Status; - -namespace domi { -class WeightsParser; -class ModelParser; - -typedef std::shared_ptr (*MODEL_PARSER_CREATOR_FUN)(void); - -// Create modelparser for different frameworks -class ModelParserFactory { - public: - static ModelParserFactory *Instance(); - - /** - * @ingroup domi_omg - * @brief Create a modelparser based on the type entered - * @param [in] type Framework type - * @return Created modelparser - */ - std::shared_ptr CreateModelParser(const domi::FrameworkType type); - - /** - * @ingroup domi_omg - * @brief Register create function - * @param [in] type Framework type - * @param [in] fun ModelParser's create function - */ - void RegisterCreator(const domi::FrameworkType type, MODEL_PARSER_CREATOR_FUN fun); - - protected: - ModelParserFactory() {} - ~ModelParserFactory(); - - private: - std::map creator_map_; -}; // end class ModelParserFactory - -class ModelParserRegisterar { - public: - ModelParserRegisterar(const domi::FrameworkType type, MODEL_PARSER_CREATOR_FUN fun) { - ModelParserFactory::Instance()->RegisterCreator(type, fun); - } - ~ModelParserRegisterar() {} -}; - -// Registration macros for model parsers -#define REGISTER_MODEL_PARSER_CREATOR(type, clazz) \ - std::shared_ptr Creator_##type##_Model_Parser() { \ - std::shared_ptr ptr = nullptr; \ - try { \ - ptr = 
make_shared(); \ - } catch (...) { \ - ptr = nullptr; \ - } \ - return std::shared_ptr(ptr); \ - } \ - ModelParserRegisterar g_##type##_Model_Parser_Creator(type, Creator_##type##_Model_Parser) - -typedef std::shared_ptr (*WEIGHTS_PARSER_CREATOR_FUN)(void); - -// Create weightsparser for different frameworks -class WeightsParserFactory { - public: - static WeightsParserFactory *Instance(); - - /** - * @ingroup domi_omg - * @brief Create weightsparser based on the type entered - * @param [in] type Framework type - * @return Created weightsparser - */ - std::shared_ptr CreateWeightsParser(const domi::FrameworkType type); - - /** - * @ingroup domi_omg - * @brief Register create function - * @param [in] type Framework type - * @param [in] fun WeightsParser's create function - */ - void RegisterCreator(const domi::FrameworkType type, WEIGHTS_PARSER_CREATOR_FUN fun); - - protected: - WeightsParserFactory() {} - ~WeightsParserFactory(); - - private: - std::map creator_map_; -}; // end class WeightsParserFactory - -class WeightsParserRegisterar { - public: - WeightsParserRegisterar(const domi::FrameworkType type, WEIGHTS_PARSER_CREATOR_FUN fun) { - WeightsParserFactory::Instance()->RegisterCreator(type, fun); - } - ~WeightsParserRegisterar() {} -}; - -// Register macro of weight resolver -#define REGISTER_WEIGHTS_PARSER_CREATOR(type, clazz) \ - std::shared_ptr Creator_##type##_Weights_Parser() { \ - std::shared_ptr ptr = nullptr; \ - try { \ - ptr = make_shared(); \ - } catch (...) 
{ \ - ptr = nullptr; \ - } \ - return std::shared_ptr(ptr); \ - } \ - WeightsParserRegisterar g_##type##_Weights_Parser_Creator(type, Creator_##type##_Weights_Parser) -}; // namespace domi - -#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ diff --git a/inc/framework/omg/parser/parser_inner_ctx.h b/inc/framework/omg/parser/parser_inner_ctx.h deleted file mode 100644 index 53f79895..00000000 --- a/inc/framework/omg/parser/parser_inner_ctx.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ -#define INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ - -#include -#include -#include -#include -#include -#include -#include "external/register/register_fmk_types.h" -#include "external/register/register_types.h" -#include "framework/omg/omg_inner_types.h" - -namespace ge { -struct ParserContext { - std::unordered_map> input_dims; - domi::domiTensorFormat_t format = domi::DOMI_TENSOR_ND; - RunMode run_mode = ONLY_PRE_CHECK; - std::string custom_proto_path; // save caffe custom proto path, used by caffe parse - std::string caffe_proto_path; // save caffe proto path, used by caffe parse - std::string enable_scope_fusion_passes; // name of the pass that needs to take effect -}; - -ParserContext &GetParserContext(); -} // namespace ge - -#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ diff --git a/inc/framework/omg/parser/weights_parser.h b/inc/framework/omg/parser/weights_parser.h deleted file mode 100644 index 1b5216b3..00000000 --- a/inc/framework/omg/parser/weights_parser.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_ -#define INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_ - -#include "graph/graph.h" -#include "graph/attr_value.h" -#include "graph/compute_graph.h" -#include "graph/ge_tensor.h" -#include "graph/op_desc.h" -#include "graph/operator.h" -#include "graph/range_vistor.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" - -namespace domi { -/** - * @ingroup domi_omg - * @brief Weight information resolver - * - */ -class WeightsParser { - public: - /** - * @ingroup domi_omg - * @brief Constructor - */ - WeightsParser() {} - - /** - * @ingroup domi_omg - * @brief Deconstructor - */ - virtual ~WeightsParser() {} - - /** - * @ingroup domi_omg - * @brief Analyze weight data - * @param [in] file Path of weight file after training - * @param [in|out] graph Graph for saving weight information after analysis - * @return SUCCESS - * @return Others failed - */ - virtual Status Parse(const char *file, ge::Graph &graph) = 0; - - /** - * @ingroup domi_omg - * @brief Parse relevant data from memory and save it to graph - * @param [in] input Model file memory data - * @param [in|out] graph A graph for saving the model information after analysis - * @return SUCCESS - * @return FAILED - * @author - */ - virtual Status ParseFromMemory(const char *input, uint32_t lengt, ge::ComputeGraphPtr &graph) = 0; -}; -} // namespace domi - -#endif // INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_ diff --git a/inc/graph/anchor.h b/inc/graph/anchor.h new file mode 100644 index 00000000..565f0843 --- /dev/null +++ b/inc/graph/anchor.h @@ -0,0 +1,284 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_ANCHOR_H_ +#define INC_GRAPH_ANCHOR_H_ + +#include +#include +#include +#include "graph/ge_error_codes.h" +#include "graph/range_vistor.h" +#include "graph/types.h" + +namespace ge { +enum AnchorStatus { + ANCHOR_SUSPEND = 0, // dat null + ANCHOR_CONST = 1, + ANCHOR_DATA = 2, // Effective + ANCHOR_RESERVED = 3 +}; +using std::string; +using std::vector; + +class Node; + +using NodePtr = std::shared_ptr; + +class Edge; + +using EdgePtr = std::shared_ptr; + +class Anchor; + +using AnchorPtr = std::shared_ptr; + +class DataAnchor; + +using DataAnchorPtr = std::shared_ptr; + +class InDataAnchor; + +using InDataAnchorPtr = std::shared_ptr; + +class OutDataAnchor; + +using OutDataAnchorPtr = std::shared_ptr; + +class ControlAnchor; + +using ControlAnchorPtr = std::shared_ptr; + +class InControlAnchor; + +using InControlAnchorPtr = std::shared_ptr; + +class OutControlAnchor; + +using OutControlAnchorPtr = std::shared_ptr; + +using ConstAnchor = const Anchor; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable_shared_from_this { + friend class AnchorUtils; + + public: + using TYPE = const char *; + template + using Vistor = RangeVistor>; + + Anchor(const NodePtr &ownerNode, int idx); + + virtual ~Anchor() = default; + + protected: + // Whether the two anchor is equal + virtual bool Equal(AnchorPtr anchor) const = 0; + virtual bool IsTypeOf(TYPE type) const; + + public: + // Get all peer anchors connected to current anchor + Vistor GetPeerAnchors() const; + // Get peer anchor size + size_t 
GetPeerAnchorsSize() const; + // Get first peer anchor + AnchorPtr GetFirstPeerAnchor() const; + + // Get the anchor belong to which node + NodePtr GetOwnerNode() const; + + // Remove all links with the anchor + void UnlinkAll() noexcept; + + // Remove link with the given anchor + graphStatus Unlink(const AnchorPtr &peer); + + // Replace peer with new peers + graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); + + // Judge if the anchor is linked with the given anchor + bool IsLinkedWith(const AnchorPtr &peer); + + // Get anchor index of the node + int GetIdx() const; + + // set anchor index of the node + void SetIdx(int index); + + protected: + // All peer anchors connected to current anchor + vector> peer_anchors_; + // The owner node of anchor + std::weak_ptr owner_node_; + // The index of current anchor + int idx_; + template + static Anchor::TYPE TypeOf() { + static_assert(std::is_base_of::value, "T must be a Anchor!"); + return __PRETTY_FUNCTION__; + } + + public: + template + static std::shared_ptr DynamicAnchorCast(AnchorPtr anchorPtr) { + static_assert(std::is_base_of::value, "T must be a Anchor!"); + if (anchorPtr == nullptr || !anchorPtr->IsTypeOf()) { + return nullptr; + } + return std::static_pointer_cast(anchorPtr); + } + + template + bool IsTypeOf() { + return IsTypeOf(TypeOf()); + } +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY DataAnchor : public Anchor { + friend class AnchorUtils; + + public: + explicit DataAnchor(const NodePtr &ownerNode, int idx); + + virtual ~DataAnchor() = default; + + protected: + bool IsTypeOf(TYPE type) const override; + + private: + Format format_{FORMAT_ND}; + AnchorStatus status_{ANCHOR_SUSPEND}; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataAnchor { + friend class OutDataAnchor; + + friend class OutControlAnchor; + + public: + explicit InDataAnchor(const NodePtr &ownerNode, int idx); + + virtual ~InDataAnchor() 
= default; + + // Get source out data anchor + OutDataAnchorPtr GetPeerOutAnchor() const; + + // Build connection from OutDataAnchor to InDataAnchor + graphStatus LinkFrom(const OutDataAnchorPtr &src); + + protected: + bool Equal(AnchorPtr anchor) const override; + bool IsTypeOf(TYPE type) const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchor : public DataAnchor { + friend class InDataAnchor; + + friend class AnchorUtils; + + public: + template + using Vistor = RangeVistor>; + + explicit OutDataAnchor(const NodePtr &ownerNode, int idx); + + virtual ~OutDataAnchor() = default; + // Get dst in data anchor(one or more) + Vistor GetPeerInDataAnchors() const; + uint32_t GetPeerInDataNodesSize() const; + + // Get dst in control anchor(one or more) + Vistor GetPeerInControlAnchors() const; + + // Build connection from OutDataAnchor to InDataAnchor + graphStatus LinkTo(const InDataAnchorPtr &dest); + + // Build connection from OutDataAnchor to InControlAnchor + graphStatus LinkTo(const InControlAnchorPtr &dest); + + protected: + bool Equal(AnchorPtr anchor) const override; + bool IsTypeOf(TYPE type) const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ControlAnchor : public Anchor { + public: + explicit ControlAnchor(const NodePtr &ownerNode); + + explicit ControlAnchor(const NodePtr &ownerNode, int idx); + + virtual ~ControlAnchor() = default; + + protected: + bool IsTypeOf(TYPE type) const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchor : public ControlAnchor { + friend class OutControlAnchor; + + friend class OutDataAnchor; + + public: + explicit InControlAnchor(const NodePtr &ownerNode); + + explicit InControlAnchor(const NodePtr &ownerNode, int idx); + + virtual ~InControlAnchor() = default; + + // Get source out control anchors + Vistor GetPeerOutControlAnchors() const; + bool IsPeerOutAnchorsEmpty() const { return peer_anchors_.empty(); } + + // Get source out data 
anchors + Vistor GetPeerOutDataAnchors() const; + + // Build connection from OutControlAnchor to InControlAnchor + graphStatus LinkFrom(const OutControlAnchorPtr &src); + + protected: + bool Equal(AnchorPtr anchor) const override; + bool IsTypeOf(TYPE type) const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchor : public ControlAnchor { + friend class InControlAnchor; + + public: + template + using Vistor = RangeVistor>; + + explicit OutControlAnchor(const NodePtr &ownerNode); + + explicit OutControlAnchor(const NodePtr &ownerNode, int idx); + + virtual ~OutControlAnchor() = default; + + // Get dst in control anchor(one or more) + Vistor GetPeerInControlAnchors() const; + // Get dst data anchor in control anchor(one or more) + Vistor GetPeerInDataAnchors() const; + + // Build connection from OutControlAnchor to InControlAnchor + graphStatus LinkTo(const InControlAnchorPtr &dest); + // Build connection from OutDataAnchor to InDataAnchor + graphStatus LinkTo(const InDataAnchorPtr &dest); + + protected: + bool Equal(AnchorPtr anchor) const override; + bool IsTypeOf(TYPE type) const override; +}; +} // namespace ge +#endif // INC_GRAPH_ANCHOR_H_ diff --git a/inc/graph/attr_value_serializable.h b/inc/graph/attr_value_serializable.h new file mode 100644 index 00000000..a69beb96 --- /dev/null +++ b/inc/graph/attr_value_serializable.h @@ -0,0 +1,191 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ +#define INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ + +#include +#include +#include "graph/ge_attr_value.h" + +namespace ge { + +class GeAttrValue; +class _GeSerializable { + public: + template + struct ge_serializable_int64_t_support_type { + using DT = typename std::remove_cv::type; + static const bool value = std::is_same::value // by cast + || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; + }; + + template + static GeAttrValue SaveItemAsAttrValue(const T &t) { + return GeAttrValue::CreateFrom(t); + } + + template + static GeAttrValue SaveItemAsAttrValue(const vector &t) { + return GeAttrValue::CreateFrom(t); + } + + template = 0, typename DT = typename std::remove_cv::type> + static GeAttrValue SaveItemAsAttrValue(const T &t) { + return GeAttrValue::CreateFrom
(t); + } + // int64_t support type + template ::value, int>::type = 0> + static GeAttrValue SaveItemAsAttrValue(const T &t) { + return GeAttrValue::CreateFrom(t); + } + // vector int64_t support type + template ::value, int>::type = 0> + static GeAttrValue SaveItemAsAttrValue(const vector &t) { + return GeAttrValue::CreateFrom(t); + } + + template + static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { + return attrVal.GetValue(t); + } + + template + static graphStatus LoadItemFromAttrValue(vector &t, GeAttrValue &attrVal) { + return attrVal.GetValue(t); + } + + template = 0, typename DT = typename std::remove_cv::type> + static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { + return attrVal.GetValue
(t); + } + + template ::value, int>::type = 0> + static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { + return attrVal.GetValue(t); + } + + template ::value, int>::type = 0> + static graphStatus LoadItemFromAttrValue(vector &t, GeAttrValue &attrVal) { + return attrVal.GetValue(t); + } + + template + static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { + GeAttrValue itemVal = SaveItemAsAttrValue(item); + (void)namedAttrs.SetAttr(itemName, itemVal); + SaveItem(namedAttrs, args...); + } + + static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) {} + + template + static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { + auto itemVal = namedAttrs.GetItem(itemName); + auto status = LoadItemFromAttrValue(item, itemVal); + if (status != GRAPH_SUCCESS) { + return status; + } + return LoadItem(namedAttrs, args...); + } + + static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) { + return GRAPH_SUCCESS; + } +}; + +#define _GE_FI(a) #a, a +#define _GE_MAP_FIELDS1(a1) _GE_FI(a1) +#define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) +#define _GE_MAP_FIELDS3(a1, a2, a3) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3) +#define _GE_MAP_FIELDS4(a1, a2, a3, a4) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4) +#define _GE_MAP_FIELDS5(a1, a2, a3, a4, a5) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5) +#define _GE_MAP_FIELDS6(a1, a2, a3, a4, a5, a6) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6) +#define _GE_MAP_FIELDS7(a1, a2, a3, a4, a5, a6, a7) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7) +#define _GE_MAP_FIELDS8(a1, a2, a3, a4, a5, a6, a7, a8) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8) +#define _GE_MAP_FIELDS9(a1, a2, a3, a4, a5, a6, a7, a8, a9) \ + _GE_FI(a1) \ + 
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9) +#define _GE_MAP_FIELDS10(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10) +#define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ + _GE_FI(a11) +#define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ + _GE_FI(a11), _GE_FI(a12) +#define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ + _GE_FI(a11), _GE_FI(a12), _GE_FI(a13) +#define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ + _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) +#define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ + _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) + +#define _GE_PRIVATE_ARGS_GLUE(x, y) x y + +#define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, \ + ...) \ + N +#define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL(args) _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT args +#define _GE_COUNT_MACRO_VAR_ARGS(...) 
\ + _GE_PRIVATE_MACRO_VAR_ARGS_IMPL((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)) + +#define _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) M##count +#define _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) +#define _GE_PRIVATE_MACRO_CHOOSE_HELPER(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) + +#define _GE_INVOKE_VAR_MACRO(...) \ + _GE_PRIVATE_ARGS_GLUE(_GE_PRIVATE_MACRO_CHOOSE_HELPER(_GE_MAP_FIELDS, _GE_COUNT_MACRO_VAR_ARGS(__VA_ARGS__)), \ + (__VA_ARGS__)) + +#define GE_SERIALIZABLE(...) \ + public: \ + friend class ge::GeAttrValue; \ + using __ge_serializable = int; \ + \ + private: \ + ge::graphStatus Save(GeAttrValue &ar) const { \ + GeAttrValue::NAMED_ATTRS named_attrs; \ + _GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ + return ar.SetValue(named_attrs); \ + } \ + ge::graphStatus Load(const GeAttrValue &ar) { \ + GeAttrValue::NAMED_ATTRS named_attrs; \ + ge::graphStatus status = ar.GetValue(named_attrs); \ + if (status != GRAPH_SUCCESS) { \ + return status; \ + } \ + return _GeSerializable::LoadItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ + } + +// end NamedAttrs Helper: GE_SERIALIZABLE +} // namespace ge +#endif // INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ diff --git a/inc/graph/buffer.h b/inc/graph/buffer.h new file mode 100644 index 00000000..ca4355a7 --- /dev/null +++ b/inc/graph/buffer.h @@ -0,0 +1,82 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_BUFFER_H_ +#define INC_GRAPH_BUFFER_H_ + +#include +#include +#include +#include +#include "detail/attributes_holder.h" + +namespace ge { +#ifdef HOST_VISIBILITY +#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_HOST_VISIBILITY +#endif +#ifdef DEV_VISIBILITY +#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_DEV_VISIBILITY +#endif + +using std::shared_ptr; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { + public: + Buffer(); + Buffer(const Buffer &other); + + explicit Buffer(std::size_t bufferSize, std::uint8_t defualtVal = 0); + + ~Buffer() = default; + + Buffer &operator=(const Buffer &other); + + static Buffer CopyFrom(const std::uint8_t *data, std::size_t bufferSize); + + const std::uint8_t *GetData() const; + std::uint8_t *GetData(); + std::size_t GetSize() const; + void ClearBuffer(); + + // For compatibility + inline const std::uint8_t *data() const { return GetData(); } + inline std::uint8_t *data() { return GetData(); } // lint !e659 + inline std::size_t size() const { return GetSize(); } + inline void clear() { return ClearBuffer(); } + uint8_t operator[](size_t index) const { // lint !e1022 !e1042 + if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574 + return (uint8_t)(*buffer_)[index]; + } + return 0xff; + } + + private: + GeIrProtoHelper data_; + std::string *buffer_ = nullptr; + + // Create from protobuf obj + Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); + Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); + + friend class GeAttrValueImp; + friend class GeTensor; +}; +} // namespace ge +#endif // INC_GRAPH_BUFFER_H_ diff --git a/inc/graph/compute_graph.h b/inc/graph/compute_graph.h new file mode 100644 index 00000000..2ec6b663 --- /dev/null +++ 
b/inc/graph/compute_graph.h @@ -0,0 +1,308 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_COMPUTE_GRAPH_H_ +#define INC_GRAPH_COMPUTE_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/anchor.h" +#include "graph/node.h" +#include "graph/op_desc.h" +#include "graph/range_vistor.h" + +namespace ge { +class Node; +using NodePtr = std::shared_ptr; +class Edge; +using EdgePtr = std::shared_ptr; + +class InDataAnchor; +using InDataAnchorPtr = std::shared_ptr; + +class OutDataAnchor; +using OutDataAnchorPtr = std::shared_ptr; + +class ControlAnchor; +using ControlAnchorPtr = std::shared_ptr; +class InControlAnchor; +using InControlAnchorPtr = std::shared_ptr; +class OutControlAnchor; +using OutControlAnchorPtr = std::shared_ptr; +class GeAttrValue; +using AttrValuePtr = std::shared_ptr; +using ConstComputeGraph = const ComputeGraph; + +class OperatorImpl; +using OperatorImplPtr = std::shared_ptr; + +class ComputeGraph : public std::enable_shared_from_this, public AttrHolder { + friend class GraphUtils; + + public: + template + using Vistor = RangeVistor>; + + explicit ComputeGraph(const std::string &name); + ~ComputeGraph() override; + + std::string GetName() const; + void SetName(const std::string &name); + + using AttrHolder::DelAttr; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using 
AttrHolder::SetAttr; + + size_t GetAllNodesSize() const; + Vistor GetAllNodes() const; + // is_unknown_shape: false, same with GetAllNodes func + // is_unknown_shape: true, same with GetDirectNodes func + Vistor GetNodes(bool is_unknown_shape) const; + size_t GetDirectNodesSize() const; + Vistor GetDirectNode() const; + Vistor GetInputNodes() const; + Vistor GetOutputNodes() const; + + NodePtr FindNode(const std::string &name) const; + NodePtr FindFirstNodeMatchType(const std::string &name) const; + /*lint -e504*/ + // AddNode with NodePtr + NodePtr AddNode(NodePtr node); + NodePtr AddNode(OpDescPtr op); + NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize + NodePtr AddNodeFront(NodePtr node); + NodePtr AddNodeFront(const OpDescPtr &op); + NodePtr AddInputNode(NodePtr node); + NodePtr AddOutputNode(NodePtr node); + NodePtr AddOutputNodeByIndex(NodePtr node, int32_t index); + // insert node with specific pre_node + NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node); + NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node); + + graphStatus RemoveNode(const NodePtr &node); + graphStatus RemoveInputNode(const NodePtr &node); + graphStatus RemoveOutputNode(const NodePtr &node); + graphStatus RemoveConstInput(const NodePtr &node); + + /// Add a subgraph to this graph. The subgraph must has a parent graph and parent node, + /// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph + /// must be called before add it to the root graph. and subgraph->GetParentNode()->GetOwnerGraph() + /// must equal to subgraph->GetOwnerGraph(). + /// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph. 
+ /// The subgraph's name SHOULD(not must) be the same as the parameter `name` + graphStatus AddSubgraph(const std::string &name, const std::shared_ptr &subgraph); + graphStatus AddSubgraph(const std::shared_ptr &subgraph); + + void RemoveSubgraph(const std::string &name); + void RemoveSubgraph(const std::shared_ptr &subgraph); + + std::shared_ptr GetSubgraph(const std::string &name) const; + std::vector> GetAllSubgraphs() const; + + // obsolete + std::shared_ptr AddSubGraph(std::shared_ptr sub_graph); + // obsolete + graphStatus RemoveSubGraph(const std::shared_ptr &sub_graph); + + /// + /// @brief Update input-mapping + /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input + /// @return graphStatus + /// + graphStatus UpdateInputMapping(const std::map &input_mapping); + + /// + /// @brief Update output-mapping + /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output + /// @return graphStatus + /// + graphStatus UpdateOutputMapping(const std::map &output_mapping); + + graphStatus TopologicalSorting(); + bool IsValid() const; + void InValid() { is_valid_flag_ = false; } + void Dump() const; + + void Swap(ComputeGraph &graph); + + graphStatus IsolateNode(const NodePtr &node); + graphStatus Verify(); + graphStatus InferShape(); + graphStatus InferOriginFormat(); + graphStatus InferShapeInNeed(); + graphStatus InsertEventNodes(); + bool operator==(const ComputeGraph &r_compute_graph) const; + + /*lint +e504*/ + const std::map, std::vector> &GetShareParamLayer() const { + return params_share_map_; + } + + void SetShareParamLayer(const std::map, std::vector> params_share_map) { + params_share_map_ = params_share_map; + } + + void SetInputsOrder(const std::vector &inputs_order) { inputs_order_ = inputs_order; } + + void SetGraphOutNodes(std::map> out_nodes_map) { out_nodes_map_ = out_nodes_map; } + + void AppendGraphOutNodes(std::map> out_nodes_map) { + for (auto &item : out_nodes_map) { 
+ (void)out_nodes_map_.emplace(item.first, item.second); + } + } + + shared_ptr GetParentGraph(); + void SetParentGraph(const shared_ptr &parent); + shared_ptr GetParentNode(); + void SetParentNode(const shared_ptr &parent); + + const std::map> &GetGraphOutNodes() const { return out_nodes_map_; } + + void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } + + ComputeGraphPtr GetOrigGraph(void) { return origGraph_; } + void SetOutputSize(uint32_t size) { output_size_ = size; } + uint32_t GetOutputSize() const { return output_size_; } + void SetInputSize(uint32_t size) { input_size_ = size; } + uint32_t GetInputSize() const { return input_size_; } + + // false: known shape true: unknow shape + bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; } + void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; } + + /// + /// Set is need train iteration. + /// If set true, it means this graph need to be run iteration some + /// times(according variant "npu_runconfig/iterations_per_loop"). + /// @param need_iteration is need iteration + /// + void SetNeedIteration(bool need_iteration) { need_iteration_ = need_iteration; } + + void SetUserDefOutput(const std::string &output_name); + + const std::string GetOutput(); + + /// + /// Get is need train iteration. 
+ /// @return is need iteration + /// + bool GetNeedIteration() const { return need_iteration_; } + + void SetGraphOpName(const std::map &op_name_map) { op_name_map_ = op_name_map; } + const std::map &GetGraphOpName() const { return op_name_map_; } + + const std::map &GetAllNodesInfo() const; + + void SetAllNodesInfo(const std::map &nodes) { all_nodes_infos_ = nodes; } + + void SetGraphOutNodesInfo(std::vector> &out_nodes_info) { + output_nodes_info_ = out_nodes_info; + } + + void AppendGraphOutNodesInfo(std::vector> &out_nodes_info) { + output_nodes_info_.insert(output_nodes_info_.end(), out_nodes_info.begin(), out_nodes_info.end()); + } + + const std::vector> &GetGraphOutNodesInfo() const { return output_nodes_info_; } + + void SetGraphTargetNodesInfo(const std::vector &target_nodes_info) { + target_nodes_info_ = target_nodes_info; + } + const std::vector &GetGraphTargetNodesInfo() const { return target_nodes_info_; } + + void SetSessionID(uint64_t session_id) { session_id_ = session_id; } + uint64_t GetSessionID() const { return session_id_; } + + void SetGraphID(uint32_t graph_id) { graph_id_ = graph_id; } + uint32_t GetGraphID() const { return graph_id_; } + + void SaveDataFormat(ge::Format data_format) { data_format_ = data_format; } + ge::Format GetDataFormat() const { return data_format_; } + bool IsSummaryGraph() const { return is_summary_graph_; } + void SetSummaryFlag(bool is_summary_graph) { is_summary_graph_ = is_summary_graph; } + // Graph Before BFE + ComputeGraphPtr origGraph_; + + protected: + ProtoAttrMapHelper MutableAttrMap() override; + ConstProtoAttrMapHelper GetAttrMap() const override; + + private: + graphStatus DFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, + std::vector &stack); + graphStatus BFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, + std::deque &stack); + graphStatus CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, + std::map &breadth_node_map); + graphStatus 
TopologicalSortingGraph(); + graphStatus SortNodes(std::vector &stack, std::map &mapInEdgeNum); + Vistor AllGraphNodes(std::vector> &subgraphs) const; + size_t GetInEdgeSize(const NodePtr &node); + size_t GetOutEdgeSize(const NodePtr &node); + graphStatus RemoveExtraOutEdge(const NodePtr &node); + bool GraphMembersAreEqual(const ComputeGraph &r_graph) const; + bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const; + bool VectorInputNodePtrIsEqual(const std::vector &r_node_ptr_vector, + const std::vector &l_node_ptr_vector) const; + + void SetNodesOwner(); + + friend class ModelSerializeImp; + friend class GraphDebugImp; + friend class OnnxUtils; + friend class TuningUtils; + + std::string name_; + uint32_t graph_id_ = 0; + ProtoAttrMapHelper attrs_; + std::vector nodes_; + std::map all_nodes_infos_; + std::vector target_nodes_info_; + + std::vector input_nodes_; + std::vector inputs_order_; + uint32_t input_size_ = 1; + std::map> out_nodes_map_; + uint32_t output_size_ = 1; + std::vector> output_nodes_info_; + + std::vector> sub_graph_; + std::map> names_to_subgraph_; + std::weak_ptr parent_graph_; + std::weak_ptr parent_node_; + + // the members followed should not in the ComputeGraph class + bool is_valid_flag_; + bool is_summary_graph_ = false; + // Indicates whether it is need iteration + bool need_iteration_ = false; + std::map, std::vector> params_share_map_; + // TaskIdx -> op_name Map + std::map op_name_map_; + uint64_t session_id_ = 0; + ge::Format data_format_ = ge::FORMAT_ND; + // unknown graph indicator, default is false, mean known shape + bool is_unknown_shape_graph_ = false; +}; +} // namespace ge +#endif // INC_GRAPH_COMPUTE_GRAPH_H_ diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h new file mode 100644 index 00000000..47b11ba8 --- /dev/null +++ b/inc/graph/debug/ge_attr_define.h @@ -0,0 +1,1130 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*lint -e618*/ +#ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ +#define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ + +#include +#include "graph/types.h" + +namespace ge { +#ifdef HOST_VISIBILITY +#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_HOST_VISIBILITY +#endif +#ifdef DEV_VISIBILITY +#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_DEV_VISIBILITY +#endif +// Public attribute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WORKSPACE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHT_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_QUANTIZE_FACTOR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALPHA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BETA; + +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADMODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADMODES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SCALE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WINDOWS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GLOBAL_POOLING; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELU_FLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALGO; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT; + +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_K; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_NORM_REGION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_LOCAL_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_ALPHA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_BETA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROADCAST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TIDX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TPADDINGS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IMG_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NET_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NET_W; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TMULTIPLES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTIPLES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_T; + +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string ATTR_NAME_N; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string ATTR_NAME_TSHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_INPUTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_OUTPUTS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_AIPP_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_RELATED_AIPP_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_AIPP_DATA_NAME_MAP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_HAS_BEEN_ADDED; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_NODE_DEF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_OP_DEF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string ATTR_NAME_FRAMEWORK_FWK_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_TENSOR_DESC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INFERRED_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_ORIGIN_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_INPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_OUTPUT; + +// to be deleted +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_LOC_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_CONF_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_OCR_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; + +// _Arg +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INDEX; +// _RetVal +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETVAL_ATTR_NAME_INDEX; +// Data +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DATA_ATTR_NAME_DATA_TYPE; + +// Send +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SEND_ATTR_EVENT_ID; + +// Recv +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RECV_ATTR_EVENT_ID; + +// Convolution +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COEF; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATIONS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_ALGO; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_GROUP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_PAD_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_PAD; + 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_STRIDE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_DILATION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_NUM_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_KERNEL; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_FILTER; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_BIAS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_RELU_FLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_ADJ; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_TARGET_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_BEFORE_PAD; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_HAS_BIAS; + +// Pooling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_NAN_OPT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_PAD_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_GLOBAL_POOLING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_WINDOW; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_PAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_STRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_CEIL_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_DATA_MODE; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_BEFORE_PAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_NAME_ALGO; + +// Eltwise +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_COEFF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_WEIGHT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_RELU_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_ALPHA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_BETA; + +// BatchNorm +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_EPSILON; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_MEAN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; + +// Huberloss +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; + 
+// SSDRealDivTileMul +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; + +// SSDSumMulRealDivMean +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; +/// ConcatFive2Four +/// ConcatFour2Five +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; +// Scale +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; + +// FullConnection +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_FILTER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_BIAS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_RELU_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_ATTR_NAME_ALGO; + +// SoftmaxOpParams +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_ALGO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_MODE; + +// SparseSoftmaxCrossEntropy +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE; +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD; +// Attr labelSmoothing +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING; + +// ApplyMomentum +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string APPLYMENTUM_ATTR_IS_GRAPH_FUSION; + +// Activation +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ACTIVATION_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ACTIVATION_ATTR_COEF; + +// Concat +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_ATTR_NAME_AXIS; + +// Const +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_DATA_TRANSTYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; + +// Roipooling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLED_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLED_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO; + +// DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IMG_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE; +// Ssd DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_ETA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K; + +// Refinedet DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE; + +// Yolo DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_ClASSES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
DETECTIONOUTPUT_ATTR_BIASES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_RELATIVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION; + +// DetectionPostprocess +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_CLS_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_POST_NMS_TOPN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT; + +// Spatialtransfrom +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_OUTPUT_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_OUTPUT_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM; + +// Proposal +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string PROPOSAL_ATTR_NAME_BASE_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_MIN_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_RATIO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_NMS_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_TOP_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_IMG_W; +// Softmax +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_AXIS; + +// Permute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; + +// SSD Normalize +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_EPS; + +// Flatten +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_ATTR_END_AXIS; + +// SsdPRIORBOX +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_FLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_CLIP; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_IMG_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_STEP_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_STEP_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_VARIANCE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM; + +// PRelu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PRELU_ATTR_CHANNEL_SHARED; + +// Psroi pooling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_OUTPUT_DIM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_GROUP_SIZE; + +// Power +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_POWER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; + +// Log +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; +// Pack +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; + +// Dynamic stitch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; +// Unpack +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; +// Gathernd +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND_ATTR_NAME_TINDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND_ATTR_NAME_TPARAMS; + +// Argmax +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; + +// Upsample +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; +// Relu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; + +// FreeSpaceExtract 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT; + +// Split +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_SLICE_POINT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_SIZE_SPLIT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_NUM_SPLIT; + +// Tvm +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_WORKSPACE_TYPE; + +// Squeeze +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_OP_NAME; + +// Stride slice +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_BEGIN_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_END_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; + +// Slice +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SLICE_ATTR_NAME_BEGINS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SLICE_ATTR_NAME_SIZES; + +// Roialign +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SPATIAL_SCALE; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; + +// Generate_rpn_proposal +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string + GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string + GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; +// Decode_bbox +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DECODE_BBOX_ATTR_DECODECLIP; + +// Cast +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_DSTT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_SRCT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_DST_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_TRUNCATE; + +// Fastrcnnn predications +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES; + +// REORG +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
REORG_ATTR_STRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_ATTR_REVERSE; + +// MERGE +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; +static const std::string NOT_NET_OUTPUT = "not_net_output"; + +// ENTER +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_CONSTANT_FLAG; + +// Concatv2 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_V2_ATTR_TIDX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_V2_ATTR_N; +// SUM +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_TIDX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_KEEP_DIMS; + +// ResizeBilinear +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_HEIGHT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_WIDTH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_PAD_END; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; + +// RetinaNet +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; +// MatMul +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_HAS_BIAS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_ATTR_IS_TRAINING; + +// Flatten +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_START_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_END_AXIS; + +// Reshape +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NUM_AXES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_ALPHA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_BETA; + +// Frameoworkop +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string T_IN_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string T_OUT_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_N; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_C; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_PAD_DEPTH_CONV; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_PAD_CONV; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BEFORE_PAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ANN_MEAN_KEEPDIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_ATTR_PADDINGDS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_ATTR_CONSTANT_VALUE; + +// ConvGradFilter +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE; +// ConvGradInput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE; + +// Rnn +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_MODE_STATIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MUTI_RNN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CELL_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CNN_RNN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; + +// Upsample +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; + +// PadV2 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; + +// MirrorPad +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; +// Filler +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; + +// Shufflechannel +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHUFFLE_CHANNEL_GROUP; + +// TopKV2 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TOPKV2_ATTR_K; + +// Calibaration +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
STRIDE_H_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_W_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_TOP_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_BOTTOM_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_RIGHT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATION_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_EPSILON; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_POOLING_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CLASS_NUM; +// Model +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TARGET_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_STREAM_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_HUGE_STREAM_LIST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_MODEL_OUT_NODES_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; + +// Public attribute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BYTE_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_INFERENCE_ID; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_OPDEF; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IO_OP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_SCOPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPATTR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUFLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SEQLEN_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_X_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONT_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_XSTATIC_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_MINI; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_TINY; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_LITE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_NAME_CONTINUOUS_INPUT_ALLOC; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; + +// Used for operators that do not generate task +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOTASK; + +// Used for operators that output reuse input +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_REUSE_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; + +// L2_normalize +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; +// HCOM +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; + +// Log time stamp +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; +// SpaceToDepth/DepthToSpace +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; + +// SparseSoftmaxCrossEntropyWithLogits +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; + +// MaxPoolGradWithArgmax +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; + +// AvgPoolGrad +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; + +// Varible +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; + +// Assign +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VAR_NAME; + +// ShapeN +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; + +// Space2bacth batch2space +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string BATCH_SPACE_ATTR_PADDING; +// Depth_to_space space_to_depth +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; +// FakeQuantWithMinMaxVars +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; +// Mobilenet_ssd_conv_fusion +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; + +// Lsh project +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; + +// Control flow +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; + +// GatherV2 attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; + +// Reshape attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; + +// Axis attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; +// The node link with SparseSoftmaxCrossEntropyWithLogits +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string ATTR_NAME_LINK_WITH_SPARE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; +// For constant folding +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; + +// Used for mark the active label list to find stream of activated node +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE; + +// Multi batch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_BATCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER; + +// Control flow +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string ATTR_NAME_SWITCH_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIG_NODE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; + +// Function Op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; + +// Used for mark the active node is for loop, type:bool +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_RANGE; + +// Atomic addr clean attrs +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_INPUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_OUTPUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_IS_FUSION_NODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_IS_ATOMIC_NODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_WORKSPACE_INFO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_WORKSPACE_OFFSET; +// Used for find variable session_id +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MODEL_ATTR_SESSION_ID; + +// Source/dst format for Op FormatTransfer +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_SRC_FORMAT; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_DST_FORMAT; + +// For compile op by ge call +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEED_COMPILE; + +// For mutil-batch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_TYPE; + +// For inserted op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; + +// For compress weight +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPRESS_WEIGHT; + +// For data dump +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_SUB_SPLITER_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_GROUP_OP_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; + +// used for lX fusion +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ENGINE_NAME_FOR_LX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_LX_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPTIMIZE_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_COMPILE_STRATEGY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_BUFFER; + +// for unregistered op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST; + +// op overflow dump +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE; + +// op dynamic input +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_INPUT_START; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_INPUT_END; + +// functional ops attr +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_ELSE_BRANCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; + +// used for label switch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_END_NODE; + +// Variable +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
REF_VAR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; + +// HCOM +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; +// used for LX tiling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_L1_SPACE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST; + +// Dynamic stitch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; + +// Used for support Horovod +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_INTER_EVENT_IDENTIFY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE; +// for gradient group +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG; + +// dynamic shape attrs +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX; + +// atc user def dtype&format +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT; + +// for fusion op plugin +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; + +// graph partition for aicpu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME; + +// input and output memory type +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_VARIABLE_PLACEMENT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE; + +// input_output_offset +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET; +} // namespace ge + +#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ +/*lint +e618*/ diff --git a/inc/graph/def_types.h b/inc/graph/def_types.h new file mode 100644 index 00000000..6d70fb18 --- /dev/null +++ 
b/inc/graph/def_types.h @@ -0,0 +1,195 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_DEF_TYPES_H_ +#define INC_GRAPH_DEF_TYPES_H_ + +#include +#include +#include +#include "graph/attr_value_serializable.h" +#include "graph/buffer.h" +namespace ge { +#define DEF_TYPE_DEC(type, name) \ + inline void set_##name(const type &value) { name = value; } \ + type *mutable_##name() { return &name; } + +#define DEF_TYPE_HAS_DEC(type, name) \ + inline void set_##name(const type &value) { name = value; } \ + \ + private: \ + bool has_mutable_##name{false}; \ + \ + public: \ + bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \ + type *mutable_##name() { \ + has_mutable_##name = true; \ + return &name; \ + } + +#define DEF_TYPE_VEC_DEC(type, name) \ + inline int name##_size() const { return name.size(); } \ + inline void clear_##name() { name.clear(); } \ + inline void set_##name(int index, type value) { name[index] = value; } \ + inline void add_##name(type value) { name.push_back(value); } \ + inline std::vector *mutable_##name() { return &name; } + +#define DEF_TYPE_BYTES_DEC(name) \ + inline void clear_##name() { name.ClearBuffer(); } \ + inline void set_##name(const void *value, size_t size) { name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ + inline Buffer *mutable_##name() { return &name; } + +struct CompressInfo { + public: + 
CompressInfo() {} + CompressInfo(int32_t blockRow, int32_t blockCol, int32_t fractalK, int32_t fractalN, int32_t lastFractalK, + int32_t lastFractalN, int32_t cubeSize, int32_t loadDir) { + blockrow = blockRow; + blockcol = blockCol; + fractalk = fractalK; + fractaln = fractalN; + lastfractalk = lastFractalK; + lastfractaln = lastFractalN; + cubesize = cubeSize; + loaddir = loadDir; + } + + int32_t blockrow{0}; // Block row + int32_t blockcol{0}; // Block col + int32_t fractalk{0}; // Fractal K + int32_t fractaln{0}; // Fractal N + int32_t lastfractalk{0}; // K of last fractal + int32_t lastfractaln{0}; // N of last fractal + int32_t cubesize{0}; // Cube's length + int32_t loaddir{0}; // Data load directtiono 0:col load 1:row load + DEF_TYPE_DEC(int32_t, blockrow); + DEF_TYPE_DEC(int32_t, blockcol); + DEF_TYPE_DEC(int32_t, fractalk); + DEF_TYPE_DEC(int32_t, fractaln); + DEF_TYPE_DEC(int32_t, lastfractalk); + DEF_TYPE_DEC(int32_t, lastfractaln); + DEF_TYPE_DEC(int32_t, cubesize); + DEF_TYPE_DEC(int32_t, loaddir); + + GE_SERIALIZABLE(blockrow, blockcol, fractalk, fractaln, lastfractalk, lastfractaln, cubesize, loaddir); +}; + +enum QuantizeScaleType { VECTOR_SCALE = 0, SCALAR_SCALE = 1 }; +enum QuantizeScaleMode { NORMAL_MODE = 0, SQRT_MODE = 1 }; +enum QuantizeAlgorithm { + NON_OFFSET_ALGO = 0, + HALF_OFFSET_ALGO = 1, + ALL_OFFSET_ALGO = 2, +}; +struct QuantizeFactor { + public: + // QuantizeScaleMode scale_mode; + uint32_t scale_mode{0}; + Buffer scale_value; + int64_t scale_offset{0}; + Buffer offset_data_value; + int64_t offset_data_offset{0}; + Buffer offset_weight_value; + int64_t offset_weight_offset{0}; + Buffer offset_pad_value; + int64_t offset_pad_offset{0}; + + DEF_TYPE_DEC(uint32_t, scale_mode); + DEF_TYPE_BYTES_DEC(scale_value); + + DEF_TYPE_DEC(int64_t, scale_offset); + DEF_TYPE_BYTES_DEC(offset_data_value); + DEF_TYPE_DEC(int64_t, offset_data_offset); + + DEF_TYPE_BYTES_DEC(offset_weight_value); + DEF_TYPE_DEC(int64_t, offset_weight_offset); + 
DEF_TYPE_BYTES_DEC(offset_pad_value); + DEF_TYPE_DEC(int64_t, offset_pad_offset); + + GE_SERIALIZABLE(scale_mode, scale_value, scale_offset, offset_data_value, offset_data_offset, offset_weight_value, + offset_weight_offset, offset_pad_value, offset_pad_offset) +}; + +static inline bool QuantizeFactorHasData(const QuantizeFactor &factor) { + return factor.scale_value.GetSize() > 0 || factor.offset_data_value.GetSize() > 0 || + factor.offset_weight_value.GetSize() > 0 || factor.offset_pad_value.GetSize() > 0; +} + +struct AllOffsetQuantizeInfo { + public: + AllOffsetQuantizeInfo() {} + AllOffsetQuantizeInfo(float s, int32_t o) : scale(s), offset(o) {} + float scale{0}; + int32_t offset{0}; + + DEF_TYPE_DEC(float, scale); + DEF_TYPE_DEC(int32_t, offset); + + GE_SERIALIZABLE(scale, offset) +}; + +struct QuantizeCalcFactor { + public: + Buffer offsetw; + int64_t offsetw_offset{0}; + Buffer offsetd; + int64_t offsetd_offset{0}; + Buffer scalereq; + int64_t scaledreq_offset{0}; + Buffer offsetdnext; + int64_t offsetdnext_offset{0}; + + DEF_TYPE_BYTES_DEC(offsetw); + DEF_TYPE_DEC(int64_t, offsetw_offset); + DEF_TYPE_BYTES_DEC(offsetd); + DEF_TYPE_DEC(int64_t, offsetd_offset); + DEF_TYPE_BYTES_DEC(scalereq); + DEF_TYPE_DEC(int64_t, scaledreq_offset); + DEF_TYPE_BYTES_DEC(offsetdnext); + DEF_TYPE_DEC(int64_t, offsetdnext_offset); + + GE_SERIALIZABLE(offsetw, offsetw_offset, offsetd, offsetd_offset, scalereq, scaledreq_offset, offsetdnext, + offsetdnext_offset); +}; + +static inline bool QuantizeFactorHasData(const QuantizeCalcFactor &factor) { + return factor.offsetw.GetSize() > 0 || factor.offsetd.GetSize() > 0 || factor.scalereq.GetSize() > 0 || + factor.offsetdnext.GetSize() > 0; +} + +struct QuantizeFactorParams { + uint32_t quantize_algo{0}; + uint32_t scale_type{0}; + QuantizeFactor quantize_param; + QuantizeFactor dequantize_param; + QuantizeFactor requantize_param; + QuantizeCalcFactor quantizecalc_param; + DEF_TYPE_DEC(uint32_t, quantize_algo); + 
DEF_TYPE_DEC(uint32_t, scale_type); + DEF_TYPE_HAS_DEC(QuantizeFactor, quantize_param); + DEF_TYPE_HAS_DEC(QuantizeFactor, dequantize_param); + DEF_TYPE_HAS_DEC(QuantizeFactor, requantize_param); + DEF_TYPE_HAS_DEC(QuantizeCalcFactor, quantizecalc_param); + + GE_SERIALIZABLE(quantize_algo, scale_type, quantize_param, dequantize_param, requantize_param, quantizecalc_param, + has_mutable_quantize_param, has_mutable_dequantize_param, has_mutable_requantize_param, + has_mutable_quantizecalc_param); +}; + +#undef DEF_TYPE_DEC +} // namespace ge + +#endif // INC_GRAPH_DEF_TYPES_H_ diff --git a/inc/graph/detail/any_map.h b/inc/graph/detail/any_map.h new file mode 100644 index 00000000..70533ea1 --- /dev/null +++ b/inc/graph/detail/any_map.h @@ -0,0 +1,120 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_DETAIL_ANY_MAP_H_ +#define INC_GRAPH_DETAIL_ANY_MAP_H_ + +#include +#include +#include +#include + +namespace ge { +using std::shared_ptr; +using std::string; + +class TypeID { + public: + template + static TypeID Of() { + return TypeID(__PRETTY_FUNCTION__); + } + + ~TypeID() = default; + + bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; } + + private: + explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32 + + string type_; +}; + +class AnyMap { + public: + template + bool Set(const string &name, const DT &val); + + template + bool Get(const string &name, T &retValue) const; + + bool Has(const string &name) const { return anyValues_.find(name) != anyValues_.end(); } + + void Swap(AnyMap &other) { anyValues_.swap(other.anyValues_); } + + private: + class Placeholder { + public: + virtual ~Placeholder() = default; + + virtual const TypeID &GetTypeInfo() const = 0; + }; + + template + class Holder : public Placeholder { + public: + explicit Holder(const VT &value) : value_(value) {} + + ~Holder() override = default; + + const TypeID &GetTypeInfo() const override { + static const TypeID typeId = TypeID::Of(); + return typeId; + } + + const VT value_; + }; + + std::map> anyValues_; +}; + +template +bool AnyMap::Set(const string &name, const DT &val) { + auto it = anyValues_.find(name); + + std::shared_ptr> tmp; + try { + tmp = std::make_shared>(val); + } catch (std::bad_alloc &e) { + tmp = nullptr; + } catch (...) { + tmp = nullptr; + } + + if (it == anyValues_.end()) { + (void)anyValues_.emplace(name, tmp); + } else { + if (it->second && it->second->GetTypeInfo() == TypeID::Of
()) { + it->second = tmp; + } else { + return false; + } + } + return true; +} + +template +bool AnyMap::Get(const string &name, T &retValue) const { + auto it = anyValues_.find(name); + if (it != anyValues_.end() && it->second && it->second->GetTypeInfo() == TypeID::Of()) { + auto retPtr = std::static_pointer_cast>(it->second); + retValue = retPtr->value_; + return true; + } + return false; +} +} // namespace ge +#endif // INC_GRAPH_DETAIL_ANY_MAP_H_ diff --git a/inc/graph/detail/attributes_holder.h b/inc/graph/detail/attributes_holder.h new file mode 100644 index 00000000..49741143 --- /dev/null +++ b/inc/graph/detail/attributes_holder.h @@ -0,0 +1,165 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ +#define INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ + +#include +#include +#include +#include +#include +#include +#include "graph/detail/any_map.h" +#include "graph/ge_error_codes.h" +#include "graph/types.h" + +namespace google { +namespace protobuf { +class Message; +template +class Map; +} // namespace protobuf +} // namespace google + +namespace ge { +using std::string; +class GeAttrValue; + +namespace proto { +class AttrDef; +class TensorDef; +class TensorDescriptor; +class ShapeDef; +class NamedAttrs; +class ModelDef; +class OpDef; +class GraphDef; +} // namespace proto + +using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073 +using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>; + +template +class GeIrProtoHelper { + public: + GeIrProtoHelper(const ProtoMsgOwner &protoOwner, ProtoType *protoMsg) + : protoOwner_(protoOwner), protoMsg_(protoMsg) {} + + GeIrProtoHelper() { + protoOwner_ = std::shared_ptr<::google::protobuf::Message>(nullptr); + protoMsg_ = nullptr; + } + virtual ~GeIrProtoHelper() = default; + + template + GeIrProtoHelper(const GeIrProtoHelper &other) { + protoOwner_ = other.protoOwner_; + protoMsg_ = other.protoMsg_; + } + template + GeIrProtoHelper &operator=(const GeIrProtoHelper &other) { + protoOwner_ = other.protoOnwer_; + protoMsg_ = other.protoMsg_; + return *this; + } + void InitDefault(); + template + bool operator==(const GeIrProtoHelper &other) const { + return protoOwner_ == other.protoOwner_ && protoMsg_ == other.protoMsg_; + } + + inline const ProtoMsgOwner &GetProtoOwner() const { return protoOwner_; } + inline ProtoType *GetProtoMsg() const { return protoMsg_; } + void CopyValueFrom(const GeIrProtoHelper &other) { + if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) { + *protoMsg_ = *other.protoMsg_; + } + } + void MoveValueFrom(GeIrProtoHelper &&other) { + if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) { + 
*protoMsg_ = std::move(*other.protoMsg_); + } + } + + void Swap(GeIrProtoHelper &other) { + protoOwner_.swap(other.protoOwner_); + + ProtoType *temp = protoMsg_; + protoMsg_ = other.protoMsg_; + other.protoMsg_ = temp; + } + + // protoMsg_ is part of protoOwner_, they have the same runtime + ProtoMsgOwner protoOwner_ = nullptr; + ProtoType *protoMsg_ = nullptr; + friend class GeIrProtoHelper::value, typename std::remove_const::type, const ProtoType>::type>; +}; + +using ProtoAttrMapHelper = GeIrProtoHelper; +using ConstProtoAttrMapHelper = GeIrProtoHelper; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { + public: + AttrHolder() = default; + virtual ~AttrHolder() = default; + + graphStatus SetAttr(const string &name, const GeAttrValue &value); + + graphStatus GetAttr(const string &name, GeAttrValue &value) const; + + bool HasAttr(const string &name) const; + + graphStatus DelAttr(const string &name); + + void CopyAttrsFrom(const AttrHolder &holder); + + void Swap(AttrHolder &holder) { + requiredAttrs_.swap(holder.requiredAttrs_); + extAttrs_.Swap(holder.extAttrs_); + } + + template + bool SetExtAttr(const string &name, const T &value) { + return extAttrs_.Set(name, value); + } + template + T TryGetExtAttr(const string &name, T defaultValue) const { + T ret(defaultValue); + (void)extAttrs_.Get(name, ret); + return ret; + } + + protected: + graphStatus AddRequiredAttr(const std::string &name); + const std::unordered_set GetAllAttrNames() const; + const std::map GetAllAttrs() const; // lint !e1073 + + virtual ProtoAttrMapHelper MutableAttrMap() = 0; + virtual ConstProtoAttrMapHelper GetAttrMap() const = 0; + + friend class ModelSerializeImp; + friend class AttrUtils; + friend class AttrUtilsHelper; + + std::vector requiredAttrs_; + + private: + AnyMap extAttrs_; +}; +} // namespace ge +#endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ diff --git a/inc/graph/detail/model_serialize_imp.h b/inc/graph/detail/model_serialize_imp.h new file mode 100644 
index 00000000..ff27335a --- /dev/null +++ b/inc/graph/detail/model_serialize_imp.h @@ -0,0 +1,93 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ +#define INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ + +#include +#include +#include +#include +#include "graph/anchor.h" +#include "graph/detail/attributes_holder.h" +#include "graph/ge_tensor.h" +#include "graph/graph.h" +#include "graph/node.h" + +namespace ge { +using ComputeGraphPtr = std::shared_ptr; + +struct NodeNameGraphReq { + string node_name; + int32_t index; + ComputeGraphPtr graph; +}; + +struct NodeNameNodeReq { + string src_node_name; + int32_t src_out_index; + NodePtr dst_node; + int32_t dst_in_index; + string dst_node_name; +}; + +class ModelSerializeImp { + public: + bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false); + + bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false); + + bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); + + bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false); + + bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false); + + bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); + + bool UnserializeModel(Model &model, proto::ModelDef &modeProto); + + 
bool UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graphProto); + + bool UnserializeGraph(ComputeGraphPtr &graph, proto::GraphDef &graphProto); + + bool HandleNodeNameRef(); + + bool UnserializeOpDesc(OpDescPtr &opDesc, proto::OpDef &opDefProto); + void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_in, std::vector &key_out, + std::vector &value_in, std::vector &value_out, std::vector &opt); + void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto); + + bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &opDefProto); + + bool UnserializeTensor(GeTensorPtr &tensor, proto::TensorDef &tensorProto); + + bool ParseNodeIndex(const string &node_index, string &nodeName, int32_t &index); + + void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; } + + private: + bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map &subgraphs); + + std::vector graph_input_node_names_; + std::vector graph_output_node_names_; + std::vector node_input_node_names_; + std::map node_map_; + ProtoMsgOwner protobuf_owner_; +}; +} // namespace ge + +#endif // INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ diff --git a/inc/graph/ge_attr_value.h b/inc/graph/ge_attr_value.h new file mode 100644 index 00000000..0c265c20 --- /dev/null +++ b/inc/graph/ge_attr_value.h @@ -0,0 +1,343 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_GE_ATTR_VALUE_H_ +#define INC_GRAPH_GE_ATTR_VALUE_H_ + +#include +#include +#include +#include +#include +#include +#include "graph/buffer.h" +#include "detail/attributes_holder.h" +#include "graph/ge_error_codes.h" +#include "graph/ge_tensor.h" + +using std::map; +using std::string; +using std::vector; + +namespace ge { +class GeTensor; + +using GeTensorPtr = std::shared_ptr; +using ConstGeTensorPtr = std::shared_ptr; + +class ComputeGraph; +using ComputeGraphPtr = std::shared_ptr; +using ConstComputeGraphPtr = std::shared_ptr; + +class GeTensorDesc; +class GeAttrValue; +class GeAttrValueImp; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder { + public: + NamedAttrs(); + virtual ~NamedAttrs() = default; + void SetName(const std::string &name); + string GetName() const; + GeAttrValue GetItem(const string &key) const; + + protected: + ProtoAttrMapHelper MutableAttrMap() override; + ConstProtoAttrMapHelper GetAttrMap() const override; + + private: + // Create namedAttrs from protobuf obj + NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); + GeIrProtoHelper named_attrs_; + friend class GeAttrValueImp; + friend class GeAttrValue; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { + public: + using INT = int64_t; + using FLOAT = float; + using BOOL = bool; + using STR = std::string; + using TENSOR = GeTensorPtr; + using TENSOR_DESC = GeTensorDesc; + using GRAPH = ComputeGraphPtr; + using BYTES = Buffer; + using NAMED_ATTRS = ge::NamedAttrs; + using DATA_TYPE = ge::DataType; + + using LIST_INT = vector; + using LIST_FLOAT = vector; + using LIST_BOOL = vector; + using LIST_STR = vector; + using LIST_TENSOR = vector; + using LIST_TENSOR_DESC = vector; + using LIST_GRAPH = vector; + using LIST_BYTES = vector; + using LIST_NAMED_ATTRS = vector; + using LIST_LIST_INT 
= vector>; + using LIST_DATA_TYPE = vector; + + using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs). + + enum ValueType { + VT_NONE = 0, + VT_STRING, + VT_FLOAT, + VT_BOOL, + VT_INT, + VT_TENSOR_DESC, + VT_TENSOR, + VT_BYTES, + VT_GRAPH, + VT_NAMED_ATTRS, + VT_LIST_LIST_INT, + VT_DATA_TYPE, + + VT_LIST_BASE = 1000, + VT_LIST_STRING = VT_LIST_BASE + VT_STRING, + VT_LIST_FLOAT = VT_LIST_BASE + VT_FLOAT, + VT_LIST_BOOL = VT_LIST_BASE + VT_BOOL, + VT_LIST_INT = VT_LIST_BASE + VT_INT, + VT_LIST_TENSOR_DESC = VT_LIST_BASE + VT_TENSOR_DESC, + VT_LIST_TENSOR = VT_LIST_BASE + VT_TENSOR, + VT_LIST_BYTES = VT_LIST_BASE + VT_BYTES, + VT_LIST_GRAPH = VT_LIST_BASE + VT_GRAPH, + VT_LIST_NAMED_ATTRS = VT_LIST_BASE + VT_NAMED_ATTRS, + VT_LIST_DATA_TYPE = VT_LIST_BASE + VT_DATA_TYPE, + }; + + template + struct IsAttrTypeEnable { + using DT = typename std::remove_cv::type; + + static bool const VALUE = std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; + + // Not has list type of NamedAttrs + static bool const LIST_VALUE = std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value; + }; + + template + // To cols + using enable_if_vector_type_valid_t = typename std::enable_if::LIST_VALUE, int>::type; + + template + using enable_if_one_type_valid_t = typename std::enable_if::VALUE, int>::type; + + template + using enable_if_type_valid_t = + typename std::enable_if::VALUE || IsAttrTypeEnable::LIST_VALUE, int>::type; + + template + using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; + + GeAttrValue(); + ~GeAttrValue() = default; + // SetValue, Set 
initializer_list + template = 0> + graphStatus SetValue(std::initializer_list
&&val) { + T vectorVal; + for (auto &item : val) { + vectorVal.push_back(item); + } + return SetValue(vectorVal); + } + + // SetValue, Set vector + template = 0> + graphStatus SetValue(const std::vector
&val) { + T vectorVal; + for (auto item : val) { + vectorVal.push_back(item); + } + return SetValue(vectorVal); + } + + // SetValue, not list type + template = 0> + graphStatus SetValue(DT &&val) { + return SetValue(T(std::forward
(val))); + } + + // GE_SERIALIZABLE + template = 0> + graphStatus SetValue(const T &t) { + return t.Save(*this); + } + + template = 0> + graphStatus SetValue(const vector &t) { + vector attrs; + for (auto &item : t) { + GeAttrValue val; + item.Save(val); + NamedAttrs attrsItem; + (void)val.GetValue(attrsItem); + attrs.push_back(attrsItem); + } + return SetValue(attrs); + } + + // GetValue, list value + template = 0, + typename std::enable_if::value, int>::type = 0> + graphStatus GetValue(std::vector
&val) const { + T valGet; + val.clear(); + auto status = GetValue(valGet); + if (status != GRAPH_SUCCESS) { + return status; + } + for (auto item : valGet) { + val.push_back(item); + } + return GRAPH_SUCCESS; + } + + // GetValue, not list type + template = 0, + typename std::enable_if::value, int>::type = 0> + graphStatus GetValue(DT &val) const { + T valGet; + auto status = GetValue(valGet); + if (status != GRAPH_SUCCESS) { + return status; + } + val = DT(valGet); + return GRAPH_SUCCESS; + } + + // GE_SERIALIZABLE + template = 0> + graphStatus GetValue(T &t) { + return t.Load(*this); + } + + template = 0> + graphStatus GetValue(vector &t) { + graphStatus status; + t.clear(); + vector attrs; + status = this->GetValue(attrs); + if (status != GRAPH_SUCCESS) { + return status; + } + for (auto &attr : attrs) { + T item; + GeAttrValue val; + (void)val.SetValue(attr); + status = item.Load(val); + if (status != GRAPH_SUCCESS) { + return status; + } + t.push_back(item); + } + return GRAPH_SUCCESS; + } + + template = 0> + static GeAttrValue CreateFrom(DT &&val) { + GeAttrValue valRet; + (void)valRet.SetValue(std::forward
(val)); + return valRet; + } + + template = 0> + static GeAttrValue CreateFrom(std::initializer_list
&&val) { + GeAttrValue valRet; + (void)valRet.SetValue(std::move(val)); + return valRet; + } + + template = 0> + static GeAttrValue CreateFrom(const T &val) { + GeAttrValue valRet; + (void)valRet.SetValue(val); + return valRet; + } + + template = 0> + static GeAttrValue CreateFrom(const vector &val) { + GeAttrValue valRet; + (void)valRet.SetValue(val); + return valRet; + } + + ValueType GetValueType() const; + + bool IsEmpty() const; + + GeAttrValue Copy() const; + + // For map key + bool operator==(const GeAttrValue &other) const { return value_ == other.value_; } + + graphStatus MutableTensor(GeTensorPtr &tensor); + graphStatus MutableListTensor(vector &list_tensor); + + private: +#define VALUE_SET_GET_DEC(DT) \ + graphStatus SetValue(const DT &val); \ + graphStatus GetValue(DT &val) const; + VALUE_SET_GET_DEC(GeAttrValue::STR) + VALUE_SET_GET_DEC(GeAttrValue::INT) + VALUE_SET_GET_DEC(GeAttrValue::FLOAT) + VALUE_SET_GET_DEC(GeAttrValue::BOOL) + VALUE_SET_GET_DEC(GeTensorDesc) + VALUE_SET_GET_DEC(GeAttrValue::TENSOR) + VALUE_SET_GET_DEC(GeAttrValue::GRAPH) + VALUE_SET_GET_DEC(BYTES) + VALUE_SET_GET_DEC(NamedAttrs) + VALUE_SET_GET_DEC(ge::DataType) // lint !e665 + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector>) // lint !e665 + VALUE_SET_GET_DEC(vector) // lint !e665 +#undef VALUE_SET_GET_DEC + + GeIrProtoHelper value_; + GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val); + + friend class AttrHolder; + friend class ModelSerializeImp; + friend class OnnxUtils; +}; + +class AttrValueImpl { + public: + AttrValueImpl() = default; + ~AttrValueImpl() = default; + + GeAttrValue geAttrValue_; +}; +} // namespace ge +#endif // INC_GRAPH_GE_ATTR_VALUE_H_ diff --git a/inc/graph/ge_context.h b/inc/graph/ge_context.h new file mode 
100644 index 00000000..53985e9c --- /dev/null +++ b/inc/graph/ge_context.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_GE_CONTEXT_H_ +#define INC_GRAPH_GE_CONTEXT_H_ + +#include +#include "graph/ge_error_codes.h" + +namespace ge { +class GEContext { + public: + graphStatus GetOption(const std::string &key, std::string &option); + bool GetHostExecFlag(); + uint64_t SessionId(); + uint32_t DeviceId(); + uint64_t TraceId(); + void Init(); + void SetSessionId(uint64_t session_id); + void SetCtxDeviceId(uint32_t device_id); + + private: + uint64_t session_id_ = 0; + uint32_t device_id_ = 0; + uint64_t trace_id_ = 0; +}; // class GEContext + +/// Get context +/// @return +GEContext &GetContext(); +} // namespace ge + +#endif // INC_GRAPH_GE_CONTEXT_H_ diff --git a/inc/graph/ge_global_options.h b/inc/graph/ge_global_options.h new file mode 100644 index 00000000..b55192e2 --- /dev/null +++ b/inc/graph/ge_global_options.h @@ -0,0 +1,26 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_GE_GLOBAL_OPTIONS_H_ +#define INC_GRAPH_GE_GLOBAL_OPTIONS_H_ + +#include +#include + +namespace ge { +std::map &GetMutableGlobalOptions(); +} +#endif // INC_GRAPH_GE_GLOBAL_OPTIONS_H_ diff --git a/inc/graph/ge_local_context.h b/inc/graph/ge_local_context.h new file mode 100644 index 00000000..b47098fb --- /dev/null +++ b/inc/graph/ge_local_context.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_GE_LOCAL_CONTEXT_H_ +#define INC_GRAPH_GE_LOCAL_CONTEXT_H_ + +#include +#include +#include +#include "graph/ge_error_codes.h" + +using std::map; +using std::string; + +namespace ge { +class GEThreadLocalContext { + public: + graphStatus GetOption(const string &key, string &option); + void SetGraphOption(map options_map); + void SetSessionOption(map options_map); + void SetGlobalOption(map options_map); + + private: + map graph_options_; + map session_options_; + map global_options_; +}; // class GEThreadLocalContext + +GEThreadLocalContext &GetThreadLocalContext(); +} // namespace ge +#endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ diff --git a/inc/graph/ge_tensor.h b/inc/graph/ge_tensor.h new file mode 100644 index 00000000..834dca0b --- /dev/null +++ b/inc/graph/ge_tensor.h @@ -0,0 +1,193 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_GE_TENSOR_H_ +#define INC_GRAPH_GE_TENSOR_H_ + +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/buffer.h" +#include "graph/ge_error_codes.h" +#include "graph/types.h" + +namespace ge { +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { + public: + GeShape(); + ~GeShape() = default; + explicit GeShape(std::vector s); + + size_t GetDimNum() const; + // If the idx is invalid, return 0 + int64_t GetDim(size_t idx) const; + graphStatus SetDim(size_t idx, int64_t value); + std::vector GetDims() const; + + int64_t GetShapeSize() const; + std::string ToString() const; + + /// + /// @brief Check is unknown shape + /// @return bool + /// + bool IsUnknownShape() const; + + /// + /// @brief Check is a scalar + /// @return bool + /// + bool IsScalar() const; + + GeShape(const GeShape &other); + GeShape(GeShape &&other); + GeShape &operator=(const GeShape &other); + GeShape &operator=(GeShape &&other); + + private: + GeIrProtoHelper shape_def_; + friend class GeTensorDesc; + // Create from proto obj + GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); + + void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder { + friend class TensorUtils; + friend class GeAttrValue; + friend class ModelSerialize; + + public: + GeTensorDesc(); + explicit GeTensorDesc(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); + GeTensorDesc(const GeTensorDesc &desc); + GeTensorDesc(GeTensorDesc &&desc); + + ~GeTensorDesc() = default; + bool operator==(const GeTensorDesc &r_ge_tensor_desc) const; + + void Update(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); + + GeShape GetShape() const; + GeShape &MutableShape(); + void SetShape(GeShape shape); + + // set shape with -2, it stand for unknown shape + void SetUnknownDimNumShape(); + // for unknown shape + graphStatus 
SetShapeRange(const std::vector> &range); + graphStatus GetShapeRange(std::vector> &range) const; + + GeShape GetOriginShape() const; + void SetOriginShape(const GeShape &originShape); + + Format GetFormat() const; + void SetFormat(Format format); + + Format GetOriginFormat() const; + void SetOriginFormat(Format originFormat); + + void SetName(const std::string &name); + const std::string GetName() const; + + DataType GetDataType() const; + void SetDataType(DataType dt); + + DataType GetOriginDataType() const; + void SetOriginDataType(DataType originDataType); + + std::vector GetRefPortIndex() const; + void SetRefPortByIndex(const std::vector &index); + + GeTensorDesc Clone() const; + GeTensorDesc &operator=(const GeTensorDesc &desc); + GeTensorDesc &operator=(GeTensorDesc &&desc); + + graphStatus IsValid() const; + + protected: + ProtoAttrMapHelper MutableAttrMap() override; + ConstProtoAttrMapHelper GetAttrMap() const override; + + private: + bool GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const; + using AttrHolder::DelAttr; + using AttrHolder::GetAllAttrs; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + void Init(); + + // Create from proto obj + GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); + friend class GeTensor; + friend class GeAttrValueImp; + friend class ModelSerializeImp; + friend class OnnxUtils; + + GeIrProtoHelper tensor_descriptor_; + // Reference from tensorDescriptor_, do not direct use + mutable GeShape __shape_; + + void RefTo(const GeTensorDesc &tensorDesc) { tensor_descriptor_ = tensorDesc.tensor_descriptor_; } + GeShape &ShapeReference() const; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { + public: + GeTensor(); + explicit GeTensor(const GeTensorDesc &tensorDesc); + explicit GeTensor(const GeTensorDesc &tensorDesc, const std::vector &data); + explicit GeTensor(const GeTensorDesc &tensorDesc, const Buffer &data); + 
explicit GeTensor(const GeTensorDesc &tensorDesc, const uint8_t *data, size_t size); + explicit GeTensor(GeTensorDesc &&tensorDesc, std::vector &&data); + ~GeTensor() = default; + + GeTensorDesc GetTensorDesc() const; + GeTensorDesc &MutableTensorDesc(); + void SetTensorDesc(const GeTensorDesc &tensorDesc); + + const Buffer GetData() const; + Buffer MutableData(); + graphStatus SetData(std::vector &&data); + graphStatus SetData(const std::vector &data); + graphStatus SetData(const Buffer &data); + graphStatus SetData(const uint8_t *data, size_t size); + + GeTensor Clone() const; + + // Share value + GeTensor(const GeTensor &other); + // Share value + GeTensor &operator=(const GeTensor &other); + + private: + friend class GeAttrValueImp; + friend class ModelSerializeImp; + friend class OnnxUtils; + // Create from proto obj + GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); + GeIrProtoHelper tensor_def_; + // Reference from tensorDef_, do not direct use + mutable GeTensorDesc __desc_; + GeTensorDesc &DescReference() const; +}; +} // namespace ge +#endif // INC_GRAPH_GE_TENSOR_H_ diff --git a/inc/graph/graph_util.h b/inc/graph/graph_util.h new file mode 100644 index 00000000..c39ecbc1 --- /dev/null +++ b/inc/graph/graph_util.h @@ -0,0 +1,134 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_GRAPH_UTIL_H_ +#define INC_GRAPH_GRAPH_UTIL_H_ + +#include + +#include "proto/om.pb.h" + +namespace ge { +using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; +bool HasOpAttr(const OpDef *opdef, std::string attr_name); +bool GetOpAttr(const std::string &key, int32_t *value, const OpDef *opdef); + +static const char OP_TYPE_DATA[] = "Data"; +static const char OP_TYPE_INPUT[] = "Input"; +static const char ATTR_KEY_INPUT_FORMAT[] = "input_format"; +static const char ATTR_KEY_OUTPUT_FORMAT[] = "output_format"; +static const char OP_TYPE_ANN_DATA[] = "AnnData"; +} // namespace ge + +#if !defined(__ANDROID__) && !defined(ANDROID) +#include "toolchain/slog.h" +const char levelStr[4][8] = {"ERROR", "WARN", "INFO", "DEBUG"}; +#else +#include +#include +const char levelStr[8][8] = {"EMERG", "ALERT", "CRIT", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"}; +#endif + +#ifdef _MSC_VER +#define FUNC_NAME __FUNCTION__ +#else +#define FUNC_NAME __PRETTY_FUNCTION__ +#endif + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) \ + dlog_info(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) +#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) \ + dlog_warn(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) +#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) \ + dlog_error(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) +#else +#define D_GRAPH_LOG(level, format, ...) \ + do { \ + { \ + fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ + __FILE__, __LINE__, ##__VA_ARGS__); \ + syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ + ##__VA_ARGS__); \ + } \ + } while (0) +#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) 
D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#endif + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define GRAPH_LOGI(...) D_GRAPH_LOGI(GRAPH_MOD_NAME, __VA_ARGS__) +#define GRAPH_LOGW(...) D_GRAPH_LOGW(GRAPH_MOD_NAME, __VA_ARGS__) +#define GRAPH_LOGE(...) D_GRAPH_LOGE(GRAPH_MOD_NAME, __VA_ARGS__) +#else + +#define GRAPH_LOG(level, format, ...) \ + do { \ + { \ + fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ + __FILE__, __LINE__, ##__VA_ARGS__); \ + syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ + ##__VA_ARGS__); \ + } \ + } while (0) +#define GRAPH_LOGI(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define GRAPH_LOGW(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define GRAPH_LOGE(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#endif + +#define GRAPH_CHK_STATUS_RET_NOLOG(expr) \ + do { \ + const domi::graphStatus _status = (expr); \ + if (_status != domi::GRAPH_SUCCESS) { \ + return _status; \ + } \ + } while (0) + +#define GRAPH_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GRAPH_LOGE(__VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +#define GRAPH_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ + { \ + bool b = (expr); \ + if (!b) { \ + exec_expr; \ + } \ + }; + +#define GRAPH_IF_BOOL_EXEC(expr, exec_expr) \ + { \ + if (expr) { \ + exec_expr; \ + } \ + } + +#define GRAPH_RETURN_WITH_LOG_IF_ERROR(expr, ...) 
\ + do { \ + const ::domi::graphStatus _status = (expr); \ + if (_status) { \ + GRAPH_LOGE(__VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +#endif // INC_GRAPH_GRAPH_UTIL_H_ diff --git a/inc/graph/model.h b/inc/graph/model.h new file mode 100644 index 00000000..38ea501b --- /dev/null +++ b/inc/graph/model.h @@ -0,0 +1,94 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_MODEL_H_ +#define INC_GRAPH_MODEL_H_ + +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/ge_attr_value.h" +#include "graph/graph.h" + +namespace ge { +using std::map; +using std::string; +using std::vector; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { + public: + Model(); + + ~Model() = default; + + Model(const string &name, const string &custom_version); + + string GetName() const; + void SetName(const string &name); + + uint32_t GetVersion() const; + + void SetVersion(uint32_t version) { version_ = version; } + + std::string GetPlatformVersion() const; + + void SetPlatformVersion(string version) { platform_version_ = version; } + + Graph GetGraph() const; + + void SetGraph(const Graph &graph); + + void SetAttr(const ProtoAttrMapHelper &attrs); + + using AttrHolder::GetAllAttrNames; + using AttrHolder::GetAllAttrs; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + 
graphStatus Save(Buffer &buffer, bool is_dump = false) const; + + graphStatus SaveToFile(const string &file_name) const; + // Model will be rewrite + static graphStatus Load(const uint8_t *data, size_t len, Model &model); + graphStatus Load(ge::proto::ModelDef &model_def); + graphStatus LoadFromFile(const string &file_name); + + bool IsValid() const; + + protected: + ConstProtoAttrMapHelper GetAttrMap() const override; + ProtoAttrMapHelper MutableAttrMap() override; + + private: + void Init(); + ProtoAttrMapHelper attrs_; + friend class ModelSerializeImp; + friend class GraphDebugImp; + friend class OnnxUtils; + friend class ModelHelper; + friend class ModelBuilder; + string name_; + uint32_t version_; + std::string platform_version_{""}; + Graph graph_; +}; +} // namespace ge +using ModelPtr = std::shared_ptr; + +#endif // INC_GRAPH_MODEL_H_ diff --git a/inc/graph/model_serialize.h b/inc/graph/model_serialize.h new file mode 100644 index 00000000..16529512 --- /dev/null +++ b/inc/graph/model_serialize.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_MODEL_SERIALIZE_H_ +#define INC_GRAPH_MODEL_SERIALIZE_H_ + +#include +#include +#include "graph/buffer.h" +#include "graph/compute_graph.h" +#include "graph/model.h" + +namespace ge { +class ModelSerialize { + public: + Buffer SerializeModel(const Model &model, bool is_dump = false); + + Model UnserializeModel(const uint8_t *data, size_t len); + Model UnserializeModel(ge::proto::ModelDef &model_def); + + Buffer SerializeGraph(const ComputeGraphPtr &graph); + + ComputeGraphPtr UnserializeGraph(const uint8_t *data, size_t len); + + Buffer SerializeOpDesc(const ConstOpDescPtr &opDesc); + OpDescPtr UnserializeOpDesc(const uint8_t *data, size_t len); + + size_t GetSerializeModelSize(const Model &model); + + private: + static std::map &MutableTensorDescAttrMap(GeTensorDesc &tensorDesc); + + static const std::map &GetTensorDescAttrMap(const GeTensorDesc &tensorDesc); + + friend class ModelSerializeImp; + friend class GraphDebugImp; +}; +} // namespace ge +#endif // INC_GRAPH_MODEL_SERIALIZE_H_ diff --git a/inc/graph/node.h b/inc/graph/node.h new file mode 100644 index 00000000..f4a1c6a8 --- /dev/null +++ b/inc/graph/node.h @@ -0,0 +1,213 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_NODE_H_ +#define INC_GRAPH_NODE_H_ + +#include +#include +#include +#include +#include +#include +#include "graph/ge_attr_value.h" +#include "utils/attr_utils.h" + +#include "graph/op_desc.h" +#include "graph/range_vistor.h" + +namespace ge { +class ComputeGraph; + +using ComputeGraphPtr = std::shared_ptr; + +class Node; + +using NodePtr = std::shared_ptr; +using ConstNodePtr = std::shared_ptr; +using NodeRef = std::weak_ptr; + +class Anchor; + +using AnchorPtr = std::shared_ptr; + +class InDataAnchor; + +using InDataAnchorPtr = std::shared_ptr; + +class OutDataAnchor; + +using OutDataAnchorPtr = std::shared_ptr; + +class ControlAnchor; + +using ControlAnchorPtr = std::shared_ptr; + +class InControlAnchor; + +using InControlAnchorPtr = std::shared_ptr; + +class OutControlAnchor; + +using OutControlAnchorPtr = std::shared_ptr; + +using OpDescPtr = std::shared_ptr; + +using ConstNode = const Node; + +typedef std::vector> kFusionDataFlowVec_t; + +// Node is a component of ComputeGraph +class Node : public std::enable_shared_from_this { + friend class ComputeGraph; + friend class ModelSerializeImp; + + public: + template + using Vistor = RangeVistor>; + ~Node(); + Node(const Node &) = delete; + Node &operator=(const Node &) = delete; + bool operator==(const Node &r_node) const; + + protected: + Node() = default; + Node(const OpDescPtr &op, const ComputeGraphPtr &ownerGraph); + + public: + graphStatus Init(); + + std::string GetName() const; + std::string GetType() const; + + ComputeGraphPtr GetOwnerComputeGraph() const; + graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph); + + Vistor GetAllInDataAnchors() const; + Vistor GetAllOutDataAnchors() const; + uint32_t GetAllInDataAnchorsSize() const; + uint32_t GetAllOutDataAnchorsSize() const; + Vistor GetAllOutAnchors() const; + Vistor GetAllInAnchors() const; + InDataAnchorPtr GetInDataAnchor(int idx) const; + OutDataAnchorPtr GetOutDataAnchor(int idx) const; + InControlAnchorPtr 
GetInControlAnchor() const; + OutControlAnchorPtr GetOutControlAnchor() const; + Vistor GetInNodes() const; + Vistor GetOutNodes() const; + AnchorPtr GetInAnchor(int idx) const; + AnchorPtr GetOutAnchor(int idx) const; + + bool IsAllInNodesSeen(std::unordered_set &nodes_seen) const; + + // All in Data nodes + Vistor GetInDataNodes() const; + // All in Control nodes + Vistor GetInControlNodes() const; + // GetInAllNodes = InDataNodes + InControlNodes + Vistor GetInAllNodes() const; + + // All out Data nodes + Vistor GetOutDataNodes() const; + uint32_t GetOutDataNodesSize() const; + // All out Control nodes + Vistor GetOutControlNodes() const; + // GetOutAllNodes = OutDataNodes + InControlNodes + Vistor GetOutAllNodes() const; + + // Get all in data nodes and its out-anchor + Vistor> GetInDataNodesAndAnchors() const; + + // Get all out data nodes and its in-anchor + Vistor> GetOutDataNodesAndAnchors() const; + + graphStatus InferShapeAndType() const; + graphStatus Verify() const; + + graphStatus InferOriginFormat() const; + + OpDescPtr GetOpDesc() const; + + graphStatus UpdateOpDesc(const OpDescPtr &op); + + graphStatus AddLinkFrom(const NodePtr &input_node); + + graphStatus AddLinkFrom(const uint32_t &index, NodePtr input_node); + + graphStatus AddLinkFrom(const string &name, NodePtr input_node); + + graphStatus AddLinkFromForParse(const NodePtr &input_node); + + void AddSendEventId(uint32_t event_id) { send_event_id_list_.push_back(event_id); } + + void AddRecvEventId(uint32_t event_id) { recv_event_id_list_.push_back(event_id); } + + const std::vector &GetSendEventIdList() const { return send_event_id_list_; } + + const std::vector &GetRecvEventIdList() const { return recv_event_id_list_; } + void GetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { + fusion_input_list = fusion_input_dataflow_list_; + } + + void GetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { + fusion_output_list = fusion_output_dataflow_list_; + } + + void 
SetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { + fusion_input_dataflow_list_ = fusion_input_list; + } + + void SetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { + fusion_output_dataflow_list_ = fusion_output_list; + } + + bool GetHostNode() const { return host_node_; } + void SetHostNode(bool is_host) { host_node_ = is_host; } + + void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } + + NodePtr GetOrigNode() { return orig_node_; } + + private: + bool NodeMembersAreEqual(const Node &r_node) const; + bool NodeAttrsAreEqual(const Node &r_node) const; + bool NodeInConnectsAreEqual(const Node &r_node) const; + bool NodeOutConnectsAreEqual(const Node &r_node) const; + bool NodeAnchorIsEqual(const AnchorPtr &l_anchor, const AnchorPtr &r_anchor, size_t i) const; + OpDescPtr op_; + std::weak_ptr owner_graph_; + vector in_data_anchors_; + vector out_data_anchors_; + InControlAnchorPtr in_control_anchor_; + OutControlAnchorPtr out_control_anchor_; + map attrs_; // lint !e1073 + bool has_init_{false}; + bool host_node_{false}; + bool anchor_status_updated_{false}; + std::vector send_event_id_list_; + std::vector recv_event_id_list_; + + kFusionDataFlowVec_t fusion_input_dataflow_list_; + kFusionDataFlowVec_t fusion_output_dataflow_list_; + + NodePtr orig_node_; + friend class NodeUtils; + friend class OnnxUtils; + friend class TuningUtils; +}; +} // namespace ge + +#endif // INC_GRAPH_NODE_H_ diff --git a/inc/graph/op_desc.h b/inc/graph/op_desc.h new file mode 100644 index 00000000..4d724c42 --- /dev/null +++ b/inc/graph/op_desc.h @@ -0,0 +1,329 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_OP_DESC_H_ +#define INC_GRAPH_OP_DESC_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/range_vistor.h" + +#define DYNAMIN_INPUT_NAME(name, index) (((name)) + std::to_string((index))) +#define DYNAMIN_OUTPUT_NAME(name, index) (((name)) + std::to_string((index))) +namespace ge { +using std::map; +using std::pair; +using std::shared_ptr; +using std::string; +using std::vector; + +class Operator; +class GeTensorDesc; + +using GeTensorDescPtr = shared_ptr; +using ConstGeTensorDescPtr = shared_ptr; + +class OpDesc; + +using OpDescPtr = shared_ptr; +using ConstOpDescPtr = shared_ptr; + +class GeAttrValue; + +using ConstOpDesc = const OpDesc; + +enum SubgraphType { kStatic, kDynamic, kSubgraphTypeEnd }; + +class OpDesc : public std::enable_shared_from_this, public AttrHolder { + public: + template + using Vistor = RangeVistor>; + + friend class GraphBuilderImpl; + + friend class OperatorImpl; + + OpDesc(const string &name, const string &type); + + OpDesc(); + + ~OpDesc(); + + bool operator==(const OpDesc &r_op_desc) const; + + string GetName() const; + + void SetName(const string &name); + + string GetType() const; + + void SetType(const string &type); + + graphStatus AddInputDesc(const GeTensorDesc &input_desc); + + graphStatus AddInputDesc(const string &name, const GeTensorDesc &input_desc); + + graphStatus AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc); + + graphStatus AddInputDescForward(const string &name, const unsigned int 
num); + + graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); + + graphStatus AddOutputDescMiddle(const string &name, const unsigned int num, size_t index); + + graphStatus AddOutputDescForward(const string &name, const unsigned int num); + + graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); + + graphStatus UpdateInputDesc(uint32_t index, const GeTensorDesc &tensor_desc); + + graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc); + + bool InputIsSet(const string &name) const; + + GeTensorDesc GetInputDesc(uint32_t index) const; + + GeTensorDesc GetInputDesc(const string &name) const; + + Vistor GetAllInputNames() const; + + GeTensorDescPtr MutableInputDesc(uint32_t index) const; + + GeTensorDescPtr MutableInputDesc(const string &name) const; + + Vistor GetAllInputsDesc() const; + + Vistor GetAllInputsDescPtr() const; + + size_t GetInputsSize() const; + + size_t GetAllInputsSize() const; + + graphStatus AddOutputDesc(const GeTensorDesc &output_desc); + + graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); + + graphStatus UpdateOutputDesc(uint32_t index, const GeTensorDesc &tensor_desc); + + graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc); + + GeTensorDesc GetOutputDesc(uint32_t index) const; + + GeTensorDesc GetOutputDesc(const string &name) const; + + GeTensorDescPtr MutableOutputDesc(uint32_t index) const; + + GeTensorDescPtr MutableOutputDesc(const string &name) const; + + uint32_t GetAllOutputsDescSize() const; + + Vistor GetAllOutputsDesc() const; + + Vistor GetAllOutputsDescPtr() const; + + size_t GetOutputsSize() const; + + ConstGeTensorDescPtr GetOutputDescPtr(uint32_t index) const; + + ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; + + ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const; + + ConstGeTensorDescPtr GetInputDescPtr(const string &name) const; + + graphStatus 
AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); + + graphStatus AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index); + + graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); + + bool IsOptionalInput(const string &name) const; + + bool IsOptionalInput(uint32_t index) const; + + std::map GetAllInputName() const; + + std::map GetAllOutputName(); + + bool UpdateInputName(std::map inputNameIdx); + + bool UpdateOutputName(std::map outputNameIdx); + + void AddInferFunc(const std::function &func); + + std::function GetInferFunc() const; + + graphStatus InferShapeAndType(); + + void AddInferFormatFunc(const std::function &func); + + std::function GetInferFormatFunc() const; + + graphStatus DefaultInferFormat(); + + std::function GetVerifyFunc() const; + + void AddVerifierFunc(const std::function &func); + + graphStatus CallInferFormatFunc(Operator &op); + + graphStatus OpVerify(); + + graphStatus CommonVerify() const; + + graphStatus AddRegisterInputName(const string &name); + + graphStatus AddRegisterOutputName(const string &name); + + vector GetRegisterInputName() const; + + vector GetRegisterOutputName() const; + + using AttrHolder::AddRequiredAttr; + using AttrHolder::DelAttr; + using AttrHolder::GetAllAttrNames; + using AttrHolder::GetAllAttrs; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + void SetId(int64_t id); + int64_t GetId() const; + void SetStreamId(int64_t stream_id); + int64_t GetStreamId() const; + void SetInputName(const vector &input_name); + vector GetInputName() const; + void SetSrcName(const vector &src_name); + vector GetSrcName() const; + void SetSrcIndex(const vector &src_index); + vector GetSrcIndex() const; + void SetInputOffset(const vector &input); + vector GetInputOffset() const; + void SetOutputOffset(const vector &input); + vector GetOutputOffset() const; + void SetDstName(const 
vector &dst_name); + vector GetDstName() const; + void SetDstIndex(const vector &dst_index); + vector GetDstIndex() const; + void SetWorkspace(const vector &workspace); + vector GetWorkspace() const; + void SetWorkspaceBytes(const vector &workspace_bytes); + vector GetWorkspaceBytes() const; + void SetIsInputConst(const vector &is_input_const); + vector GetIsInputConst() const; + + void SetOpInferDepends(const vector &depend_names); + vector GetOpInferDepends() const; + + string GetInputNameByIndex(uint32_t index) const; + string GetValidInputNameByIndex(uint32_t index) const; + int GetValidInputIndexByName(const string &name) const; + int GetInputIndexByName(const string &name) const; + + string GetOutputNameByIndex(uint32_t index) const; + + int GetOutputIndexByName(const string &name) const; + + graphStatus RestoreInputNameIdx(const string &name, const int &index); + + graphStatus RestoreOutputNameIdx(const string &name, const int &index); + + graphStatus CallInferFunc(Operator &op); + + void SetOpKernelLibName(const std::string &name); + + std::string GetOpKernelLibName() const; + + void SetOpEngineName(const std::string &name); + + std::string GetOpEngineName() const; + + void RegisterSubgraphIrName(const std::string &name, SubgraphType type); + const std::map &GetSubgraphIrNames() const; + SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; + + graphStatus AddSubgraphName(const std::string &name); + const std::map &GetSubgraphNameIndexes() const; + + std::string GetSubgraphInstanceName(uint32_t index) const; + const std::vector &GetSubgraphInstanceNames() const; + /// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, + /// because this kind of functions will only append a new subgraph instance name + /// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. 
+ /// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. + /// \param index + /// \param name + /// \return + graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); + void RemoveSubgraphInstanceName(const std::string &name); + + graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const; + + protected: + ProtoAttrMapHelper MutableAttrMap() override; + ConstProtoAttrMapHelper GetAttrMap() const override; + + private: + OpDesc(const ProtoMsgOwner &proto_msg_owner, ge::proto::OpDef *op_def); + bool OpDescMembersAreEqual(const OpDesc &r_op_desc) const; + bool OpDescAttrsAreEqual(const OpDesc &r_op_desc) const; + bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; + + GeIrProtoHelper op_def_; + std::vector subgraph_instance_names_; + + // subgraph names to index, for a `if` operator: + // then_branch: 0 + // else_branch: 1 + // or for a `case` node: + // branches0: 0 + // branches1: 1 + // branches2: 2 + std::map subgraph_names_to_index_; + + // subgraph ir names to type, for a `if` operator: + // then_branch: static + // else_branch: static + // or for a `case` op: + // branches: dynamic + std::map subgraph_ir_names_to_type_; + + vector inputs_desc_{}; + map input_name_idx_{}; + vector register_input_name_{}; + std::unordered_set optional_input_names_{}; + vector outputs_desc_{}; + map output_name_idx_{}; + vector register_output_name_{}; + std::function infer_func_ = nullptr; + std::function infer_format_func_ = nullptr; + std::function verifier_func_ = nullptr; + string op_kernel_lib_name_; + string engine_name_; + friend class OpDescUtils; + friend class ModelSerializeImp; + friend class AttrUtils; + friend class GeAttrValueImp; + friend class OnnxUtils; +}; +} // namespace ge +#endif // INC_GRAPH_OP_DESC_H_ diff --git a/inc/graph/op_kernel_bin.h b/inc/graph/op_kernel_bin.h new file mode 100644 index 00000000..3970460a --- 
/dev/null +++ b/inc/graph/op_kernel_bin.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_OP_KERNEL_BIN_H_ +#define INC_GRAPH_OP_KERNEL_BIN_H_ + +#include +#include +#include +#include + +namespace ge { +class OpKernelBin { + public: + OpKernelBin(std::string name, std::vector &&data) : name_(std::move(name)), data_(std::move(data)) {} + + ~OpKernelBin() = default; + + const std::string &GetName() const { return name_; } + const uint8_t *GetBinData() const { return (const uint8_t *)data_.data(); } + size_t GetBinDataSize() const { return data_.size(); } + OpKernelBin(const OpKernelBin &) = delete; + const OpKernelBin &operator=(const OpKernelBin &) = delete; + + private: + std::string name_; + std::vector data_; +}; + +using OpKernelBinPtr = std::shared_ptr; +const char *const OP_EXTATTR_NAME_TBE_KERNEL = "tbeKernel"; +const char *const OP_EXTATTR_CUSTAICPU_KERNEL = "cust_aicpu_kernel"; +} // namespace ge + +#endif // INC_GRAPH_OP_KERNEL_BIN_H_ diff --git a/inc/graph/operator_factory_impl.h b/inc/graph/operator_factory_impl.h new file mode 100644 index 00000000..ea343ebc --- /dev/null +++ b/inc/graph/operator_factory_impl.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ +#define INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ + +#include +#include +#include +#include +#include "graph/operator_factory.h" + +namespace ge { +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl { + public: + static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); + + static graphStatus GetOpsTypeList(std::vector &all_ops); + + static bool IsExistOp(const string &operator_type); + + static InferShapeFunc GetInferShapeFunc(const std::string &operator_type); + + static InferFormatFunc GetInferFormatFunc(const std::string &operator_type); + + static VerifyFunc GetVerifyFunc(const std::string &operator_type); + + static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreator const &op_creator); + + static graphStatus RegisterInferShapeFunc(const std::string &operator_type, InferShapeFunc const infer_shape_func); + + static graphStatus RegisterInferFormatFunc(const std::string &operator_type, InferFormatFunc const infer_format_func); + + static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); + + static shared_ptr> operator_creators_; + static shared_ptr> operator_infershape_funcs_; + static shared_ptr> operator_inferformat_funcs_; + static shared_ptr> operator_verify_funcs_; +}; +} // namespace ge + +#endif // INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ diff --git a/inc/graph/opsproto_manager.h b/inc/graph/opsproto_manager.h new file mode 100644 index 00000000..06846573 --- 
/dev/null +++ b/inc/graph/opsproto_manager.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_OPSPROTO_MANAGER_H_ +#define INC_GRAPH_OPSPROTO_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace ge { +class OpsProtoManager { + public: + static OpsProtoManager *Instance(); + + bool Initialize(const std::map &options); + void Finalize(); + + private: + void LoadOpsProtoPluginSo(std::string &path); + + std::string pluginPath_; + std::vector handles_; + bool is_init_ = false; + std::mutex mutex_; +}; +} // namespace ge + +#endif // INC_GRAPH_OPSPROTO_MANAGER_H_ diff --git a/inc/graph/range_vistor.h b/inc/graph/range_vistor.h new file mode 100644 index 00000000..8635d413 --- /dev/null +++ b/inc/graph/range_vistor.h @@ -0,0 +1,57 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_RANGE_VISTOR_H_ +#define INC_GRAPH_RANGE_VISTOR_H_ + +#include + +template +class RangeVistor { + public: + /*lint -e151*/ + using Iterator = typename std::vector::iterator; + using ConstIterator = typename std::vector::const_iterator; + /*lint +e151*/ + + RangeVistor(O owner, const std::vector &vs) : owner_(owner), elements_(vs) {} + + ~RangeVistor() {} + + Iterator begin() { return elements_.begin(); } + + Iterator end() { return elements_.end(); } + + ConstIterator begin() const { return elements_.begin(); } + + ConstIterator end() const { return elements_.end(); } + + std::size_t size() const { return elements_.size(); } + + bool empty() const { return elements_.empty(); } + + /*lint -e659*/ + E &at(std::size_t index) { return elements_.at(index); } + /*lint +e659*/ + + const E &at(std::size_t index) const { return elements_.at(index); } + + private: + O owner_; + std::vector elements_; +}; + +#endif // INC_GRAPH_RANGE_VISTOR_H_ diff --git a/inc/graph/ref_relation.h b/inc/graph/ref_relation.h new file mode 100644 index 00000000..71457916 --- /dev/null +++ b/inc/graph/ref_relation.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMMON_GRAPH_REF_RELATION_H_ +#define COMMON_GRAPH_REF_RELATION_H_ + +#include +#include +#include +#include + +#include "graph/compute_graph.h" +#include "graph/types.h" +#include "graph/ge_error_codes.h" +#include "node.h" + +namespace ge { +enum InOutFlag { + NODE_IN = 0, // input flag + NODE_OUT = 1, // output flag +}; + +struct RefCell { + std::string node_name; + ge::NodePtr node = nullptr; + InOutFlag in_out = NODE_IN; + int in_out_idx = 0; + + bool operator==(const RefCell &c) const { + return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; + } + + RefCell() = default; + RefCell(std::string name, ge::NodePtr node_ptr, InOutFlag in_out_flag, int idx) { + node_name = name; + node = node_ptr; + in_out = in_out_flag; + in_out_idx = idx; + }; + ~RefCell() = default; +}; + +struct RefCellHash { + size_t operator()(const RefCell &c) const { + unsigned long number = reinterpret_cast(reinterpret_cast(c.node.get())); + string tmp = c.node_name + std::to_string(c.in_out) + std::to_string(c.in_out_idx) + std::to_string(number); + return std::hash()(tmp); + } +}; + +class RefRelations { + public: + graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set &result); + graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); + graphStatus Clear(); + + RefRelations(); + ~RefRelations() = default; + + public: + class Impl; + std::shared_ptr impl_ = nullptr; +}; + +} // namespace ge +#endif // COMMON_GRAPH_REF_RELATION_H_ diff --git a/inc/graph/runtime_inference_context.h b/inc/graph/runtime_inference_context.h new file mode 100644 index 00000000..f0b38546 --- /dev/null +++ b/inc/graph/runtime_inference_context.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ +#define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ + +#include +#include +#include +#include +#include "external/graph/ge_error_codes.h" +#include "external/graph/tensor.h" +#include "ge_attr_value.h" + +namespace ge { +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext { + public: + static graphStatus GetContext(const std::string &context_id, RuntimeInferenceContext **ctx); + static graphStatus CreateContext(const std::string &context_id); + static void DestroyContext(const std::string &context_id); + + graphStatus SetTensor(int64_t node_id, int output_id, Tensor &&tensor); + graphStatus GetTensor(int64_t node_id, int output_id, GeTensorPtr &tensor); + graphStatus GetTensor(int64_t node_id, int output_id, Tensor &tensor); + + private: + std::map> tensors_; + std::map> ge_tensors_; + std::mutex mu_; + + static std::map> contexts_; + static std::mutex ctx_mu_; +}; +} // namespace ge + +#endif // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ diff --git a/inc/graph/shape_refiner.h b/inc/graph/shape_refiner.h new file mode 100644 index 00000000..4f8783a3 --- /dev/null +++ b/inc/graph/shape_refiner.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_SHAPE_REFINER_H_ +#define INC_GRAPH_SHAPE_REFINER_H_ + +#include +#include "external/graph/inference_context.h" + +#include "external/graph/ge_error_codes.h" +#include "graph/node.h" + +namespace ge { +// ShapeRefiner performs shape inference for compute graphs +class ShapeRefiner { + public: + static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph); + static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); + static graphStatus InferShapeAndType(const NodePtr &node); + static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); + static void ClearContextMap(); + + private: + static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); +}; +} // namespace ge +#endif // INC_GRAPH_SHAPE_REFINER_H_ diff --git a/inc/graph/tuning_utils.h b/inc/graph/tuning_utils.h new file mode 100644 index 00000000..98262a23 --- /dev/null +++ b/inc/graph/tuning_utils.h @@ -0,0 +1,130 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MAIN_TUNING_UTILS_H +#define MAIN_TUNING_UTILS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "utils/attr_utils.h" +#include "utils/node_utils.h" +#include "external/ge/ge_api_types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +namespace ge { +// Configure build mode, default value is "normal" +const char *const BUILD_MODE = "ge.buildMode"; +const char *const BUILD_STEP = "ge.buildStep"; +// Configure tuning path +const char *const TUNING_PATH = "ge.tuningPath"; +// for interface: aclgrphBuildModel +const std::set ir_builder_supported_options_for_lx_fusion = {BUILD_MODE, BUILD_STEP, TUNING_PATH}; + +// Build model +const char *const BUILD_MODE_NORMAL = "normal"; +const char *const BUILD_MODE_TUNING = "tuning"; +const char *const BUILD_MODE_BASELINE = "baseline"; +const std::set build_mode_options = {BUILD_MODE_NORMAL, BUILD_MODE_TUNING, BUILD_MODE_BASELINE}; + +// Build step +const char *const BUILD_STEP_BEFORE_UB_MATCH = "before_ub_match"; +const char *const BUILD_STEP_AFTER_UB_MATCH = "after_ub_match"; +const char *const BUILD_STEP_AFTER_BUILDER = "after_builder"; +const char *const BUILD_STEP_AFTER_BUILDER_SUB = "after_builder_sub"; +const char *const BUILD_STEP_AFTER_MERGE = "after_merge"; +const std::set build_step_options = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_UB_MATCH, + BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB, + BUILD_STEP_AFTER_MERGE}; + +using SubgraphCreateOutNode = std::unordered_map; +using NodetoNodeMap = std::unordered_map; +using NodeSet = std::set; +using NodeNametoNodeNameMap = std::unordered_map; +using NodetoNodeNameMap = 
std::unordered_map; +class TuningUtils { + public: + TuningUtils() = default; + ~TuningUtils() = default; + // Dump all the subgraphs and modify + // the subgraphs in them to be executable subgraphs if exe_flag is true + // `tuning_path` means path to save the graphs + static graphStatus ConvertGraphToFile(std::vector tuning_subgraphs, + std::vector non_tuning_subgraphs = {}, bool exe_flag = false, + const std::string &path = "", const std::string &user_path = ""); + // Recovery `graph` from graph dump files configured in options + static graphStatus ConvertFileToGraph(const map &options, ge::Graph &graph); + + private: + // part 1 + struct HelpInfo { + int64_t index; + bool exe_flag; + bool is_tuning_graph; + const std::string &path; + const std::string &user_path; + }; + static graphStatus MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info); + static graphStatus HandlePld(NodePtr &node); + static graphStatus HandleEnd(NodePtr &node); + static graphStatus ChangePld2Data(NodePtr &node, NodePtr &data_node); + static graphStatus ChangeEnd2NetOutput(NodePtr &node, NodePtr &out_node); + static graphStatus LinkEnd2NetOutput(NodePtr &node, NodePtr &out_node); + static graphStatus CreateDataNode(NodePtr &node, NodePtr &data_node); + static graphStatus CreateNetOutput(NodePtr &node, NodePtr &out_node); + static graphStatus AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node); + static graphStatus AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node); + static void DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path); + + static SubgraphCreateOutNode create_output_; + // part 2 + static graphStatus MergeAllSubGraph(std::vector &graphs, ComputeGraphPtr &graph); + static graphStatus MergeSubGraph(ComputeGraphPtr &graph); + // Deletes new data and output nodes added by call `MakeExeGraph()` func in part 1 + static graphStatus RemoveDataNetoutputEdge(ComputeGraphPtr &graph); + 
static graphStatus GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor, + AnchorPtr &src_out_anchor); + static NodeNametoNodeNameMap data_2_netoutput_; + static NodetoNodeNameMap data_node_2_netoutput_; + static NodetoNodeMap data_node_2_netoutput_node_; + static NodeSet netoutput_nodes_; + static NodeSet merged_graph_nodes_; + static std::mutex mutex_; + // for debug + static std::string PrintCheckLog(); + static std::string GetNodeNameByAnchor(const Anchor *anchor); +}; +} // namespace ge +#endif // MAIN_TUNING_UTILS_H diff --git a/inc/graph/usr_types.h b/inc/graph/usr_types.h new file mode 100644 index 00000000..90e02001 --- /dev/null +++ b/inc/graph/usr_types.h @@ -0,0 +1,133 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_USR_TYPES_H_ +#define INC_GRAPH_USR_TYPES_H_ + +#include +#include +#include +namespace ge { +#define USR_TYPE_DEC(type, name) \ + inline void set_##name(const type &value) { name = value; } \ + type *mutable_##name() { return &name; } + +#define USR_TYPE_HAS_DEC(type, name) \ + inline void set_##name(const type &value) { name = value; } \ + \ + private: \ + bool has_mutable_##name{false}; \ + \ + public: \ + bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \ + type *mutable_##name() { \ + has_mutable_##name = true; \ + return &name; \ + } + +#define USR_TYPE_BYTES_DEC(name) \ + inline void clear_##name() { name.clear(); } \ + inline void set_##name(const void *value, size_t size) { \ + name.assign(reinterpret_cast(const_cast(value)), \ + reinterpret_cast(const_cast(value)) + size); \ + } + +enum UsrQuantizeScaleType { USR_VECTOR_SCALE = 0, USR_SCALAR_SCALE = 1 }; +enum UsrQuantizeScaleMode { USR_NORMAL_MODE = 0, USR_SQRT_MODE = 1 }; +enum UsrQuantizeAlgorithm { + USR_NON_OFFSET_ALGO = 0, + USR_HALF_OFFSET_ALGO = 1, + USR_ALL_OFFSET_ALGO = 2, +}; + +struct UsrQuantizeFactor { + public: + // QuantizeScaleMode scale_mode; + UsrQuantizeScaleMode scale_mode{USR_NORMAL_MODE}; + std::vector scale_value; + int64_t scale_offset{0}; + std::vector offset_data_value; + int64_t offset_data_offset{0}; + std::vector offset_weight_value; + int64_t offset_weight_offset{0}; + std::vector offset_pad_value; + int64_t offset_pad_offset{0}; + + USR_TYPE_DEC(UsrQuantizeScaleMode, scale_mode); + USR_TYPE_BYTES_DEC(scale_value); + + USR_TYPE_DEC(int64_t, scale_offset); + USR_TYPE_BYTES_DEC(offset_data_value); + USR_TYPE_DEC(int64_t, offset_data_offset); + + USR_TYPE_BYTES_DEC(offset_weight_value); + USR_TYPE_DEC(int64_t, offset_weight_offset); + USR_TYPE_BYTES_DEC(offset_pad_value); + USR_TYPE_DEC(int64_t, offset_pad_offset); +}; + +static inline bool QuantizeFactorHasData(const UsrQuantizeFactor &factor) { + return 
factor.scale_value.size() > 0 || factor.offset_data_value.size() > 0 || + factor.offset_weight_value.size() > 0 || factor.offset_pad_value.size() > 0; +} + +struct UsrQuantizeCalcFactor { + public: + std::vector offsetw; + int64_t offsetw_offset{0}; + std::vector offsetd; + int64_t offsetd_offset{0}; + std::vector scalereq; + int64_t scaledreq_offset{0}; + std::vector offsetdnext; + int64_t offsetdnext_offset{0}; + + USR_TYPE_BYTES_DEC(offsetw); + USR_TYPE_DEC(int64_t, offsetw_offset); + USR_TYPE_BYTES_DEC(offsetd); + USR_TYPE_DEC(int64_t, offsetd_offset); + USR_TYPE_BYTES_DEC(scalereq); + USR_TYPE_DEC(int64_t, scaledreq_offset); + USR_TYPE_BYTES_DEC(offsetdnext); + USR_TYPE_DEC(int64_t, offsetdnext_offset); +}; + +static inline bool QuantizeFactorHasData(const UsrQuantizeCalcFactor &factor) { + return factor.offsetw.size() > 0 || factor.offsetd.size() > 0 || factor.scalereq.size() > 0 || + factor.offsetdnext.size() > 0; +} + +struct UsrQuantizeFactorParams { + UsrQuantizeAlgorithm quantize_algo{USR_NON_OFFSET_ALGO}; + UsrQuantizeScaleType scale_type{USR_VECTOR_SCALE}; + UsrQuantizeFactor quantize_param; + UsrQuantizeFactor dequantize_param; + UsrQuantizeFactor requantize_param; + UsrQuantizeCalcFactor quantizecalc_param; + USR_TYPE_DEC(UsrQuantizeAlgorithm, quantize_algo); + USR_TYPE_DEC(UsrQuantizeScaleType, scale_type); + USR_TYPE_HAS_DEC(UsrQuantizeFactor, quantize_param); + USR_TYPE_HAS_DEC(UsrQuantizeFactor, dequantize_param); + USR_TYPE_HAS_DEC(UsrQuantizeFactor, requantize_param); + USR_TYPE_HAS_DEC(UsrQuantizeCalcFactor, quantizecalc_param); +}; + +#undef USR_TYPE_DEC +#undef USR_TYPE_HAS_DEC +#undef USR_TYPE_BYTES_DEC +} // namespace ge + +#endif // INC_GRAPH_USR_TYPES_H_ diff --git a/inc/graph/utils/anchor_utils.h b/inc/graph/utils/anchor_utils.h new file mode 100644 index 00000000..35b3b035 --- /dev/null +++ b/inc/graph/utils/anchor_utils.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_ANCHOR_UTILS_H_ +#define INC_GRAPH_UTILS_ANCHOR_UTILS_H_ + +#include "graph/anchor.h" +#include "graph/node.h" + +namespace ge { +class AnchorUtils { + public: + // Get anchor format + static Format GetFormat(const DataAnchorPtr &dataAnchor); + + // Set anchor format + static graphStatus SetFormat(const DataAnchorPtr &dataAnchor, Format dataFormat); + + // Get anchor status + static AnchorStatus GetStatus(const DataAnchorPtr &dataAnchor); + + // Set anchor status + static graphStatus SetStatus(const DataAnchorPtr &dataAnchor, AnchorStatus anchorStatus); + + static bool HasControlEdge(const AnchorPtr &anchor); + + static bool IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst); + + static int GetIdx(const AnchorPtr &anchor); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_ANCHOR_UTILS_H_ diff --git a/inc/graph/utils/attr_utils.h b/inc/graph/utils/attr_utils.h new file mode 100644 index 00000000..15a815d4 --- /dev/null +++ b/inc/graph/utils/attr_utils.h @@ -0,0 +1,150 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_ +#define INC_GRAPH_UTILS_ATTR_UTILS_H_ + +#include +#include +#include +#include "graph/detail/attributes_holder.h" +#include "graph/ge_attr_value.h" +#include "graph/types.h" + +namespace ge { +class OpDesc; +using OpDescPtr = std::shared_ptr; +using ConstOpDescPtr = std::shared_ptr; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { + public: + class ConstAttrHolderAdapter; + class AttrHolderAdapter; + // Set + static bool HasAttr(ConstAttrHolderAdapter &&obj, const string &name); + + static bool SetInt(AttrHolderAdapter &&obj, const string &name, const int64_t &value); + static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value); + + static bool SetFloat(AttrHolderAdapter &&obj, const string &name, const float &value); + static bool SetListFloat(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetBool(AttrHolderAdapter &&obj, const string &name, const bool &value); + static bool SetListBool(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetStr(AttrHolderAdapter &&obj, const string &name, const string &value); + static bool SetListStr(AttrHolderAdapter &&obj, const string &name, const vector &value); + static 
bool SetTensorDesc(AttrHolderAdapter &&obj, const string &name, const GeTensorDesc &value); + static bool SetListTensorDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value); + static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const ConstGeTensorPtr &value); + static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensor &value); + static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, + std::initializer_list &&value); + static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetGraph(AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value); + static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); + static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value); + static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, + const vector &value); + static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); + + // Get + static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int64_t &value); + static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value); + static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value); + static bool GetListInt(ConstAttrHolderAdapter &&obj, 
const string &name, vector &value); + static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetFloat(ConstAttrHolderAdapter &&obj, const string &name, float &value); + static bool GetListFloat(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetBool(ConstAttrHolderAdapter &&obj, const string &name, bool &value); + static bool GetListBool(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetStr(ConstAttrHolderAdapter &&obj, const string &name, string &value); + static bool GetListStr(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, GeTensorDesc &value); + static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value); + static bool MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value); + static bool GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetGraph(ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value); + static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); + static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value); + static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, + vector &value); + static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const 
string &name, vector &value); + // Value will be moved + static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); + static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); + // Value will be moved + static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector &listBuffer); + static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &listBuffer); + + static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector> &value); + static bool GetListListInt(ConstAttrHolderAdapter &&obj, const string &name, vector> &value); + + static bool SetListDataType(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool GetListDataType(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + + static bool SetDataType(AttrHolderAdapter &&obj, const string &name, const ge::DataType &value); + static bool GetDataType(ConstAttrHolderAdapter &&obj, const string &name, ge::DataType &value); + + static OpDescPtr CloneOpDesc(const ConstOpDescPtr &orgOpDesc); + + static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); + + static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); + + class AttrHolderAdapter { + public: + AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} + ~AttrHolderAdapter() {} + template + AttrHolderAdapter(const std::shared_ptr &obj) : obj_(obj.get()) {} + AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {} + operator bool() const { return obj_ != nullptr; } + AttrHolder *operator->() { return obj_; } + AttrHolder *get() { return obj_; } + + AttrHolder *obj_; + }; + + class ConstAttrHolderAdapter { + public: + ConstAttrHolderAdapter(const AttrHolder *obj) : obj_(obj) {} + ~ConstAttrHolderAdapter() {} + template + ConstAttrHolderAdapter(const std::shared_ptr obj) : obj_(obj.get()) {} + ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {} + operator bool() const { 
return obj_ != nullptr; } + const AttrHolder *operator->() const { return obj_; } + const AttrHolder *get() const { return obj_; } + + private: + const AttrHolder *obj_; + }; +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_ATTR_UTILS_H_ diff --git a/inc/graph/utils/graph_utils.h b/inc/graph/utils/graph_utils.h new file mode 100644 index 00000000..fdcbe1a9 --- /dev/null +++ b/inc/graph/utils/graph_utils.h @@ -0,0 +1,771 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_ +#define INC_GRAPH_UTILS_GRAPH_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "graph/anchor.h" +#include "graph/node.h" +#include "graph/compute_graph.h" +#include "graph/utils/anchor_utils.h" +#include "graph/graph.h" +#include "graph/model.h" + +#define GE_DUMP(compute_graph, name) \ + do { \ + GraphUtils::DumpGEGraph(compute_graph, name); \ + GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); \ + uint64_t i = 0; \ + for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { \ + auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \ + GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \ + GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \ + } \ + } while (0) + +#define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ + do { \ + DataType ret; \ + attr.GetValue(ret); \ + } while (0) + +#define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ + do { \ + if (value_type == VT_ENUM) { \ + REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ + stream << ret; \ + } \ + } while (0) + +#define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ + do { \ + if (value_type == VT_ENUM) { \ + REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ + stream << "["; \ + for (int i = 0; i < ret.size(); i++) { \ + stream << ret[i]; \ + if (i + 1 != ret.size()) stream << ", "; \ + } \ + stream << "]"; \ + } \ + } while (0) + +#define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ + else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) + +#define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ + else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) + +#define PRINT_SHAPE(i_o, n, idx, stream) \ + do { \ + auto op = n->GetOpDesc(); \ + GeTensorDesc td = i_o == "input" ? 
op->GetInputDesc(idx) : op->GetOutputDesc(idx); \ + auto shape = td.GetShape().GetDims(); \ + stream << "["; \ + for (int i = 0; i < shape.size(); i++) { \ + stream << shape[i]; \ + if (i + 1 < shape.size()) stream << ", "; \ + } \ + stream << "]"; \ + } while (0) + +#define PRINT_ATTR_FUNC(stream) \ + [&](GeAttrValue attr) { \ + auto type = attr.GetValueType(); \ + PRINT_ATTR_VALUE_IF(type, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR, attr, stream) \ + PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT, attr, stream) \ + PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL, attr, stream) \ + PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT, attr, stream) \ + PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_STRING, GeAttrValue::LIST_STR, attr, stream) \ + PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_FLOAT, GeAttrValue::LIST_FLOAT, attr, stream) \ + PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_BOOL, GeAttrValue::LIST_BOOL, attr, stream) \ + PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_INT, GeAttrValue::LIST_INT, attr, stream) \ + else if (type == GeAttrValue::ValueType::VT_TENSOR_DESC) stream << "TENSOR_DESC"; \ + else if (type == GeAttrValue::ValueType::VT_TENSOR) stream << "TENSOR"; \ + else if (type == GeAttrValue::ValueType::VT_BYTES) stream << "BYTES"; \ + else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) stream << "LIST_TENSOR_DESC"; \ + else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR) stream << "LIST_TENSOR"; \ + else if (type == GeAttrValue::ValueType::VT_LIST_BYTES) stream << "LIST_BYTES"; \ + }; + +namespace ge { +enum IOType { kIn, kOut }; + +struct NodeIndexIO { + NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type) + : node_(std::move(node)), index_(index), io_type_(io_type) { + if (node_ != nullptr) { + value_ = node_->GetName() + (io_type_ == kOut ? 
"_out_" : "_in_") + std::to_string(index_); + } + } + NodeIndexIO(ge::NodePtr node, int index, IOType io_type) + : node_(std::move(node)), index_(static_cast(index)), io_type_(io_type) { + if (node_ != nullptr) { + value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_); + } + } + ~NodeIndexIO() {} + + NodePtr node_ = nullptr; + uint32_t index_ = 0; + IOType io_type_ = kOut; + std::string value_; + + const std::string &ToString() const { return value_; } +}; + +class GraphUtils { + public: + static ComputeGraphPtr GetComputeGraph(const Graph &graph); + + static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); + + static graphStatus RecoverGraphOperators(const Graph &graph); + + static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector &inputs); + + static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); + + static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst, + const Format &dst_format); + + static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst); + + static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); + + static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); + + // check whether src is link to dst and then remove + static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); + + static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst); + + static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); + + static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); + + static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, + const InDataAnchorPtr &new_dst); + + static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, + const 
InControlAnchorPtr &new_dst); + + static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, + const NodePtr &new_node); + + static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node); + + static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node); + + static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, + const std::vector &vec_op_desc); + + /// + /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst + /// @param [in] src + /// @param [in] dsts + /// @param [in] insert_node + /// @param [in] input_index + /// @param [in] output_index + /// @return graphStatus + /// + static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); + + static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); + + static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); + + static void RecordOriginalNames(std::vector original_nodes, const ge::NodePtr &node); + + static void RecordOriginalNames(std::vector names_tmp, const ge::NodePtr &node); + + static bool MatchDumpStr(const std::string &suffix); + + static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false, + const std::string &user_graph_name = ""); + + static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); + + static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph); + + static void BreakConnect(const std::map &all_nodes_infos); + + static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); + + static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph); + + static bool 
ReadProtoFromTextFile(const char *file, google::protobuf::Message *message); + + static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path); + + static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node); + + /// + /// Isolating `node`, relinking data links from the in-anchor peer nodes to + /// the out-anchor peer nodes according to `io_map`, relinking control links + /// to ensure that input nodes of `node` are before out nodes + /// + /// Link the `io_map[i]` input anchor peer node to `i` output anchor peer + /// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0, + /// unlink all links from `i` output anchor without any relinking. + /// + /// @param node + /// @param io_map + /// @return + /// + static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list &io_map); + static graphStatus IsolateNode(const NodePtr &node, const std::vector &io_map); + + /// + /// Isolate `node` which must be one input one output, equivalent to + /// `IsolateNode(node, {0})` + /// @param node + /// @return + /// + static graphStatus IsolateNodeOneIO(const NodePtr &node); + + /// + /// The data anchors replacing behavior is the same with + /// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control + /// anchors with `new_node`'s. + /// @param new_node + /// @param old_node + /// @param inputs_map + /// @param outputs_map + /// @return + /// + static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, + std::initializer_list inputs_map, std::initializer_list outputs_map); + + static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, + const std::vector &inputs_map, const std::vector &outputs_map); + + /// + /// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`. 
+ /// Replace the `i` in/out data anchor on `old_node` with + /// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`. + /// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in + /// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain + /// on `old_node`. + /// @param new_node + /// @param old_node + /// @param inputs_map + /// @param outputs_map + /// @return + /// + static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, + std::initializer_list inputs_map, + std::initializer_list outputs_map); + + static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, + const std::vector &inputs_map, const std::vector &outputs_map); + + /// + /// Copy all in-control edges from `src_node` to `dst_node` + /// @param src_node + /// @param dst_node + /// @return + /// + static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); + + static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); + + /// + /// Copy all out-control edges from `src_node` to `dst_node` + /// @param src_node + /// @param dst_node + /// @return success: GRAPH_SUCESS + /// + static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); + + /// + /// Move all out-control edges from `src_node` to `dst_node` + /// @param src_node + /// @param dst_node + /// @return success: GRAPH_SUCESS + /// + static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); + + /// + /// Copy all in-data edges from `src_node` to `dst_node` + /// @param src_node + /// @param dst_node + /// @return + /// + static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node); + + static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); + + /// + /// Make a copy of ComputeGraph. + /// @param graph: original graph. + /// @param prefix: node name prefix of new graph. 
+ /// @return ComputeGraphPtr + /// + static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const string &prefix, + std::vector &input_nodes, std::vector &output_nodes); + + /// + /// Copy tensor attribute to new node. + /// @param [in] dst_desc: cloned node. + /// @param [in] src_node: original node. + /// @return success: GRAPH_SUCESS + /// + static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node); + + static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec); + + /// + /// Get reference-mapping of all data_anchors in graph + /// @param [in] graph + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus GetRefMapping(const ComputeGraphPtr &graph, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs + /// of the graph have UNKNOWN_SHAPE operators or not. + /// Note: This function will only look 'down' from the graph, not 'up'. 
For example, the following + /// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE + /// ROOT graph: A -----> B -----> C + /// K subgraph U + /// | + /// V + /// SUB graph: D --> E --> F + /// K K K + /// @param [in] graph + /// @return bool + /// + static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph); + + static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name); + + private: + /// + /// Get reference-mapping for in_data_anchors of node + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleInAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Get reference-mapping for out_data_anchors of node + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleOutAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Handle input of subgraph + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleSubgraphInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Handle input of Merge op + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleMergeInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Handle output of subgraph + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleSubgraphOutput(const NodePtr &node, + std::map> 
&symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Relink all edges for cloned ComputeGraph. + /// @param [in] node: original node. + /// @param [in] prefix: node name prefix of new node. + /// @param [in] all_nodes: all nodes in new graph. + /// @return success: GRAPH_SUCESS + /// + static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix, + const std::unordered_map &all_nodes); + + /// + /// Union ref-mapping + /// @param [in] exist_node_info1 + /// @param [in] exist_node_info2 + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @param [out] symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol, std::string &symbol); + + /// + /// Update symbol mapping with a new reference pair + /// @param [in] cur_node_info + /// @param [in] exist_node_info + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Check if out_data_anchor is reference of input + /// @param [in] out_data_anchor + /// @param [out] reuse_in_index + /// @return bool + /// + static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); +}; + +class ComputeGraphBuilder { + public: + ComputeGraphBuilder() : owner_graph_(nullptr) {} + ComputeGraphBuilder(const ComputeGraphBuilder &) = delete; + ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete; + ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete; + ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete; + ~ComputeGraphBuilder() = default; + + /// + /// @brief Add node to graph + /// @param [in] op_desc + /// 
@return ComputeGraphBuilder + /// + virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc); + + /// + /// @brief Add data-link among nodes in graph + /// @param [in] src_name + /// @param [in] out_anchor_ind + /// @param [in] dst_name + /// @param [in] in_anchor_ind + /// @return ComputeGraphBuilder + /// + virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, + const std::string &dst_name, uint32_t in_anchor_ind); + + /// + /// @brief Add ctrl-link among nodes in graph + /// @param [in] src_name + /// @param [in] dst_name + /// @return ComputeGraphBuilder + /// + virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name); + + /// + /// @brief Build graph + /// @param [out] error_code + /// @param [out] error_msg + /// @return ComputeGraphPtr + /// + virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0; + + /// @brief Get node with name + /// @param [in] name + /// @return NodePtr + /// + NodePtr GetNode(const std::string &name); + + /// @brief Get all nodes + /// @return std::vector + /// + std::vector GetAllNodes(); + + protected: + /// + /// @brief Build nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildNodes(graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Build data-links + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildDataLinks(graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Build ctrl-links + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg); + + ComputeGraphPtr owner_graph_; + + // node_name -> node + std::map node_names_; + std::vector nodes_; + + // -> + std::vector, std::pair>> data_links_; + // src_node_name -> dst_node_name + std::vector> ctrl_links_; +}; + +class CompleteGraphBuilder : 
public ComputeGraphBuilder { + public: + explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {} + CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; + CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; + CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; + CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete; + ~CompleteGraphBuilder() = default; + + /// + /// @brief Add node to graph + /// @param [in] op_desc + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override; + + /// + /// @brief Add data-link among nodes in graph + /// @param [in] src_name + /// @param [in] out_anchor_ind + /// @param [in] dst_name + /// @param [in] in_anchor_ind + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, + uint32_t in_anchor_ind) override; + + /// + /// @brief Add ctrl-link among nodes in graph + /// @param [in] src_name + /// @param [in] dst_name + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; + + /// + /// @brief Set index_th input anchor for graph + /// @param [in] index + /// @param [in] node_names + /// @param [in] anchor_inds + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &SetInput(uint32_t index, const std::vector &node_names, + const std::vector &anchor_inds); + + /// + /// @brief Set index_th input of graph as useless + /// @param [in] index + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &SetUselessInput(uint32_t index); + + /// + /// @brief Add output anchor for graph + /// @param [in] owner_node_name + /// @param [in] anchor_ind + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); + + /// + /// @brief 
Add target for graph + /// @param [in] target_name + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddTarget(const std::string &target_name); + + /// + /// @brief Set parent-node of graph + /// @param [in] parent_node + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node); + + /// + /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node + /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &SetInputMapping(const std::map &input_mapping); + + /// + /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind + /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &SetOutputMapping(const std::map &output_mapping); + + /// + /// @brief Build graph + /// @param [out] error_code + /// @param [out] error_msg + /// @return ComputeGraphPtr + /// + ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; + + private: + /// + /// @brief Add data nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void AddDataNodes(graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Add data node + /// @param [in] index + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Add RetVal nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void AddRetValNodes(graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Build target-nodes for graph + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildGraphTargets(graphStatus &error_code, std::string &error_msg); + + std::string 
name_; + NodePtr parent_node_; + std::map, std::vector>> graph_inputs_; + std::vector> graph_outputs_; + std::vector graph_targets_; + + // index_of_graph_input -> in_anchor_index_of_parent_node + std::map input_mapping_; + // index_of_graph_output -> out_anchor_index_of_parent_node + std::map output_mapping_; +}; + +class PartialGraphBuilder : public ComputeGraphBuilder { + public: + PartialGraphBuilder() = default; + PartialGraphBuilder(const PartialGraphBuilder &) = delete; + PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete; + PartialGraphBuilder(const PartialGraphBuilder &&) = delete; + PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete; + ~PartialGraphBuilder() = default; + + /// + /// @brief Add node to graph + /// @param [in] op_desc + /// @return PartialGraphBuilder + /// + PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override; + + /// + /// @brief Add data-link among nodes in graph + /// @param [in] src_name + /// @param [in] out_anchor_ind + /// @param [in] dst_name + /// @param [in] in_anchor_ind + /// @return PartialGraphBuilder + /// + PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, + uint32_t in_anchor_ind) override; + + /// + /// @brief Add ctrl-link among nodes in graph + /// @param [in] src_name + /// @param [in] dst_name + /// @return PartialGraphBuilder + /// + PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; + + /// + /// @brief Set owner graph + /// @param [in] graph + /// @return PartialGraphBuilder + /// + PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph); + + /// + /// @brief Add exist node + /// @param [in] node + /// @return PartialGraphBuilder + /// + PartialGraphBuilder &AddExistNode(const NodePtr &node); + + /// + /// @brief Build multi nodes with links + /// @param [out] error_code + /// @param [out] error_msg + /// @return ComputeGraphPtr + /// + 
ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; + + private: + /// + /// @brief Build exist nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildExistNodes(graphStatus &error_code, std::string &error_msg); + + std::vector exist_nodes_; +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_ diff --git a/inc/graph/utils/node_utils.h b/inc/graph/utils/node_utils.h new file mode 100644 index 00000000..bf57148d --- /dev/null +++ b/inc/graph/utils/node_utils.h @@ -0,0 +1,170 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_NODE_UTILS_H_ +#define INC_GRAPH_UTILS_NODE_UTILS_H_ + +#include +#include +#include +#include "external/graph/operator.h" +#include "graph/node.h" + +namespace ge { +// Op types of Const like Opps. +extern const std::set kConstOpTypes; +// Op types of If like Opps. +extern const std::set kIfOpTypes; +// Op types of While like Opps. +extern const std::set kWhileOpTypes; +// Op types of Case like Opps. +extern const std::set kCaseOpTypes; +// Op types of For like Opps. 
+extern const std::set kForOpTypes; + +class NodeUtils { + public: + static graphStatus AddSendEventId(const NodePtr &node, const uint32_t &event_id); + static graphStatus AddRecvEventId(const NodePtr &node, const uint32_t &event_id); + static graphStatus GetSendEventIdList(const NodePtr &node, std::vector &vec_send); + static graphStatus GetRecvEventIdList(const NodePtr &node, std::vector &vec_recv); + + static graphStatus ClearSendInfo(); + static graphStatus ClearRecvInfo(); + + static graphStatus GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst); + + static graphStatus GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, + InControlAnchorPtr &in_control); + + static graphStatus ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor); + static graphStatus SetAllAnchorStatus(const NodePtr &nodePtr); + static graphStatus SetAllAnchorStatus(Node &node); + static bool IsAnchorStatusSet(const NodePtr &nodePtr); + static bool IsAnchorStatusSet(const Node &node); + + static graphStatus MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node); + + static void UpdateIsInputConst(const NodePtr &nodePtr); + static void UpdateIsInputConst(Node &node); + static bool IsConst(const Node &node); + static void UnlinkAll(const Node &node); + static graphStatus UpdatePeerNodeInputDesc(const NodePtr &node_ptr); + + static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t num); + static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t num); + + static graphStatus AppendOutputAnchor(const NodePtr &node, uint32_t num); + static graphStatus RemoveOutputAnchor(const NodePtr &node, uint32_t num); + + static bool IsInNodesEmpty(const Node &node); + static GeTensorDesc GetOutputDesc(const Node &node, uint32_t index); + static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); + static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape 
&shape); + static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); + // check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true; + // for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2, + // the out param "is_unknow" will be true too + static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); + + static std::string GetNodeType(const Node &node); + static std::string GetNodeType(const NodePtr &node); + + static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); + static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); + + /// + /// Check if node is input of subgraph + /// @param [in] node + /// @return bool + /// + static bool IsSubgraphInput(const NodePtr &node); + + /// + /// Check if node is output of subgraph + /// @param [in] node + /// @return bool + /// + static bool IsSubgraphOutput(const NodePtr &node); + + /// + /// @brief Get subgraph original input node. + /// @param [in] node + /// @return Node + /// + static NodePtr GetParentInput(const Node &node); + static NodePtr GetParentInput(const NodePtr &node); + + /// + /// @brief Get is dynamic shape graph from node. + /// @param [in] node + /// @return bool + /// + static bool IsDynamicShape(const Node &node); + static bool IsDynamicShape(const NodePtr &node); + + /// + /// @brief Check is varying_input for while node + /// @param [in] node: Data node for subgraph + /// @return bool + /// + static bool IsWhileVaryingInput(const ge::NodePtr &node); + + /// + /// @brief Get subgraph input is constant. + /// @param [in] node + /// @param [out] string + /// @return bool + /// + static bool GetConstOpType(const NodePtr &node, std::string &type); + + /// + /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. 
+ /// @param [in] node + /// @return return GRAPH_SUCCESS if remove successfully, other for failed. + /// + static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); + + /// + /// @brief Get subgraph input data node by index. + /// @param [in] node + /// @return Node + /// + static vector GetSubgraphDataNodesByIndex(const Node &node, int index); + + /// + /// @brief Get subgraph input data node by index. + /// @param [in] node + /// @return Node + /// + static vector GetSubgraphOutputNodes(const Node &node); + + static NodePtr GetInDataNodeByIndex(const Node &node, const int index); + + static vector> GetOutDataNodesWithAnchorByIndex(const Node &node, const int index); + + static ge::ConstNodePtr GetNodeFromOperator(const Operator &oprt); + + static graphStatus GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor); + + static graphStatus GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor); + + private: + static std::map> map_send_info_; + static std::map> map_recv_info_; +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_NODE_UTILS_H_ diff --git a/inc/graph/utils/op_desc_utils.h b/inc/graph/utils/op_desc_utils.h new file mode 100644 index 00000000..daa95ebe --- /dev/null +++ b/inc/graph/utils/op_desc_utils.h @@ -0,0 +1,182 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_UTILS_OP_DESC_UTILS_H_ +#define INC_GRAPH_UTILS_OP_DESC_UTILS_H_ + +#include +#include +#include +#include "graph/def_types.h" +#include "graph/node.h" +#include "graph/op_desc.h" +#include "graph/operator.h" +#include "graph/range_vistor.h" + +namespace ge { +class OpDesc; +using OpDescPtr = std::shared_ptr; + +class OpDescUtils { + public: + template + using Vistor = RangeVistor>; + + OpDescUtils() = default; + ~OpDescUtils() = default; + static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); + static bool HasQuantizeFactorParams(const OpDesc& op_desc); + static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant); + static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant); + static graphStatus SetQuantizeFactorParams(const OpDescPtr& op_desc, const QuantizeFactorParams& quant); + static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant); + + static vector GetConstInputNode(const ge::Node& node); + static vector GetInputData(const vector& input_nodes); + + static vector GetWeights(const ge::Node& node); + static vector GetWeights(const ge::ConstNodePtr& node); + static vector MutableWeights(const ge::Node& node); + static vector MutableWeights(const ge::NodePtr node); + static graphStatus SetWeights(ge::Node& node, const vector& weights); + static graphStatus SetWeights(ge::NodePtr node, const vector& weights); + static graphStatus SetWeights(ge::Node& node, const map& weights_map); + static graphStatus ClearWeights(ge::NodePtr node); + + static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); + static bool ClearInputDesc(const ge::NodePtr& node); + static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index); + static bool ClearOutputDesc(const ge::NodePtr& node); + static vector GetConstInputs(const ge::Node& node); + static vector GetConstInputs(const ge::ConstNodePtr& node); + static size_t 
GetNonConstInputsSize(const ge::Node& node); + static size_t GetNonConstInputsSize(ge::ConstNodePtr node); + // Index: Indicates the index of all non const inputs + static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0); + static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0); + static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index); + static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index); + // Index: Indicates the index of all inputs + static bool IsNonConstInput(const ge::Node& node, size_t index = 0); + static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0); + + static vector GetNonConstTensorDesc(const ge::ConstNodePtr& node); + static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); + + static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); + static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); + static OpDescPtr GetOpDescFromOperator(const Operator& oprt); + + static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); + + static graphStatus SetSubgraphInstanceName(const std::string& subgraph_name, + const std::string& subgraph_instance_name, OpDescPtr& op_desc); + + private: + static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); + static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); + static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); + static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); +}; + +class OpDescBuilder { + public: + OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} + OpDescBuilder(const OpDescBuilder&) = delete; + OpDescBuilder& operator=(const OpDescBuilder&) = delete; + OpDescBuilder(const OpDescBuilder&&) = delete; + OpDescBuilder& operator=(const OpDescBuilder&&) = 
delete; + ~OpDescBuilder() = default; + + /// + /// @brief Add input + /// @param [in] name + /// @return OpDescBuilder + /// + OpDescBuilder& AddInput(const std::string& name); + + /// + /// @brief Add input + /// @param [in] name + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddInput(const std::string& name, const GeTensorDesc& tensor); + + /// + /// @brief Add dynamic input + /// @param [in] name + /// @param [in] num + /// @return OpDescBuilder + /// + OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); + + /// + /// @brief Add dynamic input + /// @param [in] name + /// @param [in] num + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); + + /// + /// @brief Add output + /// @param [in] name + /// @return OpDescBuilder + /// + OpDescBuilder& AddOutput(const std::string& name); + + /// + /// @brief Add output + /// @param [in] name + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddOutput(const std::string& name, const GeTensorDesc& tensor); + + /// + /// @brief Add dynamic output + /// @param [in] name + /// @param [in] num + /// @return OpDescBuilder + /// + OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); + + /// + /// @brief Add dynamic output + /// @param [in] name + /// @param [in] num + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); + + /// + /// @brief Build op_desc + /// @return OpDescPtr + /// + OpDescPtr Build(); + + private: + std::string name_; + std::string type_; + std::vector> inputs_; + std::vector> outputs_; +}; +} // namespace ge + +#endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ diff --git a/inc/graph/utils/tensor_adapter.h b/inc/graph/utils/tensor_adapter.h new file mode 100644 index 00000000..a7355553 --- /dev/null +++ 
b/inc/graph/utils/tensor_adapter.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ +#define INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ + +#include +#include "graph/ge_tensor.h" +#include "graph/tensor.h" + +namespace ge { +using GeTensorPtr = std::shared_ptr; +using ConstGeTensorPtr = std::shared_ptr; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorAdapter { + public: + static GeTensorDesc TensorDesc2GeTensorDesc(const TensorDesc &tensorDesc); + static TensorDesc GeTensorDesc2TensorDesc(const GeTensorDesc &geTensorDesc); + static GeTensorPtr Tensor2GeTensor(const Tensor &tensor); + static Tensor GeTensor2Tensor(const ConstGeTensorPtr &geTensor); + + static ConstGeTensorPtr AsGeTensorPtr(const Tensor &tensor); // Share value + static GeTensorPtr AsGeTensorPtr(Tensor &tensor); // Share value + static const GeTensor AsGeTensor(const Tensor &tensor); // Share value + static GeTensor AsGeTensor(Tensor &tensor); // Share value + static const Tensor AsTensor(const GeTensor &tensor); // Share value + static Tensor AsTensor(GeTensor &tensor); // Share value +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ diff --git a/inc/graph/utils/tensor_utils.h b/inc/graph/utils/tensor_utils.h new file mode 100644 index 00000000..caa80dcf --- /dev/null +++ b/inc/graph/utils/tensor_utils.h @@ -0,0 +1,77 @@ +/** + * Copyright 
2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_TENSOR_UTILS_H_ +#define INC_GRAPH_UTILS_TENSOR_UTILS_H_ + +#include +#include "graph/def_types.h" +#include "graph/ge_error_codes.h" +#include "graph/ge_tensor.h" + +namespace ge { +class TensorUtils { + public: + static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size); + static void SetSize(GeTensorDesc &tensorDesc, int64_t size); + static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); + static uint32_t GetWeightSize(const GeTensor &tensor); + static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); + static uint8_t *GetWeightAddr(const ConstGeTensorPtr &tensorPtr, uint8_t *base); + static uint8_t *GetWeightAddr(const GeTensor &tensor, uint8_t *base); + static void SetWeightSize(GeTensorDesc &tensorDesc, uint32_t size); + static ge::graphStatus GetReuseInput(const GeTensorDesc &tensorDesc, bool &flag); + static void SetReuseInput(GeTensorDesc &tensorDesc, bool flag); + static ge::graphStatus GetOutputTensor(const GeTensorDesc &tensorDesc, bool &flag); + static void SetOutputTensor(GeTensorDesc &tensorDesc, bool flag); + static graphStatus GetDeviceType(const GeTensorDesc &tensorDesc, DeviceType &type); + static void SetDeviceType(GeTensorDesc &tensorDesc, DeviceType type); + static ge::graphStatus GetInputTensor(const GeTensorDesc &tensorDesc, bool &flag); + static void SetInputTensor(GeTensorDesc 
&tensorDesc, bool flag); + static ge::graphStatus GetRealDimCnt(const GeTensorDesc &tensorDesc, uint32_t &cnt); + static void SetRealDimCnt(GeTensorDesc &tensorDesc, uint32_t cnt); + static ge::graphStatus GetReuseInputIndex(const GeTensorDesc &tensorDesc, uint32_t &idx); + static void SetReuseInputIndex(GeTensorDesc &tensorDesc, uint32_t idx); + static ge::graphStatus GetDataOffset(const GeTensorDesc &tensorDesc, int64_t &offset); + static void SetDataOffset(GeTensorDesc &tensorDesc, int64_t offset); + static ge::graphStatus GetCmpsSize(const GeTensorDesc &tensorDesc, uint32_t &cmp_size); + static void SetCmpsSize(GeTensorDesc &tensorDesc, uint32_t cmp_size); + static ge::graphStatus GetCmpsTab(const GeTensorDesc &tensorDesc, vector &vec); + static void SetCmpsTab(GeTensorDesc &tensorDesc, const uint8_t *data, size_t size); + static ge::graphStatus GetCmpsTabOffset(const GeTensorDesc &tensorDesc, int64_t &tab_offset); + static void SetCmpsTabOffset(GeTensorDesc &tensorDesc, int64_t tab_offset); + static ge::graphStatus GetCmpsInfo(const GeTensorDesc &tensorDesc, CompressInfo &info); + static void SetCmpsInfo(GeTensorDesc &tensorDesc, const CompressInfo &info); + static bool HasAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc); + static ge::graphStatus GetAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc, AllOffsetQuantizeInfo &info); + static void SetAlloffsetQuantizeInfo(GeTensorDesc &tensorDesc, const AllOffsetQuantizeInfo &info); + static ge::graphStatus GetRC(const GeTensorDesc &tensorDesc, uint32_t &rc); + static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); + + /// + /// calculate tensor mem size. 
+ /// @param shape tensor shape + /// @param format tensor format + /// @param data_type tensor data type + /// @param mem_size -1 means unknown shape,other means mem size + /// @return GRAPH_SUCCESS:success, other:failed + /// + static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); + static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); + static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ diff --git a/inc/graph/utils/type_utils.h b/inc/graph/utils/type_utils.h new file mode 100644 index 00000000..38509b9a --- /dev/null +++ b/inc/graph/utils/type_utils.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_GRAPH_UTILS_TYPE_UTILS_H_ +#define INC_GRAPH_UTILS_TYPE_UTILS_H_ + +#include +#include +#include +#include "graph/def_types.h" +#include "graph/ge_error_codes.h" +#include "graph/types.h" +#include "graph/usr_types.h" +#include "register/register_types.h" +#include "external/register/register_fmk_types.h" + +namespace ge { +class TypeUtils { + public: + static bool IsDataTypeValid(DataType dt); + static bool IsFormatValid(Format format); + static bool IsInternalFormat(Format format); + + static std::string ImplyTypeToSerialString(domi::ImplyType imply_type); + static std::string DataTypeToSerialString(DataType data_type); + static DataType SerialStringToDataType(const std::string &str); + static std::string FormatToSerialString(Format format); + static Format SerialStringToFormat(const std::string &str); + static Format DataFormatToFormat(const std::string &str); + static Format DomiFormatToFormat(domi::domiTensorFormat_t domi_format); + static std::string FmkTypeToSerialString(domi::FrameworkType fmk_type); + + static graphStatus Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def); + static graphStatus Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr); + + static bool GetDataTypeLength(ge::DataType data_type, uint32_t &length); + static bool CheckUint64MulOverflow(uint64_t a, uint32_t b); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_TYPE_UTILS_H_ diff --git a/metadef b/metadef deleted file mode 160000 index 6681ff4b..00000000 --- a/metadef +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6681ff4b61da48441640602501e76dceb1bf1bb6 diff --git a/parser b/parser deleted file mode 160000 index c6b1f992..00000000 --- a/parser +++ /dev/null @@ -1 +0,0 @@ -Subproject commit c6b1f992dbcc73d8da2106975d4fcce29ff38a78 diff --git a/src/common/graph/CMakeLists.txt b/src/common/graph/CMakeLists.txt new file mode 100755 index 00000000..bb63eb81 --- /dev/null +++ 
b/src/common/graph/CMakeLists.txt @@ -0,0 +1,77 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# libgraph.so +# compiling proto files generates some warnings, use no-unused-variable to suppress them +set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") +# add all proto files, generate corresponding .h and .cc files +file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../../proto/om.proto" + "../../proto/ge_ir.proto" + "../../proto/insert_op.proto" + "../../proto/task.proto" + "../../proto/fwk_adaper.proto" + "../../proto/op_mapping_info.proto" + "../../proto/dump_task.proto" + ) + +file(GLOB_RECURSE ONNX_PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "${onnx_INC}/onnx/onnx.proto" + ) + +ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +ge_protobuf_generate(ge PROTO_ONNX_SRCS PROTO_ONNX_HDRS ${ONNX_PROTO_LIST}) + +# need to remove dependencies on pb files later +file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "*.cc" + "utils/*.cc" + "opsproto/*.cc" + "detail/*.cc" + "debug/*.cc" + "option/*.cc" + ) + +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${GE_SOURCE_DIR}) +include_directories(${GE_SOURCE_DIR}/src) +include_directories(${GE_SOURCE_DIR}/src/ge) +include_directories(${GE_SOURCE_DIR}/src/common) 
+include_directories(${GE_SOURCE_DIR}/src/common/graph) +include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/framework) +include_directories(${GE_SOURCE_DIR}/inc/external) +include_directories(${GE_SOURCE_DIR}/inc/external/graph) +include_directories(${GE_SOURCE_DIR}/inc/graph) +include_directories(${GE_SOURCE_DIR}/inc/common) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) +include_directories(${CMAKE_BINARY_DIR}) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) +include_directories(${GE_SOURCE_DIR}/build) + +######### libgraph.so ############# +add_library(graph SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_ONNX_SRCS}) +target_compile_definitions(graph PRIVATE + DAVINCI_CLOUD + Werror) +target_link_libraries(graph PRIVATE + ${PROTOBUF_LIBRARY} + ${c_sec} + ${slog} + ${error_manager} + rt + dl) diff --git a/src/common/graph/anchor.cc b/src/common/graph/anchor.cc new file mode 100644 index 00000000..f02037e5 --- /dev/null +++ b/src/common/graph/anchor.cc @@ -0,0 +1,371 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/anchor.h" +#include +#include +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/node.h" + +namespace ge { +Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), idx_(idx) {} + +bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf(), type) == 0; } + +size_t Anchor::GetPeerAnchorsSize() const { return peer_anchors_.size(); } + +Anchor::Vistor Anchor::GetPeerAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + ret.push_back(anchor.lock()); + } + return Anchor::Vistor(shared_from_this(), ret); +} + +AnchorPtr Anchor::GetFirstPeerAnchor() const { + if (peer_anchors_.empty()) { + return nullptr; + } else { + return Anchor::DynamicAnchorCast(peer_anchors_.begin()->lock()); + } +} + +NodePtr Anchor::GetOwnerNode() const { return owner_node_.lock(); } + +void Anchor::UnlinkAll() noexcept { + if (!peer_anchors_.empty()) { + do { + auto peer_anchor_ptr = peer_anchors_.begin()->lock(); + if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { + GELOGW("unlink peer_anchor_ptr failed."); + } + } while (!peer_anchors_.empty()); + } +} + +graphStatus Anchor::Unlink(const AnchorPtr &peer) { + if (peer == nullptr) { + GELOGE(GRAPH_FAILED, "peer anchor is invalid."); + return GRAPH_FAILED; + } + auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr &an) { + auto anchor = an.lock(); + return peer->Equal(anchor); + }); + + GE_IF_BOOL_EXEC(it == peer_anchors_.end(), GELOGW("this anchor is not connected to peer"); return GRAPH_FAILED); + + auto it_peer = + std::find_if(peer->peer_anchors_.begin(), peer->peer_anchors_.end(), [this](const std::weak_ptr &an) { + auto anchor = an.lock(); + return Equal(anchor); + }); + + GE_CHK_BOOL_RET_STATUS(it_peer != peer->peer_anchors_.end(), GRAPH_FAILED, "peer is not connected to this anchor"); + + (void)peer_anchors_.erase(it); + (void)peer->peer_anchors_.erase(it_peer); + return 
GRAPH_SUCCESS; +} + +graphStatus Anchor::ReplacePeer(const AnchorPtr &old_peer, const AnchorPtr &first_peer, const AnchorPtr &second_peer) { + GE_CHK_BOOL_RET_STATUS(old_peer != nullptr, GRAPH_FAILED, "this old peer anchor is nullptr"); + GE_CHK_BOOL_RET_STATUS(first_peer != nullptr, GRAPH_FAILED, "this first peer anchor is nullptr"); + GE_CHK_BOOL_RET_STATUS(second_peer != nullptr, GRAPH_FAILED, "this second peer anchor is nullptr"); + auto this_it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [old_peer](const std::weak_ptr &an) { + auto anchor = an.lock(); + return old_peer->Equal(anchor); + }); + + GE_CHK_BOOL_RET_STATUS(this_it != peer_anchors_.end(), GRAPH_FAILED, "this anchor is not connected to old_peer"); + + auto old_it = std::find_if(old_peer->peer_anchors_.begin(), old_peer->peer_anchors_.end(), + [this](const std::weak_ptr &an) { + auto anchor = an.lock(); + return Equal(anchor); + }); + + GE_CHK_BOOL_RET_STATUS(old_it != old_peer->peer_anchors_.end(), GRAPH_FAILED, + "old_peer is not connected to this anchor"); + *this_it = first_peer; + first_peer->peer_anchors_.push_back(shared_from_this()); + *old_it = second_peer; + second_peer->peer_anchors_.push_back(old_peer); + return GRAPH_SUCCESS; +} + +bool Anchor::IsLinkedWith(const AnchorPtr &peer) { + auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr &an) { + auto anchor = an.lock(); + GE_CHK_BOOL_RET_STATUS(peer != nullptr, false, "this old peer anchor is nullptr"); + return peer->Equal(anchor); + }); + return (it != peer_anchors_.end()); +} + +int Anchor::GetIdx() const { return idx_; } + +void Anchor::SetIdx(int index) { idx_ = index; } + +DataAnchor::DataAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {} + +bool DataAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return Anchor::IsTypeOf(type); +} + +InDataAnchor::InDataAnchor(const NodePtr &owner_node, int idx) : 
DataAnchor(owner_node, idx) {} + +OutDataAnchorPtr InDataAnchor::GetPeerOutAnchor() const { + if (peer_anchors_.empty()) { + return nullptr; + } else { + return Anchor::DynamicAnchorCast(peer_anchors_.begin()->lock()); + } +} + +graphStatus InDataAnchor::LinkFrom(const OutDataAnchorPtr &src) { + // InDataAnchor must be only linkfrom once + if (src == nullptr || !peer_anchors_.empty()) { + GELOGE(GRAPH_FAILED, "src anchor is invalid or the peerAnchors is not empty."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(src); + src->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +bool InDataAnchor::Equal(AnchorPtr anchor) const { + auto in_data_anchor = Anchor::DynamicAnchorCast(anchor); + if (in_data_anchor != nullptr) { + if (GetOwnerNode() == in_data_anchor->GetOwnerNode() && GetIdx() == in_data_anchor->GetIdx()) { + return true; + } + } + return false; +} + +bool InDataAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return DataAnchor::IsTypeOf(type); +} + +OutDataAnchor::OutDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {} + +OutDataAnchor::Vistor OutDataAnchor::GetPeerInDataAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (in_data_anchor != nullptr) { + ret.push_back(in_data_anchor); + } + } + return OutDataAnchor::Vistor(shared_from_this(), ret); +} + +uint32_t OutDataAnchor::GetPeerInDataNodesSize() const { + uint32_t out_nums = 0; + for (const auto &anchor : peer_anchors_) { + auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (in_data_anchor != nullptr && in_data_anchor->GetOwnerNode() != nullptr) { + out_nums++; + } + } + return out_nums; +} + +OutDataAnchor::Vistor OutDataAnchor::GetPeerInControlAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto in_control_anchor = 
Anchor::DynamicAnchorCast(anchor.lock()); + if (in_control_anchor != nullptr) { + ret.push_back(in_control_anchor); + } + } + return OutDataAnchor::Vistor(shared_from_this(), ret); +} + +graphStatus OutDataAnchor::LinkTo(const InDataAnchorPtr &dest) { + if (dest == nullptr || !dest->peer_anchors_.empty()) { + GELOGE(GRAPH_FAILED, "dest anchor is invalid or the peerAnchors is not empty."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(dest); + dest->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +graphStatus OutDataAnchor::LinkTo(const InControlAnchorPtr &dest) { + if (dest == nullptr) { + GELOGE(GRAPH_FAILED, "dest anchor is invalid."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(dest); + dest->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +graphStatus OutControlAnchor::LinkTo(const InDataAnchorPtr &dest) { + if (dest == nullptr) { + GELOGE(GRAPH_FAILED, "dest anchor is invalid."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(dest); + dest->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +bool OutDataAnchor::Equal(AnchorPtr anchor) const { + CHECK_FALSE_EXEC(anchor != nullptr, return false); + auto out_data_anchor = Anchor::DynamicAnchorCast(anchor); + if (out_data_anchor != nullptr) { + if (GetOwnerNode() == out_data_anchor->GetOwnerNode() && GetIdx() == out_data_anchor->GetIdx()) { + return true; + } + } + return false; +} + +bool OutDataAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return DataAnchor::IsTypeOf(type); +} + +ControlAnchor::ControlAnchor(const NodePtr &owner_node) : Anchor(owner_node, -1) {} + +ControlAnchor::ControlAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {} + +bool ControlAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return Anchor::IsTypeOf(type); +} + +InControlAnchor::InControlAnchor(const NodePtr 
&owner_node) : ControlAnchor(owner_node) {} + +InControlAnchor::InControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {} + +InControlAnchor::Vistor InControlAnchor::GetPeerOutControlAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto out_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (out_control_anchor != nullptr) { + ret.push_back(out_control_anchor); + } + } + return InControlAnchor::Vistor(shared_from_this(), ret); +} + +InControlAnchor::Vistor InControlAnchor::GetPeerOutDataAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto out_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (out_data_anchor != nullptr) { + ret.push_back(out_data_anchor); + } + } + return InControlAnchor::Vistor(shared_from_this(), ret); +} + +graphStatus InControlAnchor::LinkFrom(const OutControlAnchorPtr &src) { + if (src == nullptr) { + GELOGE(GRAPH_FAILED, "src anchor is invalid."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(src); + src->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +bool InControlAnchor::Equal(AnchorPtr anchor) const { + CHECK_FALSE_EXEC(anchor != nullptr, return false); + auto in_control_anchor = Anchor::DynamicAnchorCast(anchor); + if (in_control_anchor != nullptr) { + if (GetOwnerNode() == in_control_anchor->GetOwnerNode()) { + return true; + } + } + return false; +} + +bool InControlAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return ControlAnchor::IsTypeOf(type); +} + +OutControlAnchor::OutControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} + +OutControlAnchor::OutControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {} + +OutControlAnchor::Vistor OutControlAnchor::GetPeerInControlAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto in_control_anchor = 
Anchor::DynamicAnchorCast(anchor.lock()); + if (in_control_anchor != nullptr) { + ret.push_back(in_control_anchor); + } + } + return OutControlAnchor::Vistor(shared_from_this(), ret); +} + +OutControlAnchor::Vistor OutControlAnchor::GetPeerInDataAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (in_data_anchor != nullptr) { + ret.push_back(in_data_anchor); + } + } + return OutControlAnchor::Vistor(shared_from_this(), ret); +} + +graphStatus OutControlAnchor::LinkTo(const InControlAnchorPtr &dest) { + if (dest == nullptr) { + GELOGE(GRAPH_FAILED, "dest anchor is invalid."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(dest); + dest->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +bool OutControlAnchor::Equal(AnchorPtr anchor) const { + auto out_control_anchor = Anchor::DynamicAnchorCast(anchor); + if (out_control_anchor != nullptr) { + if (GetOwnerNode() == out_control_anchor->GetOwnerNode()) { + return true; + } + } + return false; +} + +bool OutControlAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return ControlAnchor::IsTypeOf(type); +} +} // namespace ge diff --git a/src/common/graph/attr_value.cc b/src/common/graph/attr_value.cc new file mode 100644 index 00000000..066767c2 --- /dev/null +++ b/src/common/graph/attr_value.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "external/graph/attr_value.h" +#include "debug/ge_log.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_attr_value.h" + +namespace ge { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue::AttrValue() { impl = ComGraphMakeShared(); } + +#define ATTR_VALUE_SET_GET_IMP(type) \ + graphStatus AttrValue::GetValue(type &val) const { \ + if (impl != nullptr) { \ + GELOGW("GetValue failed."); \ + return impl->geAttrValue_.GetValue(val); \ + } \ + return GRAPH_FAILED; \ + } + +ATTR_VALUE_SET_GET_IMP(AttrValue::STR) +ATTR_VALUE_SET_GET_IMP(AttrValue::INT) +ATTR_VALUE_SET_GET_IMP(AttrValue::FLOAT) +} // namespace ge diff --git a/src/common/graph/buffer.cc b/src/common/graph/buffer.cc new file mode 100644 index 00000000..48cdd397 --- /dev/null +++ b/src/common/graph/buffer.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/buffer.h" +#include "proto/ge_ir.pb.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +Buffer::Buffer() { + data_.InitDefault(); + if (data_.GetProtoMsg()) { + buffer_ = data_.GetProtoMsg()->mutable_bt(); + } +} + +Buffer::Buffer(const Buffer &other) { + // Share data + data_ = other.data_; + buffer_ = other.buffer_; +} + +Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default + auto proto_msg = data_.GetProtoMsg(); + if (proto_msg != nullptr) { + try { + proto_msg->set_bt(std::string(buffer_size, default_val)); + buffer_ = proto_msg->mutable_bt(); + } catch (std::bad_alloc &e) { + GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size); + buffer_ = nullptr; + } + } +} + +Buffer Buffer::CopyFrom(const std::uint8_t *data, std::size_t buffer_size) { + Buffer buffer; + auto proto_msg = buffer.data_.GetProtoMsg(); + if (proto_msg != nullptr && data != nullptr) { + try { + proto_msg->set_bt(data, buffer_size); + buffer.buffer_ = proto_msg->mutable_bt(); + } catch (std::bad_alloc &e) { + GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size); + buffer.buffer_ = nullptr; + } + } + return buffer; +} + +Buffer::Buffer(const std::shared_ptr &proto_owner, proto::AttrDef *buffer) + : data_(proto_owner, buffer) { + if (data_.GetProtoMsg() != nullptr) { + buffer_ = data_.GetProtoMsg()->mutable_bt(); + } +} + +Buffer::Buffer(const std::shared_ptr &proto_owner, std::string *buffer) + : data_(proto_owner, nullptr) { + buffer_ = buffer; +} + +Buffer &Buffer::operator=(const Buffer &other) { + if (&other != this) { + // Share data + data_ = other.data_; + buffer_ = other.buffer_; + } + return *this; +} + +const std::uint8_t *Buffer::GetData() const { + if (buffer_ != nullptr) { + return (const std::uint8_t *)buffer_->data(); + } + return nullptr; +} + +std::uint8_t *Buffer::GetData() { + if (buffer_ != nullptr && !buffer_->empty()) { + // Avoid copy 
on write + (void)(*buffer_)[0]; + return reinterpret_cast(const_cast(buffer_->data())); + } + return nullptr; +} + +std::size_t Buffer::GetSize() const { + if (buffer_ != nullptr) { + return buffer_->size(); + } + return 0; +} + +void Buffer::ClearBuffer() { + if (buffer_ != nullptr) { + buffer_->clear(); + } +} +} // namespace ge diff --git a/src/common/graph/compute_graph.cc b/src/common/graph/compute_graph.cc new file mode 100644 index 00000000..bae4d362 --- /dev/null +++ b/src/common/graph/compute_graph.cc @@ -0,0 +1,1314 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/compute_graph.h" +#include +#include "./format_refiner.h" +#include "./ge_context.h" +#include "debug/ge_attr_define.h" +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "ge/ge_api_types.h" +#include "graph/shape_refiner.h" +#include "proto/ge_ir.pb.h" +#include "utils/ge_ir_utils.h" +#include "utils/graph_utils.h" +#include "utils/node_utils.h" +#include "utils/op_desc_utils.h" +#include "utils/string_utils.h" +#include "utils/tensor_utils.h" + +namespace ge { +namespace { +const size_t OUTPUT_PARAM_SIZE = 2; +const std::string alias_name_attr = "_aliasName"; +bool IsUseBFS() { + string run_mode; + const int base = 10; + if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { + if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) { + return true; + } + } else { + GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); + } + return false; +} +} // namespace + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) + : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { + attrs_.InitDefault(); +} + +ComputeGraph::~ComputeGraph() {} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { + return GetAllNodes().size(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetAllNodes() const { + std::vector> subgraphs; + return AllGraphNodes(subgraphs); +} + +ComputeGraph::Vistor ComputeGraph::AllGraphNodes(std::vector> &subgraphs) const { + std::vector all_nodes; + std::deque 
candidates; + + candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); + while (!candidates.empty()) { + NodePtr node = candidates.front(); + all_nodes.emplace_back(node); + candidates.pop_front(); + + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + + const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); + for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { + auto subgraph = GetSubgraph(*name_iter); + if (subgraph != nullptr) { + subgraphs.emplace_back(subgraph); + candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); + } + } + } + + return Vistor(shared_from_this(), all_nodes); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetNodes( + bool is_unknown_shape) const { + if (is_unknown_shape) { + return GetDirectNode(); + } else { + return GetAllNodes(); + } +} + +size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetDirectNode() const { + return Vistor(shared_from_this(), nodes_); +} + +ComputeGraph::Vistor ComputeGraph::GetInputNodes() const { + return Vistor(shared_from_this(), input_nodes_); +} + +ComputeGraph::Vistor ComputeGraph::GetOutputNodes() const { + std::vector result; + for (auto iter = output_nodes_info_.begin(); iter != output_nodes_info_.end(); ++iter) { + result.push_back(iter->first); + } + return Vistor(shared_from_this(), result); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const { + for (const auto &node : nodes_) { + if (node == nullptr) { + continue; + } + if (node->GetName() == name) { + return node; + } + std::vector out_alias_name; + if (AttrUtils::GetListStr(node->GetOpDesc(), alias_name_attr, out_alias_name)) { + for (const auto &alias_name : out_alias_name) { + if (alias_name == name) { 
+ return node; + } + } + } + } + return nullptr; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr +ComputeGraph::FindFirstNodeMatchType(const std::string &name) const { + for (const auto &node : nodes_) { + if (node == nullptr) { + continue; + } + if (node->GetType() == name) { + return node; + } + } + return nullptr; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreEqual( + const ComputeGraph &r_graph) const { + // ProtoMsgOwner <::google::protobuf::Message> is temporarily ignored + if ((this->attrs_.protoMsg_ != nullptr) && (r_graph.attrs_.protoMsg_ != nullptr)) { + const auto &proto_attr_map = *(this->attrs_.protoMsg_); + const auto &r_proto_attr_map = *(r_graph.attrs_.protoMsg_); + // 1.Verify graph's ProtoAttrMap size + if (proto_attr_map.size() != r_proto_attr_map.size()) { + GELOGE(GRAPH_FAILED, "Size of compute graph's ProtoAttrMap verify failed, graph name: %s.", + this->GetName().c_str()); + return false; + } + // 2.Verify graph's ProtoAttrMap key, verify values is temporarily not implemented + for (const auto &it : proto_attr_map) { + if (r_proto_attr_map.count(it.first) == 0) { + GELOGE(GRAPH_FAILED, "Key of compute graph's ProtoAttrMap verify failed, graph name: %s key name: %s.", + this->GetName().c_str(), it.first.c_str()); + return false; + } + } + return true; + } + return ((this->attrs_.protoMsg_ == nullptr) && (r_graph.attrs_.protoMsg_ == nullptr)); +} + +/// Since there may be different input nodes +/// chosen by user in the same graph, special judgment is needed +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNodePtrIsEqual( + const std::vector &left_nodes, const std::vector &right_nodes) const { + const auto left_nodes_size = left_nodes.size(); + const auto right_nodes_size = right_nodes.size(); + if (left_nodes_size != right_nodes_size) { + GELOGE(GRAPH_FAILED, + "Check failed with graph input_nodes_: " + "left inputNodes size %zu is different with right inputNodes 
size %zu .", + left_nodes_size, right_nodes_size); + return false; + } + for (size_t j = 0; j < left_nodes_size; j++) { + if (left_nodes.at(j) == nullptr || right_nodes.at(j) == nullptr) { + GELOGE(GRAPH_FAILED, "left_nodes.at(%zu) or right_nodes.at(%zu) is nullptr", j, j); + return false; + } + const auto &left_input_name = left_nodes.at(j)->GetName(); + const auto &right_input_name = right_nodes.at(j)->GetName(); + if (left_input_name != right_input_name) { + GELOGE(GRAPH_FAILED, + "Check failed with graph input_nodes_: " + "left inputNode name %s is different with right inputNode name %s at inputNodes index %zu.", + left_input_name.c_str(), right_input_name.c_str(), j); + return false; + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( + const ComputeGraph &r_graph) const { + return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") && + IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && + VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && + IsEqual(this->name_, r_graph.name_, "graph.name_") && + IsEqual(this->is_valid_flag_, r_graph.is_valid_flag_, "graph.is_valid_flag_") && + IsEqual(this->need_iteration_, r_graph.need_iteration_, "graph.need_iteration_") && + IsEqual(this->params_share_map_, r_graph.params_share_map_, "graph.params_share_map_") && + IsEqual(this->out_nodes_map_, r_graph.out_nodes_map_, "graph.out_nodes_map_") && + IsEqual(this->inputs_order_, r_graph.inputs_order_, "graph.inputs_order_") && + IsEqual(this->output_size_, r_graph.output_size_, "graph.output_size_") && + IsEqual(this->input_size_, r_graph.input_size_, "graph.input_size_") && + IsEqual(this->output_nodes_info_, r_graph.output_nodes_info_, "graph.output_nodes_info_")); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::operator==(const ComputeGraph &r_graph) const { + // Firstly: Graph's members equal + if 
((!GraphMembersAreEqual(r_graph)) || (!GraphAttrsAreEqual(r_graph))) { + return false; + } + + // Secondly: Node equal means the link relationship between node and node itself equal + for (const auto &left_node : nodes_) { + if (left_node == nullptr) { + GELOGE(GRAPH_FAILED, "left_node is nullptr"); + return false; + } + const auto &node_name = left_node->GetName(); + // After TopologicalSorting, node order can change, so find node by name + const auto &right_node = r_graph.FindNode(node_name); + GE_IF_BOOL_EXEC(right_node == nullptr, GELOGE(GRAPH_FAILED, "right_node is NULL!!!"); return false); + if (!(*right_node == *left_node)) { + GELOGE(GRAPH_FAILED, "Compare graph failed, node name: %s.", node_name.c_str()); + return false; + } + } + + // Thirdly: Recursively determine whether the sub graphs are equal + for (size_t i = 0; i < this->sub_graph_.size(); i++) { + if (!(*((this->sub_graph_)[i]) == *((r_graph.sub_graph_)[i]))) { + return false; + } + } + return true; +} + +NodePtr ComputeGraph::AddNodeFront(NodePtr node) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); + return nullptr; + } + node->SetHostNode(is_valid_flag_); + node->GetOpDesc()->SetId(nodes_.size()); + if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { + (void)nodes_.insert(nodes_.begin() + 1, node); + } else { + (void)nodes_.insert(nodes_.begin(), node); + } + return node; +} + +NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) { + if (op == nullptr) { + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); + return nullptr; + } + op->SetId(nodes_.size()); + NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); + GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); + GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); + return AddNodeFront(node_ptr); +} + +NodePtr 
ComputeGraph::AddNodeAfter(NodePtr node, const NodePtr &pre_node) { + if (node == nullptr || node->GetOpDesc() == nullptr || pre_node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); + return nullptr; + } + node->SetHostNode(is_valid_flag_); + node->GetOpDesc()->SetId(nodes_.size()); + auto node_iter = std::find(nodes_.begin(), nodes_.end(), pre_node); + if (node_iter != nodes_.end()) { + nodes_.insert(node_iter + 1, node); + } else { + GELOGE(GRAPH_FAILED, "Cannot find pre_node in nodes_."); + return nullptr; + } + + return node; +} + +NodePtr ComputeGraph::AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node) { + if (op == nullptr) { + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); + return nullptr; + } + op->SetId(nodes_.size()); + NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); + GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); + GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init failed."); return nullptr); + return AddNodeAfter(node_ptr, pre_node); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(NodePtr node) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return nullptr; + } + node->SetHostNode(is_valid_flag_); + node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize()); + nodes_.push_back(node); + return node; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpDescPtr op) { + if (op == nullptr) { + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); + return nullptr; + } + op->SetId(GetDirectNodesSize()); + NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); + GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); + GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, 
GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); + return AddNode(node_ptr); +} + +NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. + if (op == nullptr) { + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); + return nullptr; + } + op->SetId(id); + NodePtr node = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); + GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); + GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); + node->SetHostNode(is_valid_flag_); + nodes_.push_back(node); + return node; +} + +NodePtr ComputeGraph::AddInputNode(NodePtr node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return nullptr; + } + input_nodes_.push_back(node); + if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { + GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed"); + } + return node; +} + +NodePtr ComputeGraph::AddOutputNode(NodePtr node) { return AddOutputNodeByIndex(node, 0); } + +NodePtr ComputeGraph::AddOutputNodeByIndex(NodePtr node, int32_t index) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null."); + return nullptr; + } + + bool already_have = false; + NodePtr result = node; + // [output_nodes_info_ : should not be null] + for (const auto &item : output_nodes_info_) { + if (item.first->GetName() == node->GetName() && item.second == index) { + already_have = true; + result = item.first; + break; + } + } + + if (!already_have) { + output_nodes_info_.emplace_back(std::make_pair(node, index)); + GELOGI("Push back node name:%s, index:%ld, into output_nodes_info_.", node->GetName().c_str(), index); + } + + if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { + GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed"); + } + return 
result; +} + +graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) { + GE_CHECK_NOTNULL(node); + + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) { + continue; + } + if (out_anchor->GetOwnerNode()->GetType() == CONSTANT || out_anchor->GetOwnerNode()->GetType() == CONSTANTOP) { + GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED, + "Remove edge from const op failed."); + if (out_anchor->GetOwnerNode()->GetOutNodes().size() == 0) { + GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str()); + auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode()); + if (iter != nodes_.end()) { + (void)nodes_.erase(iter); + } + } + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return GRAPH_FAILED; + } + + // delete const op for this node + (void)RemoveConstInput(node); + + // if the node save as input node, delete it + (void)RemoveInputNode(node); + + // if the node save as input node, delete it + (void)RemoveOutputNode(node); + + if (GRAPH_SUCCESS != IsolateNode(node)) { + GELOGE(GRAPH_FAILED, "Isolate node failed, node name: %s.", node->GetName().c_str()); + return GRAPH_FAILED; + } + + auto iter = find(nodes_.begin(), nodes_.end(), node); + if (iter != nodes_.end()) { + (void)nodes_.erase(iter); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +// Used in sub_graph scenes +graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return GRAPH_FAILED; + } + + auto iter = find(input_nodes_.begin(), input_nodes_.end(), node); + if (iter != input_nodes_.end()) { + 
(void)input_nodes_.erase(iter); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +// Used in sub_graph scenes +graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return GRAPH_FAILED; + } + + auto iter = output_nodes_info_.begin(); + bool find_node = false; + // [output_nodes_info_ : should not be null] + while (iter != output_nodes_info_.end()) { + if (node->GetName() == iter->first->GetName()) { + iter = output_nodes_info_.erase(iter); + find_node = true; + } else { + ++iter; + } + } + GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED); + return GRAPH_SUCCESS; +} + +std::shared_ptr ComputeGraph::AddSubGraph(std::shared_ptr sub_graph) { + if (sub_graph == nullptr) { + GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); + return nullptr; + } + sub_graph_.push_back(sub_graph); + names_to_subgraph_[sub_graph->GetName()] = sub_graph; + return sub_graph; +} + +graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr &sub_graph) { + if (sub_graph == nullptr) { + GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); + return GRAPH_FAILED; + } + + names_to_subgraph_.erase(sub_graph->GetName()); + auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph); + if (iter != sub_graph_.end()) { + (void)sub_graph_.erase(iter); + return GRAPH_SUCCESS; + } else { + GELOGW("find sub_graph failed"); + return GRAPH_SUCCESS; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr &subgraph) { + if (subgraph == nullptr) { + GE_LOGE("Try to add a null subgraph, name %s", name.c_str()); + return GRAPH_PARAM_INVALID; + } + auto parent_graph = subgraph->GetParentGraph(); + if (parent_graph == nullptr) { + GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str()); + return GRAPH_PARAM_INVALID; + } + auto parent_node = subgraph->GetParentNode(); + if 
(parent_node == nullptr) { + GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str()); + return GRAPH_PARAM_INVALID; + } + if (parent_node->GetOwnerComputeGraph() != parent_graph) { + GE_LOGE( + "Try to add a subgraph which parent node's parent graph is not equal to " + "the subgraph's parent graph, subgraph name %s, parent node name %s", + subgraph->GetName().c_str(), parent_graph->GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + if (!this->parent_graph_.expired()) { + GELOGW("The subgraphs should only be added to the root graph"); + } + if (name != subgraph->GetName()) { + GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); + } + if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) { + GE_LOGE("The subgraph %s existed", name.c_str()); + return GRAPH_PARAM_INVALID; + } + sub_graph_.push_back(subgraph); + names_to_subgraph_[name] = subgraph; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +ComputeGraph::AddSubgraph(const std::shared_ptr &subgraph) { + if (subgraph == nullptr) { + return GRAPH_PARAM_INVALID; + } + return AddSubgraph(subgraph->GetName(), subgraph); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) { + auto iter = names_to_subgraph_.find(name); + if (iter == names_to_subgraph_.end()) { + return; + } + for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) { + if (*vec_iter == iter->second) { + sub_graph_.erase(vec_iter); + break; + } + } + names_to_subgraph_.erase(iter); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph( + const std::shared_ptr &subgraph) { + if (subgraph != nullptr) { + RemoveSubgraph(subgraph->GetName()); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr ComputeGraph::GetSubgraph( + const std::string &name) const { + std::shared_ptr parent = 
parent_graph_.lock(); + if (parent == nullptr) { + auto iter = names_to_subgraph_.find(name); + return iter == names_to_subgraph_.end() ? nullptr : iter->second; + } else { + return parent->GetSubgraph(name); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector> +ComputeGraph::GetAllSubgraphs() const { + return sub_graph_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr ComputeGraph::GetParentGraph() { + return parent_graph_.lock(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph( + const shared_ptr &parent) { + parent_graph_ = parent; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr ComputeGraph::GetParentNode() { + return parent_node_.lock(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr &parent) { + parent_node_ = parent; +} + +/// +/// @brief Update input-mapping +/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input +/// @return graphStatus +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +ComputeGraph::UpdateInputMapping(const std::map &input_mapping) { + for (auto &input : nodes_) { + if (input->GetType() == DATA) { + uint32_t cur_index = 0; + if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { + continue; + } + auto iter = input_mapping.find(cur_index); + if (iter == input_mapping.end()) { + continue; + } + if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { + GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); + return GRAPH_FAILED; + } + } + } + + return GRAPH_SUCCESS; +} + +/// +/// @brief Update output-mapping +/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output +/// @return graphStatus +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +ComputeGraph::UpdateOutputMapping(const std::map 
&output_mapping) { + NodePtr net_output = FindFirstNodeMatchType(NETOUTPUT); + if (net_output == nullptr) { + GE_LOGE("UpdateOutputMapping failed: node type %s not exist in graph.", NETOUTPUT); + return GRAPH_FAILED; + } + OpDescPtr op_desc = net_output->GetOpDesc(); + if (op_desc == nullptr) { + GE_LOGE("UpdateOutputMapping failed: op_desc is NULL."); + return GRAPH_FAILED; + } + + size_t num = op_desc->GetAllInputsSize(); + for (size_t i = 0; i < num; i++) { + GeTensorDesc tensor = op_desc->GetInputDesc(i); + uint32_t cur_index = 0; + if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { + continue; + } + auto iter = output_mapping.find(cur_index); + if (iter == output_mapping.end()) { + continue; + } + if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { + GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); + return GRAPH_FAILED; + } + if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) { + GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i); + return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { + std::vector node_vec = nodes_; + for (const auto &node : GetDirectNode()) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGW("node or OpDescPtr is nullptr."); + continue; + } + GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should not be null."); return GRAPH_FAILED); + if (node->GetOpDesc()->GetType() == RECV) { + auto iter = find(node_vec.begin(), node_vec.end(), node); + if (iter == node_vec.end()) { + GELOGW("no node found."); + } else { + (void)node_vec.erase(iter); + } + + auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0)); + (void)node_vec.insert(dst_iter, node); + } + if (node->GetOpDesc()->GetType() == SEND) { + auto iter = find(node_vec.begin(), node_vec.end(), node); + if 
(iter == node_vec.end()) { + GELOGW("no node found."); + } else { + (void)node_vec.erase(iter); + } + + auto src_iter = find(node_vec.begin(), node_vec.end(), node->GetInControlNodes().at(0)); + (void)node_vec.insert(src_iter + 1, node); + } + } + nodes_.clear(); + for (size_t i = 0; i < node_vec.size(); ++i) { + NodePtr node = node_vec[i]; + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGW("node or OpDescPtr is nullptr."); + } else { + node->GetOpDesc()->SetId((int64_t)i); + nodes_.push_back(node); + } + } + return GRAPH_SUCCESS; +} + +graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, + std::map &map_in_edge_num, + std::vector &stack) { + GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); + // Record the number of non data nodes but no input nodes + GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); + + // Only data nodes here + while (!stack.empty()) { + NodePtr node = stack.back(); + stack.pop_back(); + node_vec.push_back(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); + for (const auto &anchor : node->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(anchor); + for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { + GE_CHECK_NOTNULL(peer_in_anchor); + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && --iter->second == 0) { + stack.push_back(peer_in_anchor->GetOwnerNode()); + } + } + for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { + GE_CHECK_NOTNULL(peer_in_anchor); + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && --iter->second == 0) { + stack.push_back(peer_in_anchor->GetOwnerNode()); + } + } + } + GE_IF_BOOL_EXEC( + node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor + : node->GetOutControlAnchor()->GetPeerAnchors()) { + 
GE_CHECK_NOTNULL(peer_in_anchor); + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && --iter->second == 0) { + stack.push_back(peer_in_anchor->GetOwnerNode()); + } + }) + } + + return GRAPH_SUCCESS; +} + +graphStatus ComputeGraph::BFSTopologicalSorting(std::vector &node_vec, + std::map &map_in_edge_num, + std::deque &stack) { + GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); + std::vector stack_input; + std::map breadth_node_map; + // Record the number of non data nodes but no input nodes + GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); + + // Only data nodes here + while (!stack_input.empty() || !stack.empty()) { + NodePtr node = nullptr; + if (!stack.empty()) { + node = stack.back(); + stack.pop_back(); + } else { + node = stack_input.back(); + stack_input.pop_back(); + } + + node_vec.push_back(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); + CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); + + for (const auto &name_node : breadth_node_map) { + (void)stack.push_front(name_node.second); + } + breadth_node_map.clear(); + } + return GRAPH_SUCCESS; +} + +graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, + std::map &breadth_node_map) { + for (const auto &anchor : node->GetAllOutDataAnchors()) { + for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && 0 == --iter->second) { + (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); + } + } + + for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && 0 == --iter->second) { + 
(void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); + } + } + } + if (node->GetOutControlAnchor() != nullptr) { + for (AnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerAnchors()) { + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && 0 == --iter->second) { + (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); + } + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { + auto ret = TopologicalSortingGraph(); + if (ret != SUCCESS) { + GraphUtils::DumpGEGraphToOnnx(*this, "black_box"); + GELOGE(ret, "Graph [%s] topological sort failed, saved to file black_box", name_.c_str()); + return ret; + } + + if (sub_graph_.empty()) { + return SUCCESS; + } + + // partition sub graph + for (const auto &sub_graph : sub_graph_) { + ret = sub_graph->TopologicalSortingGraph(); + if (ret != SUCCESS) { + GELOGE(ret, "Sub graph topological sort Failed"); + return ret; + } + } + + std::vector> subgraphs; + auto nodes = AllGraphNodes(subgraphs); + for (size_t i = 0; i < nodes.size(); i++) { + NodePtr node = nodes.at(i); // [node: should not be null] + node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] + } + if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original + GELOGW("Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), subgraphs.size()); + return SUCCESS; + } + sub_graph_.swap(subgraphs); + return SUCCESS; +} + +graphStatus ComputeGraph::TopologicalSortingGraph() { + std::vector node_vec; + std::map map_in_edge_num; + bool use_BFS = IsUseBFS(); + if (use_BFS) { + std::deque stack; + if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } else { + std::vector stack; + if 
(DFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + + // If they are not equal, there is a closed loop + if (node_vec.size() != nodes_.size()) { + std::set itered_nodes_set; + for (auto &node : node_vec) { + itered_nodes_set.insert(node.get()); + } + GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", nodes_.size(), + node_vec.size()); + for (auto &node : nodes_) { + if (itered_nodes_set.count(node.get()) == 0) { + GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); + } + } + return GRAPH_FAILED; + } + + nodes_.clear(); + for (size_t i = 0; i < node_vec.size(); i++) { + NodePtr node = node_vec[i]; // [node: should not be null] + node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] + nodes_.push_back(node); + } + + is_valid_flag_ = true; + return GRAPH_SUCCESS; +} + +graphStatus ComputeGraph::SortNodes(std::vector &stack, std::map &map_in_edge_num) { + // Record the number of non data nodes but no input nodes + uint32_t spec_node_size = 0; + bool verify_isolated = false; + string run_mode; + const int base = 10; + // Need verify isolated point in PREDICTION mode. + if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { + if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) < TRAIN) { + verify_isolated = true; + } + } + for (const auto &node : GetDirectNode()) { + GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); + map_in_edge_num[node] = static_cast(GetInEdgeSize(node)); + if (map_in_edge_num[node] == 0) { + if ((node->GetOpDesc()->GetType() != DATA) && (node->GetOpDesc()->GetType() != AIPPDATA) && + (node->GetOpDesc()->GetType() != INPUT_TYPE) && (node->GetOpDesc()->GetType() != ANN_DATA)) { + // At present, can only judge the isolated point without input and output. + // It is impossible to judge the situation with multiple output nodes. 
+ if (verify_isolated && GetOutEdgeSize(node) == 0) { + GELOGE(GRAPH_FAILED, "May has isolated nodes in graph, node name: %s.", node->GetName().c_str()); + return GRAPH_FAILED; + } + (void)stack.insert(stack.begin(), node); + spec_node_size++; + continue; + } + // Need to insert the data nodes in reverse order + (void)stack.insert(stack.begin() + spec_node_size, node); + } + } + + /// Make sure the inputs order matches with user-designated + /// 1. Get the index of two input nodes in the user-inputs-order(inputs_order_) + /// 2. Compare two indices, if not match, swap the positions of two inputs + /// *: Remind: stack is reverse-order + for (size_t i = 0; i < stack.size(); ++i) { + // If not found in 'inputs_order_', skip it + auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); + GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); + auto inx_i = it_i - inputs_order_.begin(); + for (size_t j = i + 1; j < stack.size(); ++j) { + // If not found in 'inputs_order_', skip it + auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); + GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); + + // Compare index, swap them if it should be + auto inx_j = it_j - inputs_order_.begin(); + GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); + } + } + + return GRAPH_SUCCESS; +} + +size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { + size_t in_edge_size = 0; + if (node == nullptr) { + return in_edge_size; + } + for (const auto &anchor : node->GetAllInDataAnchors()) { + in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize(); + // Break flow control data loop. 
+ OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); + if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) { + NodePtr out_node = out_anchor->GetOwnerNode(); + if (out_node == nullptr) { + GELOGW("out node is nullptr"); + continue; + } + if ((out_node->GetType() == NEXTITERATION) || (out_node->GetType() == REFNEXTITERATION)) { + GE_IF_BOOL_EXEC(in_edge_size == 0, GELOGE(GRAPH_FAILED, "If [in_edge_size = 0], the result will be reversed"); + return in_edge_size); + in_edge_size -= 1; + } + } + } + if (node->GetInControlAnchor() != nullptr) { + in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize(); + } + return in_edge_size; +} + +size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { + size_t out_edge_size = 0; + if (node == nullptr) { + return out_edge_size; + } + + // Break flow control data loop. + if ((node->GetType() != NEXTITERATION) && (node->GetType() != REFNEXTITERATION)) { + for (const auto &anchor : node->GetAllOutDataAnchors()) { + if (anchor != nullptr) { + out_edge_size = out_edge_size + anchor->GetPeerAnchors().size(); + } + } + } + if (node->GetOutControlAnchor() != nullptr) { + if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { + return 0; + } + out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size(); + } + return out_edge_size; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { + GELOGI("graph name = %s.", GetName().c_str()); + for (const auto &node : GetAllNodes()) { + GELOGI("node name = %s.", node->GetName().c_str()); + for (const auto &anchor : node->GetAllOutDataAnchors()) { + for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { + GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, + GELOGI("node name = %s, out data node name 
= %s.", node->GetName().c_str(), + peer_in_anchor->GetOwnerNode()->GetName().c_str())); + } + for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { + GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, + GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), + peer_in_anchor->GetOwnerNode()->GetName().c_str())); + } + } + auto out_control_anchor = node->GetOutControlAnchor(); + if (out_control_anchor != nullptr) { + for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) { + GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, + GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), + peer_in_anchor->GetOwnerNode()->GetName().c_str())); + } + for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) { + GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, + GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), + peer_in_anchor->GetOwnerNode()->GetName().c_str())); + } + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Swap(ComputeGraph &graph) { + this->AttrHolder::Swap(graph); + + origGraph_.swap(graph.origGraph_); + + name_.swap(graph.name_); + std::swap(graph_id_, graph.graph_id_); + attrs_.Swap(graph.attrs_); + nodes_.swap(graph.nodes_); + all_nodes_infos_.swap(graph.all_nodes_infos_); + target_nodes_info_.swap(graph.target_nodes_info_); + + input_nodes_.swap(graph.input_nodes_); + inputs_order_.swap(graph.inputs_order_); + std::swap(input_size_, graph.input_size_); + out_nodes_map_.swap(graph.out_nodes_map_); + std::swap(output_size_, graph.output_size_); + output_nodes_info_.swap(graph.output_nodes_info_); + + sub_graph_.swap(graph.sub_graph_); + names_to_subgraph_.swap(graph.names_to_subgraph_); + parent_graph_.swap(graph.parent_graph_); + parent_node_.swap(graph.parent_node_); + + // the 
members followed should not in the ComputeGraph class + std::swap(is_valid_flag_, graph.is_valid_flag_); + std::swap(is_summary_graph_, graph.is_summary_graph_); + std::swap(need_iteration_, graph.need_iteration_); + params_share_map_.swap(graph.params_share_map_); + op_name_map_.swap(graph.op_name_map_); + std::swap(session_id_, graph.session_id_); + std::swap(data_format_, graph.data_format_); + std::swap(is_unknown_shape_graph_, graph.is_unknown_shape_graph_); + + // Update Node owner. + SetNodesOwner(); + graph.SetNodesOwner(); +} + +void ComputeGraph::SetNodesOwner() { + for (const auto &node : nodes_) { + if (node == nullptr) { + continue; + } + node->SetOwnerComputeGraph(shared_from_this()); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto next_nodes = node->GetOutAllNodes(); + // If there is input data side + for (size_t i = 0; i < node->GetAllInDataAnchors().size(); i++) { + auto in_data_anchor = node->GetInDataAnchor(static_cast(i)); + auto pre_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + if (pre_out_data_anchor != nullptr) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_IF_BOOL_EXEC(pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANT || + pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANTOP, + continue); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { + 
GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + } + auto out_ctrl_anchor = node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(out_ctrl_anchor); + auto pre_out_ctrl_anchor = pre_out_data_anchor->GetOwnerNode()->GetOutControlAnchor(); + GE_CHECK_NOTNULL(pre_out_ctrl_anchor); + for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + } + } + + // If there is an input control side + auto in_ctrl_anchor = node->GetInControlAnchor(); + GE_CHECK_NOTNULL(in_ctrl_anchor); + for (const auto &pre_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_ctrl_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED, + "remove edge failed"); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + } + auto out_ctrl_anchor = node->GetOutControlAnchor(); + if (out_ctrl_anchor != nullptr) { + for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return 
GRAPH_FAILED, "remove edge failed"); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + } + } + + for (const auto &out_peer_data_anchor : in_ctrl_anchor->GetPeerOutDataAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_peer_data_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED, + "remove edge failed"); + for (const auto &next_node : next_nodes) { + auto next_in_control_anchor = next_node->GetInControlAnchor(); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(out_peer_data_anchor, next_in_control_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + } + + return RemoveExtraOutEdge(node); +} + +graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { + GE_CHECK_NOTNULL(node); + // Remove redundant output edges + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + } + + for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + } + } + auto out_ctrl_anchor = node->GetOutControlAnchor(); + if (out_ctrl_anchor != nullptr) { + for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + } + } + return GRAPH_SUCCESS; +} + +graphStatus ComputeGraph::Verify() { + bool is_unknown_graph = GetGraphUnknownFlag(); + for (const auto &node_ptr : GetAllNodes()) { + GE_CHECK_NOTNULL(node_ptr); + GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); + 
GE_IF_BOOL_EXEC(is_unknown_graph, continue); + GE_CHK_BOOL_EXEC(node_ptr->GetOpDesc()->CommonVerify() == GRAPH_SUCCESS, return GRAPH_FAILED, + "Verifying %s failed.", node_ptr->GetName().c_str()); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferOriginFormat() { + return ge::FormatRefiner::InferOrigineFormat(shared_from_this()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferShapeInNeed() { + GE_CHK_BOOL_ONLY_LOG(TopologicalSorting() == GRAPH_SUCCESS, "Verifying failed."); + for (const auto &node_ptr : GetAllNodes()) { + GE_CHECK_NOTNULL(node_ptr); + auto op_desc = node_ptr->GetOpDesc(); + bool is_need_infer = false; + (void)ge::AttrUtils::GetBool(op_desc, NEED_INFER, is_need_infer); + if (is_need_infer) { + GE_CHK_BOOL_EXEC(node_ptr->Verify() == GRAPH_SUCCESS, return GRAPH_FAILED, "Verifying %s failed.", + node_ptr->GetName().c_str()); + + graphStatus status = node_ptr->InferShapeAndType(); + GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == DATA || GRAPH_PARAM_INVALID != status, break, + "Op %s does not have the IMPLEMT_INFERFUNC definition," + " and subsequent operators no longer perform shape inference.", + node_ptr->GetName().c_str()); + GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS, return GRAPH_FAILED, "Inferring %s failed.", + node_ptr->GetName().c_str()); + + for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc()); + auto output_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); + ge::TensorUtils::SetRealDimCnt(output_tensor, output_tensor.GetShape().GetDims().size()); + (void)out_anchor->GetOwnerNode()->GetOpDesc()->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor); + for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { + (void)peer_anchor->GetOwnerNode()->GetOpDesc()->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor); + } + } + } + } + 
return GRAPH_SUCCESS; +} + +ProtoAttrMapHelper ComputeGraph::MutableAttrMap() { return attrs_; } + +ConstProtoAttrMapHelper ComputeGraph::GetAttrMap() const { + return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); +} + +const std::map &ComputeGraph::GetAllNodesInfo() const { return all_nodes_infos_; } + +void ComputeGraph::SetUserDefOutput(const std::string &output_name) { + if (output_name.empty()) { + return; + } + + vector nodes = StringUtils::Split(output_name, ';'); + for (string node : nodes) { + vector item = StringUtils::Split(node, ':'); + if (item.size() != OUTPUT_PARAM_SIZE) { + GELOGW("invalid output param!input:%s", output_name.c_str()); + continue; + } + + int32_t index; + try { + index = stoi(StringUtils::Trim(item[1])); + } catch (const std::out_of_range &) { + GELOGW("outputname cause out of range execption!output_name:%s", output_name.c_str()); + continue; + } catch (const std::invalid_argument &) { + GELOGW("outputname cause invalid argument!output_name:%s", output_name.c_str()); + continue; + } catch (...) { + GELOGW("stoi fail! 
output_name:%s", output_name.c_str()); + continue; + } + auto iter = out_nodes_map_.find(item[0]); + if (iter == out_nodes_map_.end()) { + out_nodes_map_[item[0]] = std::vector(1, index); + } else { + auto idx_iter = std::find(iter->second.begin(), iter->second.end(), index); + if (idx_iter == iter->second.end()) { + iter->second.push_back(index); + } + } + } +} + +const std::string ComputeGraph::GetOutput() { + static const int resultDefaultSize = 2048; + string result; + result.reserve(resultDefaultSize); + auto iter = out_nodes_map_.begin(); + while (iter != out_nodes_map_.end()) { + auto idxes = iter->second; + for (auto idx : idxes) { + (void)result.append(iter->first).append(":").append(std::to_string(idx)).append(";"); + } + ++iter; + } + + return result.substr(0, result.length() - 1); +} +} // namespace ge diff --git a/src/common/graph/debug/ge_log.h b/src/common/graph/debug/ge_log.h new file mode 100644 index 00000000..14a66709 --- /dev/null +++ b/src/common/graph/debug/ge_log.h @@ -0,0 +1,147 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_DEBUG_GE_LOG_H_ +#define COMMON_GRAPH_DEBUG_GE_LOG_H_ + +#include "graph/ge_error_codes.h" +#include "framework/common/debug/ge_log.h" + +#define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) + +#define GE_LOGI_IF(condition, ...) 
\ + if ((condition)) { \ + GELOGI(__VA_ARGS__); \ + } + +#define GE_LOGW_IF(condition, ...) \ + if ((condition)) { \ + GELOGW(__VA_ARGS__); \ + } + +#define GE_LOGE_IF(condition, ...) \ + if ((condition)) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + } + +#define GE_CHK_STATUS_RET_NOLOG(expr) \ + do { \ + const ge::graphStatus _status = (expr); \ + if (ge::SUCCESS != _status) { \ + return _status; \ + } \ + } while (0) + +#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +#define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ + { \ + bool b = (expr); \ + if (!b) { \ + exec_expr; \ + } \ + } + +#define GE_IF_BOOL_EXEC(expr, exec_expr) \ + { \ + if (expr) { \ + exec_expr; \ + } \ + } + +#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ + do { \ + const ge::graphStatus _status = (expr); \ + if (_status) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// If expr is true, the log is printed and a custom statement is executed +#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + exec_expr; \ + } \ + } + +// Only check error log +#define GE_CHK_BOOL_ONLY_LOG(expr, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GELOGI(__VA_ARGS__); \ + } \ + } while (0) + +// If expr is not true, do not print the log and return the specified status +#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + return _status; \ + } \ + } while (0) + +// If expr is not true, the log is printed and a custom statement is executed +#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) 
\ + { \ + bool b = (expr); \ + if (!b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + exec_expr; \ + } \ + } + +// If expr is not true, the log is printed and a custom statement is executed +#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (!b) { \ + GELOGI(__VA_ARGS__); \ + exec_expr; \ + } \ + } + +// If expr is not GRAPH_SUCCESS, print the log and return the same value +#define GE_CHK_STATUS_RET(expr, ...) \ + do { \ + const ge::graphStatus _status = (expr); \ + if (ge::SUCCESS != _status) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ + try { \ + exec_expr0; \ + } catch (...) { \ + GELOGE(ge::FAILED, "Make shared failed"); \ + exec_expr1; \ + } + +#endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ diff --git a/src/common/graph/debug/ge_op_types.h b/src/common/graph/debug/ge_op_types.h new file mode 100644 index 00000000..dff87331 --- /dev/null +++ b/src/common/graph/debug/ge_op_types.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ +#define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ + +namespace ge { +#define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name __attribute__((unused)) = str_name + +GE_REGISTER_OPTYPE(DATA, "Data"); +GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); +GE_REGISTER_OPTYPE(MATMUL, "MatMul"); +GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); +GE_REGISTER_OPTYPE(PERMUTE, "Permute"); +GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); +GE_REGISTER_OPTYPE(_WHILE, "_While"); +GE_REGISTER_OPTYPE(WHILE, "While"); +GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); +GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); +GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); +GE_REGISTER_OPTYPE(SWITCH, "Switch"); +GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); +GE_REGISTER_OPTYPE(SWITCHN, "SwitchN"); +GE_REGISTER_OPTYPE(MERGE, "Merge"); +GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); +GE_REGISTER_OPTYPE(ENTER, "Enter"); +GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); +GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); +GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); +GE_REGISTER_OPTYPE(CONSTANT, "Const"); +GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); +GE_REGISTER_OPTYPE(END, "End"); +GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); +GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); +GE_REGISTER_OPTYPE(INITDATA, "InitData"); +GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); +GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); + +GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); +GE_REGISTER_OPTYPE(VARIABLE, "Variable"); +GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); + +GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); + +// Horovod operator +GE_REGISTER_OPTYPE(HVDCALLBACKALLREDUCE, "hvdCallbackAllreduce"); +GE_REGISTER_OPTYPE(HVDCALLBACKALLGATHER, "hvdCallbackAllgather"); +GE_REGISTER_OPTYPE(HVDCALLBACKBROADCAST, "hvdCallbackBroadcast"); +GE_REGISTER_OPTYPE(HVDWAIT, "hvdWait"); + +GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); + +GE_REGISTER_OPTYPE(RECV, "Recv"); 
+GE_REGISTER_OPTYPE(SEND, "Send"); +}; // namespace ge +#endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ diff --git a/src/common/graph/debug/ge_util.h b/src/common/graph/debug/ge_util.h new file mode 100644 index 00000000..4c6ae051 --- /dev/null +++ b/src/common/graph/debug/ge_util.h @@ -0,0 +1,274 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_DEBUG_GE_UTIL_H_ +#define COMMON_GRAPH_DEBUG_GE_UTIL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_log.h" +#include "graph/ge_error_codes.h" + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define GE_DYNAMIC_CAST dynamic_cast +#define GE_DYNAMIC_POINTER_CAST std::dynamic_pointer_cast +#else +#define GE_DYNAMIC_CAST static_cast +#define GE_DYNAMIC_POINTER_CAST std::static_pointer_cast +#endif + +#define GE_RETURN_IF_ERROR(expr) \ + do { \ + const ::ge::optStatus _status = (expr); \ + if (_status) return _status; \ + } while (0) + +#define GE_RETURN_WITH_LOG_IF_INFO(expr, ...) \ + do { \ + const ::ge::optStatus _status = (expr); \ + if (_status) { \ + GELOGI(__VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// Verify whether the parameter is true. If yes, return graph failed and record the error log +#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) 
\ + do { \ + if (condition) { \ + GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +// Verify whether the parameter is false. If yes, return graph failed and record the error log +#define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ + do { \ + bool _condition = (condition); \ + if (!_condition) { \ + GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +// Verify whether the parameter is true. If yes, return GRAPH_PARAM_INVALID and record the error log +#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ + do { \ + if (condition) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Verify whether the parameter is false. If yes, return GRAPH_PARAM_INVALID and record the error log +#define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \ + do { \ + bool _condition = (condition); \ + if (!_condition) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log +#define GE_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log +#define GE_CHECK_NOTNULL_EXEC(val, expr) \ + do { \ + if (val == nullptr) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ + expr; \ + } \ + } while (0) + +// Verify whether the parameter is null. 
If yes, return false and record the error log +#define GE_RT_FALSE_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + GELOGE(ge::GRAPH_FAILED, "param[%s] must not be null.", #val); \ + return false; \ + } \ + } while (0) + +// Check whether the parameter is out of range +#define GE_CHECK_SIZE(size) \ + do { \ + if (size == 0) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +/// +/// @ingroup GE_common +/// eg:GE_DEFINE_BYTE_SIZE(filter_byte, filter.data().size(), sizeof(float)); +/// +#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ + uint32_t _var_name; \ + do { \ + uint32_t _expr_size = (_expr); \ + uint32_t _sizeof_size = (_sizeof); \ + if (_expr_size > (0xffffffff) / _sizeof_size) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "byte size : %s is out of range", #_var_name); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + _var_name = _sizeof_size * _expr_size; \ + } while (0); + +// Check whether the container is empty +#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ + do { \ + if (vector.empty()) { \ + GELOGE(ge::GRAPH_FAILED, "param[#vector] is empty", #vector); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +// Check whether the container is empty and return the specified status code +#define GE_CHECK_VECTOR_NOT_EMPTY_RET_STATUS(vector, _status) \ + do { \ + if (vector.empty()) { \ + GELOGE(_status, "param[%s] is empty", #vector); \ + return _status; \ + } \ + } while (0) + +/// +/// @ingroup GE_common +/// @brief This macro provides the ability to disable copying constructors and assignment operators. 
+/// It is usually placed under private +/// +#define GE_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName &) = delete; \ + void operator=(const TypeName &) = delete + +/// Check whether the size is 0 or out of range +/// @param:size:Size to be verified +#define GE_CHECK_SIZE_RANGE(size) \ + do { \ + if (size == 0 || size >= UINT_MAX / 4) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +#define GE_CHECK_SHORT_SIZE_RANGE(size) \ + do { \ + if (size == 0 || size >= UINT_MAX / 2) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ + do { \ + if (size <= 0) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not a positive number", #size); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +#define GE_CHECK_POSITIVE_SHORT_SIZE_RANGE(size) \ + do { \ + if (size <= 0 || size == 0 || size >= UINT_MAX / 4) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Verify that the value on the left is greater than or equal to the value on the right +#define GE_CHECK_GE(lhs, rhs) \ + do { \ + if (lhs < rhs) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is less than[%s]", #lhs, #rhs); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Check whether the parameters are equal +#define GE_CHECK_EQ(val1, val2) \ + do { \ + if (val1 != val2) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not equals to[%s]", #val1, #val2); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Verify that the value on the left is less than or equal to the value on the right +#define GE_CHECK_LE(lhs, rhs) \ + do { \ + if (lhs > rhs) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is greater than[%s]", #lhs, #rhs); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ 
+ } while (0) + +// Check whether the parameters are equal +#define GE_CHECK_EQ_WITH_LOG(val1, val2, ...) \ + do { \ + if (val1 != val2) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// If expr is false, the custom statement is executed +#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + exec_expr; \ + } \ + } while (0) + +#define GE_DELETE_NEW_SINGLE(var) \ + do { \ + if (var != nullptr) { \ + delete var; \ + var = nullptr; \ + } \ + } while (0) + +#define GE_DELETE_NEW_ARRAY(var) \ + do { \ + if (var != nullptr) { \ + delete[] var; \ + var = nullptr; \ + } \ + } while (0) + +template +static inline std::shared_ptr ComGraphMakeShared(Args &&... args) { + using T_nc = typename std::remove_const::type; + std::shared_ptr ret(new (std::nothrow) T_nc(std::forward(args)...)); + return ret; +} + +#endif // COMMON_GRAPH_DEBUG_GE_UTIL_H_ diff --git a/src/common/graph/debug/graph_debug.cc b/src/common/graph/debug/graph_debug.cc new file mode 100644 index 00000000..7ce9db37 --- /dev/null +++ b/src/common/graph/debug/graph_debug.cc @@ -0,0 +1,246 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/debug/graph_debug.h" +#include +#include +#include +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" + +#define TAB " " +#define STR_FMT(str) (" \"" + std::string(str) + "\" ") +#define INPUT_ANCHOR_PORT(name) ("__input__" + (name)) +#define OUTPUT_ANCHOR_PORT(name) ("__output__" + (name)) + +namespace ge { +std::unordered_set control_anchor; +std::vector types = { + "DT_FLOAT", "DT_FLOAT16", "DT_INT8", "DT_INT32", "DT_UINT8", "", + "DT_INT16", "DT_UINT16", "DT_UINT32", "DT_INT64", "DT_UINT64", "DT_DOUBLE", + "DT_BOOL", "DT_DUAL", "DT_DUAL_SUB_INT8", "DT_DUAL_SUB_UINT8", "DT_UNDEFINED"}; + +std::vector formats = {"FORMAT_NCHW", + "FORMAT_NHWC", + "FORMAT_ND", + "FORMAT_NC1HWC0", + "FORMAT_FRACTAL_Z", + "FORMAT_NC1C0HWPAD", + "FORMAT_NHWC1C0", + "FORMAT_FSR_NCHW", + "FORMAT_FRACTAL_DECONV", + "FORMAT_C1HWNC0", + "FORMAT_FRACTAL_DECONV_TRANSPOSE", + "FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS", + "FORMAT_NC1HWC0_C04", + "FORMAT_FRACTAL_Z_C04", + "FORMAT_CHWN", + "FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS", + "FORMAT_HWCN", + "FORMAT_NC1KHKWHWC0", + "FORMAT_BN_WEIGHT", + "FORMAT_FILTER_HWCK", + "FORMAT_HASHTABLE_LOOKUP_LOOKUPS", + "FORMAT_HASHTABLE_LOOKUP_KEYS", + "FORMAT_HASHTABLE_LOOKUP_VALUE", + "FORMAT_HASHTABLE_LOOKUP_OUTPUT", + "FORMAT_HASHTABLE_LOOKUP_HITS", + "FORMAT_RESERVED"}; + +std::vector data_nodes = {"Const", "Data"}; + +void GraphDebugPrinter::DumpNodeToDot(const NodePtr node, std::ostringstream &out_) { + if (node == nullptr) { + GELOGI("Some nodes are null."); + return; + } + + bool in_control = false; + auto name = node->GetName(); + out_ << TAB << STR_FMT(name); + auto input_cnt = std::max(static_cast(1), node->GetAllInDataAnchors().size()); + auto output_cnt = std::max(static_cast(1), node->GetAllOutDataAnchors().size()); + if (control_anchor.find(node->GetName()) != control_anchor.end()) { + input_cnt++; + in_control = true; + } + auto max_col = input_cnt * output_cnt; + out_ << "[\n"; + if 
(find(data_nodes.begin(), data_nodes.end(), node->GetType()) != data_nodes.end()) { + out_ << TAB << TAB << "shape=plaintext, color=goldenrod\n"; + } else { + out_ << TAB << TAB << "shape=plaintext, color=deepskyblue\n"; + } + out_ << TAB << TAB << "label=<\n"; + out_ << TAB << TAB << R"(" << std::endl; + + auto input_anchors = node->GetAllInDataAnchors(); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return ); + if (!input_anchors.empty()) { + out_ << TAB << TAB << ""; + } + for (const auto &anchor : input_anchors) { + string anchor_text = op_desc->GetInputNameByIndex(anchor->GetIdx()); + + out_ << ""; + } + if (in_control) { + string anchor_text = "ctrl"; + out_ << ""; + } + if (!input_anchors.empty()) { + out_ << "\n"; + } + // Node type + out_ << TAB << TAB << "\n"; + // Output + auto output_anchors = node->GetAllOutDataAnchors(); + if (!output_anchors.empty()) { + out_ << TAB << TAB << ""; + } + for (const auto &anchor : output_anchors) { + string anchor_text = op_desc->GetOutputNameByIndex(anchor->GetIdx()); + + out_ << ""; + } + + if (!output_anchors.empty()) { + out_ << "\n"; + } + out_ << TAB << TAB << "
" + << anchor_text << "" + << anchor_text << "
" + << "" << node->GetType() << "
" + << anchor_text << "
\n" << TAB << ">];\n"; +} + +void GraphDebugPrinter::DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag) { + if (node == nullptr) { + GELOGI("Some nodes are null."); + return; + } + auto all_out_anchor = node->GetAllOutDataAnchors(); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return ); + for (const auto &anchor : all_out_anchor) { + auto src_anchor = anchor; + auto src_node_name = node->GetName(); + auto src_anchor_index = op_desc->GetOutputNameByIndex(static_cast(src_anchor->GetIdx())); + auto des_anchors = anchor->GetPeerAnchors(); + for (const auto &peer_in_anchor : des_anchors) { + auto in_data_anchor = Anchor::DynamicAnchorCast(peer_in_anchor); + std::string dst_node_name; + out_ << TAB << STR_FMT(src_node_name); + out_ << ":" << OUTPUT_ANCHOR_PORT(src_anchor_index); + auto op = peer_in_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op, continue); + if (in_data_anchor != nullptr) { + dst_node_name = in_data_anchor->GetOwnerNode()->GetName(); + string des_anchor_index = op->GetInputNameByIndex(static_cast(in_data_anchor->GetIdx())); + out_ << " -> " << STR_FMT(dst_node_name); + out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index); + out_ << "["; + } + auto in_control_anchor = Anchor::DynamicAnchorCast(peer_in_anchor); + if (in_control_anchor != nullptr) { + dst_node_name = in_control_anchor->GetOwnerNode()->GetName(); + string des_anchor_index = "ctrl"; + out_ << " -> " << STR_FMT(dst_node_name); + out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index); + out_ << "["; + out_ << " style=dashed "; + } + if (flag != DOT_NOT_SHOW_EDGE_LABEL && in_data_anchor) { + string label; + auto src_ops = src_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(src_ops, return ); + auto src_shape = src_ops->GetOutputDesc(src_anchor->GetIdx()).GetShape(); + auto dim = src_shape.GetDims(); + std::ostringstream tensor_info; + if (dim.size() > 0) { + for (size_t i = 0; i < dim.size(); i++) { + if (i != dim.size() - 
1) { + tensor_info << dim[i] << "x"; + } else { + tensor_info << dim[i]; + } + } + } else { + tensor_info << "?"; + } + auto src_tensor_desc = src_ops->GetOutputDescPtr(src_anchor->GetIdx()); + GE_CHECK_NOTNULL_EXEC(src_tensor_desc, return ); + auto format = src_tensor_desc->GetFormat(); + auto datatype = src_tensor_desc->GetDataType(); + tensor_info << " : " << formats[format] << " : " << types[datatype]; + label = tensor_info.str(); + out_ << "label=" << STR_FMT(label); + } + out_ << "]" << std::endl; + } + } +} + +graphStatus GraphDebugPrinter::DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name, + uint32_t flag) { + auto compute_graph = GraphUtils::GetComputeGraph(graph); + if (compute_graph == nullptr) { + GELOGI("Compute graph is NULL ."); + return GRAPH_SUCCESS; + } + return DumpGraphDotFile(compute_graph, output_dot_file_name, flag); +} + +graphStatus GraphDebugPrinter::DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name, + uint32_t flag) { + if (graph == nullptr) { + GELOGI("graph is null."); + return GRAPH_SUCCESS; + } + std::ostringstream out_; + out_ << "digraph G{\n"; + out_ << TAB << R"(ratio=compress;size="8, 100")" << std::endl; + out_ << TAB << R"(node[fontname="Consolas"])" << std::endl; + out_ << TAB << R"(edge[fontsize = "8" fontname = "Consolas" color="dimgray" ])" << std::endl; + auto all_nodes = graph->GetAllNodes(); + for (const auto &node : all_nodes) { + for (const auto &temp : node->GetAllOutDataAnchors()) { + for (const auto &peer : temp->GetPeerAnchors()) { + auto temp_control_anchor = Anchor::DynamicAnchorCast(peer); + if (temp_control_anchor) { + (void)control_anchor.insert(peer->GetOwnerNode()->GetName()); + } + } + } + } + for (const auto &node : all_nodes) { + DumpNodeToDot(node, out_); + } + for (const auto &node : all_nodes) { + DumpEdgeToDot(node, out_, flag); + } + out_ << "}"; + std::ofstream output_file(output_dot_file_name); + if (output_file.is_open()) { + 
output_file << out_.str(); + } else { + GELOGW("%s open error.", output_dot_file_name.c_str()); + } + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/src/common/graph/debug/graph_debug.h b/src/common/graph/debug/graph_debug.h new file mode 100644 index 00000000..29de632a --- /dev/null +++ b/src/common/graph/debug/graph_debug.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ +#define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ +#include +#include +#include +#include +#include +#include "external/graph/graph.h" +#include "./ge_error_codes.h" +#include "graph/compute_graph.h" +#include "graph/debug/ge_log.h" +#include "graph/node.h" +#include "utils/graph_utils.h" + +namespace ge { +enum DotFileFlag { + // Show nodes, edges, size, type and format + DOT_FLAG_DEFAULT = 0, + DOT_NOT_SHOW_EDGE_LABEL = 1, +}; +class GraphDebugPrinter { + public: + static graphStatus DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name, + uint32_t flag = DOT_FLAG_DEFAULT); + static graphStatus DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name, + uint32_t flag = DOT_FLAG_DEFAULT); + static void DumpNodeToDot(const NodePtr node, std::ostringstream &out_); + static void DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag = DOT_FLAG_DEFAULT); +}; +} // namespace ge + +#endif // 
COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ diff --git a/src/common/graph/detail/attributes_holder.cc b/src/common/graph/detail/attributes_holder.cc new file mode 100644 index 00000000..7e3b6de9 --- /dev/null +++ b/src/common/graph/detail/attributes_holder.cc @@ -0,0 +1,241 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "detail/attributes_holder.h" +#include +#include "debug/ge_log.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_attr_value.h" +#include "proto/ge_ir.pb.h" + +namespace ge { +using std::map; +using std::unordered_set; +void AttrHolder::CopyAttrsFrom(const AttrHolder &holder) { MutableAttrMap().CopyValueFrom(holder.GetAttrMap()); } +graphStatus AttrHolder::SetAttr(const std::string &name, const GeAttrValue &value) { + if (value.IsEmpty()) { + GELOGE(GRAPH_FAILED, "value is empty, key of the attr is %s", name.c_str()); + return GRAPH_FAILED; + } + auto proto_map = MutableAttrMap().GetProtoMsg(); + auto proto_val = value.value_.GetProtoMsg(); + if (proto_map == nullptr || proto_val == nullptr) { + return GRAPH_FAILED; + } + auto it = proto_map->find(name); + if (it != proto_map->end()) { + if (it->second.value_case() != proto::AttrDef::VALUE_NOT_SET && + it->second.value_case() != proto_val->value_case()) { + return GRAPH_FAILED; + } + } + (*proto_map)[name] = *proto_val; + return GRAPH_SUCCESS; +} + +graphStatus 
AttrHolder::AddRequiredAttr(const std::string &name) { + if (HasAttr(name)) { + return GRAPH_FAILED; + } + requiredAttrs_.push_back(name); + return GRAPH_SUCCESS; +} + +graphStatus AttrHolder::GetAttr(const std::string &name, GeAttrValue &value) const { + auto proto_map = GetAttrMap().GetProtoMsg(); + auto proto_val = value.value_.GetProtoMsg(); + if (proto_map == nullptr || proto_val == nullptr) { + return GRAPH_FAILED; + } + auto it = proto_map->find(name); + if (it != proto_map->end()) { + *proto_val = it->second; + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +bool AttrHolder::HasAttr(const std::string &name) const { + auto proto_map = GetAttrMap().GetProtoMsg(); + if (proto_map != nullptr) { + if (proto_map->find(name) != proto_map->end()) { + return true; + } + } + return std::find(requiredAttrs_.begin(), requiredAttrs_.end(), name) != requiredAttrs_.end(); +} + +graphStatus AttrHolder::DelAttr(const std::string &name) { + auto proto_map = MutableAttrMap().GetProtoMsg(); + if (proto_map == nullptr) { + return GRAPH_FAILED; + } + auto it = proto_map->find(name); + if (it != proto_map->end()) { + (void)proto_map->erase(it); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +const std::map AttrHolder::GetAllAttrs() const { + std::map attr_value_map; + auto proto_map = GetAttrMap().GetProtoMsg(); + if (proto_map != nullptr) { + auto proto_owner = GetAttrMap().GetProtoOwner(); + GE_CHK_BOOL_EXEC(proto_owner != nullptr, return attr_value_map, "proto_owner is nullptr"); + for (const auto &it : *proto_map) { + attr_value_map[it.first] = GeAttrValue(proto_owner, const_cast(&it.second)); + } + } + return attr_value_map; +} + +const std::unordered_set AttrHolder::GetAllAttrNames() const { + std::unordered_set names; + auto proto_map = GetAttrMap().GetProtoMsg(); + if (proto_map != nullptr) { + for (const auto &it : *proto_map) { + (void)names.insert(it.first); + } + } + for (const string &it : requiredAttrs_) { + (void)names.insert(it); + } + return 
names; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::AttrDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::TensorDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::ShapeDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::NamedAttrs make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::OpDef make shared 
failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); + return; + } + protoMsg_ = proto_owner->mutable_attr(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); + return; + } + protoMsg_ = &proto_owner->attr(); + protoOwner_ = proto_owner; +} +} // namespace ge diff --git a/src/common/graph/format_refiner.cc b/src/common/graph/format_refiner.cc new file mode 100644 index 00000000..c716825a --- /dev/null +++ b/src/common/graph/format_refiner.cc @@ -0,0 +1,508 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "format_refiner.h" + +#include +#include +#include +#include +#include + +#include "graph/ref_relation.h" +#include "./compute_graph.h" +#include "./ge_error_codes.h" +#include "./graph/ge_tensor.h" +#include "./operator.h" +#include "./operator_factory.h" +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "utils/node_utils.h" +#include "utils/op_desc_utils.h" +#include "utils/tensor_utils.h" +#include "utils/type_utils.h" + +using namespace ge; +using namespace std; +namespace ge { +namespace { +const std::unordered_set kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; +const string kIsGraphInferred = "_is_graph_inferred"; +thread_local RefRelations reflection_builder; +} // namespace + +graphStatus ReflectionProcess(const std::unordered_set &reflection, + std::deque &nodes, ge::Format to_be_set_format) { + for (const auto &cell : reflection) { + auto node = cell.node; + auto in_out_idx = cell.in_out_idx; + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (cell.in_out == ge::NODE_IN) { + auto desc = node->GetOpDesc()->GetInputDesc(static_cast(in_out_idx)); + desc.SetOriginFormat(to_be_set_format); + desc.SetFormat(to_be_set_format); + (void)node->GetOpDesc()->UpdateInputDesc(static_cast(in_out_idx), desc); + } else { + auto desc = node->GetOpDesc()->GetOutputDesc(static_cast(in_out_idx)); + desc.SetOriginFormat(to_be_set_format); + desc.SetFormat(to_be_set_format); + (void)node->GetOpDesc()->UpdateOutputDesc(static_cast(in_out_idx), desc); + } + nodes.push_back(cell.node); + } + + return GRAPH_SUCCESS; +} + +graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) { + // 5 meas dim num + if (node_ptr->GetType() != "BiasAdd") { + return GRAPH_SUCCESS; + } + std::unordered_map kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}}; + for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) { + auto in_desc = 
node_ptr->GetOpDesc()->MutableInputDesc(i); + GE_CHECK_NOTNULL(in_desc); + if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num + continue; + } + auto format = in_desc->GetOriginFormat(); + auto key = TypeUtils::FormatToSerialString(format); + auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; + in_desc->SetOriginFormat(fixed_format); + in_desc->SetFormat(fixed_format); + GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i, + node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::FormatToSerialString(fixed_format).c_str()); + } + for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) { + auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i); + GE_CHECK_NOTNULL(out_desc); + if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num + continue; + } + auto format = out_desc->GetOriginFormat(); + auto key = TypeUtils::FormatToSerialString(format); + auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; + out_desc->SetOriginFormat(fixed_format); + out_desc->SetFormat(fixed_format); + GELOGD("fix the %zu'th output of node[%s]. 
Origin format is %s , after fixed it is %s", i, + node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::FormatToSerialString(fixed_format).c_str()); + } + return GRAPH_SUCCESS; +} + +graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(graph); + GE_CHECK_NOTNULL(op_desc); + if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) { + ConstGeTensorPtr tensor_value; + if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { + GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str()); + return GRAPH_FAILED; + } + GE_CHECK_NOTNULL(tensor_value); + (void)op_desc->UpdateOutputDesc(0, tensor_value->GetTensorDesc()); + } + return GRAPH_SUCCESS; +} + +graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points, + std::vector &data_nodes, + std::unordered_map &node_status) { + if (graph == nullptr) { + GELOGE(GRAPH_FAILED, "input graph is null"); + return GRAPH_FAILED; + } + anchor_points.clear(); + // Get all anchor point nodes and switch nodes + for (auto &node_ptr : graph->GetAllNodes()) { + if (node_ptr == nullptr) { + return GRAPH_FAILED; + } + auto op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + return GRAPH_FAILED; + } + graphStatus status = RefreshConstantOutProcess(graph, op_desc); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "refresh constant out process failed!"); + return GRAPH_FAILED; + } + // consider special node save process + // get all input desc format + bool node_is_all_nd = false; + auto input_size = static_cast(op_desc->GetAllInputsSize()); + for (uint32_t i = 0; i < input_size; i++) { + // Operator pre-set format but not origin format + GE_IF_BOOL_EXEC(op_desc->MutableInputDesc(i) == nullptr, continue); + auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); + // Pre-save data node (only main graph data) and default infer 
fail + if (node_ptr->GetType() == DATA) { + data_nodes.push_back(node_ptr); + } + if (input_format != FORMAT_ND && input_format != FORMAT_RESERVED) { + node_is_all_nd = true; + } + } + // Get all output desc format + auto output_size = static_cast(op_desc->GetOutputsSize()); + for (uint32_t i = 0; i < output_size; i++) { + GE_IF_BOOL_EXEC(op_desc->MutableOutputDesc(i) == nullptr, continue); + auto output_format = op_desc->MutableOutputDesc(i)->GetFormat(); + if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { + node_is_all_nd = true; + } + } + // check anchor point valid + if (!node_is_all_nd) { + continue; + } + // special process for biasAdd op + // In tensorflow, biasAdd's format is alwayse NHWC even though set the arg + // "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism + // so here do special process + status = BiasAddFormatFixProcess(node_ptr); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "fix biasAdd process failed!"); + return GRAPH_FAILED; + } + + GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str()); + anchor_points.push_back(node_ptr); + } + GELOGI("anchor_points number is %zu", anchor_points.size()); + return GRAPH_SUCCESS; +} +graphStatus FormatRefiner::AnchorProcess(const ge::NodePtr &anchor_node, + std::unordered_map &node_status) { + if (anchor_node == nullptr) { + GELOGE(GRAPH_FAILED, "anchor node is null!"); + return GRAPH_FAILED; + } + std::deque nodes; + nodes.push_back(anchor_node); + while (!nodes.empty()) { + ge::NodePtr node = nodes.front(); + nodes.pop_front(); + graphStatus status = BackInferProcess(nodes, node, node_status); + if (status != GRAPH_SUCCESS && node != nullptr) { + GELOGE(status, "BackInferProcess failed!node name [%s]", node->GetName().c_str()); + return status; + } + status = ForwardInferProcess(nodes, node, node_status); + if (status != GRAPH_SUCCESS && node != nullptr) { + GELOGE(status, "ForwardInferProcess failed!node name [%s]", 
node->GetName().c_str()); + return status; + } + } + return GRAPH_SUCCESS; +} +graphStatus FormatRefiner::BackInferProcess(std::deque &nodes, ge::NodePtr &node, + std::unordered_map &node_status) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + + GELOGD("Enter back infer process!Node is [%s]", (node->GetName()).c_str()); + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); + auto in_data_anchor_idx = in_anchor->GetIdx(); + auto input_desc = node->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); + GE_IF_BOOL_EXEC(input_desc == nullptr, continue); + auto to_be_set_format = input_desc->GetOriginFormat(); + if (to_be_set_format == FORMAT_ND) { + GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); + continue; + } + auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_data_anchor == nullptr) { + GELOGW("Node[%s] %dth in data anchor's peer_out_anchor is null", (node->GetName()).c_str(), in_data_anchor_idx); + continue; + } + auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); + if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { + GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (node->GetName()).c_str()); + continue; + } + // Check format whether have been set + int idx = peer_out_data_anchor->GetIdx(); + // do peer_out_node name and index as key to lookup reflections + ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx); + std::unordered_set reflection; + auto status = reflection_builder.LookUpRefRelations(key, reflection); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge", + (peer_out_data_node->GetName()).c_str(), idx); + return GRAPH_FAILED; + } + + auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast(idx)); + if 
(ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { + auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); + if (dim_num == 0) { + GELOGD("node name:%s idx:%d out is scalar. stop back infer!", peer_out_data_node->GetName().c_str(), idx); + continue; + } + /// Check whether node to change dims () + /// Because some node will calculate with 5D, C dim maybe multi meaning + auto peer_out_data_node_type = peer_out_data_node->GetType(); + auto iter1 = kChangeDimNodes.find(peer_out_data_node_type); + // 4 means dims num + if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { + GELOGD("Node[%s] is change dim node and shape is smaller than 4. do not modify format", + (peer_out_data_node->GetName()).c_str()); + continue; + } + + if (reflection.empty()) { + ge_tensor_desc.SetOriginFormat(to_be_set_format); + ge_tensor_desc.SetFormat(to_be_set_format); + (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast(idx), ge_tensor_desc); + + // Call operator infer format api (forward) to get out format + GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); + status = peer_out_data_node->InferOriginFormat(); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); + return GRAPH_FAILED; + } + nodes.push_back(peer_out_data_node); + } else { + auto status = ReflectionProcess(reflection, nodes, to_be_set_format); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "reflection process failed!"); + return GRAPH_FAILED; + } + } + } + } + return GRAPH_SUCCESS; +} +graphStatus FormatRefiner::ForwardInferProcess(std::deque &nodes, ge::NodePtr &node, + std::unordered_map &node_status) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + GELOGD("Enter forward infer process!Node is [%s]", (node->GetName()).c_str()); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); + 
GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); + auto out_data_anchor_idx = out_data_anchor->GetIdx(); + auto to_be_set_format = + node->GetOpDesc()->MutableOutputDesc(static_cast(out_data_anchor_idx))->GetOriginFormat(); + if (to_be_set_format == FORMAT_ND) { + GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); + continue; + } + for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue); + + auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); + GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue); + GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue); + + // Check format whether have been set + int idx = peer_in_data_anchor->GetIdx(); + // do peer_out_node name and index as key to lookup reflections + ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx); + std::unordered_set reflection; + auto status = reflection_builder.LookUpRefRelations(key, reflection); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge", + (peer_in_data_node->GetName()).c_str(), idx); + return GRAPH_FAILED; + } + auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast(idx)); + if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { + auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); + if (dim_num == 0) { + GELOGI("node name:%s idx:%d in is scalar. stop forward infer!", peer_in_data_node->GetName().c_str(), idx); + continue; + } + /// Check whether node to change dims () + /// Because some node will calculate with 5D, C dim maybe multi meaning + auto peer_in_data_node_type = peer_in_data_node->GetType(); + auto iter1 = kChangeDimNodes.find(peer_in_data_node_type); + // 4 means dims num + if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { + GELOGD("Node[%s] is change dim node. 
do not infer origin format", (peer_in_data_node->GetName()).c_str()); + continue; + } + + if (reflection.empty()) { + ge_tensor_desc.SetOriginFormat(to_be_set_format); + ge_tensor_desc.SetFormat(to_be_set_format); + (void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast(idx), ge_tensor_desc); + + /// Because netoutput node added before infer format ,so netoutput is end condition + /// must set netoutput format , because saved result depend on format + if (peer_in_data_node_type == NETOUTPUT) { + continue; + } + + // Call operator infer format api (forward) to get out format + GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); + status = peer_in_data_node->InferOriginFormat(); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); + return GRAPH_FAILED; + } + nodes.push_back(peer_in_data_node); + } else { + auto status = ReflectionProcess(reflection, nodes, to_be_set_format); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "reflection process failed!"); + return GRAPH_FAILED; + } + } + } + } + } + return GRAPH_SUCCESS; +} + +void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector &anchor_points) { + for (const auto &node : anchor_points) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + continue; + } + for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { + if (input_desc != nullptr) { + input_desc->SetOriginFormat(input_desc->GetFormat()); + } + } + for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { + if (output_desc != nullptr) { + output_desc->SetOriginFormat(output_desc->GetFormat()); + } + } + } +} + +graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, + ge::Format data_format, + std::unordered_map &node_status) { + if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != 
FORMAT_ND))) { + GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph), + TypeUtils::FormatToSerialString(data_format).c_str()); + return GRAPH_SUCCESS; + } + GELOGD("Enter DataNodeFormatProcess"); + std::vector uninfered_data_nodes; + // Check and renew data nodes format + for (const auto &data_node : data_nodes) { + GE_CHECK_NOTNULL(data_node); + auto op_desc = data_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0)); + auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat(); + if (curr_format != FORMAT_ND) { + // Data format has been infered , continue + continue; + } + // Set format for un-infered data node + auto input_descs = op_desc->GetAllInputsDescPtr(); + auto output_descs = op_desc->GetAllOutputsDescPtr(); + + for (const auto &input_desc : input_descs) { + if (input_desc != nullptr) { + input_desc->SetOriginFormat(data_format); + input_desc->SetFormat(data_format); + } + } + for (const auto &output_desc : output_descs) { + if (output_desc != nullptr) { + output_desc->SetOriginFormat(data_format); + output_desc->SetFormat(data_format); + } + } + uninfered_data_nodes.push_back(data_node); + } + // Reinfer format from uninfered data nodes + for (const auto &node : uninfered_data_nodes) { + if (node == nullptr) { + continue; + } + GELOGD("data node [%s] start infer format process", node->GetName().c_str()); + auto status = AnchorProcess(node, node_status); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "data node [%s] infer format process failed!", node->GetName().c_str()); + return GRAPH_FAILED; + } + } + GELOGD("DataNodeFormatProcess success"); + return GRAPH_SUCCESS; +} + +graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) { + GELOGI("Enter InferOrigineFormat process!"); + + // True: infered false:no-infered + std::unordered_map node_status; + std::vector anchor_points; + std::vector data_nodes; + // 
global net format + + if (graph == nullptr) { + GELOGE(GRAPH_FAILED, "input graph is null"); + return GRAPH_FAILED; + } + // build reflection relations of boundary + (void)reflection_builder.Clear(); + auto status = reflection_builder.BuildRefRelations(*graph); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!"); + return GRAPH_FAILED; + } + // User set global net format + status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!"); + return GRAPH_FAILED; + } + // Refresh origin format of anchor point + RefreshOriginFormatOfAnchor(anchor_points); + // Infer format process + for (const auto &anchor_node : anchor_points) { + if (anchor_node == nullptr) { + continue; + } + status = AnchorProcess(anchor_node, node_status); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Anchor node [%s] process failed!", anchor_node->GetName().c_str()); + return GRAPH_FAILED; + } + } + /// According to discuss with sys-enginer, data node default format is ND.Its format + /// should be set by infered.But if some data-node can not be got by infer, set context's + /// format for these data nodes. 
+ /// Notice: ignore 5D formats + auto data_format = graph->GetDataFormat(); + status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status); + + (void)AttrUtils::SetBool(graph, kIsGraphInferred, true); + + return status; +} + +bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { + bool is_graph_inferred = false; + return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); +} +} // namespace ge diff --git a/src/common/graph/format_refiner.h b/src/common/graph/format_refiner.h new file mode 100644 index 00000000..eca93bae --- /dev/null +++ b/src/common/graph/format_refiner.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMMON_GRAPH_FORMAT_REFINER_H_ +#define COMMON_GRAPH_FORMAT_REFINER_H_ + +#include +#include +#include +#include +#include "./compute_graph.h" +#include "./external/graph/types.h" +#include "./ge_error_codes.h" + +namespace ge { +// ShapeRefiner performs shape inference for compute graphs +class FormatRefiner { + public: + static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); + + private: + static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); + static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points, + std::vector &data_nodes, + std::unordered_map &node_status); + static graphStatus AnchorProcess(const ge::NodePtr &anchor_node, std::unordered_map &node_status); + static void RefreshOriginFormatOfAnchor(std::vector &anchor_points); + static graphStatus BackInferProcess(std::deque &nodes, ge::NodePtr &node, + std::unordered_map &node_status); + static graphStatus ForwardInferProcess(std::deque &nodes, ge::NodePtr &node, + std::unordered_map &node_status); + static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, + ge::Format data_format, std::unordered_map &node_status); + static bool IsGraphInferred(const ComputeGraphPtr &graph); +}; +} // namespace ge +#endif // COMMON_GRAPH_FORMAT_REFINER_H_ diff --git a/src/common/graph/ge_attr_define.cc b/src/common/graph/ge_attr_define.cc new file mode 100644 index 00000000..9b723bb3 --- /dev/null +++ b/src/common/graph/ge_attr_define.cc @@ -0,0 +1,1086 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace ge { +// Public attribute +const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; + +const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; + +const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE = "_unknown_shape_type"; + +const std::string ATTR_NAME_NAME = "name"; + +const std::string ATTR_NAME_TYPE = "type"; + +const std::string ATTR_NAME_WEIGHT_NAME = "weight_name"; + +const std::string ATTR_NAME_IS_QUANTIZE_FACTOR = "quantize_factor"; + +const std::string ATTR_NAME_ALPHA = "alpha"; + +const std::string ATTR_NAME_BETA = "beta"; + +const std::string ATTR_NAME_PADMODE = "pad_mode"; + +const std::string ATTR_NAME_PADMODES = "padding"; + +const std::string ATTR_NAME_MODE = "mode"; + +const std::string ATTR_NAME_FILTER = "filter"; + +const std::string ATTR_NAME_BIAS = "bias"; + +const std::string ATTR_NAME_BIAS_TERM = "bias_term"; + +const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; + +const std::string ATTR_NAME_PAD = "pad"; + +const std::string ATTR_NAME_PADS = "pad"; + +const std::string ATTR_NAME_PAD_SIZE = "pad size"; + +const std::string ATTR_NAME_PAD_MODE = "pad mode"; + +const std::string ATTR_NAME_SCALE = "scale"; + +const std::string ATTR_NAME_WINDOWS = "windows"; + +const std::string ATTR_NAME_GLOBAL_POOLING = "global_pooling"; + +const std::string ATTR_NAME_CEIL_MODE = "ceil_mode"; + +const std::string ATTR_NAME_RELUMODE = "relu_mode"; + +const std::string ATTR_NAME_STRIDE_SIZE = "stride size"; + +const std::string ATTR_NAME_RELU_FLAG = "relu_flag"; + +const 
std::string ATTR_NAME_ALGO = "algo"; + +const std::string ATTR_NAME_FORMAT = "format"; + +const std::string ATTR_NAME_STORAGE_FORMAT = "storage_format"; + +const std::string ATTR_NAME_STORAGE_SHAPE = "storage_shape"; + +const std::string ATTR_NAME_FILTER_FORMAT = "filter_format"; + +const std::string ATTR_NAME_LRN_K = "lrn_k"; + +const std::string ATTR_NAME_LRN_NORM_REGION = "lrn_normregion"; + +const std::string ATTR_NAME_LRN_LOCAL_SIZE = "lrn_localsize"; + +const std::string ATTR_NAME_LRN_ALPHA = "lrn_alpha"; + +const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; + +const std::string ATTR_NAME_AXIS = "axis"; +const std::string ATTR_NAME_BROADCAST = "broadcast"; + +const std::string ATTR_NAME_OUTPUT = "output"; +const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; +const std::string ATTR_NAME_TIDX = "t_idx"; + +const std::string ATTR_NAME_TPADDINGS = "t_paddings"; +const std::string ATTR_IMG_H = "img_h"; +const std::string ATTR_IMG_W = "img_w"; +const std::string ATTR_NET_H = "net_h"; +const std::string ATTR_NET_W = "net_w"; + +const std::string ATTR_NAME_TMULTIPLES = "t_multiples"; + +const std::string ATTR_NAME_MULTIPLES = "multiples"; + +const std::string ATTR_NAME_T = "T"; +const std::string ATTR_NAME_N = "N"; + +const std::string ATTR_NAME_TSHAPE = "Tshape"; +const std::string ATTR_NAME_NAN_OPT = "nan_opt"; + +const std::string ATTR_NAME_AIPP = "aipp"; +const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; + +const std::string ATTR_NAME_AIPP_INPUTS = "_aipp_inputs"; +const std::string ATTR_NAME_AIPP_OUTPUTS = "_aipp_outputs"; + +const std::string ATTR_NAME_INPUT_DIMS = "input_dims"; +const std::string ATTR_DYNAMIC_AIPP_INPUT_DIMS = "_dynamic_aipp_input_dims"; +const std::string ATTR_DATA_RELATED_AIPP_MODE = "_data_related_aipp_mode"; +const std::string ATTR_DATA_AIPP_DATA_NAME_MAP = "_data_aipp_data_name_map"; + +const std::string ATTR_NAME_GRAPH_HAS_BEEN_ADDED = "_graph_has_been_added"; + +const std::string ATTR_NAME_SESSION_GRAPH_ID = 
"_session_graph_id"; +const std::string ATTR_NAME_PARENT_GRAPH_NAME = "_parent_graph_name"; + +const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; +const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; +const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; + +const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; +const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; + +const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; +const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; +const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; +const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; +const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; + +const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; +const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; + +const std::string ATTR_NAME_INFERRED_FORMAT = "inferred_format"; +const std::string ATTR_NAME_PRED_PERMUTE_DELETED = "pred_permute_deleted"; +const std::string ATTR_NAME_IGNORE_PRED_FORMAT = "ignore_pred_format"; +const std::string ATTR_NAME_WEIGHTS = "value"; +const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; +const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; +const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; +const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; +const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; +const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; +const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; +const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; +const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; +const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS = "_dynamic_output_dims"; +const std::string ATTR_NAME_INPUT_ORIGIN_SIZE = "input_origin_size"; + +// Identify node connecting to input 
and output +const std::string ATTR_NAME_NODE_CONNECT_INPUT = "_is_connected_to_data"; +const std::string ATTR_NAME_NODE_CONNECT_OUTPUT = "_is_connected_to_netoutput"; + +// To be deleted +const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; +const std::string PERMUTE_RESHAPE_FUSION = "permute_reshape_fusion"; +const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL = "fusion_conv_proposal"; +const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX = "fusion_conv_decodebbox"; +const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM = "box_type_num"; +const std::string SSD_MBOX_LOC_FUSION = "permute_flatten_fusion"; +const std::string SSD_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; +const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; +const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; +const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; + +// Refinedet +const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; + +const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; +const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; +const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; +const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; + +// _Arg +const std::string ATTR_NAME_INDEX = "index"; +// _RetVal +const std::string RETVAL_ATTR_NAME_INDEX = "retval_index"; +// Data +const std::string DATA_ATTR_NAME_DATA_TYPE = "data_type"; + +// Send +const std::string SEND_ATTR_EVENT_ID = "event_id"; + +// Recv +const std::string RECV_ATTR_EVENT_ID = "event_id"; + +// convolution +const std::string ATTR_NAME_COEF = "coef"; + +const std::string ATTR_NAME_STRIDE = "stride"; + +const std::string ATTR_NAME_STRIDES = "stride"; + +const std::string ATTR_NAME_DILATION = "dilation"; + +const std::string ATTR_NAME_DILATIONS = "dilation"; + +const std::string 
CONV_ATTR_NAME_MODE = "mode"; + +const std::string CONV_ATTR_NAME_ALGO = "algo"; + +const std::string CONV_ATTR_NAME_GROUP = "group"; + +const std::string CONV_ATTR_NAME_PAD_MODE = "pad_mode"; + +const std::string CONV_ATTR_NAME_PAD = "pad"; + +const std::string CONV_ATTR_NAME_STRIDE = "stride"; + +const std::string CONV_ATTR_NAME_DILATION = "dilation"; + +const std::string CONV_ATTR_NAME_NUM_OUTPUT = "num_output"; + +const std::string CONV_ATTR_NAME_KERNEL = "kernel"; + +const std::string CONV_ATTR_NAME_FILTER = "filter"; + +const std::string CONV_ATTR_NAME_BIAS = "bias"; + +const std::string CONV_ATTR_NAME_RELU_FLAG = "relu_flag"; + +const std::string CONV_ATTR_NAME_ADJ = "adj"; + +const std::string CONV_ATTR_NAME_TARGET_SHAPE = "target_shape"; + +const std::string CONV_ATTR_NAME_BEFORE_PAD = "before_pad"; + +const std::string CONV_ATTR_NAME_HAS_BIAS = "has_bias"; + +const std::string NEED_INFER = "isNeedInfer"; + +// Pooling +const std::string POOLING_ATTR_MODE = "mode"; +const std::string POOLING_ATTR_NAN_OPT = "nan_opt"; +const std::string POOLING_ATTR_PAD_MODE = "pad_mode"; +const std::string POOLING_ATTR_GLOBAL_POOLING = "global_pooling"; +const std::string POOLING_ATTR_WINDOW = "window"; +const std::string POOLING_ATTR_PAD = "pad"; +const std::string POOLING_ATTR_STRIDE = "stride"; +const std::string POOLING_ATTR_CEIL_MODE = "ceil_mode"; +const std::string POOLING_ATTR_DATA_MODE = "data_mode"; +const std::string POOLING_ATTR_BEFORE_PAD = "before_pad"; +const std::string POOLING_ATTR_NAME_ALGO = "algo"; + +// Eltwise +const std::string ELTWISE_ATTR_MODE = "mode"; +const std::string ELTWISE_ATTR_COEFF = "coeff"; +const std::string ELTWISE_ATTR_WEIGHT = "weight"; +const std::string ELTWISE_ATTR_RELU_FLAG = "relu_flag"; +const std::string ELTWISE_ATTR_ALPHA = "alpha"; +const std::string ELTWISE_ATTR_BETA = "beta"; + +// BatchNorm +const std::string BATCHNORM_ATTR_MODE = "mode"; +const std::string BATCHNORM_ATTR_EPSILON = "epsilon"; +const std::string 
BATCHNORM_ATTR_USE_GLOBAL_STATS = "use_global_stats"; +const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION = "moving_average_fraction"; +const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; +const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; +const std::string BATCHNORM_ATTR_SCALE = "scale"; +const std::string BATCHNORM_ATTR_BIAS = "bias"; +const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; +const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; +const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; + +// huberloss +const std::string HUBER_LOSS_ATTR_DELTA = "delta"; + +// SSDRealDivTileMul +const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; + +// SSDSumMulRealDivMean +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; + +// ConcatFive2Four +// ConcatFour2Five +const std::string SSD_BOX_TYPE_NUM = "box_type_num"; +const std::string SSD_CLASS_NUM = "class_num"; +const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; +const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; +const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; +const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; + +// Scale +const std::string SCALE_ATTR_SCALE = "scale"; +const std::string SCALE_ATTR_BIAS = "bias"; + +// FullConnection +const std::string FULL_CONNECTION_ATTR_FILTER = "filter"; +const std::string FULL_CONNECTION_ATTR_BIAS = "bias"; +const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT = "num_output"; +const std::string FULL_CONNECTION_ATTR_RELU_FLAG = "relu_flag"; +const std::string FULL_ATTR_NAME_ALGO = "algo"; + +// SoftmaxOpParams +const std::string SOFTMAX_ATTR_ALGO = "algo"; +const std::string SOFTMAX_ATTR_MODE 
= "mode"; + +// SparseSoftmaxCrossEntropy +const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE = "cross_entropy_mode"; +const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD = "cross_entropy_is_grad"; +// Attr labelSmoothing +const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING = "labelSmoothing"; + +// ApplyMomentum +const std::string APPLYMENTUM_ATTR_IS_GRAPH_FUSION = "applymomentum_is_graph_fusion"; + +// Activation +const std::string ACTIVATION_ATTR_MODE = "mode"; +const std::string ACTIVATION_ATTR_COEF = "coef"; + +// Concat +const std::string CONCAT_ATTR_NAME_AXIS = "axis"; + +// Const +const std::string CONST_ATTR_NAME_DATA_TRANSTYPE = "data_transtype"; +const std::string CONST_ATTR_NAME_OUTPUT_FORMAT = "output_format"; +const std::string CONST_ATTR_NAME_OUTPUT_TYPE = "output_type"; + +// Roipooling +const std::string ROIPOOLING_ATTR_NAME_POOLED_H = "pooled_h"; +const std::string ROIPOOLING_ATTR_NAME_POOLED_W = "pooled_w"; +const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE = "spatial_scale"; +const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE = "rio_pooling_mode"; +const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE = "pooling_mode"; +const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO = "sampling_ratio"; + +// DetectionOutput +const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES = "num_classes"; +const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES = "ocr_num_classes"; +const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD = "nms_threshold"; +const std::string DETECTIONOUTPUT_ATTR_TOP_K = "top_k"; +const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD = "confidence_threshold"; +const std::string DETECTIONOUTPUT_ATTR_IMG_H = "img_h"; +const std::string DETECTIONOUTPUT_ATTR_IMG_W = "img_w"; +const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE = "batch_size"; +// Ssd DetectionOutput +const std::string DETECTIONOUTPUT_ATTR_ETA = "eta"; +const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION = "shared_location"; +const std::string 
DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID = "background_label_id"; +const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE = "code_type"; +const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET = "variance_encoded_in_target"; +const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K = "keep_top_k"; +// Refinedet DetectionOutput +const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE = "objectness_score"; +// yolo DetectionOutput +const std::string DETECTIONOUTPUT_ATTR_ClASSES = "classes"; +const std::string DETECTIONOUTPUT_ATTR_BIASES = "biases"; +const std::string DETECTIONOUTPUT_ATTR_RELATIVE = "relative"; +const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD = "objectness_threshold"; +const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD = "class_threshold"; +const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K = "post_top_k"; +const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY = "iou_threshold_decay"; +const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR = "coor_scale_factor"; +const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION = "yolo_version"; + +// DetectionPostprocess +const std::string POSTPROCESS_ATTR_NAME_CLS_NUM = "cls_num"; +const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH = "conf_thresh"; +const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH = "nms_thresh"; +const std::string POSTPROCESS_ATTR_POST_NMS_TOPN = "post_nms_topn"; +const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT = "bbox_reg_weights"; + +// SpatialTransform +const std::string SPTIALTF_ATTR_NAME_OUTPUT_H = "output_h"; +const std::string SPTIALTF_ATTR_NAME_OUTPUT_W = "output_w"; +const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE = "border_value"; +const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM = "affine_transform"; + +// Proposal +const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE = "feat_stride"; +const std::string PROPOSAL_ATTR_NAME_BASE_SIZE = "base_size"; +const std::string PROPOSAL_ATTR_NAME_MIN_SIZE = "min_size"; +const std::string PROPOSAL_ATTR_NAME_RATIO =
"ratio"; +const std::string PROPOSAL_ATTR_NAME_SCALE = "scale"; +const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN = "pre_nms_topn"; +const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN = "post_nms_topn"; +const std::string PROPOSAL_ATTR_NAME_NMS_THRESH = "nms_thresh"; +const std::string PROPOSAL_ATTR_NAME_TOP_SIZE = "top_size"; +const std::string PROPOSAL_ATTR_IMG_H = "img_h"; +const std::string PROPOSAL_ATTR_IMG_W = "img_w"; +// Softmax +const std::string SOFTMAX_ATTR_AXIS = "axis"; + +// Permute +const std::string PERMUTE_ATTR_ORDER = "order"; +const std::string PERMUTE_ATTR_PERM = "perm"; + +// SSD Normalize +const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; +const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED = "channel_shared"; +const std::string SSDNORMALIZE_ATTR_EPS = "eps"; + +// Flatten +const std::string FLATTEN_ATTR_AXIS = "axis"; +const std::string FLATTEN_ATTR_END_AXIS = "end_axis"; + +// SsdPRIORBOX +const std::string SSD_PRIOR_BOX_ATTR_FLIP = "flip"; +const std::string SSD_PRIOR_BOX_ATTR_CLIP = "clip"; +const std::string SSD_PRIOR_BOX_ATTR_IMG_H = "img_h"; +const std::string SSD_PRIOR_BOX_ATTR_IMG_W = "img_w"; +const std::string SSD_PRIOR_BOX_ATTR_STEP_H = "step_h"; +const std::string SSD_PRIOR_BOX_ATTR_STEP_W = "step_w"; +const std::string SSD_PRIOR_BOX_ATTR_OFFSET = "offset"; +const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE = "min_size"; +const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE = "max_size"; +const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM = "min_size_num"; +const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM = "max_size_num"; +const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO = "aspect_ratio"; +const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; +const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; +const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; + +// RefinedetDetectionOutput +const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; +const std::string 
REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; + +// PRelu +const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; + +// Psroi pooling +const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE = "spatial_scale"; +const std::string PSROIPOOLING_ATTR_OUTPUT_DIM = "output_dim"; +const std::string PSROIPOOLING_ATTR_GROUP_SIZE = "group_size"; + +// Power +const std::string POWER_ATTR_NAME_POWER = "power"; +const std::string POWER_ATTR_NAME_SCALE = "scale"; +const std::string POWER_ATTR_NAME_SHIFT = "shift"; + +// log +const std::string LOG_ATTR_NAME_SCALE = "scale"; +const std::string LOG_ATTR_NAME_SHIFT = "shift"; +const std::string LOG_ATTR_NAME_BASE = "base"; +// Pack +const std::string PACK_ATTR_NAME_NUM = "N"; + +// Unpack +const std::string UNPACK_ATTR_NAME_NUM = "num"; +const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; +// Gathernd +const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; +const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; + +// Argmax +const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; +const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; +const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; +const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; +const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; +const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; +const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; + +// upsample +const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; +const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; + +// Relu +const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; + +// FreeSpaceExtract +const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT = "org_height"; + +// Split +const std::string SPLIT_ATTR_NAME_SLICE_POINT = "slice_point"; +const std::string SPLIT_ATTR_NAME_SIZE_SPLIT = "size_split"; +const std::string SPLIT_ATTR_NAME_NUM_SPLIT = "num_split"; + +// Tvm +const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; +const std::string 
TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; +const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; +const std::string TVM_ATTR_NAME_WORKSPACE_TYPE = "tvm_workspace_type"; + +// Squeeze +const std::string SQUEEZE_ATTR_AXIS = "axis"; +const std::string SQUEEZE_ATTR_DIMS = "squeeze_dims"; +const std::string SQUEEZE_OP_NAME = "Squeeze"; + +// Stride slice +const std::string STRIDE_SLICE_ATTR_BEGIN_MASK = "begin_mask"; +const std::string STRIDE_SLICE_ATTR_END_MASK = "end_mask"; +const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK = "ellipsis_mask"; +const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK = "new_axis_mask"; +const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK = "shrink_axis_mask"; + +// Slice +const std::string SLICE_ATTR_NAME_BEGINS = "begins"; +const std::string SLICE_ATTR_NAME_SIZES = "sizes"; + +// Roialign +const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; +const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; +const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; +const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; + +// Generate_rpn_proposal +const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; +const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK = "post_nms_topk"; +const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE = "rpn_mini_size"; +const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH = "rpn_proposal_nms_thresh"; +const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH = "rpn_proposal_filter_thresh"; +// Decode_bbox +const std::string DECODE_BBOX_ATTR_DECODECLIP = "decodeClip"; + +// Cast +const std::string CAST_ATTR_DSTT = "DstT"; +const std::string CAST_ATTR_SRCT = "SrcT"; +const std::string CAST_ATTR_DST_TYPE = "dst_type"; +const std::string CAST_ATTR_TRUNCATE = "truncate"; + +// Fastrcnnn predications +const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK = "fsr_topk"; +const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD = 
"fsr_score_thres"; +const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD = "fsr_nms_thres"; +const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES = "fsr_num_classes"; + +// REORG +const std::string REORG_ATTR_STRIDE = "stride"; +const std::string REORG_ATTR_REVERSE = "reverse"; + +// MERGE +const std::string MERGE_DEAD_INDEX = "merge_dead_index"; +const std::string MERGE_PRENODE_FLAG = "merge_prenode_flag"; +const std::string TO_BE_OUTPUT = "to_be_output"; + +// ENTER +const std::string ENTER_ATTR_FRAME_NAME = "frame_name"; +const std::string ENTER_ATTR_CONSTANT_FLAG = "is_constant"; + +// Concatv2 +const std::string CONCAT_V2_ATTR_TIDX = "Tidx"; +const std::string CONCAT_V2_ATTR_N = "N"; +// SUM +const std::string SUM_ATTR_TIDX = "Tidx"; +const std::string SUM_ATTR_AXIS = "axis"; +const std::string SUM_ATTR_KEEP_DIMS = "keep_dims"; + +// ResizeBilinear +const std::string RESIZE_BILINEAR_ATTR_MODE = "mode"; +const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS = "align_corners"; +const std::string RESIZE_BILINEAR_ATTR_HEIGHT = "height"; +const std::string RESIZE_BILINEAR_ATTR_WIDTH = "width"; +const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR = "zoom_factor"; +const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR = "shrink_factor"; +const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN = "pad_begin"; +const std::string RESIZE_BILINEAR_ATTR_PAD_END = "pad_end"; +const std::string RESIZE_BILINEAR_ATTR_ALPHA = "alpha"; +const std::string RESIZE_BILINEAR_ATTR_BETA = "beta"; + +// RetinaNet +const std::string RETINANET_FILTER_BACKGROUND_TRUE = "retina_conv_filter_background"; +const std::string RETINANET_ANCHOR_FUSION = "retina_anchor_fusion"; + +// MatMul +const std::string MATMUL_TRANSPOSE_X = "transposeX"; +const std::string MATMUL_TRANSPOSE_W = "transposeW"; +const std::string MATMUL_HAS_BIAS = "has_bias"; +const std::string MATMUL_ATTR_IS_TRAINING = "matmul_is_training"; + +// Flatten +const std::string FLATTEN_START_AXIS = "start_axis"; +const std::string 
FLATTEN_END_AXIS = "end_axis"; + +// Reshape +const std::string RESHAPE_ATTR_AXIS = "axis"; +const std::string RESHAPE_ATTR_NUM_AXES = "num_axes"; +const std::string RESHAPE_ATTR_FORMAT = "format"; +const std::string RESHAPE_ATTR_SHAPE = "shape"; +const std::string RESHAPE_ATTR_ALPHA = "alpha"; +const std::string RESHAPE_ATTR_BETA = "beta"; + +// Frameoworkop +const std::string T_IN_DATATYPE = "t_in_datatype"; +const std::string T_OUT_DATATYPE = "t_out_datatype"; +const std::string ATTR_NAME_OUT_N = "out_n"; +const std::string ATTR_NAME_OUT_C = "out_c"; +const std::string ATTR_NAME_OUT_H = "out_h"; +const std::string ATTR_NAME_OUT_W = "out_w"; +const std::string ATTR_PAD_DEPTH_CONV = "pad_depth_conv"; +const std::string ATTR_PAD_CONV = "pad_conv"; + +const std::string ATTR_NAME_BEFORE_PAD = "before_pad"; +const std::string ANN_MEAN_KEEPDIMS = "AnnMeanKeepDims"; +const std::string PAD_ATTR_PADDINGDS = "paddings"; +const std::string PAD_ATTR_CONSTANT_VALUE = "padvalue"; + +// ConvGradFilter +const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape"; +// ConvGradInput +const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; + +// Rnn +const std::string RNN_MODE_STATIC = "rnn_static"; +const std::string MUTI_RNN = "multi_rnn"; +const std::string CNN_RNN = "cnn_rnn"; +const std::string RNN_MODE_ = "rnn_"; + +const std::string CELL_MODE = "mode"; +const std::string LSTM_CELL = "lstm_cell"; +const std::string GRU_CELL = "gru_cell"; +const std::string RNN_HT = "ht"; +const std::string RNN_XT_HT = "xt_ht"; +const std::string RNN_BATCH_SIZE = "batch_size"; +const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; +const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; +const std::string LSTM_ACTIVATE = "lstm_activate"; +const std::string LSTM_OUT_MAP = "lstm_out_map"; +const std::string LSTM_OUT_MODE = "lstm_out_mode"; +const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; +const std::string LSTM_TIME_MAJOR = 
"lstm_time_major"; +const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; + +// Upsample +const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; + +// PadV2 +const std::string PADV2_ATTR_NAME_MODE = "mode"; +const std::string PADV2_ATTR_NAME_PADS = "paddings"; +const std::string PADV2_ATTR_NAME_T = "T"; +const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; +const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; + +// MirrorPad +const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; +const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; +const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; +const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; +// Filler +const std::string FILLER_TYPE = "filler_type"; +const std::string FILLER_VALUE = "filler_value"; + +// Shufflechannel +const std::string SHUFFLE_CHANNEL_GROUP = "group"; + +// TopKV2 +const std::string TOPKV2_ATTR_K = "k"; + +// Calibaration +const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; +const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; +const std::string PAD_TOP_INDEX = "PAD_TOP_INDEX"; +const std::string PAD_BOTTOM_INDEX = "PAD_BOTTOM_INDEX"; +const std::string PAD_RIGHT_INDEX = "PAD_RIGHT_INDEX"; +const std::string PAD_LEFT_INDEX = "PAD_LEFT_INDEX"; +const std::string QUANTIZE_ALGO_ATTR = "quantize_algo"; +const std::string SCALE_TYPE_ATTR = "scale_type"; + +const std::string QUANTIZE_SCALE_MODE = "quantize_scale_mode"; +const std::string QUANTIZE_SCALE_VALUE = "quantize_scale_value"; +const std::string QUANTIZE_SCALE_OFFSET = "quantize_scale_offset"; +const std::string QUANTIZE_OFFSET_DATA_VALUE = "quantize_offset_data_value"; +const std::string QUANTIZE_OFFSET_DATA_OFFSET = "quantize_offset_data_offset"; +const std::string QUANTIZE_OFFSET_WEIGHT_VALUE = "quantize_offset_weight_value"; +const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET = "quantize_offset_weight_offset"; +const std::string QUANTIZE_OFFSET_PAD_VALUE = 
"quantize_offset_pad_value"; +const std::string QUANTIZE_OFFSET_PAD_OFFSET = "quantize_offset_pad_offset"; + +const std::string DEQUANTIZE_SCALE_MODE = "dequantize_scale_mode"; +const std::string DEQUANTIZE_SCALE_VALUE = "dequantize_scale_value"; +const std::string DEQUANTIZE_SCALE_OFFSET = "dequantize_scale_offset"; +const std::string DEQUANTIZE_OFFSET_DATA_TYPE = "dequantize_offset_data_value"; +const std::string DEQUANTIZE_OFFSET_DATA_OFFSET = "dequantize_offset_data_offset"; +const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE = "dequantize_offset_weight_value"; +const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET = "dequantize_offset_weight_offset"; +const std::string DEQUANTIZE_OFFSET_PAD_VALUE = "dequantize_offset_pad_value"; +const std::string DEQUANTIZE_OFFSET_PAD_OFFSET = "dequantize_offset_pad_offset"; + +const std::string REQUANTIZE_SCALE_MODE = "requantize_scale_mode"; +const std::string REQUANTIZE_SCALE_VALUE = "requantize_scale_value"; +const std::string REQUANTIZE_SCALE_OFFSET = "requantize_scale_offset"; +const std::string REQUANTIZE_OFFSET_DATA_VALUE = "requantize_offset_data_value"; +const std::string REQUANTIZE_OFFSET_DATA_OFFSET = "requantize_offset_data_offset"; +const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE = "requantize_offset_weight_value"; +const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET = "requantize_offset_weight_offset"; +const std::string REQUANTIZE_OFFSET_PAD_VALUE = "requantize_offset_pad_value"; +const std::string REQUANTIZE_OFFSET_PAD_OFFSET = "requantize_offset_pad_offset"; + +const std::string ATTR_NAME_IS_CONST = "attr_name_is_const"; + +const std::string ATTR_NAME_GROUP = "group"; +const std::string ATTR_NAME_DILATION_SIZE = "dilation_size"; +const std::string ATTR_NAME_EPSILON = "epsilon"; +const std::string ATTR_NAME_POOLING_MODE = "mode"; +const std::string ATTR_NAME_CLASS_NUM = "class_num"; +// model +const std::string ATTR_MODEL_TARGET_TYPE = "target_type"; + +const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; + 
+const std::string ATTR_MODEL_EVENT_NUM = "event_num"; + +const std::string ATTR_MODEL_HUGE_STREAM_LIST = "huge_stream_list"; + +const std::string ATTR_MODEL_LABEL_NUM = "label_num"; + +const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; + +const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; + +const std::string ATTR_MODEL_OUT_NODES_NAME = "attr_model_out_nodes_name"; + +const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; + +const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; + +const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; + +const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR = "task_gen_variable_addr"; + +const std::string ATTR_MODEL_VAR_SIZE = "variable_size"; + +const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; + +const std::string ATTR_MODEL_CORE_TYPE = "core_type"; + +const std::string ATTR_MODEL_ATC_VERSION = "atc_version"; + +const std::string ATTR_MODEL_OPP_VERSION = "opp_version"; + +// Public attribute +const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; + +const std::string ATTR_NAME_BYTE_SIZE = "op_byte_size"; + +const std::string ATTR_NAME_FUSION_INFERENCE_ID = "fusion_inference_id"; + +const std::string ATTR_NAME_FUSION_OPDEF = "fusion_opdef"; + +const std::string ATTR_NAME_IO_OP = "io_op"; + +const std::string ATTR_NAME_FUSION_SCOPE = "fusion_scope"; + +const std::string ATTR_NAME_OPATTR = "opattr"; + +const std::string ATTR_NAME_RELUFLAG = "relu_flag"; + +const std::string ATTR_NAME_SEQLEN_INDEX = "seqlen_index"; + +const std::string ATTR_NAME_X_INDEX = "x_index"; + +const std::string ATTR_NAME_CONT_INDEX = "cont_index"; + +const std::string ATTR_NAME_XSTATIC_INDEX = "xstatic_index"; + +const std::string TARGET_TYPE_MINI = "MINI"; + +const std::string TARGET_TYPE_TINY = "TINY"; + +const std::string TARGET_TYPE_LITE = "LITE"; + +// l2_normalize +const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; +const std::string L2_NORMALIZE_ATTR_EPS = 
"eps"; + +const std::string POOL_PARAMA_ATTR_WINDOW = "window"; +const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; +const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; +const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; +const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; +const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; + +// HCOM +const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; +const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; + +const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; +const std::string HCOM_ATTR_GROUP = "group"; +const std::string HCOM_ATTR_SR_TAG = "sr_tag"; +const std::string HCOM_ATTR_SRC_RANK = "src_rank"; +const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; +const std::string HCOM_ATTR_FUSION = "fusion"; +const std::string HCOM_ATTR_SHAPE = "shape"; +const std::string HCOM_ATTR_DATA_TYPE = "dtype"; + +// SpaceToDepth/DepthToSpace +const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; + +// SparseSoftmaxCrossEntropyWithLogits +const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; + +// MaxPoolGradWithArgmax +const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; + +// AvgPoolGrad +const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; + +// Pad +const std::string ATTR_PAD_FORMAT = "attr_pad_format"; + +// Variable +const std::string VAR_ATTR_FORMAT = "_var_format"; +const std::string VAR_ATTR_NAME = "var_name"; +const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; +const std::string VAR_ATTR_4D_FORMAT = "4D"; +const std::string VAR_ATTR_5D_FORMAT = "5D"; +const std::string VAR_ATTR_DATA_TYPE = "data_format"; +const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; +const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; +const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; +const std::string VAR_ATTR_SHAPE = "shape"; +const std::string HALF_VAR_NAME_END = "_fp16"; +const std::string VAR_ATTR_INITED = "var_is_inited"; + +const
std::string VAR_ATTR_CONTAINER = "container"; +const std::string VAR_ATTR_SHARED_NAME = "shared_name"; +const std::string VAR_ATTR_DTYPE = "dtype"; + +const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; +const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; +const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; +const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; +const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; +const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; + +// Assign +const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; +const std::string ASSIGN_VAR_NAME = "_assign_var_name"; + +// space2bacth batch2space +const std::string BATCH_SPACE_ATTR_BLOCK = "block"; +const std::string BATCH_SPACE_ATTR_PADDING = "padding"; + +// depth_to_space space_to_depth +const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; + +// FakeQuantWithMinMaxVars +const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; +const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; + +// mobilenet_ssd_conv_fusion +const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; +const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; +const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; + +// lsh project +const std::string LSH_PROJ_TYPE = "lsh_project_type"; + +// log time stamp +const std::string LOG_TIME_STAMP_LOGID = "logid"; +const std::string LOG_TIME_STAMP_NOTIFY = "notify"; + +// ShapeN +const std::string SHAPEN_ATTR_N = "N"; +const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; +const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; + +// GatherV2 attr def +const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; +const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; +const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; + +// Reshape attr def +const std::string RESHAPE_ATTR_NAME_INPUT_DESC = 
"input_desc_reshape"; +const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; + +// axis attr def +const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; + +const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; + +const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; +const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; + +// For constant folding +const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; + +const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; + +const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC = "continuous_input_alloc"; + +const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; + +const std::string ATTR_NAME_REFERENCE = "reference"; + +const std::string ATTR_NAME_NOTASK = "_no_task"; + +const std::string ATTR_NAME_OUTPUT_REUSE_INPUT = "_output_reuse_input"; + +const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX = "_reuse_input_on_dim_index"; + +const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT = "_no_padding_continuous_input"; + +const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT = "_no_padding_continuous_output"; + +const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; + +// Used for mark the active label list stream of activated node +const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; + +// Used for l2cache, true: the memory of all inputs is used for the last time. 
+const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_lifecycle"; + +// Multi batch +const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; +const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; +const std::string ATTR_NAME_BATCH_LABEL = "_batch_label"; +const std::string ATTR_NAME_COMBINED_BATCH = "_combined_batch"; + +// Control flow +const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; +const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; +const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; +const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; +const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; +const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; +const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE = "subgraph_first_active"; +const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS = "combined_dynamic_dims"; + +const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; +const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; +const std::string ATTR_NAME_SWITCH_DATA_TYPE = "_switch_data_type"; +const std::string ATTR_NAME_ORIG_NODE_NAME = "_original_node_name"; +const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG = "_cyclic_dependence_flag"; + +const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; + +// Function Op +const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; + +// Used for mark the active node is for loop, type:bool +const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; + +const std::string ATTR_NAME_MEMORY_TYPE_INPUT = "memory_type_input"; + +const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT = "memory_type_output"; + +const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; + +const std::string ATTR_NAME_MEMORY_TYPE_RANGE = "_memory_type_range"; + +const std::string MODEL_ATTR_SESSION_ID = "session_id"; + +// lx fusion 
+const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; +const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; +const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; +const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; +const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; +const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; +const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; +const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; +const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; +const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; +const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; +const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; +const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; +const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; +const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; +const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; +const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; +const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; +const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; +const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; +const std::string ATTR_NAME_ENGINE_NAME_FOR_LX = "_lxfusion_engine_name"; +const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX = "_lxfusion_op_kernel_lib_name"; +const std::string ATTR_NAME_NEED_LX_FUSION = "_lx_fusion"; +const std::string ATTR_NAME_OPTIMIZE_GROUP = "_optimize_group"; +const std::string ATTR_NAME_OP_COMPILE_STRATEGY = "_op_compile_strategy"; +const std::string ATTR_NAME_TBE_KERNEL_NAME = "_tbe_kernel_name"; +const std::string ATTR_NAME_TBE_KERNEL_BUFFER = "_tbe_kernel_buffer"; + +// Op 
debug attrs +const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; +const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; + +// Atomic addr clean attrs +const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; +const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; +const std::string ATOMIC_ATTR_IS_FUSION_NODE = "is_fusion_node"; +const std::string EXT_ATTR_ATOMIC_WORKSPACE_INFO = "sub_node_workspace_info"; +const std::string EXT_ATTR_ATOMIC_WORKSPACE_OFFSET = "sub_node_workspace_offset"; +const std::string ATOMIC_ATTR_IS_ATOMIC_NODE = "is_atomic_node"; + +// Source/dst format for Op FormatTransfer +const std::string FORMAT_TRANSFER_SRC_FORMAT = "src_format"; +const std::string FORMAT_TRANSFER_DST_FORMAT = "dst_format"; + +// For compile op by ge call +const std::string ATTR_NEED_COMPILE = "_node_need_compile"; + +const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; + +const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; + +const std::string ATTR_DYNAMIC_TYPE = "mbatch_dynamic_type"; + +const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER = "user_designate_shape_order"; + +// For inserted op +const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; + +// For compress weight +const std::string ATTR_NAME_COMPRESS_WEIGHT = "_is_compress_weight"; + +// For data dump +const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; +const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; +const std::string ATTR_NAME_DATA_DUMP_SUB_SPLITER_INDEX = "_datadump_sub_spliter_index"; +const std::string ATTR_NAME_DATA_DUMP_GROUP_OP_NAME = "_datadump_group_op_name"; +const std::string ATTR_NAME_DATA_DUMP_ORIGIN_NAME = "_datadump_origin_name"; +const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_output_index"; +const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; +const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = 
"_datadump_origin_data_type"; + +// functional ops attr +const std::string ATTR_NAME_IF_THEN_BRANCH = "then_branch"; +const std::string ATTR_NAME_IF_ELSE_BRANCH = "else_branch"; +const std::string ATTR_NAME_WHILE_COND = "cond"; +const std::string ATTR_NAME_WHILE_BODY = "body"; + +// used for label switch +const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; +const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; +const std::string ATTR_NAME_SUBGRAPH_END_NODE = "_subgraph_end_node"; + +const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; +const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; + +// used for LX tiling +const std::string ATTR_NAME_OP_L1_SPACE = "_l1_space"; +const std::string ATTR_NAME_FUSION_TYPE_LIST = "_fusion_type_list"; +const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_list_list"; +const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; +const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; +const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_output_offset_list_list"; + +// for unregistered op +const std::string ATTR_NAME_UNREGST_OPPATH = "_unregst_oppath"; +const std::string ATTR_NAME_UNREGST_ATTRLIST = "_unregst_attrlist"; + +// used for Horovod +const std::string ATTR_INTER_EVENT_IDENTIFY = "event_id"; +const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; +// used for allreduce tailing optimization +const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; +const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; + +// dynamic shape attr +const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr"; +const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index"; + +// op dynamic input +const std::string ATTR_NAME_DYNAMIC_INPUT_START = "_dynamic_input_index_start"; +const std::string ATTR_NAME_DYNAMIC_INPUT_END = 
"_dynamic_input_index_end"; + +// atc user def dtype&format +const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type"; +const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format"; + +// for fusion op plugin +const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; + +// graph partition for aicpu +const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME = "pld_front_node_engine_name"; +const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME = "end_rear_node_engine_name"; + +// input and output memory type +const std::string ATTR_VARIABLE_PLACEMENT = "_variable_placement"; +const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type"; +const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type"; + +// input_output_offset +const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset"; +const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset"; +} // namespace ge diff --git a/src/common/graph/ge_attr_value.cc b/src/common/graph/ge_attr_value.cc new file mode 100644 index 00000000..a8490470 --- /dev/null +++ b/src/common/graph/ge_attr_value.cc @@ -0,0 +1,1289 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/ge_attr_value.h" +#include "graph/ge_tensor.h" +#include "external/graph/graph.h" +#include "utils/attr_utils.h" +#include "framework/common/debug/ge_log.h" +#include "graph/model_serialize.h" +#include "proto/ge_ir.pb.h" +#include "detail/model_serialize_imp.h" +#include "debug/ge_attr_define.h" +#include "debug/ge_log.h" +#include "debug/ge_util.h" + +using std::map; +using std::string; +using std::vector; + +namespace ge { +NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } + +NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) + : named_attrs_(owner, proto_msg) {} // lint !e1744 + +void NamedAttrs::SetName(const std::string &name) { + auto proto_msg = named_attrs_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_name(name); + } +} + +string NamedAttrs::GetName() const { + auto proto_msg = named_attrs_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->name(); + } + return string(); +} + +GeAttrValue NamedAttrs::GetItem(const string &key) const { + GeAttrValue value; + (void)GetAttr(key, value); + return value; +} + +ProtoAttrMapHelper NamedAttrs::MutableAttrMap() { + auto proto_msg = named_attrs_.GetProtoMsg(); + if (proto_msg != nullptr) { + return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr()); + } + return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); +} + +ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const { + auto proto_msg = named_attrs_.GetProtoMsg(); + if (proto_msg != nullptr) { + return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr()); + } + return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); +} + +class GeAttrValueImp { + public: + static map attr_val_one_type_map_; + static map attr_val_list_type_map_; + + static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::INT val); + static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::FLOAT val); + static bool 
SetValue(proto::AttrDef &attr_def, GeAttrValue::BOOL val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::STR &val); + static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val); + static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::TENSOR_DESC &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::BYTES &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::NAMED_ATTRS &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::GRAPH &val); + static bool SetValue(proto::AttrDef &attr_def, const vector &val); + static bool SetValue(proto::AttrDef &attr_def, const vector &val); + static bool SetValue(proto::AttrDef &attr_def, const vector &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_FLOAT &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BOOL &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_STR &val); + static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); + static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); + static bool SetValue(proto::AttrDef &attr_def, const vector &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_TENSOR_DESC &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BYTES &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_NAMED_ATTRS &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_GRAPH &val); + + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::INT &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::FLOAT &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, 
GeAttrValue::BOOL &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::STR &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::TENSOR &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::TENSOR_DESC &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BYTES &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::NAMED_ATTRS &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::GRAPH &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_INT &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_FLOAT &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_BOOL &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_STR &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_TENSOR &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_TENSOR_DESC &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_BYTES &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_NAMED_ATTRS &val); + static bool 
GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_GRAPH &val); + // Value will be moved + static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer); + static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer); + // Value will be moved + static bool SetZeroCopyListBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + vector &list_buffer); + static bool GetZeroCopyListBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + vector &list_buffer); + + static bool SetValue(proto::AttrDef &attr_def, const vector> &value); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + vector> &value); + static bool SetValue(proto::AttrDef &attr_def, const vector &value); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + vector &value); + + static bool SetValue(proto::AttrDef &attr_def, const ge::DataType &value); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ge::DataType &value); +}; + +map GeAttrValueImp::attr_val_one_type_map_ = { + {proto::AttrDef::kI, GeAttrValue::VT_INT}, + {proto::AttrDef::kF, GeAttrValue::VT_FLOAT}, + {proto::AttrDef::kB, GeAttrValue::VT_BOOL}, + {proto::AttrDef::kS, GeAttrValue::VT_STRING}, + {proto::AttrDef::kT, GeAttrValue::VT_TENSOR}, + {proto::AttrDef::kTd, GeAttrValue::VT_TENSOR_DESC}, + {proto::AttrDef::kG, GeAttrValue::VT_GRAPH}, + {proto::AttrDef::kBt, GeAttrValue::VT_BYTES}, + {proto::AttrDef::kFunc, GeAttrValue::VT_NAMED_ATTRS}, + {proto::AttrDef::kListListInt, GeAttrValue::VT_LIST_LIST_INT}, + {proto::AttrDef::kDt, GeAttrValue::VT_DATA_TYPE}, +}; +map GeAttrValueImp::attr_val_list_type_map_ = { + {proto::AttrDef_ListValue_ListValueType_VT_LIST_INT, GeAttrValue::VT_LIST_INT}, + 
{proto::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT, GeAttrValue::VT_LIST_FLOAT}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_BOOL, GeAttrValue::VT_LIST_BOOL}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_STRING, GeAttrValue::VT_LIST_STRING}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, GeAttrValue::VT_LIST_TENSOR}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, GeAttrValue::VT_LIST_TENSOR_DESC}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, GeAttrValue::VT_LIST_GRAPH}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, GeAttrValue::VT_LIST_BYTES}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, GeAttrValue::VT_LIST_NAMED_ATTRS}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, GeAttrValue::VT_LIST_DATA_TYPE}, +}; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue::GeAttrValue() { value_.InitDefault(); } + +GeAttrValue::GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val) : value_(proto_owner, val) {} + +GeAttrValue::ValueType GeAttrValue::GetValueType() const { + auto proto_msg = value_.GetProtoMsg(); + if (proto_msg != nullptr) { + auto val_case = proto_msg->value_case(); + if (val_case != proto::AttrDef::kList) { + auto it = GeAttrValueImp::attr_val_one_type_map_.find(val_case); + if (it != GeAttrValueImp::attr_val_one_type_map_.end()) { + return it->second; + } + } else { + auto it = GeAttrValueImp::attr_val_list_type_map_.find(proto_msg->list().val_type()); + if (it != GeAttrValueImp::attr_val_list_type_map_.end()) { + return it->second; + } + } + } + return GeAttrValue::VT_NONE; +} + +bool GeAttrValue::IsEmpty() const { return GetValueType() == VT_NONE; } + +GeAttrValue GeAttrValue::Copy() const { + GeAttrValue valueRet; + auto proto_msg = value_.GetProtoMsg(); + auto proto_msg_ret = valueRet.value_.GetProtoMsg(); + if (proto_msg != nullptr && proto_msg_ret != nullptr) { + *proto_msg_ret = *proto_msg; + } + return valueRet; +} + +#define 
ATTR_VALUE_SET_GET_IMP(type) \ + graphStatus GeAttrValue::SetValue(const type &val) { \ + auto proto_msg = value_.GetProtoMsg(); \ + if (proto_msg) { \ + if (GeAttrValueImp::SetValue(*proto_msg, val)) { \ + return GRAPH_SUCCESS; \ + } \ + } \ + return GRAPH_FAILED; \ + } \ + \ + graphStatus GeAttrValue::GetValue(type &val) const { \ + auto proto_msg = value_.GetProtoMsg(); \ + if (proto_msg) { \ + if (GeAttrValueImp::GetValue(*proto_msg, value_.GetProtoOwner(), val)) { \ + return GRAPH_SUCCESS; \ + } \ + } \ + return GRAPH_FAILED; \ + } + +ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524 +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR_DESC) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::GRAPH) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) +ATTR_VALUE_SET_GET_IMP(vector) +/*lint -e665*/ +ATTR_VALUE_SET_GET_IMP(vector>) +/*lint +e665*/ +ATTR_VALUE_SET_GET_IMP(vector) // lint !e665 +ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665 + +#undef ATTR_VALUE_SET_GET_IMP + +graphStatus GeAttrValue::MutableTensor(GeTensorPtr &tensor) { return GetValue(tensor); } + +graphStatus GeAttrValue::MutableListTensor(vector &list_tensor) { return GetValue(list_tensor); } + +class AttrUtilsHelper { + public: + inline static bool GetValueCheckType(const proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) { + if (attr_def.value_case() != proto_case) { + GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case); + return false; + } + return true; + } + + inline static 
bool GetValueCheckListType( + const proto::AttrDef &attr_def, proto::AttrDef_ListValue_ListValueType proto_list_case, + const std::function item_check_fun) { + if (attr_def.value_case() != proto::AttrDef::kList) { + GELOGW("Check ListType Failed, value_case %u", attr_def.value_case()); + return false; + } + auto &list = attr_def.list(); + if (list.val_type() == proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE) { + return item_check_fun(attr_def); + } + if (list.val_type() != proto_list_case) { + GELOGW("Check ListType Failed, val_type %u, expected %u", list.val_type(), proto_list_case); + return false; + } + return true; + } + + inline static bool SetValueCheckType(proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) { + if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto_case) { + GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case); + return false; + } + return true; + } + + inline static bool SetValueCheckAndSetListType(proto::AttrDef &attr_def, + proto::AttrDef_ListValue_ListValueType proto_list_case) { + if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto::AttrDef::kList) { + GELOGW("AttrUtils::Check Type Failed, value_case %u", attr_def.value_case()); + return false; + } + auto list = attr_def.mutable_list(); + if (list == nullptr) { + GELOGE(GRAPH_FAILED, "list is nullptr"); + return false; + } + if (list->val_type() != proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE && + list->val_type() != proto_list_case) { + GELOGW("AttrUtils::Check ListType Type Failed, val_type %d, expected %d", static_cast(list->val_type()), + static_cast(proto_list_case)); + return false; + } + list->set_val_type(proto_list_case); + return true; + } + + static bool GetAttrMapItem(const AttrHolder *obj, const string &name, const proto::AttrDef *&attr_def) { + if (obj == nullptr) { + GELOGE(FAILED, "%s obj is nullptr", name.c_str()); + return 
false; + } + auto attr_map = obj->GetAttrMap().GetProtoMsg(); + if (attr_map == nullptr) { + GELOGE(FAILED, "%s attr map is nullptr", name.c_str()); + return false; + } + auto it = attr_map->find(name); + if (it == attr_map->end()) { + return false; + } + attr_def = &it->second; + return true; + } + + inline static bool MutableAttrMapItem(AttrHolder *obj, const string &name, proto::AttrDef *&attr_def) { + if (obj == nullptr) { + GELOGE(FAILED, " %s obj is nullptr", name.c_str()); + return false; + } + auto attr_map = obj->MutableAttrMap().GetProtoMsg(); + if (attr_map == nullptr) { + GELOGE(FAILED, "%s attr map is nullptr", name.c_str()); + return false; + } + // Get or add + attr_def = &((*attr_map)[name]); + return true; + } +}; + +#define ATTR_VALUE_IMP_SET_ONE(ValType, proto_case, protoItem) \ + bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \ + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \ + return false; \ + } \ + proto_attr_val.set_##protoItem(value); \ + return true; \ + } + +#define ATTR_VALUE_IMP_SET_LIST(ValType, proto_list_case, protoItem) \ + bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \ + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, \ + proto::AttrDef_ListValue_ListValueType_##proto_list_case)) { \ + return false; \ + } \ + auto list = proto_attr_val.mutable_list(); \ + list->clear_##protoItem(); \ + for (const auto &item : value) { \ + list->add_##protoItem(item); \ + } \ + return true; \ + } + +ATTR_VALUE_IMP_SET_ONE(int64_t, kI, i) +ATTR_VALUE_IMP_SET_ONE(float, kF, f) +ATTR_VALUE_IMP_SET_ONE(const string &, kS, s) +ATTR_VALUE_IMP_SET_ONE(bool, kB, b) + +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_FLOAT, f) +ATTR_VALUE_IMP_SET_LIST(const vector &, 
VT_LIST_STRING, s) +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_BOOL, b) + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensorDesc &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) { + return false; + } + auto proto_msg = value.tensor_descriptor_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + *proto_attr_val.mutable_td() = *proto_msg; + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_td(); + for (const auto &item : value) { + auto proto_msg = item.tensor_descriptor_.GetProtoMsg(); + if (proto_msg == nullptr) { + proto_attr_val.clear_list(); + return false; + } + *list->add_td() = *proto_msg; + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ConstGeTensorPtr &value) { + if (value) { + return SetValue(proto_attr_val, *value); + } else { + return SetValue(proto_attr_val, GeTensor()); + } +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensor &val) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { + return false; + } + auto proto_msg = val.tensor_def_.GetProtoMsg(); + if (proto_msg == nullptr) { + GELOGE(FAILED, "Proto msg is nullptr"); + return false; + } + *proto_attr_val.mutable_t() = *proto_msg; + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + vector constList(value.size()); + std::copy(value.begin(), value.end(), constList.begin()); + return SetValue(proto_attr_val, constList); +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if 
(!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_t(); + for (const auto &item : value) { + if (item == nullptr) { + GELOGE(GRAPH_FAILED, "AttrUtils::SetListTensor item is nullptr"); + proto_attr_val.clear_list(); + return false; + } + auto proto_msg = item->tensor_def_.GetProtoMsg(); + if (proto_msg == nullptr) { + GELOGE(FAILED, "Proto msg is nullptr"); + proto_attr_val.clear_list(); + return false; + } + *list->add_t() = *proto_msg; + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_t(); + for (const auto &item : value) { + auto proto_msg = item.tensor_def_.GetProtoMsg(); + if (proto_msg == nullptr) { + GELOGE(FAILED, "Proto msg is nullptr"); + proto_attr_val.clear_list(); + return false; + } + *list->add_t() = *proto_msg; + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::BYTES &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { + return false; + } + size_t val_size = value.GetSize(); + proto_attr_val.set_bt(value.GetData(), val_size); + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_bt(); + for (const auto &item : value) { + list->add_bt(item.GetData(), 
item.GetSize()); + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { + return false; + } + auto proto_msg = value.named_attrs_.GetProtoMsg(); + if (proto_msg == nullptr) { + GELOGE(FAILED, "Proto msg is nullptr"); + return false; + } + *proto_attr_val.mutable_func() = *proto_msg; + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_na(); + for (const auto &item : value) { + auto proto_msg = item.named_attrs_.GetProtoMsg(); + if (proto_msg == nullptr) { + proto_attr_val.clear_list(); + return false; + } + *list->add_na() = *proto_msg; + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::ComputeGraphPtr &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kG)) { + return false; + } + ModelSerializeImp imp; + if (!imp.SerializeGraph(value, proto_attr_val.mutable_g())) { + GELOGE(GRAPH_FAILED, "AttrUtils::SetGraph SerializeGraph Failed"); + proto_attr_val.clear_g(); + return false; + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_g(); + + ModelSerializeImp imp; + for (const auto &item : value) { + if (!imp.SerializeGraph(item, list->add_g())) { + GELOGE(GRAPH_FAILED, "AttrUtils::SetListGraph SerializeGraph"); + 
proto_attr_val.clear_list(); + return false; + } + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector> &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) { + return false; + } + proto_attr_val.clear_list_list_int(); + auto list_list_int = proto_attr_val.mutable_list_list_int(); + GE_CHECK_NOTNULL_EXEC(list_list_int, return false); + for (auto &list_int : value) { + auto list_item = list_list_int->add_list_list_i(); + GE_CHECK_NOTNULL_EXEC(list_item, return false); + for (auto &int_item : list_int) { + list_item->add_list_i(int_item); + } + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_dt(); + for (const auto &item : value) { + list->add_dt(static_cast(item)); + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::DataType &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) { + return false; + } + proto_attr_val.set_dt(static_cast(value)); + + return true; +} + +#define ATTR_VALUE_IMP_GET_ONE(ValType, proto_case, protoItem) \ + bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ValType value) { \ + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \ + return false; \ + } \ + value = proto_attr_val.protoItem(); \ + return true; \ + } + +#define ListValueItemCheck(protoItem) \ + [](const proto::AttrDef &proto_attr_val) { return proto_attr_val.list().protoItem##_size() > 0; } + +#define ATTR_VALUE_IMP_GET_LIST(ValType, proto_list_case, protoItem) \ + bool GeAttrValueImp::GetValue(const proto::AttrDef 
&proto_attr_val, const ProtoMsgOwner &, vector &value) { \ + value.clear(); \ + if (!AttrUtilsHelper::GetValueCheckListType( \ + proto_attr_val, proto::AttrDef_ListValue_ListValueType_##proto_list_case, ListValueItemCheck(protoItem))) { \ + return false; \ + } \ + auto &list = proto_attr_val.list(); \ + for (const auto &item : list.protoItem()) { \ + value.push_back(item); \ + } \ + return true; \ + } + +ATTR_VALUE_IMP_GET_ONE(int64_t &, kI, i) +ATTR_VALUE_IMP_GET_ONE(float &, kF, f) +ATTR_VALUE_IMP_GET_ONE(string &, kS, s) +ATTR_VALUE_IMP_GET_ONE(bool &, kB, b) + +ATTR_VALUE_IMP_GET_LIST(int64_t, VT_LIST_INT, i) +ATTR_VALUE_IMP_GET_LIST(float, VT_LIST_FLOAT, f) +ATTR_VALUE_IMP_GET_LIST(string, VT_LIST_STRING, s) +ATTR_VALUE_IMP_GET_LIST(bool, VT_LIST_BOOL, b) + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeTensorDesc &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) { + return false; + } + auto proto_msg = value.tensor_descriptor_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + *proto_msg = proto_attr_val.td(); + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &value) { + if (!AttrUtilsHelper::GetValueCheckListType( + proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, ListValueItemCheck(td))) { + return false; + } + auto &list = proto_attr_val.list(); + for (const auto &item : list.td()) { + value.emplace_back(GeTensorDesc()); + auto proto_msg = value.back().tensor_descriptor_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + *proto_msg = item; + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, + GeTensorPtr &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { + return false; + } + value = std::shared_ptr(new (std::nothrow) + 
GeTensor(proto_owner, const_cast(proto_attr_val).mutable_t())); + GE_CHK_BOOL_RET_STATUS(value != nullptr, false, "value is nullptr"); + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, + vector &value) { + value.clear(); + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, + ListValueItemCheck(t))) { + return false; + } + auto list = const_cast(proto_attr_val).mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + for (auto &item : *(list->mutable_t())) { + std::shared_ptr temp_value = std::shared_ptr(new (std::nothrow) GeTensor(proto_owner, &item)); + GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); + value.push_back(temp_value); + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeAttrValue::BYTES &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { + return false; + } + auto &proto_val = proto_attr_val.bt(); + GE_LOGI_IF(proto_val.size() == 0, "size res is 0."); + value = Buffer::CopyFrom(reinterpret_cast(proto_val.data()), proto_val.size()); + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &value) { + value.clear(); + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, + ListValueItemCheck(bt))) { + return false; + } + auto &list = proto_attr_val.list(); + for (const auto &item : list.bt()) { + value.push_back(Buffer::CopyFrom((const uint8_t *)item.data(), item.size())); + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + GeAttrValue::NAMED_ATTRS &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { + return false; + } + auto proto_msg = 
value.named_attrs_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + *proto_msg = proto_attr_val.func(); + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &value) { + value.clear(); + if (!AttrUtilsHelper::GetValueCheckListType( + proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { + return false; + } + auto &list = proto_attr_val.list(); + for (const auto &item : list.na()) { + value.emplace_back(GeAttrValue::NAMED_ATTRS()); + if (value.empty()) { + return false; + } + auto proto_msg = value.back().named_attrs_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + *proto_msg = item; + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ComputeGraphPtr &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kG)) { + return false; + } + ComputeGraphPtr graph = nullptr; + std::shared_ptr graph_def; + graph_def = ComGraphMakeShared(proto_attr_val.g()); + if (graph_def == nullptr) { + GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); + graph_def = nullptr; + return false; // lint !e665 + } else { + ModelSerializeImp imp; + imp.SetProtobufOwner(graph_def); + if (!imp.UnserializeGraph(graph, *graph_def)) { + GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); + return false; + } // lint !e514 + value = graph; + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &value) { + value.clear(); + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, + ListValueItemCheck(g))) { + return false; + } + auto &list = proto_attr_val.list(); + for (const auto &item : list.g()) { + std::shared_ptr graph_def; + graph_def = ComGraphMakeShared(item); + if (graph_def == nullptr) { + GELOGE(GRAPH_FAILED, 
"proto::GraphDef make shared failed"); + graph_def = nullptr; + return false; // lint !e665 + } else { + ComputeGraphPtr graph = nullptr; + ModelSerializeImp imp; + imp.SetProtobufOwner(graph_def); + if (!imp.UnserializeGraph(graph, *graph_def)) { + GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); + return false; + } // lint !e514 + value.push_back(graph); + } + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector> &value) { + value.clear(); + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) { + return false; + } + + auto &list_listint = proto_attr_val.list_list_int().list_list_i(); + for (auto &list_int : list_listint) { + vector list_item(list_int.list_i().size()); + if (!list_int.list_i().empty()) { + (void)std::copy(list_int.list_i().begin(), list_int.list_i().end(), list_item.begin()); + } + value.push_back(list_item); + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &value) { + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, + ListValueItemCheck(dt))) { + return false; + } + auto &list = proto_attr_val.list(); + for (const auto &item : list.dt()) { + value.emplace_back(static_cast(item)); + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ge::DataType &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) { + return false; + } + value = static_cast(proto_attr_val.dt()); + return true; +} + +GE_FUNC_HOST_VISIBILITY bool GeAttrValueImp::SetZeroCopyBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + Buffer &&buffer) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { + return false; + } + auto proto_msg = buffer.data_.GetProtoMsg(); + if (proto_msg == nullptr) { + return 
false; + } + proto_attr_val.set_bt(std::move(*proto_msg->mutable_bt())); + return true; +} + +bool GeAttrValueImp::GetZeroCopyBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, + Buffer &buffer) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { + return false; + } + buffer = Buffer(proto_owner, &const_cast(proto_attr_val)); + return true; +} + +bool GeAttrValueImp::SetZeroCopyListBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &list_buffer) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_bt(); + for (auto &item : list_buffer) { + auto proto_msg = item.data_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + list->add_bt(std::move(*proto_msg->mutable_bt())); + } + return true; +} + +bool GeAttrValueImp::GetZeroCopyListBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, + vector &list_buffer) { + list_buffer.clear(); + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, + ListValueItemCheck(bt))) { + return false; + } + auto list = const_cast(proto_attr_val).mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + for (auto &item : *(list->mutable_bt())) { + list_buffer.emplace_back(Buffer(proto_owner, &item)); + } + return true; +} + +bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const string &name) { + if (!obj) { + return false; + } + return obj->HasAttr(name); +} + +#define ATTR_UTILS_SET_IMP(FuncName, Type) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Set##FuncName( \ + AttrHolderAdapter &&obj, const string &name, const Type &value) { \ + proto::AttrDef *proto_attr_val = nullptr; \ + if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, 
proto_attr_val) || proto_attr_val == nullptr) { \ + return false; \ + } \ + if (!GeAttrValueImp::SetValue(*proto_attr_val, value)) { \ + GELOGW("Set" #FuncName " failed key %s", name.c_str()); \ + return false; \ + } \ + return true; \ + } + +#define ATTR_UTILS_GET_IMP(FuncName, Type) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Get##FuncName(ConstAttrHolderAdapter &&obj, \ + const string &name, Type &value) { \ + const proto::AttrDef *proto_attr_val = nullptr; \ + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ + return false; \ + } \ + if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value)) { \ + GELOGW("Get" #FuncName " failed key %s", name.c_str()); \ + return false; \ + } \ + return true; \ + } + +#define ATTR_UTILS_SET_GET_IMP(FuncName, Type) \ + ATTR_UTILS_SET_IMP(FuncName, Type) \ + ATTR_UTILS_GET_IMP(FuncName, Type) + +ATTR_UTILS_SET_GET_IMP(Int, int64_t) +ATTR_UTILS_SET_GET_IMP(Float, float) +ATTR_UTILS_SET_GET_IMP(Bool, bool) +ATTR_UTILS_SET_GET_IMP(Str, string) +ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc) +ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) +ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) +ATTR_UTILS_SET_IMP(Tensor, GeTensor) +ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) +ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) +ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) +/*lint -e665*/ +ATTR_UTILS_SET_GET_IMP(ListListInt, vector>) +/*lint +e665*/ + +ATTR_UTILS_SET_GET_IMP(ListInt, vector) +ATTR_UTILS_SET_IMP(ListInt, vector) +ATTR_UTILS_SET_IMP(ListInt, vector) +ATTR_UTILS_SET_GET_IMP(ListFloat, vector) +ATTR_UTILS_SET_GET_IMP(ListBool, vector) +ATTR_UTILS_SET_GET_IMP(ListStr, vector) +ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector) +ATTR_UTILS_SET_IMP(ListTensor, vector) +ATTR_UTILS_SET_IMP(ListTensor, vector) +ATTR_UTILS_SET_IMP(ListTensor, vector) +ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector) +ATTR_UTILS_SET_GET_IMP(ListBytes, 
vector) +ATTR_UTILS_SET_GET_IMP(ListGraph, vector) +ATTR_UTILS_SET_GET_IMP(ListDataType, vector) // lint !e665 +ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665 + +bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, + std::initializer_list &&value) { + return SetListTensor(std::move(obj), name, vector(value)); +} + +bool AttrUtils::GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value) { + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + GeTensorPtr tensor; + if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) { + return false; + } + value = tensor; + return true; +} + +bool AttrUtils::GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector &value) { + value.clear(); + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + vector tensor; + if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) { + return false; + } + value.insert(value.begin(), tensor.begin(), tensor.end()); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::MutableTensor(AttrHolderAdapter &&obj, + const string &name, GeTensorPtr &value) { + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value); +} + +bool AttrUtils::MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector &value) { + value.clear(); + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + 
return false; + } + return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value); +} + +bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value) { + proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::SetValue(*proto_attr_val, value); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, + int32_t &value) { + int64_t int64_val = 0; + if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { + return false; + } + if (int64_val > INT32_MAX) { + GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to int32_t", int64_val); + return false; + } + value = static_cast(int64_val); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, + uint32_t &value) { + int64_t int64_val = 0; + if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { + return false; + } + if (int64_val > UINT32_MAX) { + GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to uint32_t", int64_val); + return false; + } + value = static_cast(int64_val); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, + const string &name, vector &value) { + value.clear(); + vector int64_list; + if (!GetListInt(std::move(obj), name, int64_list)) { + return false; + } + + for (size_t i = 0; i < int64_list.size(); ++i) { + if (int64_list[i] > INT32_MAX) { + GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); + return false; + } + } + value.insert(value.begin(), int64_list.begin(), int64_list.end()); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, + const string 
&name, vector &value) { + value.clear(); + vector int64_list; + if (!GetListInt(std::move(obj), name, int64_list)) { + return false; + } + + for (size_t i = 0; i < int64_list.size(); ++i) { + if (int64_list[i] > UINT32_MAX) { + GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); + return false; + } + } + value.insert(value.begin(), int64_list.begin(), int64_list.end()); + return true; +} + +bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value) { + if (obj) { + vector bytes_vals; + for (auto &item : value) { + ModelSerialize serialize; + auto buffer = serialize.SerializeOpDesc(item); + if (buffer.GetSize() == 0) { + return false; + } + bytes_vals.push_back(buffer); + } + return SetZeroCopyListBytes(std::move(obj), name, bytes_vals); + } + return false; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, + const string &name, + const vector &value) { + if (obj) { + vector bytes_vals; + for (auto &item : value) { + ModelSerialize serialize; + auto buffer = serialize.SerializeOpDesc(item); + if (buffer.GetSize() == 0) { + return false; + } + bytes_vals.push_back(buffer); + } + return SetZeroCopyListBytes(std::move(obj), name, bytes_vals); + } + return false; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(ConstAttrHolderAdapter &&obj, + const string &name, + vector &value) { + value.clear(); + + vector bytes_vals; + if (!GetZeroCopyListBytes(std::move(obj), name, bytes_vals)) { + return false; + } + for (const auto &item : bytes_vals) { + ModelSerialize serialize; + auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732 + value.push_back(op_desc); + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetZeroCopyBytes(AttrHolderAdapter &&obj, + const string &name, Buffer &&buffer) { + // Value will be moved + proto::AttrDef 
*proto_attr_val = nullptr; + if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::SetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), std::move(buffer)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, + const string &name, Buffer &buffer) { + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::GetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), buffer); +} + +bool AttrUtils::SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector &list_buffer) { + // Value will be moved + proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::SetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer); +} + +bool AttrUtils::GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &list_buffer) { + list_buffer.clear(); + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::GetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { + if (org_op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "org_op_desc is null"); + return nullptr; + } + std::shared_ptr op_def; + op_def = ComGraphMakeShared(); + if (op_def == nullptr) { + GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); + return nullptr; // lint !e665 + } + ModelSerializeImp imp; + 
(void)imp.SerializeOpDesc(org_op_desc, op_def.get()); + + imp.SetProtobufOwner(op_def); + OpDescPtr op_desc = nullptr; + GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); + op_desc->extAttrs_ = org_op_desc->extAttrs_; + + // This function may be called by some passes of fusion engine, in this condition, do not need these attribute + if (!op_desc->input_name_idx_.empty()) { + op_desc->input_name_idx_.clear(); + } + if (!op_desc->output_name_idx_.empty()) { + op_desc->output_name_idx_.clear(); + } + if (!op_desc->optional_input_names_.empty()) { + op_desc->optional_input_names_.clear(); + } + + return op_desc; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { + if (org_op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "org_op_desc is null"); + return nullptr; + } + std::shared_ptr op_def = ComGraphMakeShared(); + if (op_def == nullptr) { + GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); + return nullptr; + } + ModelSerializeImp imp; + (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); + + imp.SetProtobufOwner(op_def); + OpDescPtr op_desc = nullptr; + GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); + + op_desc->extAttrs_ = org_op_desc->extAttrs_; + + op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end()); + op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(), + org_op_desc->optional_input_names_.end()); + op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); + + op_desc->infer_func_ = org_op_desc->infer_func_; + op_desc->infer_format_func_ = org_op_desc->infer_format_func_; + op_desc->verifier_func_ = org_op_desc->verifier_func_; + + return op_desc; +} +std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) { + auto holder 
= obj.get(); + if (holder == nullptr) { + return ""; + } + auto attrs_map = holder->GetAttrMap(); + if (attrs_map.GetProtoMsg() == nullptr) { + return ""; + } + + std::map ordered_attrs; + for (auto &attr : *(attrs_map.GetProtoMsg())) { + ordered_attrs[attr.first] = attr.second.SerializeAsString(); + } + + std::stringstream ss; + for (auto &attr : ordered_attrs) { + ss << attr.first << ":" << attr.second << ";"; + } + return ss.str(); +} +} // namespace ge diff --git a/src/common/graph/ge_tensor.cc b/src/common/graph/ge_tensor.cc new file mode 100644 index 00000000..65881435 --- /dev/null +++ b/src/common/graph/ge_tensor.cc @@ -0,0 +1,1021 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/ge_tensor.h" +#include +#include +#include +#include +#include "debug/ge_attr_define.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_attr_value.h" +#include "graph/model_serialize.h" +#include "proto/ge_ir.pb.h" +#include "utils/attr_utils.h" +#include "utils/ge_ir_utils.h" +#include "utils/tensor_utils.h" +#include "utils/type_utils.h" + +namespace ge { +static const char *const kKeyDataTypeSelfDefined = "__tensor_desc_data_type__"; + +static const std::map kDataTypeMap = { + {DT_UNDEFINED, proto::DT_UNDEFINED}, + {DT_FLOAT, proto::DT_FLOAT}, + {DT_FLOAT16, proto::DT_FLOAT16}, + {DT_INT8, proto::DT_INT8}, + {DT_UINT8, proto::DT_UINT8}, + {DT_INT16, proto::DT_INT16}, + {DT_UINT16, proto::DT_UINT16}, + {DT_INT32, proto::DT_INT32}, + {DT_INT64, proto::DT_INT64}, + {DT_UINT32, proto::DT_UINT32}, + {DT_UINT64, proto::DT_UINT64}, + {DT_BOOL, proto::DT_BOOL}, + {DT_DOUBLE, proto::DT_DOUBLE}, + {DT_DUAL, proto::DT_DUAL}, + {DT_DUAL_SUB_INT8, proto::DT_DUAL_SUB_INT8}, + {DT_DUAL_SUB_UINT8, proto::DT_DUAL_SUB_UINT8}, + {DT_COMPLEX64, proto::DT_COMPLEX64}, + {DT_COMPLEX128, proto::DT_COMPLEX128}, + {DT_QINT8, proto::DT_QINT8}, + {DT_QINT16, proto::DT_QINT16}, + {DT_QINT32, proto::DT_QINT32}, + {DT_QUINT8, proto::DT_QUINT8}, + {DT_QUINT16, proto::DT_QUINT16}, + {DT_RESOURCE, proto::DT_RESOURCE}, + {DT_STRING_REF, proto::DT_STRING_REF}, + {DT_STRING, proto::DT_STRING}, +}; + +static const std::map kDataTypeSelfDefinedMap = { + {DT_DUAL, 13}, {DT_DUAL_SUB_INT8, 14}, {DT_DUAL_SUB_UINT8, 15}, {DT_COMPLEX64, 16}, {DT_COMPLEX128, 17}, + {DT_QINT8, 18}, {DT_QINT16, 19}, {DT_QINT32, 20}, {DT_QUINT8, 21}, {DT_QUINT16, 22}, +}; + +GeShape::GeShape() { shape_def_.InitDefault(); } + +// Default +GeShape::GeShape(std::vector s) : GeShape() { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto i : s) { + proto_msg->add_dim(i); + } + } +} + +size_t GeShape::GetDimNum() const { + auto 
proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + if (proto_msg->dim_size() >= 0) { + // check whether contain -2, if true, return -1 + for (auto i : proto_msg->dim()) { + if (i == UNKNOWN_DIM_NUM) { + return 0; + } + } + return proto_msg->dim_size(); + } else { + return 0; + } + } + return 0; +} + +int64_t GeShape::GetDim(size_t idx) const { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + if (proto_msg->dim_size() > static_cast(idx)) { + return proto_msg->dim(static_cast(idx)); + } + } + return 0; +} + +graphStatus GeShape::SetDim(size_t idx, int64_t value) { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + auto dims = proto_msg->mutable_dim(); + GE_CHECK_NOTNULL(dims); + if (dims->empty()) { + GELOGE(GRAPH_FAILED, "shape is empty"); + return GRAPH_FAILED; + } + if (static_cast(idx) >= dims->size()) { + GELOGE(GRAPH_FAILED, "idx is out of range"); + return GRAPH_FAILED; + } + proto_msg->set_dim(static_cast(idx), value); + } + return GRAPH_SUCCESS; +} + +std::vector GeShape::GetDims() const { + vector dims; + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto i : proto_msg->dim()) { + dims.push_back(i); + } + } + return dims; +} + +std::string GeShape::ToString() const { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg == nullptr) { + return ""; + } + + std::stringstream ss; + bool first = true; + for (auto i : proto_msg->dim()) { + if (first) { + first = false; + } else { + ss << ","; + } + ss << i; + } + return ss.str(); +} + +int64_t GeShape::GetShapeSize() const { + int64_t res = 1; + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + if (proto_msg->dim().empty()) { + return 0; + } + for (auto i : proto_msg->dim()) { + // if unknown shape, return -1 + if (i == UNKNOWN_DIM || i == UNKNOWN_DIM_NUM) { + return UNKNOWN_DIM; + } + res *= i; + } + } + return res; +} + +/// +/// @brief Check is unknown shape +/// @return 
bool +/// /// +bool GeShape::IsUnknownShape() const { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto i : proto_msg->dim()) { + if (i < 0) { + return true; + } + } + } + return false; +} + +/// +/// @brief Check is a scalar +/// @return bool +/// +bool GeShape::IsScalar() const { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->dim().empty(); + } + return false; +} + +const string TENSOR_UTILS_SIZE = "size"; +const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; +const string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; +const string TENSOR_UTILS_OUTPUT_TENSOR = "output_tensor"; +const string TENSOR_UTILS_DEVICE_TYPE = "device_type"; +const string TENSOR_UTILS_INPUT_TENSOR = "input_tensor"; +const string TENSOR_UTILS_REAL_DIM_CNT = "real_dim_cnt"; +const string TENSOR_UTILS_REUSE_INPUT_INDEX = "reuse_input_index"; +const string TENSOR_UTILS_DATA_OFFSET = "data_offset"; +const string TENSOR_UTILS_CMPS_SIZE = "cmps_size"; +const string TENSOR_UTILS_CMPS_TAB = "cmps_tab"; +const string TENSOR_UTILS_CMPS_TAB_OFFSET = "cmps_tab_offset"; +const string TENSOR_UTILS_CMPSINFO = "cmps_info"; +const string TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO = "alloffset_quantize_info"; +const string TENSOR_UTILS_RC = "rc"; +const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; +const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; +const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; +const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; +const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index"; + +GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} + +GeShape::GeShape(const GeShape &other) : GeShape() { shape_def_.CopyValueFrom(other.shape_def_); } + +GeShape::GeShape(GeShape &&other) : GeShape() { shape_def_.MoveValueFrom(std::move(other.shape_def_)); } + +GeShape &GeShape::operator=(const GeShape &other) { + if 
(&other != this) { + shape_def_.CopyValueFrom(other.shape_def_); + } + return *this; +} + +GeShape &GeShape::operator=(GeShape &&other) { + if (&other != this) { + shape_def_.CopyValueFrom(std::move(other.shape_def_)); + } + return *this; +} + +GeTensorDesc::GeTensorDesc() { + tensor_descriptor_.InitDefault(); + SetDataType(DT_FLOAT); + Init(); +} + +// Default +GeTensorDesc::GeTensorDesc(GeShape shape, Format format, DataType dt) : GeTensorDesc() { + SetFormat(format); + SetDataType(dt); + ShapeReference() = std::move(shape); +} + +// Default +GeTensorDesc::GeTensorDesc(const GeTensorDesc &desc) : GeTensorDesc() { + tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_); +} + +// Default +GeTensorDesc::GeTensorDesc(GeTensorDesc &&desc) : GeTensorDesc() { + tensor_descriptor_.MoveValueFrom(std::move(desc.tensor_descriptor_)); +} + +GeTensorDesc::GeTensorDesc(const ProtoMsgOwner &proto_owner, proto::TensorDescriptor *proto_msg) + : tensor_descriptor_(proto_owner, proto_msg) { + if (proto_msg != nullptr && !proto_msg->has_out_attr()) { + proto_msg->set_has_out_attr(true); + + int64_t size = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_SIZE, size); + proto_msg->set_size(size); + + int64_t weight_size = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_WEIGHT_SIZE, weight_size); + proto_msg->set_weight_size(weight_size); + + bool reuse_input = false; + (void)AttrUtils::GetBool(this, TENSOR_UTILS_REUSE_INPUT, reuse_input); + proto_msg->set_reuse_input(reuse_input); + + bool output_tensor = false; + (void)AttrUtils::GetBool(this, TENSOR_UTILS_OUTPUT_TENSOR, output_tensor); + proto_msg->set_output_tensor(output_tensor); + + string device_type = "NPU"; + (void)AttrUtils::GetStr(this, TENSOR_UTILS_DEVICE_TYPE, device_type); + proto_msg->set_device_type(device_type); + + bool input_tensor = false; + (void)AttrUtils::GetBool(this, TENSOR_UTILS_INPUT_TENSOR, input_tensor); + proto_msg->set_input_tensor(input_tensor); + + int64_t real_dim_cnt = 0; + 
(void)AttrUtils::GetInt(this, TENSOR_UTILS_REAL_DIM_CNT, real_dim_cnt); + proto_msg->set_real_dim_cnt(real_dim_cnt); + + int64_t reuse_input_index = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_REUSE_INPUT_INDEX, reuse_input_index); + proto_msg->set_reuse_input_index(reuse_input_index); + + int64_t data_offset = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_DATA_OFFSET, data_offset); + proto_msg->set_data_offset(data_offset); + + int64_t cmps_size = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_SIZE, cmps_size); + proto_msg->set_cmps_size(cmps_size); + + string cmps_tab; + (void)AttrUtils::GetStr(this, TENSOR_UTILS_CMPS_TAB, cmps_tab); + proto_msg->set_cmps_tab(cmps_tab); + + int64_t cmps_tab_offset = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_TAB_OFFSET, cmps_tab_offset); + proto_msg->set_cmps_tab_offset(cmps_tab_offset); + } +} + +bool GeTensorDesc::GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const { + const auto &tensor_descriptor = this->tensor_descriptor_.GetProtoMsg(); + const auto &r_tensor_descriptor = r_ge_tensor_desc.tensor_descriptor_.GetProtoMsg(); + if ((tensor_descriptor != nullptr) && (r_tensor_descriptor != nullptr)) { + // Message TensorDescriptor in ge_ir.proto + return ( + IsEqual(tensor_descriptor->name(), r_tensor_descriptor->name(), "TensorDescriptor.name()") && + IsEqual(tensor_descriptor->dtype(), r_tensor_descriptor->dtype(), "TensorDescriptor.dtype()") && + // Message ShapeDef in ge_ir.proto + IsEqual(ToString(tensor_descriptor->shape().dim()), ToString(r_tensor_descriptor->shape().dim()), + "TensorDescriptor.shape().dim()") && + IsEqual(tensor_descriptor->layout(), r_tensor_descriptor->layout(), "TensorDescriptor.layout()") && + IsEqual(tensor_descriptor->has_out_attr(), r_tensor_descriptor->has_out_attr(), + "TensorDescriptor.has_out_attr()") && + IsEqual(tensor_descriptor->size(), r_tensor_descriptor->size(), "TensorDescriptor.size()") && + IsEqual(tensor_descriptor->weight_size(), 
r_tensor_descriptor->weight_size(), "TensorDescriptor.weight_size()") && + IsEqual(tensor_descriptor->reuse_input(), r_tensor_descriptor->reuse_input(), "TensorDescriptor.reuse_input()") && + IsEqual(tensor_descriptor->output_tensor(), r_tensor_descriptor->output_tensor(), + "TensorDescriptor.output_tensor()") && + IsEqual(tensor_descriptor->device_type(), r_tensor_descriptor->device_type(), "TensorDescriptor.device_type()") && + IsEqual(tensor_descriptor->input_tensor(), r_tensor_descriptor->input_tensor(), + "TensorDescriptor.input_tensor()") && + IsEqual(tensor_descriptor->real_dim_cnt(), r_tensor_descriptor->real_dim_cnt(), + "TensorDescriptor.real_dim_cnt()") && + IsEqual(tensor_descriptor->reuse_input_index(), r_tensor_descriptor->reuse_input_index(), + "TensorDescriptor.reuse_input_index()") && + IsEqual(tensor_descriptor->data_offset(), r_tensor_descriptor->data_offset(), "TensorDescriptor.data_offset()") && + IsEqual(tensor_descriptor->cmps_size(), r_tensor_descriptor->cmps_size(), "TensorDescriptor.cmps_size()") && + IsEqual(tensor_descriptor->cmps_tab(), r_tensor_descriptor->cmps_tab(), "TensorDescriptor.cmps_tab()") && + IsEqual(tensor_descriptor->cmps_tab_offset(), r_tensor_descriptor->cmps_tab_offset(), + "TensorDescriptor.cmps_tab_offset()")); + } else { + return ((tensor_descriptor == nullptr) && (r_tensor_descriptor == nullptr)); + } +} + +bool GeTensorDesc::operator==(const GeTensorDesc &r_ge_tensor_desc) const { + return GeTensorDescAttrsAreEqual(r_ge_tensor_desc); +} + +GeShape &GeTensorDesc::ShapeReference() const { + if (tensor_descriptor_.GetProtoMsg() != nullptr) { + GeShape refShape(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_shape()); + __shape_.RefTo(refShape); + } else { + GeShape refShape(tensor_descriptor_.GetProtoOwner(), nullptr); + __shape_.RefTo(refShape); + } + return __shape_; +} + +void GeTensorDesc::Init() { + SetFormat(FORMAT_ND); + SetOriginFormat(FORMAT_ND); + 
TensorUtils::SetDeviceType(*this, DeviceType::NPU); + if (tensor_descriptor_.GetProtoMsg() == nullptr) { + GELOGE(GRAPH_FAILED, "ProtoType nullptr."); + return; + } + tensor_descriptor_.GetProtoMsg()->set_has_out_attr(true); +} + +ProtoAttrMapHelper GeTensorDesc::MutableAttrMap() { + if (tensor_descriptor_.GetProtoMsg() != nullptr) { + return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_attr()); + } + return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr); +} + +ConstProtoAttrMapHelper GeTensorDesc::GetAttrMap() const { + if (tensor_descriptor_.GetProtoMsg() != nullptr) { + return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), + tensor_descriptor_.GetProtoMsg()->mutable_attr()); + } + return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr); +} + +void GeTensorDesc::Update(GeShape shape, Format format, DataType dt) { + ShapeReference() = std::move(shape); + SetFormat(format); + SetDataType(dt); +} +GeShape GeTensorDesc::GetShape() const { return ShapeReference(); } + +GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); } + +void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); } + +// set shape with -2, it stand for unknown shape +void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); } + +// for unknown shape +graphStatus GeTensorDesc::SetShapeRange(const std::vector> &range) { + std::vector> shape_range; + for (const auto &ele : range) { + shape_range.emplace_back(std::vector({ele.first, ele.second})); + } + auto ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); + return ret ? 
GRAPH_SUCCESS : GRAPH_FAILED; +} +graphStatus GeTensorDesc::GetShapeRange(std::vector> &range) const { + std::vector> shape_range; + (void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); + + for (const auto &ele : shape_range) { + // here must be only two elemenet because pair + if (ele.size() != 2) { + GELOGE(GRAPH_FAILED, "shape_range must contain only 2 value but really is %lu", ele.size()); + return GRAPH_FAILED; + } + std::pair pair({ele[0], ele[1]}); + range.emplace_back(pair); + } + + return GRAPH_SUCCESS; +} + +GeShape GeTensorDesc::GetOriginShape() const { + vector origin_shape; + if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) { + return GeShape(); + } + return GeShape(origin_shape); +} + +void GeTensorDesc::SetOriginShape(const GeShape &origin_shape) { + std::vector origin_shape_tmp = origin_shape.GetDims(); + (void)AttrUtils::SetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape_tmp); +} + +Format GeTensorDesc::GetFormat() const { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + return TypeUtils::SerialStringToFormat(tensor_descriptor_msg->layout()); + } + return FORMAT_RESERVED; +} + +void GeTensorDesc::SetFormat(Format format) { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_layout(TypeUtils::FormatToSerialString(format)); + } +} + +void GeTensorDesc::SetName(const std::string &name) { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_name(name); + return; + } + GELOGW("[SetName]tensor_descriptor_msg is null."); +} + +const std::string GeTensorDesc::GetName() const { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + return tensor_descriptor_msg->name(); + } + GELOGW("[GetName]tensor_descriptor_msg is null."); 
+ return ""; +} + +Format GeTensorDesc::GetOriginFormat() const { + std::string origin_format_str; + if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str)) { + // Can not get the certificate and it's not set, return directly + return FORMAT_RESERVED; + } + if (origin_format_str == "RESERVED") { + return FORMAT_RESERVED; + } + return TypeUtils::SerialStringToFormat(origin_format_str); +} + +void GeTensorDesc::SetOriginFormat(Format origin_format) { + std::string origin_format_str = "RESERVED"; + if (origin_format != FORMAT_RESERVED) { + origin_format_str = TypeUtils::FormatToSerialString(origin_format); + } + (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str); +} + +DataType GeTensorDesc::GetDataType() const { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg == nullptr) { + return DT_UNDEFINED; + } + auto &attr_map = *(tensor_descriptor_msg->mutable_attr()); + // Data type + auto it_data_type = attr_map.find(kKeyDataTypeSelfDefined); + if (it_data_type != attr_map.end()) { + int64_t data_type_proto = it_data_type->second.i(); + for (auto it : kDataTypeSelfDefinedMap) { + if (it.second == data_type_proto) { + return it.first; + } + } + } else { + auto data_type_proto = tensor_descriptor_msg->dtype(); + for (auto it : kDataTypeMap) { + if (it.second == data_type_proto) { + return it.first; + } + } + } + return DT_UNDEFINED; +} + +void GeTensorDesc::SetDataType(DataType dataType) { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg == nullptr) { + return; + } + auto &attr_maps = *(tensor_descriptor_msg->mutable_attr()); + (void)attr_maps.erase(kKeyDataTypeSelfDefined); + + // Data type + auto it = kDataTypeMap.find(dataType); + if (it != kDataTypeMap.end()) { + tensor_descriptor_msg->set_dtype(it->second); + return; + } + auto it2 = kDataTypeSelfDefinedMap.find(dataType); + if (it2 != kDataTypeSelfDefinedMap.end()) { + 
attr_maps[kKeyDataTypeSelfDefined].set_i(it2->second); + } +} + +void GeTensorDesc::SetOriginDataType(DataType origin_data_type) { + std::string origin_data_type_str = "RESERVED"; + if (origin_data_type != DT_UNDEFINED) { + origin_data_type_str = TypeUtils::DataTypeToSerialString(origin_data_type); + } + (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str); +} + +DataType GeTensorDesc::GetOriginDataType() const { + std::string origin_data_type_str; + if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str)) { + return DT_UNDEFINED; + } + if (origin_data_type_str == "RESERVED") { + return DT_UNDEFINED; + } + return TypeUtils::SerialStringToDataType(origin_data_type_str); +} + +std::vector GeTensorDesc::GetRefPortIndex() const { + vector ref_port_index; + (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index); + return ref_port_index; +} + +void GeTensorDesc::SetRefPortByIndex(const std::vector &index) { + (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index); +} + +graphStatus GeTensorDesc::IsValid() const { + auto dtype = this->GetDataType(); + auto format = this->GetFormat(); + if (dtype == DT_UNDEFINED && format == FORMAT_RESERVED) { + return GRAPH_PARAM_INVALID; + } + return GRAPH_SUCCESS; +} + +GeTensorDesc GeTensorDesc::Clone() const { return *this; } + +GeTensorDesc &GeTensorDesc::operator=(const GeTensorDesc &desc) { + if (&desc != this) { + tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_); + } + return *this; +} + +GeTensorDesc &GeTensorDesc::operator=(GeTensorDesc &&desc) { + if (&desc != this) { + tensor_descriptor_.CopyValueFrom(std::move(desc.tensor_descriptor_)); + } + return *this; +} + +GeTensor::GeTensor::GeTensor() { + tensor_def_.InitDefault(); + // Default init desc + DescReference() = GeTensorDesc(); +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc) : GeTensor() { DescReference() = tensor_desc; } + +GeTensor::GeTensor(const 
GeTensorDesc &tensor_desc, const vector &data) : GeTensor() { + DescReference() = tensor_desc; + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_data(data.data(), data.size()); + } +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const uint8_t *data, size_t size) : GeTensor() { + DescReference() = tensor_desc; + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr && data != nullptr) { + proto_msg->set_data(data, size); + } +} + +GeTensor::GeTensor(GeTensorDesc &&tensor_desc, vector &&data) : GeTensor() { + DescReference() = std::move(tensor_desc); + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_data(data.data(), data.size()); + } +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const Buffer &data) : GeTensor() { + DescReference() = tensor_desc; + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + if (data.size() == 0) { + GELOGI("GetSize res is 0."); + } + if (data.data() == nullptr) { + GELOGI("data addr is null."); + } + proto_msg->set_data(data.GetData(), data.GetSize()); + } +} + +GeTensor::GeTensor(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg) + : tensor_def_(proto_owner, proto_msg) {} + +GeTensorDesc GeTensor::GetTensorDesc() const { return DescReference(); } + +GeTensorDesc &GeTensor::MutableTensorDesc() { return DescReference(); } + +GeTensorDesc &GeTensor::DescReference() const { + if (tensor_def_.GetProtoMsg() != nullptr) { + GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), tensor_def_.GetProtoMsg()->mutable_desc()); + __desc_.RefTo(tensor_desc); + } else { + GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), nullptr); + __desc_.RefTo(tensor_desc); + } + return __desc_; +} + +void GeTensor::SetTensorDesc(const GeTensorDesc &tensor_desc) { DescReference() = tensor_desc; } + +const Buffer GeTensor::GetData() const { + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != 
nullptr) { + return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data()); + } + return Buffer(); +} + +Buffer GeTensor::MutableData() { + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data()); + } + return Buffer(); +} + +graphStatus GeTensor::SetData(vector &&data) { + auto proto_msg = tensor_def_.GetProtoMsg(); + GE_CHECK_NOTNULL(proto_msg); + proto_msg->set_data(data.data(), data.size()); + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(const vector &data) { + auto proto_msg = tensor_def_.GetProtoMsg(); + GE_CHECK_NOTNULL(proto_msg); + proto_msg->set_data(data.data(), data.size()); + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(const uint8_t *data, size_t size) { + GE_CHECK_NOTNULL(data); + auto proto_msg = tensor_def_.GetProtoMsg(); + GE_CHECK_NOTNULL(proto_msg); + proto_msg->set_data(data, size); + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(const Buffer &data) { + auto proto_msg = tensor_def_.GetProtoMsg(); + GE_CHECK_NOTNULL(proto_msg); + if (data.size() == 0) { + GELOGI("GetSize res is 0."); + } + if (data.data() == nullptr) { + GELOGI("data addr is null."); + } + proto_msg->set_data(data.data(), data.size()); + return GRAPH_SUCCESS; +} + +GeTensor GeTensor::Clone() const { + GeTensor tensor; + tensor.tensor_def_.CopyValueFrom(tensor_def_); + return tensor; +} + +GeTensor::GeTensor(const GeTensor &other) { tensor_def_ = other.tensor_def_; } + +GeTensor &GeTensor::operator=(const GeTensor &other) { + if (&other != this) { + tensor_def_ = other.tensor_def_; + } + return *this; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, + int64_t &size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + size = static_cast(tensor_descriptor_msg->size()); + return GRAPH_SUCCESS; +} + 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_size(size); + } +} + +uint32_t TensorUtils::GetWeightSize(const GeTensorDesc &tensor_desc) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + return static_cast(tensor_descriptor_msg->weight_size()); + } + return 0; +} + +uint32_t TensorUtils::GetWeightSize(const GeTensor &tensor) { return GetWeightSize(tensor.GetTensorDesc()); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t TensorUtils::GetWeightSize(const ConstGeTensorPtr &tensor_ptr) { + if (tensor_ptr == nullptr) { + return 0; + } + return GetWeightSize(*tensor_ptr); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint8_t *TensorUtils::GetWeightAddr(const ConstGeTensorPtr &tensor_ptr, + uint8_t *base) { + if (tensor_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "tensor_ptr is null."); + return nullptr; + } + return GetWeightAddr(*tensor_ptr, base); +} + +uint8_t *TensorUtils::GetWeightAddr(const GeTensor &tensor, uint8_t *base) { + if (base == nullptr) { + GELOGE(GRAPH_FAILED, "base is null."); + return nullptr; + } + int64_t weight_data_offset = 0; + if (GetDataOffset(tensor.GetTensorDesc(), weight_data_offset) != GRAPH_SUCCESS) return nullptr; + + if (weight_data_offset == 0) { + // The weight of offset 0 is still in const op, still get from ATTR_NAME_WEIGHTS. 
+ return const_cast(tensor.GetData().data()); + } + + return base + weight_data_offset; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetWeightSize(GeTensorDesc &tensor_desc, + uint32_t size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_weight_size(size); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetReuseInput(const GeTensorDesc &tensor_desc, + bool &flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + flag = tensor_descriptor_msg->reuse_input(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInput(GeTensorDesc &tensor_desc, bool flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_reuse_input(flag); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetOutputTensor(const GeTensorDesc &tensor_desc, + bool &flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + flag = tensor_descriptor_msg->output_tensor(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetOutputTensor(GeTensorDesc &tensor_desc, bool flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_output_tensor(flag); + } +} + +static map device_to_str_map{ + {0, "NPU"}, + {1, "CPU"}, +}; +static map str_to_device_map{ + {"NPU", 0}, + {"CPU", 1}, +}; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDeviceType(const GeTensorDesc &tensor_desc, + DeviceType &type) { + auto tensor_descriptor_msg = 
tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + string type_str = tensor_descriptor_msg->device_type(); + type = DeviceType(str_to_device_map[type_str]); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDeviceType(GeTensorDesc &tensor_desc, + DeviceType type) { + auto type_str = device_to_str_map[type]; + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_device_type(type_str); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetInputTensor(const GeTensorDesc &tensor_desc, + bool &flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + flag = tensor_descriptor_msg->input_tensor(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetInputTensor(GeTensorDesc &tensor_desc, bool flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_input_tensor(flag); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRealDimCnt(const GeTensorDesc &tensor_desc, + uint32_t &cnt) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + cnt = static_cast(tensor_descriptor_msg->real_dim_cnt()); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRealDimCnt(GeTensorDesc &tensor_desc, + uint32_t cnt) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_real_dim_cnt(cnt); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetReuseInputIndex(const GeTensorDesc &tensor_desc, uint32_t &idx) 
{ + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + + idx = static_cast(tensor_descriptor_msg->reuse_input_index()); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInputIndex(GeTensorDesc &tensor_desc, + uint32_t idx) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_reuse_input_index(idx); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDataOffset(const GeTensorDesc &tensor_desc, + int64_t &offset) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + offset = tensor_descriptor_msg->data_offset(); + return GRAPH_SUCCESS; + } else { + GELOGW("tensor_descriptor_msg is nullptr."); + return GRAPH_FAILED; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDataOffset(GeTensorDesc &tensor_desc, + int64_t offset) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_data_offset(offset); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsSize(const GeTensorDesc &tensor_desc, + uint32_t &cmp_size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + cmp_size = static_cast(tensor_descriptor_msg->cmps_size()); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsSize(GeTensorDesc &tensor_desc, + uint32_t cmp_size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_cmps_size(cmp_size); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus 
TensorUtils::GetCmpsTab(const GeTensorDesc &tensor_desc, + vector &vec) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + string str = tensor_descriptor_msg->cmps_tab(); + vec.assign(str.begin(), str.end()); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTab(GeTensorDesc &tensor_desc, + const uint8_t *data, size_t size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + GE_CHK_BOOL_EXEC(data != nullptr, return, "data is null."); + string str((const char *)data, size); + tensor_descriptor_msg->set_cmps_tab(str); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetCmpsTabOffset(const GeTensorDesc &tensor_desc, int64_t &tab_offset) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tab_offset = tensor_descriptor_msg->cmps_tab_offset(); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTabOffset(GeTensorDesc &tensor_desc, + int64_t tab_offset) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_cmps_tab_offset(tab_offset); + } +} + +graphStatus TensorUtils::GetCmpsInfo(const GeTensorDesc &tensor_desc, CompressInfo &info) { + GeAttrValue attr_value; + if (tensor_desc.GetAttr(TENSOR_UTILS_CMPSINFO, attr_value) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + return attr_value.GetValue(info); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsInfo(GeTensorDesc &tensor_desc, + const CompressInfo &info) { + (void)tensor_desc.SetAttr(TENSOR_UTILS_CMPSINFO, GeAttrValue::CreateFrom(info)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool TensorUtils::HasAlloffsetQuantizeInfo( + const 
GeTensorDesc &tensor_desc) { + return tensor_desc.HasAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetAlloffsetQuantizeInfo(const GeTensorDesc &tensor_desc, AllOffsetQuantizeInfo &info) { + GeAttrValue attr_value; + if (tensor_desc.GetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, attr_value) != GRAPH_SUCCESS) { + GELOGW("get attr alloffset_quantize_info fail."); + } + return attr_value.GetValue(info); +} + +void TensorUtils::SetAlloffsetQuantizeInfo(GeTensorDesc &tensor_desc, const AllOffsetQuantizeInfo &info) { + (void)tensor_desc.SetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, GeAttrValue::CreateFrom(info)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRC(const GeTensorDesc &tensor_desc, + uint32_t &rc) { + return AttrUtils::GetInt(&tensor_desc, TENSOR_UTILS_RC, rc) ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRC(GeTensorDesc &tensor_desc, uint32_t rc) { + (void)AttrUtils::SetInt(&tensor_desc, TENSOR_UTILS_RC, rc); +} +} // namespace ge diff --git a/src/common/graph/graph.cc b/src/common/graph/graph.cc new file mode 100644 index 00000000..fc30e9d6 --- /dev/null +++ b/src/common/graph/graph.cc @@ -0,0 +1,384 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "external/graph/graph.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/debug/ge_op_types.h" +#include "graph/model.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" + +using std::map; +using std::pair; +using std::string; +using std::vector; + +namespace ge { +class GraphImpl { + public: + friend class GraphUtils; + GraphImpl(const GraphImpl &) = delete; + GraphImpl &operator=(const GraphImpl &) = delete; + + explicit GraphImpl(const std::string &name) : name_(name) {} + + ~GraphImpl() { + if (IsValid()) { + if (compute_graph_ != nullptr) { + GraphUtils::BreakConnect(compute_graph_->GetAllNodesInfo()); + } + } + for (const auto &it : op_list_) { + Operator op = it.second; + op.BreakConnect(); + } + } + + graphStatus SetInputs(const std::vector &inputs) { + compute_graph_ = GraphUtils::CreateGraphFromOperator(name_, inputs); + GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "Build Graph failed."); + GE_CHK_BOOL_RET_STATUS(inputs.size() != 0, GRAPH_FAILED, "set input NULL."); + compute_graph_->SetInputSize(static_cast(inputs.size())); + return GRAPH_SUCCESS; + } + + graphStatus SetOutputs(const std::vector &outputs) { + if (compute_graph_ == nullptr) { + GELOGE(GRAPH_FAILED, "set ComputeGraph failed."); + return GRAPH_FAILED; + } + if (outputs.empty()) { + GELOGW("set outputs size is 0."); + return GRAPH_SUCCESS; + } + + // Construct special output node + std::vector>> output_indexs; + for (size_t i = 0; i < outputs.size(); ++i) { + output_indexs.emplace_back(outputs[i], std::vector{}); + } + + graphStatus ret = SetOutputs(output_indexs); + return ret; + } + + graphStatus SetOutputs(const std::vector>> &output_indexs) { + if (compute_graph_ == nullptr) { + GELOGE(GRAPH_FAILED, "set ComputeGraph failed."); + return GRAPH_FAILED; + } + if (output_indexs.empty()) { + GELOGW("set outputs size is 0."); + return GRAPH_SUCCESS; 
+ } + + // Construct special output node + std::vector> output_nodes; + for (const auto &item : output_indexs) { + const Operator &output = item.first; + const vector &indexs = item.second; + ge::NodePtr node = compute_graph_->FindNode(output.GetName()); + if (node == nullptr) { + GELOGW("user designated out_node [%s] not exist in graph, will ignored!", output.GetName().c_str()); + continue; + } + + ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue); + size_t out_size = tmp_op_ptr->GetOutputsSize(); + if (indexs.empty()) { + for (size_t i = 0; i < out_size; ++i) { + output_name_ += output.GetName() + ":" + std::to_string(i) + ";"; + output_nodes.emplace_back(node, i); + } + } else { + for (size_t i = 0; i < indexs.size(); ++i) { + if (indexs[i] >= out_size) { + GELOGW("index[%zu] is not belong to out_node[%s]", indexs[i], output.GetName().c_str()); + } else { + output_name_ += output.GetName() + ":" + std::to_string(i) + ";"; + output_nodes.emplace_back(node, indexs[i]); + } + } + } + } + + // Del last ";" + if (!output_name_.empty()) { + output_name_ = output_name_.substr(0, output_name_.length() - 1); + } + compute_graph_->SetUserDefOutput(output_name_); + compute_graph_->SetOutputSize(static_cast(output_indexs.size())); + compute_graph_->SetGraphOutNodesInfo(output_nodes); + return GRAPH_SUCCESS; + } + + graphStatus SetOutputs(const std::vector> &outputs) { + GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild."); + GE_CHK_BOOL_EXEC_INFO(outputs.size() != 0, return GRAPH_SUCCESS, "set outputs size is 0."); + + // Construct specified output + std::vector> output_nodes; + for (auto item : outputs) { + ge::NodePtr node = compute_graph_->FindNode(item.first.GetName()); + if (node == nullptr) { + GELOGE(GRAPH_FAILED, " Warning, user designated out_node (%s) not exist in graph, this out_node ignored!", + item.first.GetName().c_str()); + return GRAPH_FAILED; + } + ge::OpDescPtr tmp_op_ptr = 
node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue); + size_t out_size = tmp_op_ptr->GetOutputsSize(); + + if (item.second.empty()) { + for (size_t i = 0; i < out_size; ++i) { + output_name_ += item.first.GetName() + ":" + std::to_string(i) + ";"; + output_nodes.push_back(std::make_pair(node, i)); + } + } else { + int32_t index = tmp_op_ptr->GetOutputIndexByName(item.second); + if (index < 0) { + GELOGE(GRAPH_FAILED, + " Warning, user designated out_node (%s):(%s) not exist in graph, this out_node ignored!", + item.first.GetName().c_str(), item.second.c_str()); + return GRAPH_FAILED; + } + output_name_ += item.first.GetName() + ":" + std::to_string(index) + ";"; + output_nodes.push_back(std::make_pair(node, index)); + } + } + // Del last ";" + if (!output_name_.empty()) { + output_name_ = output_name_.substr(0, output_name_.length() - 1); + } + compute_graph_->SetOutputSize(static_cast(outputs.size())); + compute_graph_->SetGraphOutNodesInfo(output_nodes); + GELOGI("********************SetOutputs Success***********************"); + GE_IF_BOOL_EXEC(!output_name_.empty(), GELOGI(" NetOutputs: (%s)", output_name_.c_str())); + + return GRAPH_SUCCESS; + } + + graphStatus SetTargets(const std::vector &targets) { + GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild."); + GE_CHK_BOOL_EXEC_INFO(targets.size() != 0, return GRAPH_SUCCESS, "set targets size is 0."); + + std::vector target_nodes; + for (auto item : targets) { + ge::NodePtr node = compute_graph_->FindNode(item.GetName()); + if (node == nullptr) { + GELOGW(" Warning, user designated target_node (%s) not exist in graph, this target_node ignored!", + item.GetName().c_str()); + continue; + } + target_nodes.push_back(node); + } + compute_graph_->SetGraphTargetNodesInfo(target_nodes); + return GRAPH_SUCCESS; + } + bool IsValid() const { return (compute_graph_ != nullptr); } + + graphStatus AddOp(const ge::Operator &op) { + std::pair::iterator, bool> ret; + ret = 
op_list_.emplace(std::pair(op.GetName(), op)); + GE_CHK_BOOL_RET_STATUS(ret.second != false, GRAPH_FAILED, "the op have added before, op name:%s.", + op.GetName().c_str()); + return GRAPH_SUCCESS; + } + + graphStatus GetAllOpName(std::vector &op_name) const { + for (const auto &it : op_list_) { + op_name.push_back(it.second.GetName()); + } + return GRAPH_SUCCESS; + } + + graphStatus FindOpByName(const string &name, ge::Operator &op) const { + auto it = op_list_.find(name); + GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str()); + op = it->second; + return GRAPH_SUCCESS; + } + + graphStatus FindOpByType(const string &type, std::vector &ops) const { + for (auto &op : op_list_) { + auto op_type = op.second.GetOpType(); + if (op_type == type) { + ops.push_back(op.second); + continue; + } + if (op_type == ge::FRAMEWORKOP) { + op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type); + if (op_type == type) { + ops.push_back(op.second); + } + } + } + return GRAPH_SUCCESS; + } + + void SetNeedIteration(bool need_iteration) { + if (compute_graph_ == nullptr) { + GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null."); + return; + } + compute_graph_->SetNeedIteration(need_iteration); + } + + const std::string &GetName() const { return name_; } + + private: + std::string name_; + std::string output_name_; + std::map op_list_; + ComputeGraphPtr compute_graph_{nullptr}; +}; + +Graph::Graph(const std::string &name) { + impl_ = ComGraphMakeShared(name); + if (impl_ == nullptr) { + GELOGW("GraphImpl make shared failed, impl_ is nullptr"); + } +} + +graphStatus Graph::AddOp(const ge::Operator &op) { + GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, "AddOp failed: graph can not be used, impl is nullptr."); + return impl_->AddOp(op); +} + +graphStatus Graph::GetAllOpName(std::vector &op_name) const { + GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, + "GetAllOpName failed: graph can not be used, 
impl is nullptr."); + return impl_->GetAllOpName(op_name); +} + +graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const { + Operator op_find_op_def("NULL"); + op = op_find_op_def; + GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, + "FindOpByName failed: graph can not be used, impl is nullptr."); + return impl_->FindOpByName(name, op); +} + +graphStatus Graph::FindOpByType(const string &type, std::vector &ops) const { + GE_CHECK_NOTNULL(impl_); + return impl_->FindOpByType(type, ops); +} + +Graph &Graph::SetInputs(const vector &inputs) { + GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.") + GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0."); + (void)impl_->SetInputs(inputs); + return *this; +} + +Graph &Graph::SetOutputs(const vector &outputs) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr."); + return *this; + } + (void)impl_->SetOutputs(outputs); + return *this; +} + +Graph &Graph::SetOutputs(const std::vector>> &output_indexs) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr."); + return *this; + } + (void)impl_->SetOutputs(output_indexs); + return *this; +} + +Graph &Graph::SetOutputs(const std::vector> &outputs) { + GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetOutputs failed: graph can not be used, impl is nullptr.") + (void)impl_->SetOutputs(outputs); + return *this; +} + +Graph &Graph::SetTargets(const vector &targets) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "SetTargets failed: graph can not be used, impl is nullptr."); + return *this; + } + (void)impl_->SetTargets(targets); + return *this; +} + +bool Graph::IsValid() const { + if (impl_ == nullptr) { + return false; + } + return impl_->IsValid(); +} + +void Graph::SetNeedIteration(bool need_iteration) { + if (impl_ == nullptr) 
{ + GELOGE(GRAPH_FAILED, "Set need iteration failed, as impl is null."); + return; + } + impl_->SetNeedIteration(need_iteration); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) { + GE_CHK_BOOL_EXEC_NOLOG(graph.IsValid(), return nullptr); + return graph.impl_->compute_graph_; +} + +graphStatus Graph::SaveToFile(const string &file_name) const { + Model model = Model(); + model.SetGraph(*this); + return model.SaveToFile(file_name); +} + +graphStatus Graph::LoadFromFile(const string &file_name) { + Model model = Model(); + graphStatus ret = model.LoadFromFile(file_name); + if (ret != GRAPH_SUCCESS) { + return ret; + } + *this = model.GetGraph(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string &Graph::GetName() const { return impl_->GetName(); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph +GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { + GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph("")); + + auto name = compute_graph->GetName(); + auto graph = Graph(name); + + GE_CHK_BOOL_EXEC_NOLOG(graph.impl_ != nullptr, return graph); + graph.impl_->compute_graph_ = compute_graph; + + return graph; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { + GE_CHECK_NOTNULL(graph.impl_); + GE_CHECK_NOTNULL(graph.impl_->compute_graph_); + + graph.impl_->op_list_.clear(); + for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) { + graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node); + } + return SUCCESS; +} +} // namespace ge diff --git a/src/common/graph/graph.mk b/src/common/graph/graph.mk new file mode 100644 index 00000000..4ea84919 --- /dev/null +++ b/src/common/graph/graph.mk @@ -0,0 +1,294 @@ +LOCAL_PATH := $(call my-dir) +include $(LOCAL_PATH)/stub/Makefile +COMMON_LOCAL_SRC_FILES := \ + 
./proto/om.proto \ + ./proto/ge_ir.proto \ + ./proto/ge_onnx.proto \ + ./proto/insert_op.proto \ + ./proto/task.proto \ + ./proto/fwk_adapter.proto \ + ./proto/op_mapping_info.proto \ + ./proto/dump_task.proto \ + ./anchor.cc \ + ./ge_attr_value.cc \ + ./attr_value.cc \ + ./buffer.cc \ + ./compute_graph.cc \ + ./graph.cc \ + ./inference_context.cc \ + ./shape_refiner.cc \ + ./format_refiner.cc \ + ./ref_relation.cc \ + ./model.cc \ + ./model_serialize.cc \ + ./node.cc \ + ./op_desc.cc \ + ./operator.cc \ + ./operator_factory.cc \ + ./operator_factory_impl.cc \ + ./ge_attr_define.cc \ + ./ge_tensor.cc \ + ./detail/attributes_holder.cc \ + ./utils/anchor_utils.cc \ + ./utils/tuning_utils.cc \ + ./utils/graph_utils.cc \ + ./utils/ge_ir_utils.cc \ + ./utils/op_desc_utils.cc \ + ./utils/type_utils.cc \ + ./utils/tensor_utils.cc \ + ./tensor.cc \ + ./debug/graph_debug.cc \ + ./opsproto/opsproto_manager.cc \ + ../ops/op_imp.cpp \ + option/ge_context.cc \ + option/ge_local_context.cc \ + ./runtime_inference_context.cc \ + ./utils/node_utils.cc \ + +COMMON_LOCAL_C_INCLUDES := \ + proto/om.proto \ + proto/ge_ir.proto \ + proto_inner/ge_onnx.proto \ + proto/insert_op.proto \ + proto/task.proto \ + proto/fwk_adapter.proto \ + proto/op_mapping_info.proto \ + proto/dump_task.proto \ + inc \ + inc/external \ + inc/external/graph \ + inc/graph \ + inc/common \ + common \ + common/graph \ + third_party/protobuf/include \ + libc_sec/include \ + ops/built-in/op_proto/inc \ + + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libprotobuf \ + libslog \ + liberror_manager \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_SHARED_LIBRARY) + +#compiler for host +include $(CLEAR_VARS) 
+LOCAL_MODULE := stub/libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_SHARED_LIBRARY) + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := fwk_stub/libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/attr_value.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/inference_context.cc \ + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_SHARED_LIBRARY) + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := libgraph + +LOCAL_CFLAGS += -O2 + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libprotobuf \ + libslog \ + liberror_manager \ + +LOCAL_LDFLAGS := -lrt -ldl + +ifeq ($(device_os),android) +LOCAL_LDFLAGS := -ldl +endif + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_SHARED_LIBRARY) + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := stub/libgraph + +LOCAL_CFLAGS += -O2 + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + + 
+LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +ifeq ($(device_os),android) +LOCAL_LDFLAGS := -ldl +endif + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_SHARED_LIBRARY) + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := fwk_stub/libgraph + +LOCAL_CFLAGS += -O2 + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/attr_value.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/inference_context.cc \ + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +ifeq ($(device_os),android) +LOCAL_LDFLAGS := -ldl +endif + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_SHARED_LIBRARY) + +# compile for ut/st +include $(CLEAR_VARS) +LOCAL_MODULE := libgraph + +LOCAL_CFLAGS += + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libprotobuf \ + libslog \ + liberror_manager \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_LLT_SHARED_LIBRARY) + + +#compiler for host static lib +include $(CLEAR_VARS) +LOCAL_MODULE := libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_STATIC_LIBRARIES := \ + libprotobuf \ + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libslog \ + liberror_manager \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_STATIC_LIBRARY) + +#compiler for device static lib +include $(CLEAR_VARS) +LOCAL_MODULE := libgraph + +LOCAL_CFLAGS += -O2 + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + 
+LOCAL_STATIC_LIBRARIES := \ + libprotobuf \ + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libslog \ + liberror_manager \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_STATIC_LIBRARY) diff --git a/src/common/graph/inference_context.cc b/src/common/graph/inference_context.cc new file mode 100644 index 00000000..ed8193dc --- /dev/null +++ b/src/common/graph/inference_context.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "external/graph/inference_context.h" +#include "debug/ge_util.h" + +namespace ge { +class ShapeAndTypeImpl { + public: + ShapeAndTypeImpl() = default; + ~ShapeAndTypeImpl() = default; + + ShapeAndTypeImpl(const Shape &shape, DataType data_type) : shape_(shape), data_type_(data_type) {} + + Shape shape_; + DataType data_type_ = DT_UNDEFINED; +}; + +class InferenceContextImpl { + public: + InferenceContextImpl() = default; + ~InferenceContextImpl() = default; + + // For deliver to op in pair, help to support dynamic shape + std::vector marks_; + std::vector> input_handle_shapes_and_types_; + std::vector> output_handle_shapes_and_types_; +}; + +ShapeAndType::ShapeAndType() { shape_and_type_impl_ = ComGraphMakeShared(); } + +ShapeAndType::ShapeAndType(const Shape &shape, DataType data_type) { + shape_and_type_impl_ = ComGraphMakeShared(shape, data_type); +} + +void ShapeAndType::SetShape(const Shape &shape) { + if (shape_and_type_impl_ != nullptr) { + shape_and_type_impl_->shape_ = shape; + } +} + +void ShapeAndType::SetType(DataType data_type) { + if (shape_and_type_impl_ != nullptr) { + shape_and_type_impl_->data_type_ = data_type; + } +} + +Shape ShapeAndType::GetShape() const { + if (shape_and_type_impl_ != nullptr) { + return shape_and_type_impl_->shape_; + } + return Shape(); +} + +DataType ShapeAndType::GetDataType() const { + if (shape_and_type_impl_ != nullptr) { + return shape_and_type_impl_->data_type_; + } + return DT_UNDEFINED; +} + +InferenceContext::InferenceContext(std::unique_ptr &impl) { + inference_context_impl_ = std::move(impl); +} + +std::unique_ptr InferenceContext::Create() { + std::unique_ptr impl = + std::unique_ptr(new (std::nothrow) InferenceContextImpl()); + if (impl == nullptr) { + return nullptr; + } + + return std::unique_ptr(new (std::nothrow) InferenceContext(impl)); +} + +void InferenceContext::SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types) { + 
inference_context_impl_->input_handle_shapes_and_types_.swap(shapes_and_types); +} + +const std::vector> &InferenceContext::GetInputHandleShapesAndTypes() const { + return inference_context_impl_->input_handle_shapes_and_types_; +} + +const std::vector> &InferenceContext::GetOutputHandleShapesAndTypes() const { + return inference_context_impl_->output_handle_shapes_and_types_; +} + +void InferenceContext::SetOutputHandleShapesAndTypes(const std::vector> &shapes_and_types) { + inference_context_impl_->output_handle_shapes_and_types_ = shapes_and_types; +} + +void InferenceContext::SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types) { + inference_context_impl_->output_handle_shapes_and_types_.swap(shapes_and_types); +} + +void InferenceContext::SetMarks(const std::vector &marks) { inference_context_impl_->marks_ = marks; } + +const std::vector &InferenceContext::GetMarks() const { return inference_context_impl_->marks_; } +} // namespace ge diff --git a/src/common/graph/model.cc b/src/common/graph/model.cc new file mode 100644 index 00000000..a3628204 --- /dev/null +++ b/src/common/graph/model.cc @@ -0,0 +1,190 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/model.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "debug/ge_attr_define.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/model_serialize.h" +#include "proto/ge_ir.pb.h" +#include "utils/attr_utils.h" +#include "utils/ge_ir_utils.h" + +using google::protobuf::io::FileInputStream; +using google::protobuf::io::FileOutputStream; +using google::protobuf::io::ZeroCopyInputStream; + +namespace { +const int DEFAULT_VERSION = 1; +const int ACCESS_PERMISSION_BITS = 0400; +} // namespace + +namespace ge { +void Model::Init() { + (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); + (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); + (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); + (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); + (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); + (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); + version_ = 0; +} + +Model::Model() { + attrs_.InitDefault(); + Init(); +} + +Model::Model(const string &name, const string &custom_version) + : name_(name), version_(DEFAULT_VERSION), platform_version_(custom_version) { + attrs_.InitDefault(); + Init(); +} + +string Model::GetName() const { return name_; } + +void Model::SetName(const string &name) { name_ = name; } + +uint32_t Model::GetVersion() const { return version_; } + +string Model::GetPlatformVersion() const { return platform_version_; } + +void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; } + +Graph Model::GetGraph() const { return graph_; } + +graphStatus Model::Save(Buffer &buffer, bool is_dump) const { + ModelSerialize serialize; + buffer = serialize.SerializeModel(*this, is_dump); + return buffer.GetSize() > 0 ? 
GRAPH_SUCCESS : GRAPH_FAILED; +} + +void Model::SetAttr(const ProtoAttrMapHelper &attrs) { attrs_ = attrs; } + +graphStatus Model::Load(const uint8_t *data, size_t len, Model &model) { + ModelSerialize serialize; + model = serialize.UnserializeModel(data, len); + return model.IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +graphStatus Model::SaveToFile(const string &file_name) const { + Buffer buffer; + if ((*this).Save(buffer) != GRAPH_SUCCESS) { + GE_LOGE("save to file fail."); + return GRAPH_FAILED; + } + // Write file + ge::proto::ModelDef ge_proto; + if (buffer.GetData() != nullptr) { + std::string str((const char *)buffer.GetData(), buffer.GetSize()); + if (!ge_proto.ParseFromString(str)) { + return GRAPH_FAILED; + } + char real_path[PATH_MAX] = {0x00}; + if (strlen(file_name.c_str()) >= PATH_MAX) { + return GRAPH_FAILED; + } + if (realpath(file_name.c_str(), real_path) == nullptr) { + GELOGI("file %s does not exit, it will be created.", file_name.c_str()); + } + int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); + if (fd < 0) { + GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno)); + return GRAPH_FAILED; + } + bool ret = ge_proto.SerializeToFileDescriptor(fd); + if (!ret) { + GELOGE(GRAPH_FAILED, "SerializeToFileDescriptor failed"); + if (close(fd) != 0) { + GELOGE(GRAPH_FAILED, "close file descriptor fail."); + return GRAPH_FAILED; + } + return GRAPH_FAILED; + } + if (close(fd) != 0) { + GELOGE(GRAPH_FAILED, "close file descriptor fail."); + return GRAPH_FAILED; + } + if (!ret) { + GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed"); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +graphStatus Model::Load(ge::proto::ModelDef &model_def) { + ModelSerialize serialize; + *this = serialize.UnserializeModel(model_def); + return this->IsValid() ? 
GRAPH_SUCCESS : GRAPH_FAILED; +} + +bool Model::IsValid() const { return graph_.IsValid(); } + +graphStatus Model::LoadFromFile(const string &file_name) { + char real_path[PATH_MAX] = {0x00}; + if (strlen(file_name.c_str()) >= PATH_MAX) { + return GRAPH_FAILED; + } + if (realpath(file_name.c_str(), real_path) == nullptr) { + GELOGE(GRAPH_FAILED, "file %s does not exit, can not load.", file_name.c_str()); + return GRAPH_FAILED; + } + int fd = open(real_path, O_RDONLY); + if (fd < 0) { + GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno)); + return GRAPH_FAILED; + } + + ge::proto::ModelDef model_def; + bool ret = model_def.ParseFromFileDescriptor(fd); + if (!ret) { + GELOGE(GRAPH_FAILED, "ParseFromFileDescriptor failed"); + if (close(fd) != 0) { + GELOGE(GRAPH_FAILED, "close file descriptor fail."); + return GRAPH_FAILED; + } + return GRAPH_FAILED; + } + if (close(fd) != 0) { + GELOGE(GRAPH_FAILED, "close file descriptor fail."); + return GRAPH_FAILED; + } + if (!ret) { + GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed"); + return GRAPH_FAILED; + } + return Load(model_def); +} + +ProtoAttrMapHelper Model::MutableAttrMap() { return attrs_; } + +ConstProtoAttrMapHelper Model::GetAttrMap() const { + return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); +} +} // namespace ge diff --git a/src/common/graph/model_serialize.cc b/src/common/graph/model_serialize.cc new file mode 100644 index 00000000..16855fc5 --- /dev/null +++ b/src/common/graph/model_serialize.cc @@ -0,0 +1,763 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/model_serialize.h" +#include + +#include +#include + +#include "debug/ge_attr_define.h" +#include "debug/ge_log.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/detail/model_serialize_imp.h" +#include "proto/ge_ir.pb.h" +#include "utils/graph_utils.h" +#include "debug/ge_op_types.h" + +using std::map; +using std::string; + +namespace ge { +bool ModelSerializeImp::ParseNodeIndex(const string &node_index, string &node_name, int32_t &index) { + auto sep = node_index.rfind(":"); + if (sep == string::npos) { + GELOGW("separator is not found in node_index."); + return false; + } + node_name = node_index.substr(0, sep); + auto index_str = node_index.substr(sep + 1); + index = static_cast(std::strtol(index_str.c_str(), nullptr, 10)); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeTensor(const ConstGeTensorPtr &tensor, + proto::TensorDef *tensor_proto) { + GE_CHK_BOOL_EXEC(tensor != nullptr, return false, "tensor is null."); + GE_CHK_BOOL_EXEC(tensor_proto != nullptr, return false, "tensor_proto is null."); + + if (tensor->tensor_def_.GetProtoMsg() != nullptr) { + *tensor_proto = *tensor->tensor_def_.GetProtoMsg(); + return true; + } + return false; +} + +bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_proto) { + GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is null."); + GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); + + op_def_proto->clear_input(); + // Inputs + 
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + if (in_data_anchor != nullptr) { + auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) { + op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + + std::to_string(peer_out_anchor->GetIdx())); + } else { + op_def_proto->add_input(""); + } + } + } + // Control edge + auto control_anchor = node->GetInControlAnchor(); + if (control_anchor != nullptr) { + auto peer_out_anchors = control_anchor->GetPeerOutControlAnchors(); + for (const auto &peer_out_anchor : peer_out_anchors) { + if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) { + op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":-1"); + } + } + } + return true; +} + +bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { + GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); + GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); + if (op_desc->op_def_.GetProtoMsg() != nullptr) { + *op_def_proto = *op_desc->op_def_.GetProtoMsg(); + // Delete unnecessary attr + if (is_dump) { + auto attr = op_def_proto->mutable_attr(); + attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF); + attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF); + attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF); + GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP), + attr->erase(ATTR_NAME_WEIGHTS)); + } + op_def_proto->clear_input_desc(); + op_def_proto->clear_output_desc(); + // Input descs + if (op_desc->GetAllInputsSize() > 0) { + auto size = static_cast(op_desc->GetAllInputsSize()); + for (uint32_t i = 0; i < size; i++) { + auto tensor_desc = op_desc->GetInputDescPtrDfault(i); + if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { + *op_def_proto->add_input_desc() = 
*(tensor_desc->tensor_descriptor_.GetProtoMsg()); + } + } + } + // Output descs + if (op_desc->GetOutputsSize() > 0) { + auto size = static_cast(op_desc->GetOutputsSize()); + for (uint32_t i = 0; i < size; i++) { + auto tensor_desc = op_desc->GetOutputDescPtr(i); + if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { + *op_def_proto->add_output_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); + } + } + } + + op_def_proto->set_id(op_desc->GetId()); + for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { + op_def_proto->add_subgraph_name(name); + } + OpDescToAttrDef(op_desc, op_def_proto); + } + return true; +} + +void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { + proto::AttrDef key_in; + proto::AttrDef value_in; + auto op_desc_attr = op_def_proto->mutable_attr(); + if (!op_desc->input_name_idx_.empty()) { + for (auto &item : op_desc->input_name_idx_) { + key_in.mutable_list()->add_s(item.first); + value_in.mutable_list()->add_i(item.second); + } + op_desc_attr->insert({"_input_name_key", key_in}); + op_desc_attr->insert({"_input_name_value", value_in}); + } + proto::AttrDef key_out; + proto::AttrDef value_out; + if (!op_desc->output_name_idx_.empty()) { + for (auto &item : op_desc->output_name_idx_) { + key_out.mutable_list()->add_s(item.first); + value_out.mutable_list()->add_i(item.second); + } + op_desc_attr->insert({"_output_name_key", key_out}); + op_desc_attr->insert({"_output_name_value", value_out}); + } + proto::AttrDef opt_input; + if (!op_desc->optional_input_names_.empty()) { + for (auto &item : op_desc->optional_input_names_) { + opt_input.mutable_list()->add_s(item); + } + op_desc_attr->insert({"_opt_input", opt_input}); + } +} + +bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { + if (node == nullptr || op_def_proto == nullptr) { + GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); + 
return false; + } + if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) { + GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); + return false; + } + if (SerializeEdge(node, op_def_proto)) { + return true; + } else { + return false; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, + proto::GraphDef *graph_proto, + bool is_dump) { + if (graph == nullptr || graph_proto == nullptr) { + GELOGE(GRAPH_FAILED, "Input para Invalid"); + return false; + } + graph_proto->set_name(graph->GetName()); + // Inputs + for (const auto &input : graph->GetInputNodes()) { + if (input != nullptr) { + graph_proto->add_input(input->GetName() + ":0"); + } + } + // Outputs + for (const auto &output : graph->GetGraphOutNodesInfo()) { + if (output.first != nullptr) { + graph_proto->add_output(output.first->GetName() + ":" + std::to_string(output.second)); + GELOGI("Add output to graph proto, node name:%s, index:%ld", output.first->GetName().c_str(), output.second); + } + } + if (graph->attrs_.GetProtoMsg() != nullptr) { + *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); + } + for (const auto &node : graph->GetDirectNode()) { + if (!SerializeNode(node, graph_proto->add_op(), is_dump)) { + if (node->GetOpDesc() != nullptr) { + GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); + } + return false; + } + } + return true; +} + +bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) { + if (model_proto == nullptr) { + GELOGE(GRAPH_FAILED, "model_proto para Invalid"); + return false; + } + model_proto->set_name(model.GetName()); + model_proto->set_custom_version(model.GetPlatformVersion()); + model_proto->set_version(model.GetVersion()); + if (model.attrs_.GetProtoMsg()) { + *model_proto->mutable_attr() = *model.attrs_.GetProtoMsg(); + } + auto &graph = model.graph_; + auto compute_graph = GraphUtils::GetComputeGraph(graph); + 
if (compute_graph == nullptr) { + GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); + return false; + } + if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) { + GELOGE(GRAPH_FAILED, "SerializeGraph fail"); + return false; + } + + for (auto subgraph : compute_graph->GetAllSubgraphs()) { + if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) { + GELOGE(GRAPH_FAILED, "Serialize subgraph failed"); + return false; + } + } + + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor( + GeTensorPtr &tensor, proto::TensorDef &tensor_proto) { + tensor = std::shared_ptr(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto)); + if (tensor == nullptr) { + GELOGE(GRAPH_FAILED, "tensor is nullptr"); + return false; + } else { + return true; + } +} + +void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_in, std::vector &key_out, + std::vector &value_in, std::vector &value_out, + std::vector &opt_input) { + if (!key_in.empty()) { + if (key_in.size() != value_in.size()) { + GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(), + value_in.size()); + } else { + for (uint32_t i = 0; i < key_in.size(); ++i) { + op_desc->input_name_idx_.insert(std::pair(key_in.at(i), value_in.at(i))); + } + } + } + if (!key_out.empty()) { + if (key_out.size() != value_out.size()) { + GELOGW("Key and value vector size is different. 
key_size: %zu, value_size: %zu.", key_out.size(), + value_out.size()); + } else { + for (uint32_t i = 0; i < key_out.size(); ++i) { + op_desc->output_name_idx_.insert(std::pair(key_out.at(i), value_out.at(i))); + } + } + } + if (!opt_input.empty()) { + for (const auto &i : opt_input) { + op_desc->optional_input_names_.insert(i); + } + } +} + +bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) { + std::vector opt_input; + std::vector key_in; + std::vector value_in; + if (op_def_proto.attr().count("_opt_input") > 0) { + auto &name_list = op_def_proto.attr().at("_opt_input").list(); + for (const auto &item_s : name_list.s()) { + opt_input.push_back(item_s); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_opt_input"); + } + if (op_def_proto.attr().count("_input_name_key") > 0) { + auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list(); + for (const auto &item_s : output_name_key_list.s()) { + key_in.push_back(item_s); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_input_name_key"); + } + if (op_def_proto.attr().count("_input_name_value") > 0) { + auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list(); + for (const auto &item_i : input_name_value_list.i()) { + value_in.push_back(static_cast(item_i)); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_input_name_value"); + } + std::vector key_out; + std::vector value_out; + if (op_def_proto.attr().count("_output_name_key") > 0) { + auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list(); + for (const auto &item_s : output_name_key_list.s()) { + key_out.push_back(item_s); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_output_name_key"); + } + if (op_def_proto.attr().count("_output_name_value") > 0) { + auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list(); + for 
(const auto &item_i : output_name_value_list.i()) { + value_out.push_back(static_cast(item_i)); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_output_name_value"); + } + + op_desc = std::shared_ptr(new (std::nothrow) OpDesc(protobuf_owner_, &op_def_proto)); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr."); + + // Input tensor + for (auto &input_desc : *op_def_proto.mutable_input_desc()) { + std::shared_ptr temp_value = + std::shared_ptr(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc)); + GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); + op_desc->inputs_desc_.push_back(temp_value); + } + // Output tensor + for (auto &output_desc : *op_def_proto.mutable_output_desc()) { + std::shared_ptr temp_value = + std::shared_ptr(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc)); + GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); + op_desc->outputs_desc_.push_back(temp_value); + } + + op_desc->SetId(op_def_proto.id()); + uint32_t graph_index = 0; + for (const std::string &name : op_def_proto.subgraph_name()) { + op_desc->AddSubgraphName(name); + op_desc->SetSubgraphInstanceName(graph_index++, name); + } + + // insert name index by key and value + AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input); + + return true; +} + +bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) { + GE_RT_FALSE_CHECK_NOTNULL(graph); + OpDescPtr op_desc = nullptr; + if (!UnserializeOpDesc(op_desc, op_def_proto)) { + GELOGW("UnserializeOpDesc error."); + } + + NodePtr node = graph->AddNode(op_desc, op_desc->GetId()); + GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr."); + + // Inputs + int dst_index = 0; + for (const auto &input : op_def_proto.input()) { + string node_name; + int32_t index = 0; + if (ParseNodeIndex(input, node_name, index)) { + 
node_input_node_names_.push_back(NodeNameNodeReq{node_name, index, node, dst_index, op_def_proto.name()}); + } + if (index >= 0) { + dst_index++; + } + } + node_map_[op_def_proto.name()] = node; + return true; +} + +bool ModelSerializeImp::HandleNodeNameRef() { + // Edges + for (auto &item : node_input_node_names_) { + auto src_node_it = node_map_.find(item.src_node_name); + if (src_node_it == node_map_.end()) { + GELOGE(GRAPH_FAILED, "cannot find node %s", item.src_node_name.c_str()); + return false; + } + GE_IF_BOOL_EXEC(src_node_it->second == nullptr || item.dst_node == nullptr, continue); + if (item.src_out_index >= 0) { + auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index); + auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); + if (src_anchor == nullptr || dst_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "get anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, + item.dst_node_name.c_str(), item.dst_in_index); + return false; + } + GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 + } else { + // Control edge + auto src_anchor = src_node_it->second->GetOutControlAnchor(); + auto dst_anchor = item.dst_node->GetInControlAnchor(); + if (src_anchor != nullptr && dst_anchor != nullptr) { + GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 + } + } + } + // Graph input + for (auto &item : graph_input_node_names_) { + auto node_it = node_map_.find(item.node_name); + if (node_it == node_map_.end()) { + GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str()); + return false; + } + GE_IF_BOOL_EXEC(item.graph == nullptr, continue); + auto ret = item.graph->AddInputNode(node_it->second); + if (ret == nullptr) { + return false; + } + } + // Graph output + for (auto &item : graph_output_node_names_) { + auto node_it = node_map_.find(item.node_name); + if (node_it == node_map_.end()) { + 
GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str()); + return false; + } + + GE_IF_BOOL_EXEC(item.graph == nullptr, continue); + auto ret = item.graph->AddOutputNodeByIndex(node_it->second, item.index); + GELOGI("node name:%s, item.index:%ld", node_it->second->GetName().c_str(), item.index); + if (ret == nullptr) { + GELOGE(GRAPH_FAILED, "AddOutputNode failed."); + return false; + } + } + node_input_node_names_.clear(); + graph_input_node_names_.clear(); + graph_output_node_names_.clear(); + node_map_.clear(); + return true; +} + +bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map &subgraphs) { + std::queue all_graphs; + all_graphs.emplace(compute_graph); + while (!all_graphs.empty()) { + ComputeGraphPtr graph = all_graphs.front(); + all_graphs.pop(); + + for (const NodePtr &node : graph->GetDirectNode()) { + const OpDescPtr op_desc = node->GetOpDesc(); + for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { + auto it = subgraphs.find(name); + if (it == subgraphs.end()) { + GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), name.c_str(), + subgraphs.size()); + return false; + } + + ComputeGraphPtr &subgraph = it->second; + subgraph->SetParentGraph(graph); + subgraph->SetParentNode(node); + compute_graph->AddSubgraph(subgraph->GetName(), subgraph); + all_graphs.emplace(subgraph); + } + } + } + + return true; +} + +bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) { + model.name_ = model_proto.name(); + model.version_ = model_proto.version(); + model.platform_version_ = model_proto.custom_version(); + model.attrs_ = ProtoAttrMapHelper(protobuf_owner_, model_proto.mutable_attr()); + + auto &graphs_proto = *model_proto.mutable_graph(); + if (!graphs_proto.empty()) { + auto &graph_proto = graphs_proto[0]; + ComputeGraphPtr compute_graph_ptr; + if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) { + model.graph_ = 
GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); + } + + // 0 is main graph, following is subgraph. + map subgraphs; + for (int idx = 1; idx < graphs_proto.size(); ++idx) { + ComputeGraphPtr subgraph; + ModelSerializeImp impl; + if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) { + GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed"); + return false; + } + + if (!impl.HandleNodeNameRef()) { + GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); + return false; + } + + subgraphs[subgraph->GetName()] = subgraph; + } + + if (!RebuildOwnership(compute_graph_ptr, subgraphs)) { + GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed"); + return false; + } + } + + if (!HandleNodeNameRef()) { + GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); + return false; + } + return true; +} + +bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) { + graph = ComGraphMakeShared(graph_proto.name()); + if (graph == nullptr) { + GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); + return false; + } + + // Inputs + for (auto input : graph_proto.input()) { + string node_name; + int32_t index; + if (ParseNodeIndex(input, node_name, index)) { + graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); + } + } + // Outputs + for (auto output : graph_proto.output()) { + string node_name; + int32_t index; + if (ParseNodeIndex(output, node_name, index)) { + graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); + } + } + graph->attrs_ = ProtoAttrMapHelper(protobuf_owner_, graph_proto.mutable_attr()); + for (auto &op_def_proto : *graph_proto.mutable_op()) { + if (!UnserializeNode(graph, op_def_proto)) { + GELOGE(GRAPH_FAILED, "UnserializeNode fail"); + return false; + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph, + proto::GraphDef &graph_proto) { + if 
(!UnserializeGraphWithoutEdge(graph, graph_proto)) { + GELOGW("UnserializeGraphWithoutEdge fail"); + } + if (!HandleNodeNameRef()) { + GELOGE(GRAPH_FAILED, "Link Anchor or set graph input or output fail"); + return false; + } + return true; +} + +bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::Message *proto) { + GE_CHK_BOOL_EXEC(data != nullptr, return false, "data is null."); + GE_CHK_BOOL_EXEC(proto != nullptr, return false, "proto is null."); + + google::protobuf::io::CodedInputStream coded_stream(data, len); + // 2048M -1 + coded_stream.SetTotalBytesLimit(INT32_MAX, -1); + if (!proto->ParseFromCodedStream(&coded_stream)) { + GELOGE(GRAPH_FAILED, "ReadProtoFromBinaryFile failed len %zu", len); + return false; + } + return true; +} + +Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) { + proto::ModelDef model_def; + ModelSerializeImp imp; + if (!imp.SerializeModel(model, &model_def, is_dump)) { + return Buffer(); + } +#if !defined(__ANDROID__) && !defined(ANDROID) + Buffer buffer(model_def.ByteSizeLong()); +#else + Buffer buffer(model_def.ByteSize()); +#endif + GE_CHK_BOOL_ONLY_LOG(buffer.GetSize() != 0, "get size failed"); + GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); + auto ret = model_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); + if (ret != true) { + GELOGW("serialize to array fail."); + } + return buffer; +} + +size_t ModelSerialize::GetSerializeModelSize(const Model &model) { + proto::ModelDef model_def; + ModelSerializeImp imp; + if (!imp.SerializeModel(model, &model_def)) { + return 0; + } +#if !defined(__ANDROID__) && !defined(ANDROID) + return model_def.ByteSizeLong(); +#else + return model_def.ByteSize(); +#endif +} + +Model ModelSerialize::UnserializeModel(const uint8_t *data, size_t len) { + if (data == nullptr) { + GELOGE(GRAPH_FAILED, "data is nullptr"); + return Model(); + } + + std::shared_ptr model_proto_ptr; + model_proto_ptr = 
ComGraphMakeShared(); + if (model_proto_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed"); + return Model(); + } + + auto &model_proto = *model_proto_ptr; + if (!ReadProtoFromBinaryFile(data, len, &model_proto)) { + GELOGE(GRAPH_FAILED, "ParseFromArray fail"); + return Model(); + } + + Model model; + ModelSerializeImp imp; + imp.SetProtobufOwner(model_proto_ptr); + if (!imp.UnserializeModel(model, model_proto)) { + GELOGE(GRAPH_FAILED, "Unserialize Model fail"); + return Model(); + } + return model; +} + +Model ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def) { + std::shared_ptr model_def_ptr = ComGraphMakeShared(model_def); + GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, return Model(), "mode_def make shared failed"); + + ModelSerializeImp imp; + imp.SetProtobufOwner(model_def_ptr); + Model model; + if (!imp.UnserializeModel(model, *model_def_ptr)) { + GELOGE(GRAPH_FAILED, "Unserialize Model fail"); + return Model(); + } + return model; +} + +Buffer ModelSerialize::SerializeGraph(const ComputeGraphPtr &graph) { + proto::GraphDef graph_def; + ModelSerializeImp imp; + if (!imp.SerializeGraph(graph, &graph_def)) { + return Buffer(); + } +#if !defined(__ANDROID__) && !defined(ANDROID) + Buffer buffer(graph_def.ByteSizeLong()); +#else + Buffer buffer(graph_def.ByteSize()); +#endif + GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed"); + GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); + auto ret = graph_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); + if (ret != true) { + GE_LOGE("serialize to array fail."); + } + + return buffer; +} + +ComputeGraphPtr ModelSerialize::UnserializeGraph(const uint8_t *data, size_t len) { + if (data == nullptr) { + GELOGE(GRAPH_FAILED, "data is nullptr"); + return nullptr; + } + + std::shared_ptr graph_proto_ptr; + graph_proto_ptr = ComGraphMakeShared(); + if (graph_proto_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "proto::GraphDef make 
shared failed"); + return nullptr; + } + proto::GraphDef &graph_proto = *graph_proto_ptr; + if (!ReadProtoFromBinaryFile(data, len, &graph_proto)) { + GELOGE(GRAPH_FAILED, "ParseFromArray fail"); + return nullptr; + } + + ComputeGraphPtr graph; + ModelSerializeImp imp; + imp.SetProtobufOwner(graph_proto_ptr); + if (!imp.UnserializeGraph(graph, graph_proto)) { + return nullptr; + } + return graph; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer ModelSerialize::SerializeOpDesc(const ConstOpDescPtr &op_desc) { + proto::OpDef op_def; + ModelSerializeImp imp; + if (!imp.SerializeOpDesc(op_desc, &op_def)) { + return Buffer(); + } +#if !defined(__ANDROID__) && !defined(ANDROID) + Buffer buffer(op_def.ByteSizeLong()); +#else + Buffer buffer(op_def.ByteSize()); +#endif + GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed"); + GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); + auto ret = op_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); + if (ret != true) { + GE_LOGE("serialize to array fail."); + } + + return buffer; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr ModelSerialize::UnserializeOpDesc(const uint8_t *data, + size_t len) { + if (data == nullptr) { + GELOGE(GRAPH_FAILED, "data is nullptr"); + return nullptr; + } + + std::shared_ptr op_def_ptr; + op_def_ptr = ComGraphMakeShared(); + if (op_def_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); + return nullptr; + } + proto::OpDef &op_def = *op_def_ptr; + if (!ReadProtoFromBinaryFile(data, len, &op_def)) { + GELOGE(GRAPH_FAILED, "ParseFromArray fail"); + return nullptr; + } + + OpDescPtr op_desc; + ModelSerializeImp imp; + imp.SetProtobufOwner(op_def_ptr); + if (!imp.UnserializeOpDesc(op_desc, op_def)) { + GELOGW("UnserializeOpDesc error."); + } + return op_desc; +} +} // namespace ge diff --git a/src/common/graph/module.mk b/src/common/graph/module.mk new file mode 100644 index 00000000..1e00b7fc --- 
/dev/null +++ b/src/common/graph/module.mk @@ -0,0 +1,3 @@ +LOCAL_PATH := $(call my-dir) + +include $(LOCAL_PATH)/graph.mk diff --git a/src/common/graph/node.cc b/src/common/graph/node.cc new file mode 100644 index 00000000..d33c6008 --- /dev/null +++ b/src/common/graph/node.cc @@ -0,0 +1,878 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/node.h" +#include +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "external/graph/operator_factory.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_tensor.h" +#include "graph/operator_factory_impl.h" +#include "graph/shape_refiner.h" +#include "utils/ge_ir_utils.h" +#include "utils/node_utils.h" +#include "utils/op_desc_utils.h" +#include "common/util/error_manager/error_manager.h" + +using std::string; +using std::vector; + +namespace ge { +Node::Node(const OpDescPtr &op, const ComputeGraphPtr &owner_graph) + : op_(op), + owner_graph_(owner_graph), + in_data_anchors_(), + out_data_anchors_(), + in_control_anchor_(nullptr), + out_control_anchor_(nullptr), + attrs_(), + has_init_(false) { + anchor_status_updated_ = false; +} + +Node::~Node() { + for (const auto &in_data_anchor : in_data_anchors_) { + if (in_data_anchor != nullptr) { + in_data_anchor->UnlinkAll(); + } + } + for (const auto &out_data_anchor : out_data_anchors_) { + if (out_data_anchor != nullptr) { + out_data_anchor->UnlinkAll(); + } + 
} + if (in_control_anchor_ != nullptr) { + in_control_anchor_->UnlinkAll(); + } + if (out_control_anchor_ != nullptr) { + out_control_anchor_->UnlinkAll(); + } +} + +graphStatus Node::Init() { + if (has_init_) { + return GRAPH_SUCCESS; + } + GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); + size_t size = op_->GetAllInputsSize(); + for (size_t i = 0; i < size; i++) { + std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), i); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + in_data_anchors_.push_back(anchor); + } + size = op_->GetOutputsSize(); + for (size_t i = 0; i < size; i++) { + std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), i); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Current out_data_anchor is null, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + out_data_anchors_.push_back(anchor); + } + in_control_anchor_ = ComGraphMakeShared(shared_from_this(), -1); + out_control_anchor_ = ComGraphMakeShared(shared_from_this(), -1); + if (in_control_anchor_ == nullptr || out_control_anchor_ == nullptr) { + GELOGE(GRAPH_FAILED, "Current in_control_anchor or out_control_anchor is null, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + has_init_ = true; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetName() const { + GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr"); + return op_->GetName(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetType() const { + GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr"); + return op_->GetType(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAttrsAreEqual(const Node &r_node) const { + const auto &attr_map = this->attrs_; + const auto &r_attr_map = r_node.attrs_; + // 1.Verify node's map size + if 
(attr_map.size() != r_attr_map.size()) { + GELOGE(GRAPH_FAILED, "Size of node's attr map verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + // 2.Verify node's map key, verify values is temporarily not implemented + for (const auto &it : attr_map) { + if (r_attr_map.count(it.first) == 0) { + GELOGE(GRAPH_FAILED, "Key of node's attr map verify failed, node name: %s key name: %s.", this->GetName().c_str(), + it.first.c_str()); + return false; + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeMembersAreEqual(const Node &r_node) const { + return ((((this->op_ != nullptr) && (r_node.op_ != nullptr) && (IsEqual(*(this->op_), *(r_node.op_), "node.op_"))) || + ((this->op_ == nullptr) && (r_node.op_ == nullptr))) && + IsEqual(this->has_init_, r_node.has_init_, "node.has_init_") && + IsEqual(this->anchor_status_updated_, r_node.anchor_status_updated_, "node.anchor_status_updated_") && + IsEqual(this->send_event_id_list_, r_node.send_event_id_list_, "node.send_event_id_list_") && + IsEqual(this->recv_event_id_list_, r_node.recv_event_id_list_, "node.recv_event_id_list_")); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(const AnchorPtr &left_anchor, + const AnchorPtr &right_anchor, + size_t i) const { + GE_IF_BOOL_EXEC(left_anchor == nullptr, GELOGE(GRAPH_FAILED, "left_anchor is null."); return false); + GE_IF_BOOL_EXEC(right_anchor == nullptr, GELOGE(GRAPH_FAILED, "right_anchor is null."); return false); + + const auto anchor_peer_size = left_anchor->GetPeerAnchors().size(); + const auto right_anchor_peer_size = right_anchor->GetPeerAnchors().size(); + // Firstly, verify anchor's peer anchors size equal or not + if (anchor_peer_size != right_anchor_peer_size) { + GELOGE(GRAPH_FAILED, + "Size of anchor's peer anchors verify failed, node name: %s " + "anchor_peer_size [%zu] is different form [%zu] at index [%zu].", + this->GetName().c_str(), anchor_peer_size, 
right_anchor_peer_size, i); + return false; + } + // Secondly, verify anchor's peer anchor owner node equal or not + for (size_t j = 0; j < anchor_peer_size; j++) { + const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); + const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); + if (peer_node == nullptr || r_peer_node == nullptr) { + GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", + this->GetName().c_str(), i, j); + return false; + } + // Determine the connection relationship by linking the node's name + if (peer_node->GetName() != r_peer_node->GetName()) { + GELOGE(GRAPH_FAILED, + "anchor's peer node name verify failed, node name: %s index[%zu]" + "peer node name %s is different from %s at index [%zu].", + this->GetName().c_str(), i, peer_node->GetName().c_str(), r_peer_node->GetName().c_str(), j); + return false; + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeInConnectsAreEqual(const Node &r_node) const { + // 1.Verify all in data and control anchors size + const auto in_data_anchor_size = this->GetAllInDataAnchors().size(); + const auto r_in_data_anchor_size = r_node.GetAllInDataAnchors().size(); + if (in_data_anchor_size != r_in_data_anchor_size) { + GELOGE(GRAPH_FAILED, "Size of node's in data anchors verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + const auto l_in_anchors = this->GetAllInAnchors(); + const auto r_in_anchors = r_node.GetAllInAnchors(); + // Data anchors size equal, all anchors size not equal, means control anchor size not equal + const auto in_control_anchor_size = l_in_anchors.size() - in_data_anchor_size; + const auto r_in_control_anchor_size = r_in_anchors.size() - r_in_data_anchor_size; + if (in_control_anchor_size != r_in_control_anchor_size) { + GELOGE(GRAPH_FAILED, "Size of node's in control anchors verify failed, node name: %s.", this->GetName().c_str()); + return 
false; + } + // 2.Verify all in data and control anchors connect info + for (size_t i = 0; i < this->GetAllInAnchors().size(); i++) { + // Verify data anchors + if (i < in_data_anchor_size) { + const auto &in_anchor = l_in_anchors.at(i); + const auto &r_in_anchor = r_in_anchors.at(i); + if (!(NodeAnchorIsEqual(in_anchor, r_in_anchor, i))) { + GELOGE(GRAPH_FAILED, "Node's in data control anchor verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + } else { + // Verify control anchors + const auto &in_control_anchor = l_in_anchors.at(i); + const auto &r_in_control_anchor = r_in_anchors.at(i); + if (!(NodeAnchorIsEqual(in_control_anchor, r_in_control_anchor, i - in_data_anchor_size))) { + GELOGE(GRAPH_FAILED, "Node's in control anchor verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeOutConnectsAreEqual(const Node &r_node) const { + // 1.Verify all out data and control anchors size + const auto l_out_data_anchors = this->GetAllOutDataAnchors(); + const auto r_out_data_anchors = r_node.GetAllOutDataAnchors(); + const auto out_data_anchor_size = l_out_data_anchors.size(); + const auto r_out_data_anchor_size = r_out_data_anchors.size(); + if (out_data_anchor_size != r_out_data_anchor_size) { + GELOGE(GRAPH_FAILED, "Size of node's out data anchors verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + const auto l_out_anchors = this->GetAllOutAnchors(); + const auto r_out_anchors = r_node.GetAllOutAnchors(); + // Data anchors size equal, all anchors size not equal, means control anchor size not equal + const auto out_control_anchor_size = l_out_anchors.size() - out_data_anchor_size; + const auto r_out_control_anchor_size = r_out_anchors.size() - r_out_data_anchor_size; + if (out_control_anchor_size != r_out_control_anchor_size) { + GELOGE(GRAPH_FAILED, "Size of node's out control anchors verify failed, 
node name: %s.", this->GetName().c_str()); + return false; + } + + // 2.Verify all out data and control anchors connect info + for (size_t i = 0; i < this->GetAllOutAnchors().size(); i++) { + // Verify data anchors + if (i < out_data_anchor_size) { + const auto &out_anchor = l_out_data_anchors.at(i); + const auto &r_out_anchor = r_out_data_anchors.at(i); + if (!(NodeAnchorIsEqual(out_anchor, r_out_anchor, i))) { + GELOGE(GRAPH_FAILED, "Node's out data control anchor verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + } else { + // Verify control anchors + const auto &out_control_anchor = l_out_anchors.at(i); + const auto &r_out_control_anchor = r_out_anchors.at(i); + if (!(NodeAnchorIsEqual(out_control_anchor, r_out_control_anchor, i - out_data_anchor_size))) { + GELOGE(GRAPH_FAILED, "Node's out control anchor verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::operator==(const Node &r_node) const { + return (NodeMembersAreEqual(r_node) && NodeAttrsAreEqual(r_node) && NodeInConnectsAreEqual(r_node) && + NodeOutConnectsAreEqual(r_node)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const NodePtr &input_node) { + // This function is deprecated, please use other two overloaded functions + GE_CHECK_NOTNULL(input_node); + // Input_node ---> this + auto out_anchors = input_node->GetAllOutDataAnchors(); + if (out_anchors.size() != 1) { + GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size()); + return GRAPH_PARAM_INVALID; + } + GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); + auto op_desc = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + if (op_->AddInputDesc(op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "add input desc failed."); + return GRAPH_FAILED; + } + std::shared_ptr anchor = 
ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); + return GRAPH_FAILED; + } + in_data_anchors_.push_back(anchor); + (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const uint32_t &index, + NodePtr input_node) { + GE_CHECK_NOTNULL(input_node); + // Input_node ---> this + auto out_anchors = input_node->GetAllOutDataAnchors(); + if (out_anchors.size() != 1) { + GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size()); + return GRAPH_PARAM_INVALID; + } + + GE_CHECK_NOTNULL(op_); + auto op_desc = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + if (op_->AddInputDesc(index, op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "add input desc failed."); + return GRAPH_FAILED; + } + + if (index < GetAllInDataAnchors().size()) { + (void)out_anchors.at(0)->LinkTo(in_data_anchors_[index]); + } else { + std::shared_ptr anchor = + ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); + return GRAPH_FAILED; + } + in_data_anchors_.push_back(anchor); + (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFromForParse(const NodePtr &input_node) { + // This function is used for ParseWeights. 
+ GE_CHECK_NOTNULL(input_node); + // Input_node ---> this + auto out_anchors = input_node->GetAllOutDataAnchors(); + if (out_anchors.size() != 1) { + GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size()); + return GRAPH_PARAM_INVALID; + } + + std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, make anchor failed", out_anchors.size()); + return GRAPH_FAILED; + } + in_data_anchors_.push_back(anchor); + (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const string &name, NodePtr input_node) { + GE_CHECK_NOTNULL(input_node); + // Input_node ---> this + auto out_anchors = input_node->GetAllOutDataAnchors(); + if (out_anchors.size() != 1) { + GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size()); + return GRAPH_PARAM_INVALID; + } + + GE_CHECK_NOTNULL(op_); + auto input_op_desc = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(input_op_desc); + auto index = op_->GetInputIndexByName(name); + if (index != -1) { + if (index >= static_cast(in_data_anchors_.size())) { + GELOGE(GRAPH_FAILED, "op %s get input name %s 's index %d is illegal.", op_->GetName().c_str(), name.c_str(), + index); + return GRAPH_FAILED; + } + (void)out_anchors.at(0)->LinkTo(in_data_anchors_[index]); + } else { + std::shared_ptr anchor = + ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "in_data_anchors_size is:%zu, malloc shared_ptr failed.", in_data_anchors_.size()); + return GRAPH_FAILED; + } + in_data_anchors_.push_back(anchor); + (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); + } + if (op_->AddInputDesc(name, input_op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "add input desc failed."); + return 
GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr Node::GetOwnerComputeGraph() const { + return owner_graph_.lock(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::SetOwnerComputeGraph(const ComputeGraphPtr &graph) { + if (graph == nullptr) { + return GRAPH_PARAM_INVALID; + } + owner_graph_ = graph; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllInDataAnchors() const { + return Vistor(shared_from_this(), in_data_anchors_); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllOutDataAnchors() const { + return Vistor(shared_from_this(), out_data_anchors_); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllInDataAnchorsSize() const { + return in_data_anchors_.size(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllOutDataAnchorsSize() const { + return out_data_anchors_.size(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllInAnchors() const { + std::vector vec; + // Push back in_data_anchors_ + for (const auto &in_anchor_iter : Vistor(shared_from_this(), in_data_anchors_)) { + auto in_anchor = Anchor::DynamicAnchorCast(in_anchor_iter); + if (in_anchor != nullptr) { + vec.push_back(in_anchor); + } + } + // Push back in_control_anchor_ + if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || + (in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { + auto in_anchor = Anchor::DynamicAnchorCast(in_control_anchor_); + if (in_anchor != nullptr) { + vec.push_back(in_anchor); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllOutAnchors() const { + std::vector vec; + // Push back out_data_anchors_ + for (const auto &out_anchor_iter : Vistor(shared_from_this(), out_data_anchors_)) { + auto out_anchor = 
Anchor::DynamicAnchorCast(out_anchor_iter); + if (out_anchor != nullptr) { + vec.push_back(out_anchor); + } + } + // Push back out_control_anchor_ + if (out_control_anchor_->GetPeerInControlAnchors().size() > 0 || + out_control_anchor_->GetPeerInDataAnchors().size() > 0) { + auto out_anchor = Anchor::DynamicAnchorCast(out_control_anchor_); + if (out_anchor != nullptr) { + vec.push_back(out_anchor); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { + if (idx < 0 || idx >= static_cast(in_data_anchors_.size())) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19019", {"opname", "index", "anchorname", "optype"}, + {GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); + GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx, + GetType().c_str()); + return nullptr; + } else { + return in_data_anchors_[idx]; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { + // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ + if (idx < -1 || idx >= static_cast(in_data_anchors_.size())) { + GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); + return nullptr; + } else { + // Return control anchor + if (idx == -1) { + auto in_anchor = Anchor::DynamicAnchorCast(in_control_anchor_); + return in_anchor; + } + // Return data anchor + return in_data_anchors_[idx]; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { + // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ + if (idx < -1 || idx >= static_cast(out_data_anchors_.size())) { + ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"}, + { + 
GetName().c_str(), + std::to_string(idx), + "out_anchor", + GetType().c_str(), + }); + GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, + GetType().c_str()); + return nullptr; + } else { + // Return control anchor + if (idx == -1) { + auto out_anchor = Anchor::DynamicAnchorCast(out_control_anchor_); + return out_anchor; + } + // Return data anchor + return out_data_anchors_[idx]; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { + if (idx < 0 || idx >= static_cast(out_data_anchors_.size())) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19019", {"opname", "index", "anchorname", "optype"}, + {GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); + GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx, + GetType().c_str()); + return nullptr; + } else { + return out_data_anchors_[idx]; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchorPtr Node::GetInControlAnchor() const { + return in_control_anchor_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchorPtr Node::GetOutControlAnchor() const { + return out_control_anchor_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInNodes() const { + std::vector vec; + for (const auto &in_anchor : in_data_anchors_) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + continue; + } + auto node = out_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + if (in_control_anchor_ != nullptr) { + if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { + return Node::Vistor(shared_from_this(), vec); + } + + auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); 
+ for (const auto &out_anchor : peer_out_anchors) { + GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr"); + auto node = out_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + + auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); + for (const auto &out_control_anchor : peer_out_control_anchors) { + GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, + "in_control_anchor_ peer out control anchors is nullptr"); + auto node = out_control_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::IsAllInNodesSeen( + std::unordered_set &nodes_seen) const { + for (const auto &in_anchor : in_data_anchors_) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + continue; + } + auto node = out_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { + continue; + } + if (nodes_seen.count(node.get()) == 0) { + return false; + } + } + + if (in_control_anchor_ != nullptr) { + if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { + return true; + } + auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); + for (const auto &out_control_anchor : peer_out_control_anchors) { + GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, "out_control_anchor is nullptr"); + auto node = out_control_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { + continue; + 
} + if (nodes_seen.count(node.get()) == 0) { + return false; + } + } + } + + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInDataNodes() const { + std::vector vec; + for (const auto &in_anchor : in_data_anchors_) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); + auto anchor_ptr = in_anchor->GetPeerOutAnchor(); + if (anchor_ptr == nullptr) { + continue; + } + auto node = anchor_ptr->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInControlNodes() const { + std::vector vec; + if (in_control_anchor_ != nullptr) { + for (const auto &in_anchor : in_control_anchor_->GetPeerOutControlAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerOutControlAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutNodes() const { + std::vector vec; + for (const auto &out_anchor : out_data_anchors_) { + GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); + for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC((peer_in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr"); + auto node = peer_in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + if (out_control_anchor_ != nullptr) { + auto peer_in_control_anchors = out_control_anchor_->GetPeerInControlAnchors(); + for (const auto &in_control_anchor : peer_in_control_anchors) { + GE_CHK_BOOL_EXEC(in_control_anchor != nullptr, continue, + "out_control_anchor_ peer 
in control anchors is nullptr"); + auto node = in_control_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInAllNodes() const { + std::vector vec; + for (const auto &in_node : GetInDataNodes()) { + vec.push_back(in_node); + } + for (const auto &in_control_node : GetInControlNodes()) { + vec.push_back(in_control_node); + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutDataNodes() const { + std::vector vec; + for (const auto &out_anchor : out_data_anchors_) { + GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); + for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetOutDataNodesSize() const { + uint32_t out_nums = 0; + for (const auto &out_anchor : out_data_anchors_) { + GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); + out_nums += out_anchor->GetPeerInDataNodesSize(); + } + return out_nums; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutControlNodes() const { + std::vector vec; + + for (const auto &out_anchor : out_data_anchors_) { + GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); + for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInControlAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + 
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + + if (out_control_anchor_ != nullptr) { + for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutAllNodes() const { + std::vector vec; + for (const auto &out_anchor : out_data_anchors_) { + GE_CHK_BOOL_EXEC((out_anchor != nullptr), { continue; }, "out_data_anchors_ is nullptr"); + for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), { continue; }, "GetPeerInDataAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + + if (out_control_anchor_ != nullptr) { + for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +graphStatus Node::InferShapeAndType() const { + Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); + graphStatus ret = ShapeRefiner::InferShapeAndType(shared_from_this(), op); + return ret; +} + 
+graphStatus Node::InferOriginFormat() const { + Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); + // Get infer func and execute + GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); + return op_->CallInferFormatFunc(op); +} +graphStatus Node::Verify() const { + const string data_type = "Data"; + const string aipp_data_type = "AippData"; + const string const_type = "Const"; + const string variable_type = "Variable"; + bool is_unknown_graph = GetOwnerComputeGraph()->GetGraphUnknownFlag(); + GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); + + if (!is_unknown_graph) { + for (const auto &in_anchor_ptr : GetAllInDataAnchors()) { + GE_IF_BOOL_EXEC(in_anchor_ptr == nullptr, GELOGW("in anchor ptr is null"); continue); + bool valid_anchor = + op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type || + op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || + op_->MutableInputDesc(in_anchor_ptr->GetIdx()) == nullptr || in_anchor_ptr->GetPeerAnchors().size() > 0; + if (!valid_anchor) { + ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"}, + {GetName(), std::to_string(in_anchor_ptr->GetIdx())}); + GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); + return GRAPH_FAILED; + } + } + } + + string frameworkop_type = "FrameworkOp"; + bool need_update_name = op_->GetType() != frameworkop_type && !is_unknown_graph; + if (need_update_name) { + auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_->GetType()); + if (node_op.IsEmpty()) { + GELOGW("get op from OperatorFactory fail. opType: %s", op_->GetType().c_str()); + } else { + GELOGD("get op from OperatorFactory success. 
opType: %s", op_->GetType().c_str()); + auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + if (temp_op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "temp op desc is null"); + return GRAPH_FAILED; + } + if (!op_->UpdateInputName(temp_op_desc->GetAllInputName())) { + GELOGW("Verify UpdateInputName failed"); + } + if (!op_->UpdateOutputName(temp_op_desc->GetAllOutputName())) { + GELOGW("Verify UpdateOutputName failed"); + } + } + node_op.BreakConnect(); + } + GE_IF_BOOL_EXEC(is_unknown_graph, return GRAPH_SUCCESS;); + if (op_->CommonVerify() == GRAPH_SUCCESS) { + Operator op_proxy = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); + auto verify_func = op_->GetVerifyFunc(); + if (verify_func == nullptr) { + verify_func = OperatorFactoryImpl::GetVerifyFunc(GetType()); + } + if (verify_func != nullptr) { + return (graphStatus)verify_func(op_proxy); + } + return GRAPH_SUCCESS; + } else { + GELOGE(GRAPH_FAILED, "%s Verify failed.", op_->GetType().c_str()); + return GRAPH_FAILED; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr Node::GetOpDesc() const { return op_; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(const OpDescPtr &op_desc) { + GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return GRAPH_PARAM_INVALID, "Param OpDesc is nullptr"); + GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, + "Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), + op_desc->GetInputsSize()); + + GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, + "Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), + op_desc->GetOutputsSize()); + op_ = op_desc; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor> 
+Node::GetInDataNodesAndAnchors() const { + std::vector> vec; + for (const auto &p : in_data_anchors_) { + if (p == nullptr) { + GELOGW("indata anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); + continue; + } + auto anchor_ptr = p->GetPeerOutAnchor(); + if (anchor_ptr == nullptr) { + continue; + } + auto node = anchor_ptr->GetOwnerNode(); + if (node == nullptr) { + GELOGW("src node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); + continue; + } + vec.push_back(std::make_pair(node, anchor_ptr)); + } + return Node::Vistor>(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor> +Node::GetOutDataNodesAndAnchors() const { + std::vector> vec; + for (const auto &p : out_data_anchors_) { + if (p == nullptr) { + GELOGW("out data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); + continue; + } + for (const auto &in_anchor : p->GetPeerInDataAnchors()) { + if (in_anchor == nullptr) { + GELOGW("dst in data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); + continue; + } + auto node = in_anchor->GetOwnerNode(); + if (node == nullptr) { + GELOGW("dst node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); + continue; + } + vec.push_back(std::make_pair(node, in_anchor)); + } + } + return Node::Vistor>(shared_from_this(), vec); +} +} // namespace ge diff --git a/src/common/graph/op_desc.cc b/src/common/graph/op_desc.cc new file mode 100644 index 00000000..dee0aece --- /dev/null +++ b/src/common/graph/op_desc.cc @@ -0,0 +1,1410 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/op_desc.h" +#include "debug/ge_attr_define.h" +#include "debug/ge_util.h" +#include "external/graph/operator.h" +#include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" +#include "graph/ge_attr_value.h" +#include "graph/ge_tensor.h" +#include "graph/operator_factory_impl.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/ge_ir_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "proto/ge_ir.pb.h" + +using std::make_pair; +using std::shared_ptr; +using std::string; +using std::vector; + +/*lint -save -e521 -e681 -e732 -e737*/ +namespace ge { +const std::string ATTR_NAME_ID = "id"; + +const std::string ATTR_NAME_STREAM_ID = "stream_id"; + +const std::string ATTR_NAME_INPUT_NAME = "input_name"; + +const std::string ATTR_NAME_SRC_NAME = "src_name"; + +const std::string ATTR_NAME_SRC_INDEX = "src_index"; + +const std::string ATTR_NAME_INPUT = "input"; + +const std::string ATTR_NAME_OUTPUT = "output"; + +const std::string ATTR_NAME_INPUT_DESC = "input_desc"; + +const std::string ATTR_NAME_OUTPUT_DESC = "output_desc"; + +const std::string ATTR_NAME_DST_NAME = "dst_name"; + +const std::string ATTR_NAME_DST_INDEX = "dst_index"; + +const std::string ATTR_NAME_WORKSPACE = "workspace"; + +const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; + +const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; + +const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { + 
op_def_.InitDefault(); + if (op_def_.GetProtoMsg() != nullptr) { + op_def_.GetProtoMsg()->set_has_out_attr(true); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::~OpDesc() {} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const std::string &name, const std::string &type) { + op_def_.InitDefault(); + if (op_def_.GetProtoMsg() != nullptr) { + op_def_.GetProtoMsg()->set_has_out_attr(true); + } + SetName(name); + SetType(type); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const ProtoMsgOwner &proto_msg_owner, + ge::proto::OpDef *op_def) + : op_def_(proto_msg_owner, op_def) { + if (op_def != nullptr && !op_def->has_out_attr()) { + op_def->set_has_out_attr(true); + + int64_t id = 0; + (void)AttrUtils::GetInt(this, ATTR_NAME_ID, id); + op_def->set_id(id); + + int64_t stream_id = 0; + (void)AttrUtils::GetInt(this, ATTR_NAME_STREAM_ID, stream_id); + op_def->set_stream_id(stream_id); + + vector input_name; + (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME, input_name); + for (auto &item : input_name) { + op_def->add_input_name(item); + } + vector src_name; + (void)AttrUtils::GetListStr(this, ATTR_NAME_SRC_NAME, src_name); + for (auto &item : src_name) { + op_def->add_src_name(item); + } + vector src_index; + (void)AttrUtils::GetListInt(this, ATTR_NAME_SRC_INDEX, src_index); + for (auto &item : src_index) { + op_def->add_src_index(item); + } + vector input; + (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT, input); + for (auto &item : input) { + op_def->add_input_i(item); + } + vector output; + (void)AttrUtils::GetListInt(this, ATTR_NAME_OUTPUT, output); + for (auto &item : output) { + op_def->add_output_i(item); + } + vector dst_name; + (void)AttrUtils::GetListStr(this, ATTR_NAME_DST_NAME, dst_name); + for (auto &item : dst_name) { + op_def->add_dst_name(item); + } + vector dst_index; + (void)AttrUtils::GetListInt(this, ATTR_NAME_DST_INDEX, dst_index); + for (auto &item : dst_index) { + 
op_def->add_dst_index(item); + } + vector workspace; + (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE, workspace); + for (auto &item : workspace) { + op_def->add_workspace(item); + } + vector workspace_bytes; + (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE_BYTES, workspace_bytes); + for (auto &item : workspace_bytes) { + op_def->add_workspace_bytes(item); + } + vector is_input_const; + (void)AttrUtils::GetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const); + for (auto item : is_input_const) { + op_def->add_is_input_const(item); + } + auto input_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_INPUT_DESC].mutable_list(); + if (input_desc_mutable_list != nullptr) { + *op_def->mutable_input_desc() = *(input_desc_mutable_list->mutable_td()); + } + auto output_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_OUTPUT_DESC].mutable_list(); + if (output_desc_mutable_list != nullptr) { + *op_def->mutable_output_desc() = *(output_desc_mutable_list->mutable_td()); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetName() const { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->name(); + } + return ""; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetName(const std::string &name) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_name(name); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetType() const { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->type(); + } + return ""; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetType(const string &type) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_type(type); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddInputDesc(const ge::GeTensorDesc &input_desc) { + int index = static_cast(inputs_desc_.size()); + 
return AddInputDesc("__input" + std::to_string(index), input_desc); +} + +graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc) { + graphStatus ret = GRAPH_SUCCESS; + if (index < inputs_desc_.size()) { + // InputsDesc[index] is exist, then update it + ret = UpdateInputDesc(index, input_desc); + } else { + // InputDesc[index] is not exist, then add it + ret = AddInputDesc(input_desc); + } + return ret; +} + +graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { + if (input_name_idx_.find(name) != input_name_idx_.end()) { + GELOGI("input %s is exist, update it", name.c_str()); + graphStatus ret = UpdateInputDesc(name, input_desc); + return ret; + } else { + int index = static_cast(inputs_desc_.size()); + std::shared_ptr in_desc = ComGraphMakeShared(input_desc); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddInputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + inputs_desc_.push_back(in_desc); + (void)input_name_idx_.insert(make_pair(name, index)); + if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) { + register_input_name_.push_back(name); + } + + return GRAPH_SUCCESS; + } +} + +graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) { + for (unsigned int i = 0; i < num; i++) { + string input_name = name + std::to_string(i); + GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, + "Add input tensor_desc is existed. 
name[%s]", input_name.c_str()); + + std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + if (index > inputs_desc_.size()) { + GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); + return GRAPH_FAILED; + } + + (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc); + + // Update index in input_name_idx + for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { + if (it->second >= (index + i)) { + it->second += 1; + } + } + + (void)input_name_idx_.insert(make_pair(input_name, i + index)); + } + + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddOutputDescMiddle(const string &name, const unsigned int num, size_t index) { + for (unsigned int i = 0; i < num; i++) { + string output_name = name + std::to_string(i); + GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED, + "Add input tensor_desc is existed. 
name[%s]", output_name.c_str()); + + std::shared_ptr out_desc = ComGraphMakeShared(GeTensorDesc()); + if (out_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + if (index > outputs_desc_.size()) { + GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); + return GRAPH_FAILED; + } + + (void)outputs_desc_.insert(outputs_desc_.begin() + index + i, out_desc); + + // Update index in input_name_idx + for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) { + if (it->second >= (index + i)) { + it->second += 1; + } + } + + (void)output_name_idx_.insert(make_pair(output_name, i + index)); + } + + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { + for (unsigned int i = 0; i < num; i++) { + string input_name = name + std::to_string(i); + GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, + "Add input tensor_desc is existed. name[%s]", input_name.c_str()); + + std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); + + // Update index in input_name_idx + for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { + it->second += 1; + } + + (void)input_name_idx_.insert(make_pair(input_name, 0)); + } + + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int num) { + for (unsigned int i = 0; i < num; i++) { + string output_name = name + std::to_string(i); + GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED, + "Add output tensor_desc is existed. 
name[%s]", output_name.c_str()); + + std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddOutputDescForward failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + (void)outputs_desc_.insert(outputs_desc_.begin(), in_desc); + + // Update index in output_name_idx + for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) { + it->second += 1; + } + (void)output_name_idx_.insert(make_pair(output_name, 0)); + } + + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { + if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; + (void)optional_input_names_.insert(name); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { + if (index >= inputs_desc_.size()) { + GELOGW("The index is invalid. 
index[%u]", index); + return GRAPH_FAILED; + } + + inputs_desc_[index] = ComGraphMakeShared(tensor_Desc); + if (inputs_desc_[index] == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { + return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && + IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && + IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") && + IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && + IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { + const auto &op_def = this->op_def_.GetProtoMsg(); + const auto &r_op_def = r_op_desc.op_def_.GetProtoMsg(); + if ((op_def != nullptr) && (r_op_def != nullptr)) { + // Message OpDef in ge_ir.proto + return ( + IsEqual(op_def->name(), r_op_def->name(), "OpDef_.name()") && + IsEqual(op_def->type(), r_op_def->type(), "OpDef_.type()") && + IsEqual(ToString(op_def->input()), ToString(r_op_def->input()), "OpDef_.input()") && + IsEqual(op_def->has_out_attr(), r_op_def->has_out_attr(), "OpDef_.has_out_attr()") && + IsEqual(op_def->stream_id(), r_op_def->stream_id(), "OpDef_.stream_id()") && + IsEqual(ToString(op_def->input_name()), ToString(r_op_def->input_name()), "OpDef_.input_name()") && + IsEqual(ToString(op_def->src_name()), ToString(r_op_def->src_name()), "OpDef_.src_name()") && + IsEqual(ToString(op_def->dst_name()), ToString(r_op_def->dst_name()), "OpDef_.dst_name()") && + IsEqual(ToString(op_def->src_index()), ToString(r_op_def->src_index()), "OpDef_.src_index()") && + 
IsEqual(ToString(op_def->dst_index()), ToString(r_op_def->dst_index()), "OpDef_.dst_index()") && + IsEqual(ToString(op_def->input_i()), ToString(r_op_def->input_i()), "OpDef_.input_i()") && + IsEqual(ToString(op_def->output_i()), ToString(r_op_def->output_i()), "OpDef_.output_i()") && + IsEqual(ToString(op_def->workspace()), ToString(r_op_def->workspace()), "OpDef_.workspace()") && + IsEqual(ToString(op_def->workspace_bytes()), ToString(r_op_def->workspace_bytes()), "OpDef_.workspace_bytes()") && + IsEqual(ToString(op_def->is_input_const()), ToString(r_op_def->is_input_const()), "OpDef_.is_input_const()")); + } else { + return ((op_def == nullptr) && (r_op_def == nullptr)); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescGenTensorDescsAreEqual( + const OpDesc &r_op_desc) const { + // 1.Verify inputs and outputs desc size + const auto inputs_desc_size = this->inputs_desc_.size(); + const auto r_inputs_desc_size = r_op_desc.inputs_desc_.size(); + if (inputs_desc_size != r_inputs_desc_size) { + GELOGE(GRAPH_FAILED, "Size of OpDesc's inputs desc verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + const auto outputs_desc_size = this->outputs_desc_.size(); + const auto r_outputs_desc_size = r_op_desc.outputs_desc_.size(); + if (outputs_desc_size != r_outputs_desc_size) { + GELOGE(GRAPH_FAILED, "Size of OpDesc's outputs desc verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + // 2.Verify all inputs desc equal + for (uint32_t i = 0; i < inputs_desc_size; i++) { + const auto &in_ge_tensor_desc = this->GetInputDesc(i); + const auto &r_in_ge_tensor_desc = r_op_desc.GetInputDesc(i); + // Determine the connection relationship by GeTensorDesc + if (!(in_ge_tensor_desc == r_in_ge_tensor_desc)) { + GELOGE(GRAPH_FAILED, "Link info of OpDesc's inputs desc verify failed, OpDesc name: %s.", + this->GetName().c_str()); + return false; + } + } + // 3.Verify all outputs desc equal + for (uint32_t i = 0; i 
< outputs_desc_size; i++) { + const auto &out_ge_tensor_desc = this->GetOutputDesc(i); + const auto &r_out_ge_tensor_desc = r_op_desc.GetOutputDesc(i); + if (!(out_ge_tensor_desc == r_out_ge_tensor_desc)) { + GELOGE(GRAPH_FAILED, "Link info of OpDesc's outputs desc verify failed, OpDesc name: %s.", + this->GetName().c_str()); + return false; + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpDesc &r_op_desc) const { + return (OpDescAttrsAreEqual(r_op_desc) && OpDescMembersAreEqual(r_op_desc) && + OpDescGenTensorDescsAreEqual(r_op_desc)); +} + +graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { + auto it = input_name_idx_.find(name); + if (it == input_name_idx_.end()) { + GELOGW("Cann't find the input desc. name[%s]", name.c_str()); + return GRAPH_FAILED; + } + if (it->second >= inputs_desc_.size()) { + GELOGE(GRAPH_FAILED, "[%d] more than size of inputs_desc_", it->second); + return GRAPH_FAILED; + } + GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); + return GRAPH_FAILED); + inputs_desc_[it->second] = ComGraphMakeShared(tensor_Desc); + if (inputs_desc_[it->second] == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +bool OpDesc::InputIsSet(const string &name) const { + auto it = input_name_idx_.find(name); + if (it != input_name_idx_.end()) { + GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); + auto tensor_desc = inputs_desc_[it->second]; + GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); + auto dims = tensor_desc->GetShape().GetDims(); + if (dims.size() > 0) { + return true; + } + } + return false; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc(uint32_t index) const 
{ + GE_CHK_BOOL_RET_STATUS_NOLOG(index < inputs_desc_.size(), GeTensorDesc()); + return *(inputs_desc_[index].get()); +} + +GeTensorDesc OpDesc::GetInputDesc(const string &name) const { + auto it = input_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc()); + GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); + return *(inputs_desc_[it->second].get()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); + if (inputs_desc_[index] == nullptr) { + return nullptr; + } + if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) { + GELOGW("input desc is invalid"); + return nullptr; + } + return inputs_desc_[index]; +} + +GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const { + auto input_name_idx = GetAllInputName(); + auto it = input_name_idx.find(name); + if (it == input_name_idx.end()) { + GELOGW("Failed to get [%s] input desc", name.c_str()); + return nullptr; + } + return MutableInputDesc(it->second); +} + +GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputNames() const { + vector names; + if (input_name_idx_.empty()) { + return OpDesc::Vistor(shared_from_this(), names); + } + for (std::pair input : input_name_idx_) { + names.push_back(input.first); + } + return OpDesc::Vistor(shared_from_this(), names); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpKernelLibName(const std::string &name) { + op_kernel_lib_name_ = name; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpKernelLibName() const { + return op_kernel_lib_name_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(const std::string &name) { + engine_name_ = name; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return 
engine_name_; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDesc() const { + vector temp{}; + for (const auto &it : inputs_desc_) { + if (it->IsValid() == GRAPH_SUCCESS) { + temp.push_back(*it); + } else { + GELOGW("this inputDesc is InValid, it won't be return"); + continue; + } + } + return OpDesc::Vistor(shared_from_this(), temp); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDescPtr() const { + vector temp{}; + for (const auto &it : inputs_desc_) { + if (it->IsValid() == GRAPH_SUCCESS) { + temp.push_back(it); + } else { + GELOGW("this inputDesc is InValid, it won't be return"); + continue; + } + } + return OpDesc::Vistor(shared_from_this(), temp); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() const { + // Just return valid inputs size.InValid desc is set in default OPTION_INPUT register. + size_t size = 0; + for (const auto &it : inputs_desc_) { + if (it->IsValid() == GRAPH_SUCCESS) { + size++; + } + } + return size; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { + int index = static_cast(outputs_desc_.size()); + return AddOutputDesc("__output" + std::to_string(index), output_desc); +} + +graphStatus OpDesc::AddOutputDesc(const string &name, const ge::GeTensorDesc &output_desc) { + GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(name) == output_name_idx_.end()), GRAPH_FAILED, + "Add output tensor_Desc is existed. 
name[%s]", name.c_str()); + int index = static_cast(outputs_desc_.size()); + + std::shared_ptr tensor = ComGraphMakeShared(output_desc); + if (tensor == nullptr) { + GELOGE(GRAPH_FAILED, "AddOutputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + outputs_desc_.push_back(tensor); + (void)output_name_idx_.insert(make_pair(name, index)); + if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) { + register_output_name_.push_back(name); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDesc::UpdateOutputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { + GE_CHK_BOOL_RET_STATUS((index < outputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); + + outputs_desc_[index] = ComGraphMakeShared(tensor_Desc); + if (outputs_desc_[index] == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::UpdateOutputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { + auto it = output_name_idx_.find(name); + if (it == output_name_idx_.end()) { + GELOGW("Cann't find the output desc. 
name[%s]", name.c_str()); + return GRAPH_FAILED; + } + GE_IF_BOOL_EXEC(it->second >= outputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); + return GRAPH_FAILED); + outputs_desc_[it->second] = ComGraphMakeShared(tensor_Desc); + if (outputs_desc_[it->second] == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetOutputDesc(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS_NOLOG(index < outputs_desc_.size(), GeTensorDesc()); + return *(outputs_desc_[index].get()); +} + +GeTensorDesc OpDesc::GetOutputDesc(const string &name) const { + auto it = output_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), GeTensorDesc()); + GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < outputs_desc_.size(), GeTensorDesc()); + return *(outputs_desc_[it->second].get()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS(index < outputs_desc_.size(), nullptr, "Cann't find the output desc %u", index); + return outputs_desc_[index]; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const { + auto it = output_name_idx_.find(name); + if (it == output_name_idx_.end()) { + GELOGW("Failed to get [%s] output desc", name.c_str()); + return nullptr; + } + return MutableOutputDesc(it->second); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { + return static_cast(outputs_desc_.size()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllOutputsDesc() const { + vector temp{}; + for (const auto &it : outputs_desc_) { + temp.push_back(*it); + } + return OpDesc::Vistor(shared_from_this(), temp); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor 
OpDesc::GetAllOutputsDescPtr() const { + return OpDesc::Vistor(shared_from_this(), outputs_desc_); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetOutputsSize() const { return outputs_desc_.size(); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetOutputDescPtr(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(outputs_desc_.size()), nullptr); + return outputs_desc_[index]; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(inputs_desc_.size()), nullptr); + if (inputs_desc_[index] == nullptr) { + return nullptr; + } + if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) { + GELOGW("inputsDesc[%u] is InValid", index); + return nullptr; + } else { + return inputs_desc_[static_cast(index)]; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr +OpDesc::GetInputDescPtrDfault(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS_NOLOG((index) < (uint32_t)(inputs_desc_.size()), nullptr); + return inputs_desc_[(int32_t)index]; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const { + auto it = input_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr()); + return inputs_desc_[it->second]; +} + +graphStatus OpDesc::AddRegisterInputName(const std::string &name) { + if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) { + register_input_name_.push_back(name); + } + + return GRAPH_SUCCESS; +} + +vector OpDesc::GetRegisterInputName() const { return register_input_name_; } + +graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { + if (is_push_back) { + for (unsigned int i = 0; i < num; i++) { + if (AddInputDesc(name + std::to_string(i), 
GeTensorDesc()) != GRAPH_SUCCESS) return GRAPH_FAILED; + } + } else { + if (AddInputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED; + } + if (AddRegisterInputName(name) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) { + if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddRegisterOutputName(const string &name) { + if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) { + register_output_name_.push_back(name); + } + + return GRAPH_SUCCESS; +} + +vector OpDesc::GetRegisterOutputName() const { return register_output_name_; } + +graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { + if (is_push_back) { + for (unsigned int i = 0; i < num; i++) { + if (AddOutputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) return GRAPH_FAILED; + } + } else { + if (AddOutputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED; + } + + if (AddRegisterOutputName(name) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +bool OpDesc::IsOptionalInput(const string &name) const { + return optional_input_names_.find(name) != optional_input_names_.end(); +} + +bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } + +std::map OpDesc::GetAllInputName() const { return input_name_idx_; } + +std::map OpDesc::GetAllOutputName() { return output_name_idx_; } + +bool OpDesc::UpdateInputName(std::map input_name_idx) { + bool ret = true; + // Use inputDesc_.size() to contain the InValid OptionInput.GetInputsSize() will remove default OptionInput name. 
+ auto input_map_size = inputs_desc_.size(); + auto factory_map_size = input_name_idx.size(); + // It indicates that some inputs have no optionalname. + // The redundant optionalname of factory needs to be deleted and then assigned + if (input_map_size < factory_map_size) { + GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, + factory_map_size); + for (auto it = input_name_idx.begin(); it != input_name_idx.end();) { + if (it->second >= input_map_size) { + it = input_name_idx.erase(it); + } else { + ++it; + } + } + if (input_name_idx.size() == input_map_size) { + GELOGI("UpdateInputName"); + input_name_idx_ = input_name_idx; + } else { + ret = false; + GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); + } + } else if (input_map_size == factory_map_size) { + input_name_idx_ = input_name_idx; + } else { + ret = false; + GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); + } + return ret; +} + +bool OpDesc::UpdateOutputName(std::map output_name_idx) { + size_t output_map_size = GetAllOutputsDescSize(); + size_t factory_map_size = output_name_idx.size(); + if (output_map_size < factory_map_size) { + GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size, + factory_map_size); + for (auto it = output_name_idx.begin(); it != output_name_idx.end();) { + if (it->second >= output_map_size) { + it = output_name_idx.erase(it); + } else { + ++it; + } + } + if (output_name_idx.size() == output_map_size) { + GELOGI("UpdateoutputName"); + output_name_idx_ = output_name_idx; + return true; + } + } else if (output_map_size == factory_map_size) { + output_name_idx_ = output_name_idx; + return true; + } else { + GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size); + return false; + } + GELOGW("UpdateOutputName org name map size: %zu, factory 
map size: %zu", output_map_size, factory_map_size); + return false; +} + +std::function OpDesc::GetInferFunc() const { return infer_func_; } + +std::function OpDesc::GetVerifyFunc() const { return verifier_func_; } + +void OpDesc::AddInferFunc(const std::function &func) { infer_func_ = func; } + +std::function OpDesc::GetInferFormatFunc() const { return infer_format_func_; } + +void OpDesc::AddInferFormatFunc(const std::function &func) { infer_format_func_ = func; } + +void OpDesc::AddVerifierFunc(const std::function &func) { verifier_func_ = func; } + +graphStatus OpDesc::InferShapeAndType() { + if (infer_func_ == nullptr) { + infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType()); + if (infer_func_ == nullptr) { + GELOGW("%s does not have inferfunc_.", GetName().c_str()); + /// The infoshape function has not been added for each operator in the current operator information library. + /// No infoshape added operator skips the call + /// and directly uses the shape information passed down by the upper framework + return GRAPH_SUCCESS; + } + } + Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this()); + graphStatus ret = (graphStatus)infer_func_(op_proxy); + op_proxy.BreakConnect(); + return ret; +} + +graphStatus OpDesc::DefaultInferFormat() { + ge::Format first_none_nd_format = FORMAT_ND; + auto input_descs = GetAllInputsDescPtr(); + auto output_descs = GetAllOutputsDescPtr(); + // Overall input and output,get the first non-nd format + for (const auto &input_desc : input_descs) { + Format origin_format = input_desc->GetOriginFormat(); + if (origin_format != FORMAT_ND) { + first_none_nd_format = origin_format; + break; + } + } + for (const auto &output_desc : output_descs) { + Format origin_format = output_desc->GetOriginFormat(); + if (origin_format != FORMAT_ND) { + first_none_nd_format = origin_format; + break; + } + } + // Refresh all input output format + GELOGD("Default infer format.node[%s], first none nod format is:%d", 
GetName().c_str(), first_none_nd_format); + + for (const auto &input_desc : input_descs) { + Format origin_format = input_desc->GetOriginFormat(); + GELOGD("Default infer format[in].node[%s].origin format is:%d", GetName().c_str(), origin_format); + if (origin_format == FORMAT_ND) { + input_desc->SetOriginFormat(first_none_nd_format); + input_desc->SetFormat(first_none_nd_format); + } + } + for (const auto &output_desc : output_descs) { + Format origin_format = output_desc->GetOriginFormat(); + GELOGD("Default infer format[out].node[%s].origin format is:%d", GetName().c_str(), origin_format); + if (origin_format == FORMAT_ND) { + output_desc->SetOriginFormat(first_none_nd_format); + output_desc->SetFormat(first_none_nd_format); + } + } + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::OpVerify() { + if (verifier_func_ == nullptr) { + verifier_func_ = OperatorFactoryImpl::GetVerifyFunc(GetType()); + } + if (verifier_func_ != nullptr) { + Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this()); + graphStatus ret = (graphStatus)verifier_func_(op_proxy); + op_proxy.BreakConnect(); + return ret; + } + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::CommonVerify() const { + for (const string &iname : GetAllInputNames()) { + // Checking shape of all inputs + vector ishape = GetInputDescPtr(iname)->GetShape().GetDims(); + for (int64_t dim : ishape) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + dim < -2, ErrorManager::GetInstance().ATCReportErrMessage( + "E19014", {"opname", "value", "reason"}, + {GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); + return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), + iname.c_str()); + } + } + // Check all attributes defined + const auto &all_attributes = GetAllAttrs(); + for (const auto &name : GetAllAttrNames()) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + all_attributes.find(name) == all_attributes.end(), + 
ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {GetName(), "attribute " + name, "is empty"}); + return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { + auto it = input_name_idx_.begin(); + for (; it != input_name_idx_.end(); ++it) { + if (it->second == index) { + break; + } + } + GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); + return it->first; +} + +int OpDesc::GetInputIndexByName(const string &name) const { + auto it_find = input_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); + return static_cast(it_find->second); +} + +int OpDesc::GetValidInputIndexByName(const string &name) const { + map valid_input_name_idx{}; + uint32_t j = 0; + for (size_t i = 0; i < GetAllInputsSize(); i++) { + if (MutableInputDesc(static_cast(i)) != nullptr) { + auto valid_name = GetInputNameByIndex(static_cast(i)); + GE_CHK_BOOL_RET_STATUS_NOLOG(!valid_name.empty(), -1); + valid_input_name_idx.insert({valid_name, j}); + j++; + } + } + auto it_find = valid_input_name_idx.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != valid_input_name_idx.end(), -1); + return static_cast(it_find->second); +} + +string OpDesc::GetValidInputNameByIndex(uint32_t index) const { + map valid_input_name_idx{}; + uint32_t j = 0; + for (size_t i = 0; i < GetAllInputsSize(); i++) { + if (MutableInputDesc(static_cast(i)) != nullptr) { + auto valid_name = GetInputNameByIndex(static_cast(i)); + GE_CHK_BOOL_RET_STATUS_NOLOG(!valid_name.empty(), ""); + valid_input_name_idx.insert({valid_name, j}); + j++; + } + } + auto it = valid_input_name_idx.begin(); + for (; it != valid_input_name_idx.end(); ++it) { + if (it->second == index) { + break; + } + } + GE_CHK_BOOL_RET_STATUS_NOLOG(it != valid_input_name_idx.end(), ""); + return it->first; +} + 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetOutputNameByIndex(uint32_t index) const { + auto it = output_name_idx_.begin(); + for (; it != output_name_idx_.end(); ++it) { + if (it->second == index) { + break; + } + } + GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), ""); + return it->first; +} + +int OpDesc::GetOutputIndexByName(const string &name) const { + auto it_find = output_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != output_name_idx_.end(), -1); + return static_cast(it_find->second); +} + +ProtoAttrMapHelper OpDesc::MutableAttrMap() { + if (op_def_.GetProtoMsg() == nullptr) { + GELOGE(GRAPH_FAILED, "op def get proto msg failed"); + return GeIrProtoHelper(); + } + return ProtoAttrMapHelper(op_def_.GetProtoOwner(), op_def_.GetProtoMsg()->mutable_attr()); +} + +ConstProtoAttrMapHelper OpDesc::GetAttrMap() const { + return ConstProtoAttrMapHelper(op_def_.GetProtoOwner(), &op_def_.GetProtoMsg()->attr()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetId(int64_t id) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_id(id); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetId() const { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->id(); + } + return 0; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetStreamId(int64_t stream_id) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_stream_id(stream_id); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetStreamId() const { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->stream_id(); + } + return 0; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputName(const vector &input_name) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_input_name(); + 
for (auto &item : input_name) { + proto_msg->add_input_name(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetInputName() const { + vector input_name; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->input_name()) { + input_name.push_back(item); + } + } + return input_name; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcName(const vector &src_name) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_src_name(); + for (auto &item : src_name) { + proto_msg->add_src_name(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetSrcName() const { + vector src_name; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->src_name()) { + src_name.push_back(item); + } + } + return src_name; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcIndex(const vector &src_index) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_src_index(); + for (auto &item : src_index) { + proto_msg->add_src_index(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetSrcIndex() const { + vector src_index; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->src_index()) { + src_index.push_back(item); + } + } + return src_index; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputOffset(const vector &input) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_input_i(); + for (auto &item : input) { + proto_msg->add_input_i(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetInputOffset() const { + vector input; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->input_i()) { + 
input.push_back(item); + } + } + return input; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOutputOffset(const vector &output) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_output_i(); + for (auto &item : output) { + proto_msg->add_output_i(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOutputOffset() const { + vector output; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->output_i()) { + output.push_back(item); + } + } + return output; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstName(const vector &dst_name) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_dst_name(); + for (auto &item : dst_name) { + proto_msg->add_dst_name(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetDstName() const { + vector dst_name; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->dst_name()) { + dst_name.push_back(item); + } + } + return dst_name; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector &depend_names) { + auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); + if (ret != true) { + GELOGE(GRAPH_FAILED, "set op_infer_depends fail."); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOpInferDepends() const { + vector depend_names; + (void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); + return depend_names; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector &dst_index) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_dst_index(); + for (auto &item : dst_index) { + proto_msg->add_dst_index(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY 
vector OpDesc::GetDstIndex() const { + vector dst_index; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->dst_index()) { + dst_index.push_back(item); + } + } + return dst_index; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspace(const vector &workspace) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_workspace(); + for (auto &item : workspace) { + proto_msg->add_workspace(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetWorkspace() const { + vector workspace; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->workspace()) { + workspace.push_back(item); + } + } + return workspace; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspaceBytes(const vector &workspace_bytes) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_workspace_bytes(); + for (auto &item : workspace_bytes) { + proto_msg->add_workspace_bytes(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetWorkspaceBytes() const { + vector workspace_bytes; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->workspace_bytes()) { + workspace_bytes.push_back(item); + } + } + return workspace_bytes; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetIsInputConst(const vector &is_input_const) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_is_input_const(); + for (auto item : is_input_const) { + proto_msg->add_is_input_const(item); + } + } + // If comes from ME,which is_input_const exist as attrs, outside no need to check GE_TRAIN flag + auto ret = AttrUtils::SetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const); + if (ret != true) { + GELOGE(GRAPH_FAILED, "set is_input_const fail."); + } 
+} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetIsInputConst() const { + vector is_input_const; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto item : proto_msg->is_input_const()) { + is_input_const.push_back(item); + } + } + return is_input_const; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, + const int &index) { + if (input_name_idx_.find(name) != input_name_idx_.end()) { + GELOGI("Restore input name index is existed. name[%s]", name.c_str()); + } + (void)input_name_idx_.insert(make_pair(name, index)); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreOutputNameIdx(const string &name, + const int &index) { + if (output_name_idx_.find(name) != output_name_idx_.end()) { + GELOGI("Restore output name index is existed. name[%s]", name.c_str()); + } + (void)output_name_idx_.insert(make_pair(name, index)); + return GRAPH_SUCCESS; +} +graphStatus OpDesc::CallInferFunc(Operator &op) { + if (infer_func_ == nullptr) { + infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType()); + if (infer_func_ == nullptr) { + GELOGW("%s does not have infer func.", GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + } + graphStatus graph_status = (graphStatus)infer_func_(op); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "%s call infer func. 
ret: %u", GetName().c_str(), graph_status); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +graphStatus OpDesc::CallInferFormatFunc(Operator &op) { + if (infer_format_func_ == nullptr) { + infer_format_func_ = OperatorFactoryImpl::GetInferFormatFunc(GetType()); + if (infer_format_func_ == nullptr) { + return DefaultInferFormat(); + } + } + return (graphStatus)infer_format_func_(op); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const { + if (static_cast(index) >= subgraph_instance_names_.size()) { + return ""; + } + return subgraph_instance_names_.at(index); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector &OpDesc::GetSubgraphInstanceNames() + const { + return subgraph_instance_names_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { + for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { + if (*iter == name) { + *iter = ""; + return; + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { + GELOGI("Add subgraph name is %s", name.c_str()); + auto iter = subgraph_names_to_index_.find(name); + if (iter != subgraph_names_to_index_.end()) { + GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); + return GRAPH_FAILED; + } + auto size = subgraph_names_to_index_.size(); + subgraph_names_to_index_[name] = size; + subgraph_instance_names_.resize(size + 1); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map &OpDesc::GetSubgraphNameIndexes() + const { + return subgraph_names_to_index_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::SetSubgraphInstanceName(uint32_t index, + const std::string &name) { + GELOGI("Add sub graph instans name is %s, index is %u", name.c_str(), index); + if (index >= 
subgraph_instance_names_.size()) { + GE_LOGE("The index %u exceeds the max instance coutn %zu", index, subgraph_instance_names_.size()); + return GRAPH_PARAM_INVALID; + } + subgraph_instance_names_[index] = name; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RegisterSubgraphIrName(const string &name, + SubgraphType type) { + subgraph_ir_names_to_type_[name] = type; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map &OpDesc::GetSubgraphIrNames() + const { + return subgraph_ir_names_to_type_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY SubgraphType +OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { + auto iter = subgraph_ir_names_to_type_.find(name); + if (iter == subgraph_ir_names_to_type_.end()) { + return kSubgraphTypeEnd; + } + return iter->second; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDesc::GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const { + for (size_t idx = 0; idx < subgraph_instance_names_.size(); ++idx) { + if (subgraph_instance_names_[idx] != instance_name) { // find subgraph index. + continue; + } + + for (auto name_to_index : subgraph_names_to_index_) { + if (name_to_index.second != idx) { // find subgraph name. + continue; + } + + subgraph_name = name_to_index.first; + return GRAPH_SUCCESS; + } + } + + return GRAPH_PARAM_INVALID; +} + +} // namespace ge diff --git a/src/common/graph/op_imp.cc b/src/common/graph/op_imp.cc new file mode 100644 index 00000000..9abf242b --- /dev/null +++ b/src/common/graph/op_imp.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "debug/ge_log.h" +#include "debug/ge_util.h" + +using namespace std; + +namespace ge { + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +BroadCastInfer(const function()>& get_in1_shape, const function()>& get_in2_shape, + const function& outShape)>& set_out_shape) { + auto x1_shape = get_in1_shape(); + auto x2_shape = get_in2_shape(); + vector y_shape; + + if (x1_shape.empty()) { + y_shape = x2_shape; + set_out_shape(y_shape); + return GRAPH_SUCCESS; + } + if (x2_shape.empty()) { + y_shape = x1_shape; + set_out_shape(y_shape); + return GRAPH_SUCCESS; + } + + int len_diff = static_cast(x1_shape.size() - x2_shape.size()); + if (len_diff >= 0) { + for (int i = 0; i < len_diff; i++) { + y_shape.push_back(x1_shape[i]); + } + int x2_shape_size = static_cast(x2_shape.size()); + for (int i = 0; i < x2_shape_size; i++) { + bool shapeFlag = + ((x1_shape[i + len_diff] != x2_shape[i]) && (std::min(x1_shape[i + len_diff], x2_shape[i]) != 1)); + if (shapeFlag) { + GE_LOGE("operands could not be broadcast together"); + return GRAPH_FAILED; + } + y_shape.push_back(std::max(x1_shape[i + len_diff], x2_shape[i])); + } + } else { + for (int i = 0; i < -len_diff; i++) { + y_shape.push_back(x2_shape[i]); + } + int x1_shape_size = static_cast(x1_shape.size()); + for (int i = 0; i < x1_shape_size; i++) { + bool shapeFlag = + ((x1_shape[i] != x2_shape[i - len_diff]) && (std::min(x1_shape[i], x2_shape[i - len_diff]) != 1)); + if (shapeFlag) { + GE_LOGE("operands could not be broadcast together"); + return 
GRAPH_FAILED; + } + y_shape.push_back(std::max(x1_shape[i], x2_shape[i - len_diff])); + } + } + set_out_shape(y_shape); + return GRAPH_SUCCESS; +} + +} // namespace ge diff --git a/src/common/graph/operator.cc b/src/common/graph/operator.cc new file mode 100644 index 00000000..21554fa1 --- /dev/null +++ b/src/common/graph/operator.cc @@ -0,0 +1,1587 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "external/graph/operator.h" +#include "external/graph/operator_factory.h" +#include +#include +#include +#include +#include +#include "./array_ops.h" +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "external/graph/attr_value.h" +#include "external/graph/types.h" +#include "framework/common/debug/ge_log.h" +#include "graph/compute_graph.h" +#include "graph/ge_attr_value.h" +#include "graph/ge_context.h" +#include "graph/ge_tensor.h" +#include "graph/node.h" +#include "graph/op_desc.h" +#include "graph/runtime_inference_context.h" +#include "graph/usr_types.h" +#include "graph/utils/node_utils.h" +#include "graph/debug/ge_attr_define.h" +#include "utils/graph_utils.h" +#include "utils/op_desc_utils.h" +#include "utils/tensor_adapter.h" +#include "utils/tensor_utils.h" +#include "utils/type_utils.h" +#include +#include +#include +#include +#include + +using std::enable_shared_from_this; +using std::make_pair; +using std::shared_ptr; +using 
std::string; +using std::to_string; +using std::vector; + +/*lint -save -e529 -e728*/ +/*lint -e446 -e732*/ +/*lint -e665*/ +namespace ge { +class OpIO { + public: + OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} + + ~OpIO() = default; + + string GetName() const { return name_; } + + int GetIndex() const { return index_; } + + OperatorImplPtr GetOwner() const { return owner_; } + + bool operator==(const OpIO &r_value) const { + return (this->name_ == r_value.GetName()) && (this->index_ == r_value.GetIndex()) && + (this->GetOwner() == r_value.GetOwner()); + } + + private: + string name_; + int index_; + std::shared_ptr owner_; +}; + +class TensorTypeImpl { + public: + TensorTypeImpl() = default; + ~TensorTypeImpl() = default; + + std::vector dt_vec_; +}; + +TensorType::TensorType(DataType dt) { + tensor_type_impl_ = ComGraphMakeShared(); + if (tensor_type_impl_ != nullptr) { + tensor_type_impl_->dt_vec_.push_back(dt); + } +} + +TensorType::TensorType(const std::initializer_list &types) { + tensor_type_impl_ = ComGraphMakeShared(); + if (tensor_type_impl_ != nullptr) { + tensor_type_impl_->dt_vec_ = types; + } +} + +class OperatorImpl : public std::enable_shared_from_this { + friend class GraphBuilderImpl; + friend class OpDescUtils; + + public: + explicit OperatorImpl(const string &name, const string &type) : op_desc_(ComGraphMakeShared(name, type)) { + if (op_desc_ == nullptr) { + GELOGW("OpDesc make shared failed"); + } + } + explicit OperatorImpl(const OpDescPtr &op_desc) : op_desc_(op_desc) {} + explicit OperatorImpl(ge::ConstNodePtr node) : node_(std::move(node)) { + if (node_ != nullptr && node_->GetOpDesc() != nullptr) { + op_desc_ = node_->GetOpDesc(); + } + } + ~OperatorImpl() {} + void SetInputImpl(const string &dst_name, const ge::Operator &src_oprt) { + GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty"); + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr."); 
+ GE_CHK_BOOL_EXEC(src_oprt.operator_impl_ != nullptr, return, "operator_impl_ is nullptr."); + GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_ != nullptr, return, "op_desc_ is nullptr."); + + auto src_op_impl = src_oprt.GetOperatorImplPtr(); + GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return, "Src impl is null."); + GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return, "Src impl's opdesc is null."); + GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_->GetOutputsSize() == 1, return, + "The source operator[%s] must has one output", + src_oprt.operator_impl_->op_desc_->GetName().c_str()) + + uint32_t src_index = 0; + string src_name = src_op_impl->op_desc_->GetOutputNameByIndex(src_index); + GE_CHK_BOOL_EXEC(!src_name.empty(), return, "Src output's name is empty."); + + OpIO out_handler(src_name, src_index, src_op_impl); + input_link_.insert(std::make_pair(dst_name, out_handler)); + + int dst_index = op_desc_->GetInputIndexByName(dst_name); + GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. 
name[%s], op name:%s", dst_name.c_str(), + op_desc_->GetName().c_str()); + + bool is_const = false; + if (src_oprt.GetOpType() == CONSTANT) { + is_const = true; + } + auto is_input_const = op_desc_->GetIsInputConst(); + for (int i = static_cast(is_input_const.size()); i <= dst_index; ++i) { + is_input_const.push_back(false); + } + + is_input_const[dst_index] = is_const; + op_desc_->SetIsInputConst(is_input_const); + + OpIO op_dst(dst_name, dst_index, shared_from_this()); + src_op_impl->UpdateLinkMapImpl(src_name, op_dst); + auto output_desc = src_op_impl->GetOutputDesc(src_name); + auto input_desc = op_desc_->GetInputDesc(dst_name); + if (input_desc.GetFormat() == FORMAT_RESERVED) { + output_desc.SetFormat(FORMAT_ND); + } else { + output_desc.SetFormat(input_desc.GetFormat()); + } + // Fix for linking opdesc + if (op_desc_->UpdateInputDesc(dst_name, output_desc) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Update inputdesc failed,dst name is %s, src name is %s", dst_name.c_str(), + src_name.c_str()); + return; + } + } + + void SetInputImpl(const string &dst_name, const ge::OutHandler &out_handler) { + GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty"); + GE_CHK_BOOL_EXEC(out_handler != nullptr, return, "SetInputImpl faild, out_handler is nullptr."); + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr."); + input_link_.insert(std::make_pair(dst_name, *out_handler)); + + string src_name = out_handler->GetName(); + int dst_index = op_desc_->GetInputIndexByName(dst_name); + GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), + op_desc_->GetName().c_str()); + auto out_op_impl = out_handler->GetOwner(); + GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return, + "out_handler invalid. 
name[%s]", dst_name.c_str()); + bool is_const = false; + if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { + is_const = true; + } + auto is_input_const = op_desc_->GetIsInputConst(); + for (int i = static_cast(is_input_const.size()); i <= dst_index; ++i) { + is_input_const.push_back(false); + } + is_input_const[dst_index] = is_const; + op_desc_->SetIsInputConst(is_input_const); + + OpIO in_handler(dst_name, dst_index, shared_from_this()); + GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed."); + + out_op_impl->UpdateLinkMapImpl(src_name, in_handler); + auto src_output_desc = out_op_impl->GetOutputDesc(src_name); + auto dst_input_desc = op_desc_->GetInputDesc(dst_name); + if (dst_input_desc.GetFormat() == FORMAT_RESERVED) { + src_output_desc.SetFormat(FORMAT_ND); + } else { + src_output_desc.SetFormat(dst_input_desc.GetFormat()); + } + GE_CHK_BOOL_EXEC(op_desc_->UpdateInputDesc(dst_name, src_output_desc) == GRAPH_SUCCESS, return, + "Update input desc failed,dst name is %s,src name is %s", dst_name.c_str(), + src_name.c_str()); // fix for linking opdesc + } + + void AddControlInputImp(const ge::Operator &src_oprt) { + if (src_oprt.operator_impl_ == nullptr) { + GELOGE(FAILED, "Src operator impl is nullptr"); + return; + } + for (auto &input : control_input_link_) { + if (input.lock() == src_oprt.operator_impl_) { + return; + } + } + control_input_link_.push_back(src_oprt.operator_impl_); + src_oprt.operator_impl_->control_output_link_.push_back(shared_from_this()); + } + + graphStatus GetInputImpl(const string &dst_name, ge::OpIO &out_handler) { + auto out = input_link_.find(dst_name); + if (out == input_link_.end()) { + return GRAPH_FAILED; + } + out_handler = out->second; + return GRAPH_SUCCESS; + } + + bool InputIsSet(const string &name) { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return false, "op_desc_ is nullptr."); + return op_desc_->InputIsSet(name); + } + + string GetName() const { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, 
return string(), "op_desc_ is nullptr."); + return op_desc_->GetName(); + } + + GeTensorDesc GetInputDesc(const string &name) const { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); + return op_desc_->GetInputDesc(name); + } + + GeTensorDesc GetInputDesc(uint32_t index) const { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); + return op_desc_->GetInputDesc(index); + } + + graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GRAPH_FAILED, "op_desc_ is nullptr."); + + return op_desc_->UpdateInputDesc(name, tensor_desc); + } + + OutHandler GetOutput(const string &name) { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); + + int src_index = op_desc_->GetOutputIndexByName(name); + GE_CHK_BOOL_EXEC(src_index >= 0, return nullptr, "Find src index by name failed. name[%s]", name.c_str()); + shared_ptr output_ptr = ComGraphMakeShared(name, src_index, shared_from_this()); + if (output_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "OpIO make shared failed"); + return nullptr; + } + return output_ptr; + } + + OutHandler GetOutput(uint32_t index) { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); + + string name = op_desc_->GetOutputNameByIndex(index); + if (name.empty()) { + GELOGE(GRAPH_FAILED, "Find src name by index failed. 
index[%u]", index); + return nullptr; + } + shared_ptr output_ptr = ComGraphMakeShared(name, index, shared_from_this()); + if (output_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "OpIO make shared failed"); + return nullptr; + } + return output_ptr; + } + + GeTensorDesc GetOutputDesc(const string &name) const { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); + + return op_desc_->GetOutputDesc(name); + } + + GeTensorDesc GetOutputDesc(uint32_t index) const { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); + + return op_desc_->GetOutputDesc(index); + } + + graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc) { + GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); + + auto res = op_desc_->UpdateOutputDesc(name, tensor_desc); + if (res == GRAPH_SUCCESS) { + for (auto ol : output_links_[name]) { + if (ol.GetOwner() == nullptr) { + GELOGW("%s get owner is nullptr", ol.GetName().c_str()); + continue; + } + GE_CHK_BOOL_RET_STATUS(ol.GetOwner()->UpdateInputDesc(ol.GetName(), tensor_desc) == GRAPH_SUCCESS, GRAPH_FAILED, + "Could not update next operator's input %s.", ol.GetName().c_str()); + } + } + return res; + } + + size_t GetInputsSize() const { + GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0); + return op_desc_->GetInputsSize(); + } + + size_t GetOutputsSize() const { + GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0); + return op_desc_->GetOutputsSize(); + } + + graphStatus SetAttr(const string &name, GeAttrValue &&attr_value) { + GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); + return op_desc_->SetAttr(name, std::move(attr_value)); + } + + graphStatus GetAttr(const string &name, GeAttrValue &attr_value) const { + GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); + return op_desc_->GetAttr(name, attr_value); + } + + OpDescPtr GetOpDescImpl() const { return op_desc_; } + + 
void UpdateLinkMapImpl(const string &src_name, OpIO &op_dst) { + auto it_find = output_links_.find(src_name); + if (it_find == output_links_.end()) { + std::vector dsts{op_dst}; + output_links_.insert(std::make_pair(src_name, dsts)); + } else { + it_find->second.push_back(op_dst); + } + } + + Operator ToOperator() { return Operator(shared_from_this()); } + + static OpDescPtr GetOpDesc(const Operator &oprt) { + GE_IF_BOOL_EXEC(oprt.operator_impl_ == nullptr, return nullptr); + return oprt.operator_impl_->op_desc_; + } + + void ClearOutputLinks() noexcept { output_links_.clear(); } + + void ClearInputLinks() noexcept { input_link_.clear(); } + + ge::ConstNodePtr GetNode() { return node_; } + + void SetInferenceContext(const InferenceContextPtr &inference_context) { inference_context_ = inference_context; } + + InferenceContextPtr GetInferenceContext() const { return inference_context_; } + + void SubgraphRegister(const std::string &ir_name, bool dynamic) { + op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? 
kDynamic : kStatic); + } + + void SubgraphCountRegister(const std::string &ir_name, uint32_t count) { + if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) { + op_desc_->AddSubgraphName(ir_name); + subgraph_names_to_builders_[ir_name] = nullptr; + } else { + for (uint32_t i = 0; i < count; ++i) { + string key_name = ir_name + std::to_string(i); + op_desc_->AddSubgraphName(key_name); + subgraph_names_to_builders_[key_name] = nullptr; + } + } + } + + void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { + string key_name = ir_name; + if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { + key_name += std::to_string(index); + } + + auto it = subgraph_names_to_builders_.find(key_name); + if (it == subgraph_names_to_builders_.end()) { + GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index); + return; + } + it->second = builder; + } + + SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, uint32_t index) const { + string key_name = ir_name; + if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { + key_name += std::to_string(index); + } + + return GetSubgraphBuilder(key_name); + } + + SubgraphBuilder GetSubgraphBuilder(const std::string &name) const { + auto iter = subgraph_names_to_builders_.find(name); + if (iter == subgraph_names_to_builders_.end()) { + GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s", name.c_str()); + return nullptr; + } + + return iter->second; + } + + std::vector GetSubgraphNames() const { + std::vector names; + for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) { + names.emplace_back(subgraph_name_to_type.first); + } + return names; + } + + size_t GetSubgraphNamesCount() const { return op_desc_->GetSubgraphIrNames().size(); } + + OpDescPtr op_desc_ = nullptr; + + private: + ge::ConstNodePtr node_{nullptr}; + ge::InferenceContextPtr inference_context_; + std::map> output_links_{}; + 
std::map input_link_{}; + std::vector> control_input_link_{}; + std::vector> control_output_link_{}; + std::map subgraph_names_to_builders_; +}; + +// Used to manage OperatorImpl instances created by ge api. +class OperatorKeeper { + private: + OperatorKeeper() = default; + ~OperatorKeeper() { + for (const auto &iter : operators_) { + if (iter) { + iter->ClearInputLinks(); + iter->ClearOutputLinks(); + } + } + } + std::set operators_; + std::mutex mutex_; + + public: + static OperatorKeeper &GetInstance() { + static OperatorKeeper instance; + return instance; + } + void CheckInOperator(const OperatorImplPtr &op_impl) { + if (op_impl) { + std::lock_guard lock(mutex_); + operators_.insert(op_impl); + } + } + void CheckOutOperator(const OperatorImplPtr &op_impl) { + if (op_impl) { + std::lock_guard lock(mutex_); + operators_.erase(op_impl); + } + } +}; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromNode(ge::ConstNodePtr node_ptr) { + ge::OperatorImplPtr operator_impl_ptr = ComGraphMakeShared(node_ptr); + if (operator_impl_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); + return Operator("default"); + } + return operator_impl_ptr->ToOperator(); +} + +Operator::Operator(const std::string &type) { + static uint32_t index = 0; + string name = type + "_" + std::to_string(index++); + operator_impl_ = ComGraphMakeShared(name, type); + if (operator_impl_ == nullptr) { + GELOGW("OperatorImpl make shared failed"); + } + OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromOpDesc(OpDescPtr op_desc) { + shared_ptr operator_impl_ptr; + operator_impl_ptr = ComGraphMakeShared(op_desc); + if (operator_impl_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); + return Operator("default"); + } + OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr); + return 
operator_impl_ptr->ToOperator(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) { + return OperatorImpl::GetOpDesc(oprt); +} + +GE_FUNC_HOST_VISIBILITY Operator::Operator(const string &name, const string &type) { + operator_impl_ = ComGraphMakeShared(name, type); + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); + return; + } + OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); +} + +Operator::Operator(ge::OperatorImplPtr &&op_impl) { operator_impl_ = std::move(op_impl); } + +bool Operator::IsEmpty() const { + if (operator_impl_ == nullptr) { + return true; + } + return false; +} + +string Operator::GetName() const { + if (operator_impl_ != nullptr) { + return operator_impl_->GetName(); + } + return ""; +} + +GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const string &dst_name, const ge::Operator &src_oprt) { + // Describe the connection relationship between operators, no create action + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); + operator_impl_->SetInputImpl(dst_name, src_oprt); + return *this; +} + +Operator &Operator::SetInput(const string &dst_name, const ge::OutHandler &out_handler) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); + operator_impl_->SetInputImpl(dst_name, out_handler); + return *this; +} + +Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, const std::string &name) { + auto out_handler = src_oprt.GetOutput(name); + GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); + (void)SetInput(dst_name, out_handler); + return *this; +} + +Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) { + auto out_handler = src_oprt.GetOutput(index); + GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is 
nullptr."); + (void)SetInput(dst_name, out_handler); + return *this; +} + +Operator &Operator::AddControlInput(const Operator &src_oprt) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr."); + return *this; + } + operator_impl_->AddControlInputImp(src_oprt); + return *this; +} + +graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { + GE_CHECK_NOTNULL(operator_impl_); + auto node_ptr = operator_impl_->GetNode(); + if (node_ptr != nullptr) { + // For inner compute graph + auto op_desc = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto index = op_desc->GetInputIndexByName(dst_name); + auto in_data_anchor = node_ptr->GetInDataAnchor(index); + GE_CHECK_NOTNULL(in_data_anchor); + auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(out_data_anchor); + auto peer_node = out_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_node); + auto peer_op_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(peer_op_desc); + auto peer_op_type = peer_op_desc->GetType(); + if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { + auto const_op_impl = ComGraphMakeShared(peer_node); + GE_CHECK_NOTNULL(const_op_impl); + Operator const_op(std::move(const_op_impl)); + return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); + } else if (peer_op_type == DATA) { + auto parent_node = NodeUtils::GetParentInput(peer_node); + while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { + parent_node = NodeUtils::GetParentInput(parent_node); + } + if ((parent_node != nullptr) && + ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { + auto const_op_impl = ComGraphMakeShared(parent_node); + GE_CHECK_NOTNULL(const_op_impl); + Operator const_op(std::move(const_op_impl)); + return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); + } + } + // Try get from runtime inference context + auto session_id = std::to_string(GetContext().SessionId()); + RuntimeInferenceContext 
*runtime_infer_ctx = nullptr; + if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { + GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); + auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); + if (ret == GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + } + } else { + // For outer graph + return GetInputConstDataOut(dst_name, data); + } + auto op_name = operator_impl_->GetName(); + GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); + return GRAPH_FAILED; +} +graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { + ge::OpIO out_handle("", 0, nullptr); + GE_CHECK_NOTNULL(operator_impl_); + if (operator_impl_->GetInputImpl(dst_name, out_handle) != GRAPH_SUCCESS) { + GELOGE(FAILED, "%s get input impl failed", dst_name.c_str()); + return GRAPH_FAILED; + } + if (out_handle.GetOwner() != nullptr && out_handle.GetOwner()->GetOpDescImpl() != nullptr) { + Operator const_op(out_handle.GetOwner()); + const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); + if (op_desc_impl_type == CONSTANTOP) { + return const_op.GetAttr(op::Constant::name_attr_value(), data); + } else if (op_desc_impl_type == CONSTANT) { + return const_op.GetAttr(op::Const::name_attr_value(), data); + } + } + return GRAPH_FAILED; +} + +std::shared_ptr Operator::GetNode() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); + return operator_impl_->GetNode(); +} + +TensorDesc Operator::GetInputDesc(const std::string &name) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); +} + +void Operator::SetInferenceContext(const InferenceContextPtr &inference_context) { + GE_CHK_BOOL_EXEC(operator_impl_ 
!= nullptr, return, "operator impl is nullptr."); + operator_impl_->SetInferenceContext(inference_context); +} + +InferenceContextPtr Operator::GetInferenceContext() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); + return operator_impl_->GetInferenceContext(); +} +TensorDesc Operator::GetInputDesc(uint32_t index) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); +} + +graphStatus Operator::TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + auto check = operator_impl_->InputIsSet(name); + if (check) tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); + return check ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +graphStatus Operator::UpdateInputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + return operator_impl_->UpdateInputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +OutHandler Operator::GetOutput(const string &name) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); + return operator_impl_->GetOutput(name); +} + +OutHandler Operator::GetOutput(uint32_t index) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); + return operator_impl_->GetOutput(index); +} + +TensorDesc Operator::GetOutputDesc(const std::string &name) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name)); +} + +TensorDesc Operator::GetOutputDesc(uint32_t index) const { + 
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(index)); +} + +graphStatus Operator::UpdateOutputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + return operator_impl_->UpdateOutputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +TensorDesc Operator::GetDynamicInputDesc(const string &name, uint32_t index) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name + std::to_string(index))); +} + +graphStatus Operator::UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + return operator_impl_->UpdateInputDesc(name + std::to_string(index), + TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +TensorDesc Operator::GetDynamicOutputDesc(const string &name, uint32_t index) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name + std::to_string(index))); +} + +graphStatus Operator::UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + return operator_impl_->UpdateOutputDesc(name + std::to_string(index), + TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +graphStatus Operator::InferShapeAndType() { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 
GRAPH_FAILED, "GetOpDescImpl is nullptr."); + + return operator_impl_->GetOpDescImpl()->CallInferFunc(*this); +} + +graphStatus Operator::VerifyAllAttr(bool disable_common_verifier) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); + + if (!disable_common_verifier && (graphStatus)Operator::VerifyAll() == GRAPH_FAILED) { + return GRAPH_FAILED; + } else { + return (graphStatus)operator_impl_->GetOpDescImpl()->OpVerify(); + } +} + +GE_FUNC_HOST_VISIBILITY size_t Operator::GetInputsSize() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr"); + return operator_impl_->GetInputsSize(); +} + +GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr"); + return operator_impl_->GetOutputsSize(); +} + +// According to op get the attrs name and type +namespace { +const std::map kAttrTypesMap = { + {GeAttrValue::VT_NONE, "VT_STRING"}, + {GeAttrValue::VT_STRING, "VT_STRING"}, + {GeAttrValue::VT_FLOAT, "VT_FLOAT"}, + {GeAttrValue::VT_BOOL, "VT_BOOL"}, + {GeAttrValue::VT_INT, "VT_INT"}, + {GeAttrValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"}, + {GeAttrValue::VT_TENSOR, "VT_TENSOR"}, + {GeAttrValue::VT_BYTES, "VT_BYTES"}, + {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, + {GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"}, + {GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"}, + {GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"}, + {GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"}, + {GeAttrValue::VT_LIST_BOOL, "VT_LIST_BOOL"}, + {GeAttrValue::VT_LIST_INT, "VT_LIST_INT"}, + {GeAttrValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"}, + {GeAttrValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"}, + {GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"}, + {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, + {GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"}, 
+}; +} // namespace +const std::map Operator::GetAllAttrNamesAndTypes() const { + std::map attr_types; + + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return attr_types, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return attr_types, "GetOpDescImpl is nullptr."); + std::map attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs(); + + map::iterator iter; + for (iter = attr_map.begin(); iter != attr_map.end(); ++iter) { + string name = iter->first; + GeAttrValue attr_value = iter->second; + + GeAttrValue::ValueType type = attr_value.GetValueType(); + + auto iter2 = kAttrTypesMap.find(type); + if (iter2 != kAttrTypesMap.end()) { + attr_types[name] = iter2->second; + } + } + + return attr_types; +} + +void Operator::InputRegister(const string &name) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); +} + +void Operator::OptionalInputRegister(const string &name) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + // [No need to verify return value] + (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, + GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); +} + +void Operator::InferFuncRegister(const std::function &func) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + // [No need to verify return value] + (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); +} + +void Operator::InferFormatFuncRegister(const std::function &func) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + 
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + // [No need to verify return value] + (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); +} + +void Operator::VerifierFuncRegister(const std::function &func) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + // [No need to verify return value] + (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); +} + +void Operator::OutputRegister(const string &name) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + // [No need to verify return value] + (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); +} + +void Operator::DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return, + "set int failed"); + (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); +} + +void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) { + GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr."); + operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index); +} + +int Operator::GetDynamicInputNum(const string &name) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); + 
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); + int num = 0; + GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return num, + "Get %s int failed", name.c_str()); + return num; +} + +void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return, + "Set %s int failed", name.c_str()); + (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); +} + +int Operator::GetDynamicOutputNum(const string &name) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); + int num = 0; + GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num, + "Get %s int failed", name.c_str()); + return num; +} + +void Operator::RequiredAttrRegister(const string &name) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + operator_impl_->GetOpDescImpl()->AddRequiredAttr(name); +} + +graphStatus Operator::VerifyAll() { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); + + // Check all inputs defined + for (const string &iname : operator_impl_->GetOpDescImpl()->GetAllInputNames()) { + 
GE_CHK_BOOL_RET_STATUS(operator_impl_->GetOpDescImpl()->IsOptionalInput(iname) || operator_impl_->InputIsSet(iname), + GRAPH_FAILED, "operator input %s is not linked.", iname.c_str()); + vector ishape = operator_impl_->GetOpDescImpl()->GetInputDesc(iname).GetShape().GetDims(); + for (int64_t dim : ishape) { + GE_CHK_BOOL_RET_STATUS(dim > 0, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", + iname.c_str()); + } + } + // Check all attributes defined + const auto all_attributes = operator_impl_->GetOpDescImpl()->GetAllAttrs(); + for (const auto &name : operator_impl_->GetOpDescImpl()->GetAllAttrNames()) { + GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, + "operator attribute %s is empty.", name.c_str()); + } + + return GRAPH_SUCCESS; +} + +string Operator::GetOpType() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return "Data", "operator impl is nullptr."); + return OperatorImpl::GetOpDesc(*this)->GetType(); +} + +Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt) { + string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index); + return SetInput(dynamic_dst_name, src_oprt); +} + +Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt, + const std::string &name) { + string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index); + return SetInput(dynamic_dst_name, src_oprt, name); +} + +OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } + +#define OP_ATTR_SET_IMP(ArgType, AttrUtilsFun) \ + Operator &Operator::SetAttr(const string &name, ArgType attr_value) { \ + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ + return *this; \ + } \ + if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ + GELOGW("set 
attr name %s failed.", name.c_str()); \ + } \ + return *this; \ + } // lint !e665 + +#define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \ + graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \ + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ + return GRAPH_FAILED; \ + } \ + if (!AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ + GELOGW("get attr name %s failed.", name.c_str()); \ + return GRAPH_FAILED; \ + } \ + return GRAPH_SUCCESS; \ + } // lint !e665 + +void Operator::BreakConnect() const { + if (operator_impl_ == nullptr) { + GELOGW("operator impl is nullptr."); + return; + } + operator_impl_->ClearInputLinks(); + operator_impl_->ClearOutputLinks(); + OperatorKeeper::GetInstance().CheckOutOperator(operator_impl_); +} + +#define OP_ATTR_REG_IMP(ArgType, AttrUtilsFun) \ + void Operator::AttrRegister(const string &name, ArgType attr_value) { \ + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ + return; \ + } \ + if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ + GELOGW("reg attr name %s failed.", name.c_str()); \ + } \ + } // lint !e665 + +OP_ATTR_SET_IMP(int64_t, Int) +OP_ATTR_SET_IMP(int32_t, Int) +OP_ATTR_SET_IMP(uint32_t, Int) +OP_ATTR_GET_IMP(int64_t &, Int) +OP_ATTR_GET_IMP(int32_t &, Int) +OP_ATTR_GET_IMP(uint32_t &, Int) +OP_ATTR_SET_IMP(const vector &, ListInt) +OP_ATTR_SET_IMP(const vector &, ListInt) +OP_ATTR_SET_IMP(const vector &, ListInt) +OP_ATTR_SET_IMP(std::initializer_list &&, ListInt) +OP_ATTR_GET_IMP(vector &, ListInt) +OP_ATTR_GET_IMP(vector &, ListInt) +OP_ATTR_GET_IMP(vector &, ListInt) +OP_ATTR_GET_IMP(vector> &, ListListInt) +OP_ATTR_SET_IMP(const vector> &, ListListInt) + +OP_ATTR_SET_IMP(float, Float) 
+OP_ATTR_GET_IMP(float &, Float) +OP_ATTR_SET_IMP(const vector &, ListFloat) +OP_ATTR_GET_IMP(vector &, ListFloat) // lint !e665 + +OP_ATTR_SET_IMP(bool, Bool) +OP_ATTR_GET_IMP(bool &, Bool) +OP_ATTR_SET_IMP(const vector &, ListBool) +OP_ATTR_GET_IMP(vector &, ListBool) // lint !e665 + +OP_ATTR_SET_IMP(const string &, Str) +OP_ATTR_GET_IMP(string &, Str) +OP_ATTR_SET_IMP(const vector &, ListStr) +OP_ATTR_GET_IMP(vector &, ListStr) // lint !e665 + +OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) +OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) +OP_ATTR_SET_IMP(const vector &, ListNamedAttrs) +OP_ATTR_GET_IMP(vector &, ListNamedAttrs) // lint !e665 + +OP_ATTR_REG_IMP(int64_t, Int) +OP_ATTR_REG_IMP(const vector &, ListInt) +OP_ATTR_REG_IMP(float, Float) +OP_ATTR_REG_IMP(const vector &, ListFloat) +OP_ATTR_REG_IMP(const string &, Str) +OP_ATTR_REG_IMP(const vector &, ListStr) +OP_ATTR_REG_IMP(bool, Bool) +OP_ATTR_REG_IMP(const vector &, ListBool) +OP_ATTR_REG_IMP(const vector> &, ListListInt) +OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) +OP_ATTR_REG_IMP(const vector &, ListNamedAttrs) + +#undef OP_ATTR_SET_IMP +#undef OP_ATTR_GET_IMP +#undef OP_ATTR_REG_IMP + +Operator &Operator::SetAttr(const string &name, const Tensor &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + GeTensor tensor = TensorAdapter::AsGeTensor(attr_value); + if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { + GELOGW("set attr name %s failed.", name.c_str()); + } + return *this; +} + +Operator &Operator::SetAttr(const string &name, const vector &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + vector val_list; + for (const auto &item : attr_value) { + 
auto tensor = TensorAdapter::AsGeTensor(item); + val_list.push_back(tensor); + } + if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { + GELOGW("set attr name %s failed.", name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const string &name, Tensor &attr_value) const { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + ConstGeTensorPtr tensor; + if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { + GELOGW("get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + attr_value = TensorAdapter::GeTensor2Tensor(tensor); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const string &name, vector &attr_value) const { + attr_value.clear(); + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + vector val_list; + if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { + GELOGW("get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + for (auto &tensor : val_list) { + attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor)); + } + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const string &name, const OpBytes &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, + Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { + GELOGW("set attr name %s failed.", name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const string &name, OpBytes &attr_value) const { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { 
+ GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + Buffer buffer; + if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, buffer)) { + GELOGW("get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + attr_value.clear(); + if (buffer.data() == nullptr) { + GELOGE(GRAPH_FAILED, "buffer data is null."); + return GRAPH_FAILED; + } + attr_value.assign(buffer.data(), buffer.data() + buffer.size()); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const string &name, ge::AttrValue &&attrValue) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); + (void)operator_impl_->SetAttr(name, std::move(attrValue.impl->geAttrValue_)); + return *this; +} + +graphStatus Operator::GetAttr(const string &name, ge::AttrValue &attrValue) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + return operator_impl_->GetAttr(name, attrValue.impl->geAttrValue_); +} + +Operator &Operator::SetAttr(const string &name, const std::vector &attr_value) { + if (operator_impl_ == nullptr || !operator_impl_->GetOpDescImpl()) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("set attr name %s failed.", name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const string &name, std::vector &attr_value) const { + attr_value.clear(); + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + if (!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const string &name, const 
ge::DataType &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("set attr name %s failed.", name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const string &name, ge::DataType &attr_value) const { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +void Operator::AttrRegister(const string &name, const std::vector &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("set attr name %s failed.", name.c_str()); + } +} + +void Operator::AttrRegister(const string &name, const ge::DataType &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("set attr name %s failed.", name.c_str()); + } +} + +void Operator::AttrRegister(const string &name, const Tensor &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + auto tensor = TensorAdapter::AsGeTensor(attr_value); + if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, 
tensor)) { + GELOGW("reg attr name %s failed.", name.c_str()); + } +} + +void Operator::AttrRegister(const string &name, const vector &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + vector val_list; + for (const auto &item : attr_value) { + val_list.push_back(TensorAdapter::AsGeTensor(item)); + } + if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { + GELOGW("reg attr name %s failed.", name.c_str()); + } +} + +void Operator::AttrRegister(const string &name, const OpBytes &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, + Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { + GELOGW("reg attr name %s failed.", name.c_str()); + } +} + +void Operator::SubgraphRegister(const std::string &name, bool dynamic) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + operator_impl_->SubgraphRegister(name, dynamic ? 
kDynamic : kStatic); +} + +void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + operator_impl_->SubgraphCountRegister(name, count); +} + +void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", ir_name.c_str()); + return; + } + operator_impl_->SetSubgraphBuilder(ir_name, index, builder); +} + +std::vector Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } + +SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &ir_name, uint32_t index) const { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr."); + return nullptr; + } + return operator_impl_->GetSubgraphBuilder(ir_name, index); +} + +SubgraphBuilder Operator::GetSubgraphBuilder(const string &ir_name) const { + return GetDynamicSubgraphBuilder(ir_name, 0); +} + +Graph Operator::GetSubgraph(const string &name) const { + if (operator_impl_ == nullptr) { + GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str()); + return Graph(""); + } + auto op_desc = OpDescUtils::GetOpDescFromOperator(*this); + if (op_desc == nullptr) { + GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str()); + return Graph(""); + } + const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); + auto iter = subgraph_names_to_index.find(name); + if (iter == subgraph_names_to_index.end()) { + GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str()); + return Graph(""); + } + auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); + if (subgraph_instance_name.empty()) { + GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", name.c_str(), iter->second); + return 
Graph(""); + } + + auto node = operator_impl_->GetNode(); + if (node == nullptr) { + GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str()); + return Graph(""); + } + auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + if (root_graph == nullptr) { + GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str()); + return Graph(""); + } + auto subgraph = root_graph->GetSubgraph(subgraph_instance_name); + if (subgraph == nullptr) { + GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", name.c_str(), + iter->second, subgraph_instance_name.c_str()); + return Graph(""); + } + return GraphUtils::CreateGraphFromComputeGraph(subgraph); +} + +Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const { + return GetSubgraph(name + std::to_string(index)); +} + +size_t Operator::GetSubgraphNamesCount() const { + if (operator_impl_ == nullptr) { + GE_LOGE("Failed to get subgraph names count, the operator impl is null"); + return 0; + } + return operator_impl_->GetSubgraphNamesCount(); +} + +class GraphBuilderImpl { + public: + explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared(name)) { + if (graph_ == nullptr) { + GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); + return; + } + } + + ~GraphBuilderImpl() {} + + ComputeGraphPtr BuildGraph(const std::vector &inputs) { + std::vector vec_inputs; + for (auto &it : inputs) { + auto src_op_impl = it.operator_impl_; + GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return nullptr, "Operator Impl is null."); + GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return nullptr, "Operator impl's opdesc is null."); + + string type = src_op_impl->op_desc_->GetType(); + auto node_op = ge::OperatorFactory::CreateOperator("node_op", type); + auto tensor_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + node_op.BreakConnect(); + + GE_CHK_BOOL_EXEC(tensor_desc != nullptr, continue, "tensor_desc is 
null."); + if ((tensor_desc->GetInputsSize() == 0 && tensor_desc->GetOutputsSize() > 0) || type == DATA || + type == VARIABLE || type == INITDATA || type == GETNEXT) { + vec_inputs.push_back(it.operator_impl_); + } else { + GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); + } + } + GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, + "User Input do not include operator such as " + "Data, Variable operator or operator that has output but no input."); + auto ret = WalkAllOperators(vec_inputs); + GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); + + ret = AddEdge(); + GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "AddEdge failed."); + + return graph_; + } + + const std::map &GetAllNodesInfo() const { return all_nodes_info_; } + + private: + graphStatus WalkAllOperators(const std::vector &vec_ops) { + GE_CHK_BOOL_EXEC(graph_ != nullptr, return GRAPH_FAILED, "graph_ is null.") + std::queue> que; + que.push(vec_ops); + while (!que.empty()) { + auto vec_tem = que.front(); + que.pop(); + for (const auto &op_impl : vec_tem) { + GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") + GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue, + "This node %s has created.", op_impl->GetName().c_str()) + auto node_ptr = graph_->AddNode(op_impl->op_desc_); + GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); + all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); + + auto &out_links = op_impl->output_links_; + std::vector vec_op_forward{}; + for (const auto &out_link : out_links) { + for (const auto &op_forward : out_link.second) { + vec_op_forward.push_back(op_forward.GetOwner()); + } + } + + auto &out_control_links = op_impl->control_output_link_; + for (const auto &out_link : out_control_links) { + vec_op_forward.push_back(out_link.lock()); + } + que.push(vec_op_forward); + + auto 
&in_links = op_impl->input_link_; + std::vector vec_op_back_forward{}; + for (const auto &in_link : in_links) { + vec_op_back_forward.push_back(in_link.second.GetOwner()); + } + + auto &in_control_links = op_impl->control_input_link_; + for (const auto &in_link : in_control_links) { + vec_op_back_forward.push_back(in_link.lock()); + } + que.push(vec_op_back_forward); + + if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + } + return MoveSubgraphToRoot(graph_); + } + + graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) { + const string name = node->GetName(); + for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) { + const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first); + GE_CHK_BOOL_EXEC(builder != nullptr, return GRAPH_FAILED, "Node: %s, Get builder failed.", name.c_str()); + + Graph graph = builder(); // Build subgraph from user define builder. + const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph); + GE_CHK_BOOL_EXEC(subgraph != nullptr, return GRAPH_FAILED, "Node: %s, Build graph failed.", name.c_str()); + + subgraph->SetParentNode(node); + subgraph->SetParentGraph(graph_); + if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + + if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second); + return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; + } + + graphStatus MoveSubgraphToRoot(const ComputeGraphPtr &graph) { + const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph); + if (root_graph == nullptr) { + GELOGE(GRAPH_FAILED, "Graph: %s, Find root graph failed.", graph->GetName().c_str()); + return GRAPH_FAILED; + } + + if (root_graph == graph) { + auto subgraphs = graph->GetAllSubgraphs(); + for (auto &subgraph : subgraphs) 
{ + if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + } else { + auto subgraphs = graph->GetAllSubgraphs(); + for (auto &subgraph : subgraphs) { + if (root_graph->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + graph->RemoveSubgraph(subgraph->GetName()); + if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + } + + return GRAPH_SUCCESS; + } + + graphStatus AddEdge() { + for (const auto &node_info : all_nodes_info_) { + auto src_op_impl_ptr = node_info.first; + auto src_node_ptr = node_info.second; + + GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); + auto out_links = src_op_impl_ptr->output_links_; + GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED, + "Src operator impl's op_desc is null."); + auto &op_desc = src_op_impl_ptr->op_desc_; + GE_IF_BOOL_EXEC(op_desc == nullptr, continue); + for (const auto &out : out_links) { + auto src_idx = op_desc->GetOutputIndexByName(out.first); + GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); + + auto src_anchor = src_node_ptr->GetOutDataAnchor(src_idx); + GE_CHK_BOOL_EXEC(src_anchor != nullptr, return GRAPH_FAILED, "GetOutDataAnchor failed."); + + for (const auto &dst_opio : out.second) { + auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); + GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); + + GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); + + auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); + GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); + + auto ret = GraphUtils::AddEdge(src_anchor, dst_anchor); + GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, + "from node[%s][%d] to node[%s][%d]AddEdge failed.", src_node_ptr->GetName().c_str(), + src_anchor->GetIdx(), 
dst_node_info->second->GetName().c_str(), dst_anchor->GetIdx()); + } + } + auto out_control_anchor = src_node_ptr->GetOutControlAnchor(); + for (const auto &control_out : src_op_impl_ptr->control_output_link_) { + auto dst_node_info = all_nodes_info_.find(control_out.lock()); + if (dst_node_info == all_nodes_info_.end()) { + GELOGE(GRAPH_FAILED, "Find Dst node failed."); + return GRAPH_FAILED; + } + GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); + auto in_control_anchor = dst_node_info->second->GetInControlAnchor(); + auto ret = GraphUtils::AddEdge(out_control_anchor, in_control_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "AddEdge failed. srcNode %s:%s, dstNode %s:%s", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), dst_node_info->second->GetName().c_str(), + dst_node_info->second->GetType().c_str()); + return ret; + } + } + } + return GRAPH_SUCCESS; + } + + ComputeGraphPtr graph_ = nullptr; + std::map all_nodes_info_{}; +}; + +inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { + for (const auto &graph : compute_graph->GetAllSubgraphs()) { + std::set node_names; + for (auto const &node : graph->GetDirectNode()) { + auto result = node_names.insert(node->GetName()); + if (!result.second) { + GELOGE(GRAPH_FAILED, "graph %s has same name node%s", graph->GetName().c_str(), node->GetName().c_str()); + return true; + } + } + } + + std::set node_names; + for (auto const &node : compute_graph->GetDirectNode()) { + auto result = node_names.insert(node->GetName()); + if (!result.second) { + GELOGE(GRAPH_FAILED, "graph %s has same name node%s", compute_graph->GetName().c_str(), node->GetName().c_str()); + return true; + } + } + return false; +} + +ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector &inputs) { + auto graph_builder_impl = GraphBuilderImpl(name); + ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); + GE_CHK_BOOL_EXEC(compute_graph != nullptr, return 
compute_graph, "Computer graph is nullptr"); + compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); + if (HasSameNameNode(compute_graph)) { + GELOGW("Compute do not allow has same name nodes."); + compute_graph = nullptr; + } + + return compute_graph; +} + +void GraphUtils::BreakConnect(const std::map &all_nodes_infos) { + for (const auto &it : all_nodes_infos) { + OperatorImplPtr op_impl = it.first; + if (op_impl == nullptr) { + GELOGW("operator impl is nullptr."); + continue; + } + op_impl->ClearOutputLinks(); + op_impl->ClearInputLinks(); + OperatorKeeper::GetInstance().CheckOutOperator(op_impl); + } +} +} // namespace ge +/*lint +e446 +e732*/ +/*lint +e665*/ diff --git a/src/common/graph/operator_factory.cc b/src/common/graph/operator_factory.cc new file mode 100644 index 00000000..43d61a7c --- /dev/null +++ b/src/common/graph/operator_factory.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/operator_factory_impl.h" +#include "debug/ge_log.h" + +namespace ge { +Operator OperatorFactory::CreateOperator(const std::string &operator_name, const std::string &operator_type) { + return OperatorFactoryImpl::CreateOperator(operator_name, operator_type); +} + +graphStatus OperatorFactory::GetOpsTypeList(std::vector &all_ops) { + return OperatorFactoryImpl::GetOpsTypeList(all_ops); +} + +bool OperatorFactory::IsExistOp(const string &operator_type) { return OperatorFactoryImpl::IsExistOp(operator_type); } + +OperatorCreatorRegister::OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator) { + (void)OperatorFactoryImpl::RegisterOperatorCreator(operator_type, op_creator); +} + +InferShapeFuncRegister::InferShapeFuncRegister(const std::string &operator_type, + const InferShapeFunc &infer_shape_func) { + (void)OperatorFactoryImpl::RegisterInferShapeFunc(operator_type, infer_shape_func); +} + +InferFormatFuncRegister::InferFormatFuncRegister(const std::string &operator_type, + const InferFormatFunc &infer_format_func) { + (void)OperatorFactoryImpl::RegisterInferFormatFunc(operator_type, infer_format_func); +} + +VerifyFuncRegister::VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func) { + (void)OperatorFactoryImpl::RegisterVerifyFunc(operator_type, verify_func); +} +} // namespace ge diff --git a/src/common/graph/operator_factory_impl.cc b/src/common/graph/operator_factory_impl.cc new file mode 100644 index 00000000..026a85bc --- /dev/null +++ b/src/common/graph/operator_factory_impl.cc @@ -0,0 +1,149 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/operator_factory_impl.h" +#include "debug/ge_log.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +shared_ptr> OperatorFactoryImpl::operator_creators_; +shared_ptr> OperatorFactoryImpl::operator_infershape_funcs_; +shared_ptr> OperatorFactoryImpl::operator_inferformat_funcs_; +shared_ptr> OperatorFactoryImpl::operator_verify_funcs_; + +Operator OperatorFactoryImpl::CreateOperator(const std::string &operator_name, const std::string &operator_type) { + if (operator_creators_ == nullptr) { + return Operator(); + } + auto it = operator_creators_->find(operator_type); + if (it == operator_creators_->end()) { + GELOGW("no OpProto of [%s] registered", operator_type.c_str()); + return Operator(); + } + return it->second(operator_name); +} + +graphStatus OperatorFactoryImpl::GetOpsTypeList(std::vector &all_ops) { + all_ops.clear(); + if (operator_creators_ != nullptr) { + for (auto it = operator_creators_->begin(); it != operator_creators_->end(); ++it) { + all_ops.emplace_back(it->first); + } + } else { + GELOGE(GRAPH_FAILED, "no operator creators found"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +bool OperatorFactoryImpl::IsExistOp(const string &operator_type) { + if (operator_creators_ == nullptr) { + return false; + } + auto it = operator_creators_->find(operator_type); + if (it == operator_creators_->end()) { + return false; + } + return true; +} + +InferShapeFunc OperatorFactoryImpl::GetInferShapeFunc(const std::string &operator_type) { + if (operator_infershape_funcs_ == nullptr) { + return 
nullptr; + } + auto it = operator_infershape_funcs_->find(operator_type); + if (it == operator_infershape_funcs_->end()) { + return nullptr; + } + return it->second; +} + +InferFormatFunc OperatorFactoryImpl::GetInferFormatFunc(const std::string &operator_type) { + if (operator_inferformat_funcs_ == nullptr) { + GELOGI("operator_inferformat_funcs_ is null"); + return nullptr; + } + auto it = operator_inferformat_funcs_->find(operator_type); + if (it == operator_inferformat_funcs_->end()) { + return nullptr; + } + return it->second; +} + +VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) { + if (operator_verify_funcs_ == nullptr) { + return nullptr; + } + auto it = operator_verify_funcs_->find(operator_type); + if (it == operator_verify_funcs_->end()) { + return nullptr; + } + return it->second; +} + +graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { + if (operator_creators_ == nullptr) { + operator_creators_.reset(new (std::nothrow) std::map()); + } + auto it = operator_creators_->find(operator_type); + if (it != operator_creators_->end()) { + return GRAPH_FAILED; + } + (void)operator_creators_->emplace(operator_type, op_creator); + return GRAPH_SUCCESS; +} + +graphStatus OperatorFactoryImpl::RegisterInferShapeFunc(const std::string &operator_type, + InferShapeFunc const infer_shape_func) { + if (operator_infershape_funcs_ == nullptr) { + GELOGI("operator_infershape_funcs_ init"); + operator_infershape_funcs_.reset(new (std::nothrow) std::map()); + } + auto it = operator_infershape_funcs_->find(operator_type); + if (it != operator_infershape_funcs_->end()) { + return GRAPH_FAILED; + } + (void)operator_infershape_funcs_->emplace(operator_type, infer_shape_func); + return GRAPH_SUCCESS; +} + +graphStatus OperatorFactoryImpl::RegisterInferFormatFunc(const std::string &operator_type, + InferFormatFunc const infer_format_func) { + if (operator_inferformat_funcs_ == nullptr) 
{ + GELOGI("operator_inferformat_funcs_ init"); + operator_inferformat_funcs_.reset(new (std::nothrow) std::map()); + } + auto it = operator_inferformat_funcs_->find(operator_type); + if (it != operator_inferformat_funcs_->end()) { + return GRAPH_FAILED; + } + (void)operator_inferformat_funcs_->emplace(operator_type, infer_format_func); + return GRAPH_SUCCESS; +} + +graphStatus OperatorFactoryImpl::RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func) { + if (operator_verify_funcs_ == nullptr) { + GELOGI("operator_verify_funcs_ init"); + operator_verify_funcs_.reset(new (std::nothrow) std::map()); + } + auto it = operator_verify_funcs_->find(operator_type); + if (it != operator_verify_funcs_->end()) { + return GRAPH_FAILED; + } + (void)operator_verify_funcs_->emplace(operator_type, verify_func); + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/src/common/graph/opsproto/opsproto_manager.cc b/src/common/graph/opsproto/opsproto_manager.cc new file mode 100644 index 00000000..d482715b --- /dev/null +++ b/src/common/graph/opsproto/opsproto_manager.cc @@ -0,0 +1,187 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/opsproto_manager.h" +#include +#include +#include +#include +#include +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_log.h" + +namespace ge { +OpsProtoManager *OpsProtoManager::Instance() { + static OpsProtoManager instance; + return &instance; +} + +bool OpsProtoManager::Initialize(const std::map &options) { + std::lock_guard lock(mutex_); + + if (is_init_) { + GELOGI("OpsProtoManager is already initialized."); + return true; + } + + /*lint -e1561*/ + auto proto_iter = options.find("ge.opsProtoLibPath"); + /*lint +e1561*/ + if (proto_iter == options.end()) { + GELOGW("ge.opsProtoLibPath option not set, return."); + return false; + } + + pluginPath_ = proto_iter->second; + LoadOpsProtoPluginSo(pluginPath_); + + is_init_ = true; + + return true; +} + +void OpsProtoManager::Finalize() { + std::lock_guard lock(mutex_); + + if (!is_init_) { + GELOGI("OpsProtoManager is not initialized."); + return; + } + + for (auto handle : handles_) { + if (handle != nullptr) { + if (dlclose(handle) != 0) { + GELOGW("failed to close handle, message: %s", dlerror()); + continue; + } + GELOGI("close opsprotomanager handler success"); + } else { + GELOGW("close opsprotomanager handler failure, handler is nullptr"); + } + } + + is_init_ = false; +} + +static std::vector Split(const std::string &str, char delim) { + std::vector elems; + if (str.empty()) { + elems.emplace_back(""); + return elems; + } + + std::stringstream ss(str); + std::string item; + + while (getline(ss, item, delim)) { + elems.push_back(item); + } + + auto str_size = str.size(); + if (str_size > 0 && str[str_size - 1] == delim) { + elems.emplace_back(""); + } + + return elems; +} + +static void FindParserSo(const std::string &path, std::vector &file_list) { + // Lib plugin path not exist + if (path.empty()) { + GELOGI("realPath is empty"); + return; + } + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path.size() >= PATH_MAX, return, "path is invalid"); + + 
char resolved_path[PATH_MAX] = {0}; + + // Nullptr is returned when the path does not exist or there is no permission + // Return absolute path when path is accessible + if (realpath(path.c_str(), resolved_path) == nullptr) { + GELOGW("the path [%s] not exsit.", path.c_str()); + return; + } + + struct dirent *dent = nullptr; + DIR *dir = opendir(resolved_path); + // Lib plugin path not exist + if (dir == nullptr) { + GELOGW("Open directory %s failed,maybe it is not exit or not a dir", resolved_path); + return; + } + + while ((dent = readdir(dir)) != nullptr) { + if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) { + continue; + } + std::string name = dent->d_name; + std::string full_name = path + "/" + name; + const std::string so_suff = ".so"; + + if (dent->d_type != DT_DIR && name.size() >= so_suff.size() && + name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { + file_list.push_back(full_name); + GELOGI("OpsProtoManager Parse full name = %s \n", full_name.c_str()); + } + } + if (closedir(dir) != 0) { + GELOGW("close dir fail."); + } +} + +static void GetPluginSoFileList(const std::string &path, std::vector &file_list) { + // Support multi lib directory with ":" as delimiter + std::vector v_path = Split(path, ':'); + + for (size_t i = 0; i < v_path.size(); ++i) { + FindParserSo(v_path[i], file_list); + GELOGI("OpsProtoManager full name = %s", v_path[i].c_str()); + } +} + +void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) { + if (path.empty()) { + GELOGE(GRAPH_FAILED, "filePath is invalid. 
please check your text file %s.", path.c_str()); + return; + } + std::vector file_list; + + // If there is .so file in the lib path + GetPluginSoFileList(path, file_list); + + // Not found any .so file in the lib path + if (file_list.empty()) { + GELOGE(GRAPH_FAILED, "OpsProtoManager can not find any plugin file in pluginPath: %s \n", path.c_str()); + return; + } + // Warning message + GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted."); + + // Load .so file + for (auto elem : file_list) { + void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); + continue; + } else { + // Close dl when the program exist, not close here + GELOGI("OpsProtoManager plugin load %s success.", elem.c_str()); + handles_.push_back(handle); + } + } +} +} // namespace ge diff --git a/src/common/graph/option/ge_context.cc b/src/common/graph/option/ge_context.cc new file mode 100644 index 00000000..421e0aff --- /dev/null +++ b/src/common/graph/option/ge_context.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "./ge_context.h" +#include "./ge_global_options.h" +#include "./ge_local_context.h" +#include "framework/common/ge_types.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +namespace { +const int64_t kMinTrainingTraceJobId = 256; +const int kDecimal = 10; +const char *kHostExecPlacement = "HOST"; +} // namespace +GEContext &GetContext() { + static GEContext ge_context{}; + return ge_context; +} + +graphStatus GEContext::GetOption(const std::string &key, std::string &option) { + return GetThreadLocalContext().GetOption(key, option); +} + +bool GEContext::GetHostExecFlag() { + std::string exec_placement; + if (GetThreadLocalContext().GetOption(GE_OPTION_EXEC_PLACEMENT, exec_placement) != GRAPH_SUCCESS) { + GELOGW("get option OPTION_EXEC_PLACEMENT failed."); + return false; + } + GELOGD("Option ge.exec.placement is %s.", exec_placement.c_str()); + return exec_placement == kHostExecPlacement; +} + +std::map &GetMutableGlobalOptions() { + static std::map global_options{}; + return global_options; +} + +void GEContext::Init() { + string session_id; + (void)GetOption("ge.exec.sessionId", session_id); + try { + session_id_ = static_cast(std::stoi(session_id.c_str())); + } catch (std::invalid_argument &) { + GELOGW("%s transform to int failed.", session_id.c_str()); + } catch (std::out_of_range &) { + GELOGW("%s transform to int failed.", session_id.c_str()); + } + + string device_id; + (void)GetOption("ge.exec.deviceId", device_id); + try { + device_id_ = static_cast(std::stoi(device_id.c_str())); + } catch (std::invalid_argument &) { + GELOGW("%s transform to int failed.", device_id.c_str()); + } catch (std::out_of_range &) { + GELOGW("%s transform to int failed.", device_id.c_str()); + } + + string job_id; + (void)GetOption("ge.exec.jobId", job_id); + std::string s_job_id = ""; + for (auto c : job_id) { + if (c >= '0' && c <= '9') { + s_job_id += c; + } + } + if (s_job_id == "") { + trace_id_ = kMinTrainingTraceJobId; + return; + } + 
int64_t d_job_id = std::strtoll(s_job_id.c_str(), nullptr, kDecimal); + if (d_job_id < kMinTrainingTraceJobId) { + trace_id_ = d_job_id + kMinTrainingTraceJobId; + } else { + trace_id_ = d_job_id; + } +} + +uint64_t GEContext::SessionId() { return session_id_; } + +uint32_t GEContext::DeviceId() { return device_id_; } + +uint64_t GEContext::TraceId() { return trace_id_; } + +void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; } + +void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } + +} // namespace ge diff --git a/src/common/graph/option/ge_local_context.cc b/src/common/graph/option/ge_local_context.cc new file mode 100644 index 00000000..82b1cb01 --- /dev/null +++ b/src/common/graph/option/ge_local_context.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "./ge_local_context.h" +#include + +namespace ge { +namespace { +thread_local GEThreadLocalContext thread_context; +} + +GEThreadLocalContext &GetThreadLocalContext() { return thread_context; } + +graphStatus GEThreadLocalContext::GetOption(const string &key, string &option) { + auto graph_iter = graph_options_.find(key); + if (graph_iter != graph_options_.end()) { + option = graph_iter->second; + return GRAPH_SUCCESS; + } + auto session_iter = session_options_.find(key); + if (session_iter != session_options_.end()) { + option = session_iter->second; + return GRAPH_SUCCESS; + } + auto global_iter = global_options_.find(key); + if (global_iter != global_options_.end()) { + option = global_iter->second; + return GRAPH_SUCCESS; + } + return GRAPH_PARAM_INVALID; +} + +void GEThreadLocalContext::SetGlobalOption(map options_map) { + global_options_.clear(); + global_options_ = std::move(options_map); +} + +void GEThreadLocalContext::SetSessionOption(map options_map) { + session_options_.clear(); + session_options_ = std::move(options_map); +} + +void GEThreadLocalContext::SetGraphOption(map options_map) { + graph_options_.clear(); + graph_options_ = std::move(options_map); +} +} // namespace ge diff --git a/src/common/graph/ref_relation.cc b/src/common/graph/ref_relation.cc new file mode 100644 index 00000000..48e136fb --- /dev/null +++ b/src/common/graph/ref_relation.cc @@ -0,0 +1,455 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/ref_relation.h" + +#include +#include + +#include "utils/mem_utils.h" +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "debug/ge_attr_define.h" +#include "graph/ge_error_codes.h" +#include "graph/utils/graph_utils.h" +#include "framework/common/debug/ge_log.h" + +using namespace std; +using namespace ge; +namespace ge { +namespace { +const char *kRefIndex = "_parent_node_index"; +const string kWhile = "While"; +const string kIf = "If"; +const string kCase = "Case"; + +const uint16_t kMaxElementNum = 100; + +std::unordered_set function_op = {kWhile, kIf, kCase}; +} // namespace + +/* Impl */ +class RefRelations::Impl { + public: + graphStatus LookUpRefRelations(const RefCell &key, unordered_set &result) { + unsigned long number = static_cast(reinterpret_cast(key.node.get())); + std::string lookup_key = + key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + std::to_string(number); + auto iter = look_up_table_.find(lookup_key); + if (iter != look_up_table_.end()) { + for (auto &c : iter->second) { + result.insert(c); + } + return GRAPH_SUCCESS; + } + GELOGW("can not find any relations! 
key value of dest relation is %s", lookup_key.c_str()); + return GRAPH_SUCCESS; + }; + graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); + graphStatus Clear() { + GELOGD("Start clear boundary reflections between main graph and sub graph!"); + look_up_table_.clear(); + values_.clear(); + return GRAPH_SUCCESS; + }; + + private: + graphStatus BuildLookUpTables(); + graphStatus BuildRefRelationsForBranch(const NodePtr &root_node, const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs); + graphStatus BuildRefRelationsForWhile(const NodePtr &root_node, const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs); + graphStatus BuildRelationsWithFuncNodeType(const NodePtr &root_node, + const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs); + void GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector &data_nodes, + vector &netoutput_nodes, const std::vector &sub_graph_names, + const std::string &node_type); + + graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph); + graphStatus ProcessSubgraphDataNodes(vector &data_nodes, vector> &classed_data_nodes); + graphStatus ProcessSubgraphNetoutput(const vector &netoutput_nodes, + vector>> &classed_netoutput_nodes); + + std::unordered_map> look_up_table_; + std::vector>> values_; +}; + +// Node Level +graphStatus RefRelations::Impl::BuildRefRelationsForBranch( + const NodePtr &root_node, const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, vector> &node_refs) { + GELOGD("Enter BuildRefRelationsForBranch!"); + + size_t ref_i = 0; + for (const auto &ref_i_data_nodes : classed_data_nodes) { + vector in_ref_i_all_refs; + RefCell cell_root; + cell_root.node_name = root_node->GetName(); + cell_root.node = root_node; + cell_root.in_out = NODE_IN; + cell_root.in_out_idx = ref_i; + in_ref_i_all_refs.emplace_back(cell_root); + for 
(const auto &data : ref_i_data_nodes) { + RefCell cell_in; + RefCell cell_out; + cell_in.node_name = data->GetName(); + cell_in.node = data; + cell_in.in_out = NODE_IN; + cell_in.in_out_idx = 0; + cell_out.node_name = data->GetName(); + cell_out.node = data; + cell_out.in_out = NODE_OUT; + cell_out.in_out_idx = 0; + in_ref_i_all_refs.emplace_back(cell_in); + in_ref_i_all_refs.emplace_back(cell_out); + } + node_refs.emplace_back(in_ref_i_all_refs); + ref_i++; + } + + size_t ref_o = 0; + for (const auto &ref_o_net_nodes : classed_netoutput_nodes) { + vector out_ref_i_all_refs; + RefCell cell_root; + cell_root.node_name = root_node->GetName(); + cell_root.node = root_node; + cell_root.in_out = NODE_OUT; + cell_root.in_out_idx = ref_o; + out_ref_i_all_refs.emplace_back(cell_root); + for (const auto &ele : ref_o_net_nodes) { + RefCell cell_netoutput_in; + cell_netoutput_in.node_name = (ele.first)->GetName(); + cell_netoutput_in.node = ele.first; + cell_netoutput_in.in_out = NODE_IN; + cell_netoutput_in.in_out_idx = ele.second; + out_ref_i_all_refs.emplace_back(cell_netoutput_in); + } + node_refs.emplace_back(out_ref_i_all_refs); + ref_o++; + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::BuildLookUpTables() { + GELOGD("start to build look up table!"); + for (size_t i = 0; i < values_.size(); i++) { + vector> &val = values_[i]; + for (const auto &ele : val) { + for (const auto &ref_cell : ele) { + string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + std::to_string(ref_cell.in_out_idx) + + std::to_string(static_cast(reinterpret_cast(ref_cell.node.get()))); + look_up_table_[key] = ele; + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::BuildRefRelationsForWhile( + const NodePtr &root_node, const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, vector> &node_refs) { + GELOGD("Enter BuildRefRelations for while op!"); + // data_nodes has been sorted + // for while, input num must be same as output 
num + auto input_num = root_node->GetAllInDataAnchorsSize(); + NodePtr netoutput = nullptr; + + size_t ref_i = 0; + while (ref_i < input_num) { + auto &ref_i_data_nodes = classed_data_nodes[ref_i]; + auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i]; + + vector ref_i_all_refs; + RefCell cell_root_i; + RefCell cell_root_o; + cell_root_i.node_name = root_node->GetName(); + cell_root_i.node = root_node; + cell_root_i.in_out = NODE_IN; + cell_root_i.in_out_idx = ref_i; + ref_i_all_refs.emplace_back(cell_root_i); + cell_root_o.node_name = root_node->GetName(); + cell_root_o.node = root_node; + cell_root_o.in_out = NODE_OUT; + cell_root_o.in_out_idx = ref_i; + ref_i_all_refs.emplace_back(cell_root_o); + for (const auto &data : ref_i_data_nodes) { + RefCell cell_in; + RefCell cell_out; + cell_in.node_name = data->GetName(); + cell_in.node = data; + cell_in.in_out = NODE_IN; + cell_in.in_out_idx = 0; + cell_out.node_name = data->GetName(); + cell_out.node = data; + cell_out.in_out = NODE_OUT; + cell_out.in_out_idx = 0; + ref_i_all_refs.emplace_back(cell_in); + ref_i_all_refs.emplace_back(cell_out); + } + + for (const auto &ele : ref_i_net_nodes) { + RefCell cell_netoutput_in; + RefCell cell_netoutput_out; + cell_netoutput_in.node_name = (ele.first)->GetName(); + cell_netoutput_in.node = ele.first; + cell_netoutput_in.in_out = NODE_IN; + cell_netoutput_in.in_out_idx = ele.second; + ref_i_all_refs.emplace_back(cell_netoutput_in); + netoutput = ele.first; + } + node_refs.emplace_back(ref_i_all_refs); + ref_i++; + } + /* There exist scene like the follows, it means data0 data1 netoutput 0'th + * and 1'th tensor should be the same addr. 
+ * Data0 Data1 + * \/ + * /\ + * netoutput + */ + if (netoutput == nullptr) { + return GRAPH_SUCCESS; + } + for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) { + auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_data_anchor == nullptr) { + continue; + } + auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); + if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { + GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str()); + continue; + } + if (peer_out_data_node->GetType() != DATA) { + continue; + } + auto in_data_anchor_idx = in_anchor->GetIdx(); + auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); + int ref_d = 0; + int ref_n = 0; + (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d); + (void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n); + + node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end()); + node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end()); + } + + return GRAPH_SUCCESS; +} +// build ref relations according to diff func op type +graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( + const NodePtr &root_node, const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, vector> &node_refs) { + // data_nodes has been sorted + auto node_type = root_node->GetType(); + + auto status = GRAPH_SUCCESS; + if (node_type != kWhile) { + status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); + } else { + status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); + } + return status; +} + +void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector &data_nodes, + vector &netoutput_nodes, + const std::vector &sub_graph_names, + 
const std::string &node_type) { + int sub_graph_idx = 0; + for (const auto &name : sub_graph_names) { + auto sub_graph = root_graph.GetSubgraph(name); + if (sub_graph == nullptr) { + GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); + continue; + } + for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { + auto sub_graph_node_type = sub_graph_node->GetType(); + + if (sub_graph_node_type == DATA) { + data_nodes.emplace_back(sub_graph_node); + } else if (sub_graph_node_type == NETOUTPUT) { + // if while, the first subgraph must be cond subgraph. + // There is no meaning for refs ,so continue + if (node_type == kWhile && sub_graph_idx == 0) { + continue; + } + netoutput_nodes.emplace_back(sub_graph_node); + } + continue; + } + sub_graph_idx++; + } +} + +graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) { + auto parent_graph_ptr = graph.GetParentGraph(); + if (parent_graph_ptr == nullptr) { + root_graph = graph; + return GRAPH_SUCCESS; + } + auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr); + if (root_graph_ptr == nullptr) { + GE_LOGE("Get null root graph"); + return GRAPH_PARAM_INVALID; + } + root_graph = *root_graph_ptr; + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector &data_nodes, + vector> &classed_data_nodes) { + GELOGD("start to process subgraph data nodes!"); + int max_ref_idx = 0; + for (const auto &e : data_nodes) { + int i; + bool is_exist = true; + is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i); + if (!is_exist) { + GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIndex); + return GRAPH_FAILED; + } + max_ref_idx = (i > max_ref_idx) ? 
i : max_ref_idx; + } + + while (!data_nodes.empty()) { + auto data = data_nodes.back(); + data_nodes.pop_back(); + int ref_idx = 0; + (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); + if (ref_idx >= static_cast(classed_data_nodes.size())) { + return GRAPH_FAILED; + } + classed_data_nodes[ref_idx].emplace_back(data); + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( + const vector &netoutput_nodes, vector>> &classed_netoutput_nodes) { + GELOGD("[RefRelations]Start to process subgraph netoutput!"); + for (const auto &sub_netoutput_node : netoutput_nodes) { + auto op_desc = sub_netoutput_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { + auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx()); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it", + sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); + return GRAPH_FAILED; + } + int ref_o; + if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { + if (ref_o >= static_cast(classed_netoutput_nodes.size())) { + return GRAPH_FAILED; + } + classed_netoutput_nodes[ref_o].emplace_back( + std::pair({sub_netoutput_node, static_cast(in_data_anchor->GetIdx())})); + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { + GELOGD("Start to build ref relations!"); + /* First Step: Get root graph */ + ge::ComputeGraph &root_graph = graph; + auto status = GetRootGraph(graph, root_graph); + if (status != GRAPH_SUCCESS) { + return status; + } + + for (const auto &node : graph.GetAllNodes()) { + auto node_type = node->GetType(); + std::vector ref_nodes; + auto op_desc = node->GetOpDesc(); + auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + continue; + } + vector data_nodes; + vector netoutput_nodes; + // Get data and 
netoutput of sub_graph + GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); + size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? data_nodes.size() : kMaxElementNum; + vector> classed_data_nodes(max_elem_num); // according to ref_idx + vector>> classed_netoutput_nodes(max_elem_num); // according to ref_idx + status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); + return status; + } + + // for netoutput + // check netoutput + // here main graph output number must be the same as every sub_graph netoutput node + // key: netoutput node_ptr , + status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "process netoutput failed!"); + return status; + } + + vector> node_refs; + status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs); + if (status != GRAPH_SUCCESS) { + GELOGE(status, "BuildRelationsWithFuncNodeType Failed! 
Node is [%s]!", node->GetName().c_str()); + return status; + } + if (!node_refs.empty()) { + values_.push_back(node_refs); + } + } + /* Seconde Step: generate map */ + status = BuildLookUpTables(); + if (status != GRAPH_SUCCESS) { + GELOGE(status, "Build look up tables failed!"); + return status; + } + return GRAPH_SUCCESS; +} + +/* Ref Relations Interface */ +RefRelations::RefRelations() { + impl_ = MakeShared(); + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "MakeShared failed!"); + return; + } +} + +graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set &result) { + GE_CHECK_NOTNULL(impl_); + return impl_->LookUpRefRelations(key, result); +} + +graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) { + GE_CHECK_NOTNULL(impl_); + return impl_->BuildRefRelations(root_graph); +} + +graphStatus RefRelations::Clear() { + GE_CHECK_NOTNULL(impl_); + return impl_->Clear(); +} +} // namespace ge \ No newline at end of file diff --git a/src/common/graph/runtime_inference_context.cc b/src/common/graph/runtime_inference_context.cc new file mode 100644 index 00000000..361d893c --- /dev/null +++ b/src/common/graph/runtime_inference_context.cc @@ -0,0 +1,129 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/runtime_inference_context.h" +#include "graph/utils/tensor_adapter.h" +#include +#include "framework/common/debug/ge_log.h" + +namespace ge { +std::map> RuntimeInferenceContext::contexts_; +std::mutex RuntimeInferenceContext::ctx_mu_; + +graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id) { + GELOGI("To create context. session id = %s", context_id.c_str()); + auto ctx = std::unique_ptr(new (std::nothrow) RuntimeInferenceContext()); + if (ctx == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to create instance of RuntimeInferenceContext. context_id = %s", context_id.c_str()); + return GRAPH_FAILED; + } + + std::lock_guard lk(ctx_mu_); + auto emplace_ret = contexts_.emplace(context_id, std::move(ctx)); + if (!emplace_ret.second) { + GELOGE(GRAPH_FAILED, "Old context not destroyed"); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +void RuntimeInferenceContext::DestroyContext(const std::string &context_id) { + GELOGI("To destroy context. session id = %s", context_id.c_str()); + std::lock_guard lk(ctx_mu_); + contexts_.erase(context_id); +} + +graphStatus RuntimeInferenceContext::GetContext(const std::string &context_id, RuntimeInferenceContext **ctx) { + std::lock_guard lk(ctx_mu_); + auto it = contexts_.find(context_id); + if (it != contexts_.end()) { + *ctx = it->second.get(); + return GRAPH_SUCCESS; + } + + GELOGD("Runtime inference context not created. 
session id = %s", context_id.c_str()); + return GRAPH_FAILED; +} + +graphStatus RuntimeInferenceContext::SetTensor(int64_t node_id, int output_id, Tensor &&tensor) { + std::lock_guard lk(mu_); + auto &output_tensors = tensors_[node_id]; + if (static_cast(output_id) >= output_tensors.size()) { + output_tensors.resize(output_id + 1); + } + + GELOGD("Set tensor for node_id = %ld, output_id = %d", node_id, output_id); + output_tensors[output_id] = std::move(tensor); + + auto &output_ge_tensors = ge_tensors_[node_id]; + if (static_cast(output_id) >= output_ge_tensors.size()) { + output_ge_tensors.resize(output_id + 1); + } + + GELOGD("Set ge tensor for node_id = %ld, output_id = %d", node_id, output_id); + output_ge_tensors[output_id] = TensorAdapter::AsGeTensorPtr(tensor); + return GRAPH_SUCCESS; +} + +graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, Tensor &tensor) { + if (output_id < 0) { + GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id); + return GRAPH_PARAM_INVALID; + } + + std::lock_guard lk(mu_); + auto iter = tensors_.find(node_id); + if (iter == tensors_.end()) { + GELOGE(INTERNAL_ERROR, "Node not register. Id = %ld", node_id); + return INTERNAL_ERROR; + } + + auto &output_tensors = iter->second; + if (static_cast(output_id) >= output_tensors.size()) { + GELOGE(GRAPH_FAILED, "Node output is not registered. node_id = %ld, output index = %d", node_id, output_id); + return GRAPH_FAILED; + } + + GELOGD("Get tensor for node_id = %ld, output_id = %d", node_id, output_id); + tensor = output_tensors[output_id]; + return GRAPH_SUCCESS; +} + +graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, GeTensorPtr &tensor) { + if (output_id < 0) { + GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id); + return GRAPH_PARAM_INVALID; + } + + std::lock_guard lk(mu_); + auto iter = ge_tensors_.find(node_id); + if (iter == ge_tensors_.end()) { + GELOGE(INTERNAL_ERROR, "Node not register. 
Id = %ld", node_id); + return INTERNAL_ERROR; + } + + auto &output_tensors = iter->second; + if (static_cast(output_id) >= output_tensors.size()) { + GELOGE(GRAPH_FAILED, "Node output is not registered. node_id = %ld, output index = %d", node_id, output_id); + return GRAPH_FAILED; + } + + GELOGD("Get ge tensor for node_id = %ld, output_id = %d", node_id, output_id); + tensor = output_tensors[output_id]; + return GRAPH_SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/src/common/graph/shape_refiner.cc b/src/common/graph/shape_refiner.cc new file mode 100644 index 00000000..17423da4 --- /dev/null +++ b/src/common/graph/shape_refiner.cc @@ -0,0 +1,688 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/shape_refiner.h" + +#include +#include +#include +#include +#include +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" + +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "external/graph/operator.h" +#include "external/graph/operator_factory.h" +#include "framework/common/debug/ge_log.h" +#include "graph/compute_graph.h" +#include "utils/node_utils.h" +#include "utils/op_desc_utils.h" +#include "utils/tensor_utils.h" +#include "utils/type_utils.h" + +namespace ge { +namespace { +const uint32_t kWhileBodySubGraphIdx = 1; + +graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { + GELOGD("Enter reverse brush while body subgraph process!"); + + auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); + if (sub_graph_body == nullptr) { + GELOGE(GRAPH_FAILED, "Get while body graph failed!"); + return GRAPH_FAILED; + } + + for (const auto &node_sub : sub_graph_body->GetAllNodes()) { + for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) { + auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i); + GE_IF_BOOL_EXEC(input_desc == nullptr, + GELOGW("Get null input by index %zu from node %s ", i, node_sub->GetName().c_str()); + continue); + (void)input_desc->SetUnknownDimNumShape(); + } + for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) { + auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i); + (void)output_desc->SetUnknownDimNumShape(); + } + } + + return GRAPH_SUCCESS; +} + +graphStatus UpdataOutputForMultiBatcch(const ConstNodePtr &node, + std::vector> &ref_out_tensors) { + // check sub_graph shape. Get max for update. 
+ for (size_t i = 0; i < ref_out_tensors.size(); ++i) { + if (ref_out_tensors[i].empty()) { + continue; + } + + int64_t max_size = 0; + size_t max_shape_index = 0; + auto &ref_out_tensor = ref_out_tensors[i].at(0); + const auto &ref_out_tensor_shape = ref_out_tensor.MutableShape(); + for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) { + auto &tensor = ref_out_tensors[i].at(j); + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); + return GRAPH_FAILED; + } + + auto shape = tensor.MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGE(GRAPH_FAILED, "node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", + node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + return GRAPH_FAILED; + } + + int64_t size = 1; + for (auto dim : shape.GetDims()) { + if (INT64_MAX / dim < size) { + GELOGE(PARAM_INVALID, "The shape size overflow"); + return PARAM_INVALID; + } + size *= dim; + } + + if (size > max_size) { + max_size = size; + max_shape_index = j; + } + } + + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index)); + } + + return GRAPH_SUCCESS; +} + +graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, + std::vector> &ref_out_tensors) { + GELOGD("Enter update parent node shape for class branch op process"); + if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { + return UpdataOutputForMultiBatcch(node, ref_out_tensors); + } + + // check sub_graph shape.If not same ,do unknown shape process + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].empty()) { + continue; + } + auto ref_out_tensor = ref_out_tensors[i].at(0); + ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); + for (auto &tensor : ref_out_tensors[i]) { + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + 
GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); + return GRAPH_FAILED; + } + auto shape = tensor.MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, + shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + ref_out_tensor_shape = GeShape(UNKNOWN_RANK); + break; + } + for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { + if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { + continue; + } + GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, + j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); + } + } + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); + } + return GRAPH_SUCCESS; +} + +graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector> &ref_data_tensors, + std::vector> &ref_out_tensors) { + GELOGD("Enter update parent node shape for class while op process"); + if (ref_data_tensors.size() != ref_out_tensors.size()) { + GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(), + ref_data_tensors.size(), ref_out_tensors.size()); + return GRAPH_FAILED; + } + for (size_t i = 0; i < ref_data_tensors.size(); i++) { + if (ref_out_tensors[i].size() != 1) { + GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!"); + return GRAPH_FAILED; + } + } + bool is_need_reverse_brush = false; + // check input and output + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].empty()) { + continue; + } + auto ref_out_tensor = ref_out_tensors[i].at(0); + auto tmp_shape = ref_out_tensor.MutableShape(); + // ref_i's data and output tensor shape should be same + for (auto &tensor : 
ref_data_tensors[i]) { + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str()); + return GRAPH_FAILED; + } + auto shape = tensor.MutableShape(); + if (shape.GetDims() != tmp_shape.GetDims()) { + ref_out_tensor.SetUnknownDimNumShape(); + is_need_reverse_brush = true; + break; + } + } + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); + } + // reverse refresh while body shape + if (is_need_reverse_brush) { + return ReverseBrushWhileBodySubGraph(node); + } + return GRAPH_SUCCESS; +} + +graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { + auto op_desc = node->GetOpDesc(); + auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + return GRAPH_SUCCESS; + } + + auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + for (const auto &name : sub_graph_names) { + if (name.empty()) { + GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); + continue; + } + auto sub_graph = root_graph->GetSubgraph(name); + if (sub_graph == nullptr) { + GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + for (const auto &node_sub : sub_graph->GetDirectNode()) { + if (node_sub->GetType() != DATA) { + continue; + } + int ref_i; + auto data_opdesc = node_sub->GetOpDesc(); + if (data_opdesc == nullptr) { + GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), + node->GetName().c_str()); + return GRAPH_FAILED; + } + if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), + node->GetName().c_str()); + return GRAPH_FAILED; + } + if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { + continue; + } + auto input_desc = 
op_desc->MutableInputDesc(ref_i); + if (input_desc == nullptr) { + GE_LOGE( + "The ref index(%d) on the data %s on the sub graph %s " + "parent node %s are incompatible, inputs num %u", + ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); + return GRAPH_FAILED; + } + GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), + node->GetName().c_str()); + auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); + + if (ret != GRAPH_SUCCESS) { + GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", + node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); + return ret; + } + ret = data_opdesc->UpdateOutputDesc(0, *input_desc); + if (ret != GRAPH_SUCCESS) { + GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s", + node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); + return ret; + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr &sub_graph, NodePtr &netoutput, + const ConstNodePtr &node, + std::vector> &ref_data_tensors) { + auto sub_nodes = sub_graph->GetDirectNode(); + for (size_t i = sub_nodes.size(); i > 0; --i) { + auto sub_node = sub_nodes.at(i - 1); + if (sub_node->GetType() == NETOUTPUT) { + netoutput = sub_node; + } + if (sub_node->GetType() == DATA) { + if (sub_node->GetOpDesc() == nullptr) { + return GRAPH_FAILED; + } + + int ref_i; + if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); + return GRAPH_FAILED; + } + if (ref_i < 0 || static_cast(ref_i) >= node->GetAllInDataAnchorsSize()) { + GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(), + ref_i, node->GetAllInDataAnchorsSize()); + return GRAPH_FAILED; + } + 
ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); + } + } + return GRAPH_SUCCESS; +} + +graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { + auto op_desc = node->GetOpDesc(); + auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + return GRAPH_SUCCESS; + } + + std::vector> ref_data_tensors(node->GetAllInDataAnchorsSize()); + std::vector> ref_out_tensors(node->GetAllOutDataAnchorsSize()); + auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + + for (const auto &name : sub_graph_names) { + if (name.empty()) { + GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); + continue; + } + auto sub_graph = root_graph->GetSubgraph(name); + if (sub_graph == nullptr) { + GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + NodePtr netoutput = nullptr; + auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); + if (ret != GRAPH_SUCCESS) { + return ret; + } + if (netoutput == nullptr) { + GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + auto netoutput_opdesc = netoutput->GetOpDesc(); + if (netoutput_opdesc == nullptr) { + GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(), + node->GetName().c_str()); + return GRAPH_FAILED; + } + for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { + auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); + if (edge_desc == nullptr) { + GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(), + node->GetName().c_str(), edge_anchor->GetIdx()); + return GRAPH_FAILED; + } + GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", edge_anchor->GetIdx(), + edge_desc->GetShape().GetDimNum()); + int 
ref_i; + if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. + continue; + } + GELOGI("Parent node index of edge desc is %d", ref_i); + if (ref_i < 0 || static_cast(ref_i) >= node->GetAllOutDataAnchorsSize()) { + return GRAPH_FAILED; + } + ref_out_tensors[ref_i].emplace_back(*edge_desc); + } + } + + if (node->GetType() == WHILE) { + return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); + } + return UpdateParentNodeForBranch(node, ref_out_tensors); +} + +string Serial(const vector &dims) { + string serial_string; + serial_string += "["; + for (int64_t dim : dims) { + serial_string += std::to_string(dim) + " "; + } + serial_string += "]"; + return serial_string; +} + +graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { + GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); + GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); + for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) { + auto in_idx = in_anchor->GetIdx(); + auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_data_anchor == nullptr) { + continue; + } + auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); + if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { + continue; + } + int peer_out_idx = peer_out_data_anchor->GetIdx(); + auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast(peer_out_idx)); + + // check shape and dtype continuity. 
do not stop process + auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast(in_idx)); + if (in_desc == nullptr) { + continue; + } + auto in_shape = in_desc->GetShape().GetDims(); + auto in_dtype = in_desc->GetDataType(); + auto peer_out_shape = peer_out_desc->GetShape().GetDims(); + auto peer_out_dtype = peer_out_desc->GetDataType(); + if (peer_out_dtype != in_dtype) { + GELOGW( + "current node [%s] [%d]\'th out_dtype is [%s].peer output node [%s] [%d]\'th " + "output_dtype is [%s].The two dtype should be same! Please check graph and fix it", + node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(), + peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str()); + } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) { + string in_shape_str = Serial(in_shape); + string peer_out_shape_str = Serial(peer_out_shape); + GELOGW( + "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " + "input_shape is [%s].The two shape should be same! 
Please check graph and fix it", + node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx, + peer_out_shape_str.c_str()); + } + // refresh current node input desc + in_desc->SetOriginShape(peer_out_desc->GetOriginShape()); + in_desc->SetShape(peer_out_desc->GetShape()); + in_desc->SetDataType(peer_out_desc->GetDataType()); + in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType()); + std::vector> shape_range; + (void)peer_out_desc->GetShapeRange(shape_range); + in_desc->SetShapeRange(shape_range); + ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast(peer_out_desc->GetShape().GetDims().size())); + } + return GRAPH_SUCCESS; +} +} // namespace +void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { + if (!IsLogEnable(GE, DLOG_DEBUG)) { + return; + } + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "node is null"); + return; + } + ge::OpDescPtr op_desc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); + std::string str; + if (op_desc->GetInputsSize() != 0) { + std::string input_desc_str = "input shape: "; + for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { + input_desc_str += "["; + for (int64_t dim : input_desc->GetShape().GetDims()) { + input_desc_str += std::to_string(dim) + " "; + } + input_desc_str += "]"; + input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) + ":" + + TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " "; + } + str += input_desc_str; + + input_desc_str = "input origin shape: "; + for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { + input_desc_str += "["; + for (int64_t dim : input_desc->GetOriginShape().GetDims()) { + input_desc_str += std::to_string(dim) + " "; + } + input_desc_str += "]"; + input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" + + 
TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " "; + } + str += input_desc_str; + } + + if (op_desc->GetAllOutputsDescSize() != 0) { + std::string output_desc_str = "output shape: "; + for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { + if (output_desc == nullptr) { + continue; + } + output_desc_str += "["; + for (int64_t dim : output_desc->GetShape().GetDims()) { + output_desc_str += std::to_string(dim) + " "; + } + output_desc_str += "]"; + output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) + ":" + + TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " "; + } + str += output_desc_str; + + output_desc_str = "output origin shape: "; + for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { + if (output_desc == nullptr) { + continue; + } + output_desc_str += "["; + for (int64_t dim : output_desc->GetOriginShape().GetDims()) { + output_desc_str += std::to_string(dim) + " "; + } + output_desc_str += "]"; + output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" + + TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " "; + } + str += output_desc_str; + } + GELOGD("Shape dump [%s], Node name: [%s]. 
%s", phase.c_str(), node->GetName().c_str(), str.c_str()); +} + +graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { + return InferShapeAndType(node, op, true); +} +graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) { + auto op_desc = node->GetOpDesc(); + const auto &op_type = op_desc->GetType(); + + graphStatus ret; + if (before_subgraph) { + ret = UpdateSubGraphDataNodes(node); + if (ret != GRAPH_SUCCESS) { + return ret; + } + } + // Get infer func and execute + ret = op_desc->CallInferFunc(op); + if (ret == GRAPH_PARAM_INVALID) { + // Op ir no infer func, try to get infer func from operator factory + auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); + if (node_op.IsEmpty()) { + GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); + return ret; + } + + GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); + auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + node_op.BreakConnect(); + if (temp_op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "temp op desc is null"); + return GRAPH_FAILED; + } + if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { + GELOGW("InferShapeAndType UpdateInputName failed"); + for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { + if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { + break; + } + return GRAPH_SUCCESS; + } + } + if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { + GELOGW("InferShapeAndType UpdateOutputName failed"); + } + op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); + ret = op_desc->CallInferFunc(op); + GELOGI("op CallInferFunc second. 
ret: %u", ret); + } + if (ret != GRAPH_SUCCESS) { + return ret; + } + + if (!before_subgraph) { + return UpdateParentNodeOutTensor(node); + } + return GRAPH_SUCCESS; +} + +InferenceContextPtr CreateInferenceContext(const std::unordered_map &context_map, + const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "node is null"); + return nullptr; + } + InferenceContextPtr inference_context = std::shared_ptr(InferenceContext::Create()); + if (inference_context == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext"); + return nullptr; + } + + auto all_in_data_anchors = node->GetAllInDataAnchors(); + std::vector> input_shapes_and_types(all_in_data_anchors.size()); + std::vector marks; + + bool has_input_shapes_and_types = false; + for (const auto &in_anchor : all_in_data_anchors) { + const auto &out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + continue; + } + + auto input_node = out_anchor->GetOwnerNode(); + if (input_node == nullptr) { + continue; + } + + auto iter = context_map.find(input_node); + if (iter != context_map.end()) { + const auto &src_context = iter->second; + GE_IF_BOOL_EXEC(src_context == nullptr, GELOGE(GRAPH_FAILED, "src_context is null."); return nullptr); + GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(), + input_node->GetName().c_str()); + for (auto mark : src_context->GetMarks()) { + marks.push_back(mark); + } + auto output_idx = out_anchor->GetIdx(); + auto input_idx = in_anchor->GetIdx(); + auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes(); + if (output_idx < static_cast(output_shape_and_type.size())) { + GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx, + node->GetName().c_str(), input_idx); + input_shapes_and_types[input_idx] = output_shape_and_type[output_idx]; + has_input_shapes_and_types = true; + } else { + GELOGI("[%s] Output out of range. 
index = %d, size = %zu", node->GetName().c_str(), output_idx, + output_shape_and_type.size()); + } + } + } + + if (has_input_shapes_and_types) { + inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types)); + } + inference_context->SetMarks(marks); + + return inference_context; +} + +namespace { +thread_local std::unordered_map context_map; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { + return InferShapeAndType(node, true); +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, + bool before_subgraph) { + GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); + bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + auto opdesc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); + // some op can not infershape twice such as aipp + bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified"); + if (need_update_input) { + auto status = UpdateOpInputDesc(node); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "update op input_desc failed!"); + return status; + } + } + + if (node->Verify() != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + PrintInOutTensorShape(node, "before_infershape"); + Operator op = OpDescUtils::CreateOperatorFromNode(node); + + if (!is_unknown_graph) { + auto inference_context = CreateInferenceContext(context_map, node); + if (inference_context == nullptr) { + GELOGE(GRAPH_FAILED, "inference context is null"); + return GRAPH_FAILED; + } + GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); 
+ op.SetInferenceContext(inference_context); + } + + graphStatus status = InferShapeAndType(node, op, before_subgraph); + if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { + if (is_unknown_graph) { + PrintInOutTensorShape(node, "after_infershape when running"); + return GRAPH_SUCCESS; + } + auto op_desc = node->GetOpDesc(); + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); + output_tensor->SetOriginShape(output_tensor->GetShape()); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + + GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", + node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + } + } else { + GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + if (!is_unknown_graph) { + auto ctx_after_infer = op.GetInferenceContext(); + if (ctx_after_infer != nullptr) { + GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); + if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { + GELOGD("[%s] set inference context after. 
mark:%zu", node->GetName().c_str(), + ctx_after_infer->GetMarks().size()); + (void)context_map.emplace(node, ctx_after_infer); + } + } + } + PrintInOutTensorShape(node, "after_infershape"); + + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/src/common/graph/stub/Makefile b/src/common/graph/stub/Makefile new file mode 100644 index 00000000..f339fa33 --- /dev/null +++ b/src/common/graph/stub/Makefile @@ -0,0 +1,6 @@ +inc_path := $(shell pwd)/metadef/inc/external/ +out_path := $(shell pwd)/out/graph/lib64/stub/ +stub_path := $(shell pwd)/metadef/graph/stub/ + +mkdir_stub := $(shell mkdir -p $(out_path)) +graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) diff --git a/src/common/graph/stub/gen_stubapi.py b/src/common/graph/stub/gen_stubapi.py new file mode 100644 index 00000000..7263ff17 --- /dev/null +++ b/src/common/graph/stub/gen_stubapi.py @@ -0,0 +1,578 @@ +import os +import re +import sys +import logging + +logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', + level=logging.INFO) + +""" + this attr is used for symbol table visible +""" +GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' + +""" + generate stub func body by return type +""" +RETURN_STATEMENTS = { + 'graphStatus': ' std::cout << "[ERROR]: stub library libgraph or libge_compiler cannot be used for execution, please check your "\n ' + ' << "environment variables and compilation options to make sure you use the correct library."\n' + ' << std::endl;\n' + ' return ACL_ERROR_COMPILING_STUB_MODE;', + 'Status': ' return SUCCESS;', + 'Graph': ' return Graph();', + 'Graph&': ' return *this;', + 'Format': ' return Format();', + 'Format&': ' return *this;', + 'Shape': ' return Shape();', + 'Shape&': ' return *this;', + 'TensorDesc': ' return TensorDesc();', + 'TensorDesc&': ' return *this;', + 'Tensor': ' return Tensor();', + 'Tensor&': ' return *this;', + 'Operator': ' return Operator();', + 
'Operator&': ' return *this;', + 'Ptr': ' return nullptr;', + 'std::string': ' return "";', + 'std::string&': ' return "";', + 'string': ' return "";', + 'int': ' return 0;', + 'DataType': ' return DT_FLOAT;', + 'InferenceContextPtr': ' return nullptr;', + 'SubgraphBuilder': ' return nullptr;', + 'OperatorImplPtr': ' return nullptr;', + 'OutHandler': ' return nullptr;', + 'std::vector': ' return {};', + 'std::vector': ' return {};', + 'std::map': ' return {};', + 'uint32_t': ' return 0;', + 'int64_t': ' return 0;', + 'uint64_t': ' return 0;', + 'size_t': ' return 0;', + 'float': ' return 0.0f;', + 'bool': ' return false;', +} + +""" + max code len per line in hua_wei software programming specifications +""" +max_code_len_per_line = 100 + +""" + white_list_for_debug, include_dir_key_words is to + determines which header files to generate cc files from + when DEBUG on +""" +white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h", "inference_context.h", + "ge_ir_build.h", "ge_api.h", "ascend_string.h", "gnode.h"] +include_dir_key_words = ["ge", "graph"] +DEBUG = True + + +def need_generate_func(func_line): + """ + :param func_line: + :return: + """ + if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ + or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): + return False + return True + + +def file_endswith_white_list_suffix(file): + """ + :param file: + :return: + """ + if DEBUG: + for suffix in white_list_for_debug: + if file.endswith(suffix): + return True + return False + else: + return True + + +""" + belows are patterns used for analyse .h file +""" +# pattern function +pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after +([a-zA-Z~_] # void int likely +.* +[)] #we find ) +(?!.*{) # we do not want the case int abc() const +.*) +(;.*) #we want to find ; and after for we will replace these later +\n$ +""", re.VERBOSE | 
re.MULTILINE | re.DOTALL) + +# pattern comment +pattern_comment = re.compile(r'^\s*//') +pattern_comment_2_start = re.compile(r'^\s*/[*]') +pattern_comment_2_end = re.compile(r'[*]/\s*$') +# pattern define +pattern_define = re.compile(r'^\s*#define') +pattern_define_return = re.compile(r'\\\s*$') +# blank line +pattern_blank_line = re.compile(r'^\s*$') +# virtual,explicit,friend,static +pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') +# lead space +pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') +# functions will have patterns such as func ( or func( +# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist +# format like :"operator = ()" +pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') +# template +pattern_template = re.compile(r'^\s*template') +pattern_template_end = re.compile(r'>\s*$') +# namespace +pattern_namespace = re.compile(r'namespace.*{') +# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with +pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+ 0 and not friend_match: + line, func_name = self.handle_class_member_func(line, template_string) + # Normal functions + else: + line, func_name = self.handle_normal_func(line, template_string) + + need_generate = need_generate_func(line) + # func body + line += self.implement_function(line) + # comment + line = self.gen_comment(start_i) + line + # write to out file + self.write_func_content(line, func_name, need_generate) + # next loop + self.line_index += 1 + + logging.info('Added %s functions', len(self.func_list_exist)) + logging.info('Successfully converted,please see ' + self.output_file) + + def handle_func1(self, line): + """ + :param line: + :return: + """ + find1 = re.search('[(]', line) + if not find1: + self.line_index += 1 + return "continue", line, None + find2 = re.search('[)]', line) + start_i = 
self.line_index + space_match = pattern_leading_space.search(line) + # deal with + # int abc(int a, + # int b) + if find1 and (not find2): + self.line_index += 1 + line2 = self.input_content[self.line_index] + if space_match: + line2 = re.sub('^' + space_match.group(1), '', line2) + line += line2 + while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): + self.line_index += 1 + line2 = self.input_content[self.line_index] + line2 = re.sub('^' + space_match.group(1), '', line2) + line += line2 + + match_start = pattern_start.search(self.input_content[self.line_index]) + match_end = pattern_end.search(self.input_content[self.line_index]) + if match_start: # like ) { or ) {} int the last line + if not match_end: + self.stack.append('normal_now') + ii = start_i + while ii <= self.line_index: + ii += 1 + self.line_index += 1 + return "continue", line, start_i + logging.info("line[%s]", line) + # ' int abc();'->'int abc()' + (line, match) = pattern_func.subn(r'\2\n', line) + logging.info("line[%s]", line) + # deal with case: + # 'int \n abc(int a, int b)' + if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): + line = self.input_content[start_i - 1] + line + line = line.lstrip() + if not match: + self.line_index += 1 + return "continue", line, start_i + return "pass", line, start_i + + def handle_stack(self, match_start): + """ + :param match_start: + :return: + """ + line = self.input_content[self.line_index] + match_end = pattern_end.search(line) + if match_start: + self.stack.append('normal_now') + if match_end: + top_status = self.stack.pop() + if top_status == 'namespace_now': + self.output_fd.write(line + '\n') + elif top_status == 'class_now': + self.stack_class.pop() + self.stack_template.pop() + if match_start or match_end: + self.line_index += 1 + return "continue" + + if len(self.stack) > 0 and self.stack[-1] == 'normal_now': + self.line_index += 1 + return "continue" + return "pass" + + def 
handle_class(self, template_string, line, match_start, match_class): + """ + :param template_string: + :param line: + :param match_start: + :param match_class: + :return: + """ + if match_class: # we face a class + self.stack_template.append(template_string) + self.stack.append('class_now') + class_name = match_class.group(3) + + # class template specializations: class A > + if '<' in class_name: + k = line.index('<') + fit = 1 + for ii in range(k + 1, len(line)): + if line[ii] == '<': + fit += 1 + if line[ii] == '>': + fit -= 1 + if fit == 0: + break + class_name += line[k + 1:ii + 1] + logging.info('class_name[%s]', class_name) + self.stack_class.append(class_name) + while not match_start: + self.line_index += 1 + line = self.input_content[self.line_index] + match_start = pattern_start.search(line) + self.line_index += 1 + return "continue" + return "pass" + + def handle_template(self): + line = self.input_content[self.line_index] + match_template = pattern_template.search(line) + template_string = '' + if match_template: + match_template_end = pattern_template_end.search(line) + template_string = line + while not match_template_end: + self.line_index += 1 + line = self.input_content[self.line_index] + template_string += line + match_template_end = pattern_template_end.search(line) + self.line_index += 1 + return template_string + + def handle_namespace(self): + line = self.input_content[self.line_index] + match_namespace = pattern_namespace.search(line) + if match_namespace: # we face namespace + self.output_fd.write(line + '\n') + self.stack.append('namespace_now') + self.line_index += 1 + + def handle_normal_func(self, line, template_string): + template_line = '' + self.stack_template.append(template_string) + if self.stack_template[-1] != '': + template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) + # change '< class T = a, class U = A(3)>' to '' + template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) + template_line = 
re.sub(r'\s*=.*,', ',', template_line) + template_line = re.sub(r'\s*=.*', '', template_line) + line = re.sub(r'\s*=.*,', ',', line) + line = re.sub(r'\s*=.*\)', ')', line) + line = template_line + line + self.stack_template.pop() + func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() + logging.info("line[%s]", line) + logging.info("func_name[%s]", func_name) + return line, func_name + + def handle_class_member_func(self, line, template_string): + template_line = '' + x = '' + if template_string != '': + template_string = re.sub(r'\s*template', 'template', template_string) + template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) + template_string = re.sub(r'\s*=.*,', ',', template_string) + template_string = re.sub(r'\s*=.*', '', template_string) + if self.stack_template[-1] != '': + if not (re.search(r'<\s*>', self.stack_template[-1])): + template_line = re.sub(r'^\s*template', 'template', self.stack_template[-1]) + if not (re.search(r'<.*>', self.stack_class[-1])): + # for x we get like template -> + x = re.sub(r'template\s*<', '<', template_line) # remove template -> + x = re.sub(r'\n', '', x) + x = re.sub(r'\s*=.*,', ',', x) + x = re.sub(r'\s*=.*\>', '>', x) + x = x.rstrip() # remove \n + x = re.sub(r'(class|typename)\s+|(|\s*class)', '', + x) # remove class,typename -> + x = re.sub(r'<\s+', '<', x) + x = re.sub(r'\s+>', '>', x) + x = re.sub(r'\s+,', ',', x) + x = re.sub(r',\s+', ', ', x) + line = re.sub(r'\s*=\s+0', '', line) + line = re.sub(r'\s*=\s+.*,', ',', line) + line = re.sub(r'\s*=\s+.*\)', ')', line) + logging.info("x[%s]\nline[%s]", x, line) + # if the function is long, void ABC::foo() + # breaks into two lines void ABC::\n foo() + temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1) + if len(temp_line) > max_code_len_per_line: + line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1) + else: + line = temp_line + logging.info("line[%s]", line) + # add template 
as the above if there is one + template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) + template_line = re.sub(r'\s*=.*,', ',', template_line) + template_line = re.sub(r'\s*=.*', '', template_line) + line = template_line + template_string + line + func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() + logging.info("line[%s]", line) + logging.info("func_name[%s]", func_name) + return line, func_name + + def write_func_content(self, content, func_name, need_generate): + if not (func_name in self.func_list_exist) and need_generate: + self.output_fd.write(content) + self.func_list_exist.append(func_name) + logging.info('add func:[%s]', func_name) + + def gen_comment(self, start_i): + comment_line = '' + # Function comments are on top of function declarations, copy them over + k = start_i - 1 # one line before this func start + if pattern_template.search(self.input_content[k]): + k -= 1 + if pattern_comment_2_end.search(self.input_content[k]): + comment_line = self.input_content[k].lstrip() + while not pattern_comment_2_start.search(self.input_content[k]): + k -= 1 + comment_line = self.input_content[k].lstrip() + comment_line + else: + for j in range(k, 0, -1): + c_line = self.input_content[j] + if pattern_comment.search(c_line): + c_line = re.sub(r'\s*//', '//', c_line) + comment_line = c_line + comment_line + else: + break + return comment_line + + @staticmethod + def implement_function(func): + function_def = '' + function_def += '{\n' + + all_items = func.split() + start = 0 + return_type = all_items[start] + if return_type == "const": + start += 1 + return_type = all_items[start] + if return_type.startswith(('std::map', 'std::set', 'std::vector')): + return_type = "std::map" + if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): + return_type = "Ptr" + if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): + return_type += "&" + if 
RETURN_STATEMENTS.__contains__(return_type): + function_def += RETURN_STATEMENTS[return_type] + else: + logging.warning("Unhandled return type[%s]", return_type) + + function_def += '\n' + function_def += '}\n' + function_def += '\n' + return function_def + + +def collect_header_files(path): + """ + :param path: + :return: + """ + header_files = [] + shared_includes_content = [] + for root, dirs, files in os.walk(path): + files.sort() + for file in files: + if file.find("git") >= 0: + continue + if not file.endswith('.h'): + continue + file_path = os.path.join(root, file) + file_path = file_path.replace('\\', '/') + header_files.append(file_path) + include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) + shared_includes_content.append(include_str) + # for acl error code + shared_includes_content.append('#include \n') + shared_includes_content.append('const int ACL_ERROR_COMPILING_STUB_MODE = 100039;\n') + return header_files, shared_includes_content + + +def generate_stub_file(inc_dir, out_cc_dir): + """ + :param inc_dir: + :param out_cc_dir: + :return: + """ + target_header_files, shared_includes_content = collect_header_files(inc_dir) + for header_file in target_header_files: + if not file_endswith_white_list_suffix(header_file): + continue + cc_file = re.sub('.h*$', '.cc', header_file) + h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) + h_2_cc.h2cc() + + +def gen_code(inc_dir, out_cc_dir): + """ + :param inc_dir: + :param out_cc_dir: + :return: + """ + if not inc_dir.endswith('/'): + inc_dir += '/' + if not out_cc_dir.endswith('/'): + out_cc_dir += '/' + for include_dir_key_word in include_dir_key_words: + generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) + + +if __name__ == '__main__': + inc_dir = sys.argv[1] + out_cc_dir = sys.argv[2] + gen_code(inc_dir, out_cc_dir) diff --git a/src/common/graph/tensor.cc b/src/common/graph/tensor.cc new file mode 100644 index 00000000..1f30c876 
--- /dev/null +++ b/src/common/graph/tensor.cc @@ -0,0 +1,704 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "external/graph/tensor.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_tensor.h" +#include "securec.h" +#include "utils/attr_utils.h" +#include "utils/tensor_adapter.h" +#include "utils/tensor_utils.h" +#include "utils/type_utils.h" + +namespace { +/// Extra 8 bytes store pointer of string +/// Extra 1 byte store '\0' +const int EXTRA_STORE_POINTER_FOR_STRING = 8; +const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9; +const int64_t UNKNOWN_DIM_SIZE = -1; +} // namespace + +namespace ge { +// If not overflow return true +static bool Int64MulNotOverflow(int64_t a, int64_t b) { + if (a > 0) { + if (b > 0) { + if (a > (INT64_MAX / b)) { + return false; + } + } else { + if (b < (INT64_MIN / a)) { + return false; + } + } + } else { + if (b > 0) { + if (a < (INT64_MIN / b)) { + return false; + } + } else { + if ((a != 0) && (b < (INT64_MAX / a))) { + return false; + } + } + } + return true; +} + +class TensorDescImpl { + public: + TensorDescImpl() = default; + ~TensorDescImpl() = default; + TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {} + + Shape shape_; + std::vector> range_; + Format format_ = FORMAT_ND; + Format origin_format_ = FORMAT_ND; + DataType 
data_type_ = DT_FLOAT; + Shape origin_shape_; + int64_t size_ = 0; + int64_t real_dim_cnt_ = 0; + std::string name_; +}; + +class TensorImpl { + public: + TensorImpl() = default; + ~TensorImpl() = default; + + explicit TensorImpl(const TensorDesc &tensor_desc) : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)) {} + TensorImpl(const TensorDesc &tensor_desc, const std::vector &data) + : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data) {} + TensorImpl(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) + : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data, size) {} + TensorImpl(TensorDesc &&tensor_desc, std::vector &&data) + : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), std::move(data)) {} + + GeTensor ge_tensor; +}; + +class ShapeImpl { + public: + ShapeImpl() = default; + ~ShapeImpl() = default; + explicit ShapeImpl(const std::vector &dims) { + bool is_unknown_dim_num = false; + for (const auto &dim : dims) { + if (dim == UNKNOWN_DIM_NUM) { + is_unknown_dim_num = true; + break; + } + } + dims_ = is_unknown_dim_num ? 
std::vector({UNKNOWN_DIM_NUM}) : dims; + } + + std::vector dims_; +}; + +Shape::Shape() { impl_ = ComGraphMakeShared(); } + +Shape::Shape(const std::vector &dims) { impl_ = ComGraphMakeShared(dims); } + +size_t Shape::GetDimNum() const { + if (impl_ != nullptr) { + for (auto i : impl_->dims_) { + if (i == UNKNOWN_DIM_NUM) { + return 0; + } + } + return impl_->dims_.size(); + } + return 0; +} + +int64_t Shape::GetDim(size_t idx) const { + if (impl_ != nullptr) { + if (idx >= impl_->dims_.size()) { + return 0; + } + return impl_->dims_[idx]; + } + return 0; +} + +graphStatus Shape::SetDim(size_t idx, int64_t value) { + if (impl_ != nullptr) { + if (idx >= impl_->dims_.size()) { + return GRAPH_FAILED; + } + impl_->dims_[idx] = value; + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +std::vector Shape::GetDims() const { + vector dims; + if (impl_ != nullptr) { + return impl_->dims_; + } + return dims; +} + +int64_t Shape::GetShapeSize() const { + if (impl_ != nullptr) { + if (impl_->dims_.empty()) { + return 0; + } + int64_t size = 1; + for (auto i : impl_->dims_) { + if (i == UNKNOWN_DIM_NUM || i == UNKNOWN_DIM) { + return UNKNOWN_DIM_SIZE; + } + + if (!Int64MulNotOverflow(size, i)) { + GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); + size = 0; + return size; + } + size *= i; + } + return size; + } + return 0; +} + +TensorDesc::TensorDesc() { + impl = ComGraphMakeShared(); // lint !e665 +} + +TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) { + impl = ComGraphMakeShared(shape, format, dt); // lint !e665 + SetRealDimCnt(shape.GetDimNum()); +} + +TensorDesc::TensorDesc(const TensorDesc &desc) { + // Copy + impl = ComGraphMakeShared(); // lint !e665 + if (desc.impl != nullptr && impl != nullptr) { + *impl = *desc.impl; + } +} + +TensorDesc::TensorDesc(TensorDesc &&desc) { + // Move + impl = std::move(desc.impl); +} + +TensorDesc &TensorDesc::operator=(const TensorDesc &desc) { + // Copy + if (&desc != this) { + impl = ComGraphMakeShared(); 
+ if (desc.impl != nullptr && impl != nullptr) { + *impl = *desc.impl; + } + } + return *this; +} + +TensorDesc &TensorDesc::operator=(TensorDesc &&desc) { + if (&desc != this) { + impl = std::move(desc.impl); + } + return *this; +} + +void TensorDesc::Update(const Shape &shape, Format format, DataType dt) { + if (impl != nullptr) { + impl->shape_ = shape; + impl->format_ = format; + impl->data_type_ = dt; + } +} + +Shape TensorDesc::GetShape() const { + if (impl != nullptr) { + return impl->shape_; + } + return Shape(); +} + +void TensorDesc::SetShape(const Shape &shape) { + if (impl != nullptr) { + impl->shape_ = shape; + } +} + +// set shape with -2, it stand for unknown shape +graphStatus TensorDesc::SetUnknownDimNumShape() { + if (impl != nullptr) { + impl->shape_ = Shape({UNKNOWN_DIM_NUM}); + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Set unknown shape failed,because no impl class!"); + return GRAPH_FAILED; +} + +// for unknown shape +graphStatus TensorDesc::SetShapeRange(const std::vector> &range) { + if (impl != nullptr) { + impl->range_ = range; + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "SetShapeRange failed!impl is nullptr!"); + return GRAPH_FAILED; +} +graphStatus TensorDesc::GetShapeRange(std::vector> &range) const { + if (impl != nullptr) { + range = impl->range_; + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "impl is nullptr!"); + return GRAPH_FAILED; +} + +Shape TensorDesc::GetOriginShape() const { + if (impl != nullptr) { + return impl->origin_shape_; + } + return Shape(); +} + +void TensorDesc::SetOriginShape(const Shape &origin_shape) { + if (impl != nullptr) { + impl->origin_shape_ = origin_shape; + } +} + +Format TensorDesc::GetFormat() const { + if (impl != nullptr) { + return impl->format_; + } + return FORMAT_RESERVED; +} + +void TensorDesc::SetFormat(Format format) { + if (impl != nullptr) { + impl->format_ = format; + } +} + +Format TensorDesc::GetOriginFormat() const { + if (impl != nullptr) { + return 
impl->origin_format_; + } + return FORMAT_RESERVED; +} + +void TensorDesc::SetOriginFormat(Format origin_format) { + if (impl != nullptr) { + impl->origin_format_ = origin_format; + } +} + +DataType TensorDesc::GetDataType() const { + if (impl != nullptr) { + return impl->data_type_; + } + return DT_UNDEFINED; +} + +void TensorDesc::SetDataType(DataType dt) { + if (impl != nullptr) { + impl->data_type_ = dt; + } +} + +void TensorDesc::SetSize(int64_t size) { + if (impl != nullptr) { + impl->size_ = size; + } +} + +int64_t TensorDesc::GetSize() const { + if (impl != nullptr) { + return impl->size_; + } + return 0; +} + +void TensorDesc::SetRealDimCnt(const int64_t real_dim_cnt) { + if (impl != nullptr) { + impl->real_dim_cnt_ = real_dim_cnt; + } +} + +int64_t TensorDesc::GetRealDimCnt() const { + if (impl != nullptr) { + return impl->real_dim_cnt_; + } + return 0; +} + +std::string TensorDesc::GetName() const { + if (impl != nullptr) { + return impl->name_; + } + return ""; +} + +void TensorDesc::SetName(const std::string &name) { + if (impl != nullptr) { + impl->name_ = name; + } +} + +Tensor::Tensor() { impl = ComGraphMakeShared(); } + +Tensor::Tensor(const TensorDesc &tensor_desc) { + impl = ComGraphMakeShared(tensor_desc); // lint !e665 +} + +Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector &data) { + uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); + DataType data_type = tensor_desc.GetDataType(); + uint32_t type_length; + bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); + if (!ret) { + GELOGW("datatype %d is not found.", data_type); + } + + auto data_size = data.size(); + if (ret && (shape_size || (data_size != type_length))) { + if (type_length != 0 && UINT64_MAX / type_length < shape_size) { + GELOGW("mul overflow: %lu, %u", shape_size, type_length); + } else { + if (shape_size * type_length != data_size) { + GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, 
+ data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); + } + } + } + impl = ComGraphMakeShared(tensor_desc, data); // lint !e665 +} + +Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) { + uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); + DataType data_type = tensor_desc.GetDataType(); + uint32_t type_length; + bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); + if (!ret) { + GELOGW("datatype %d is not found.", data_type); + } + if (ret && (shape_size || (size != type_length))) { + if (type_length != 0 && UINT64_MAX / type_length < shape_size) { + GELOGW("mul overflow: %lu, %u", shape_size, type_length); + } else { + if (shape_size * type_length != size) { + GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, + size, TypeUtils::DataTypeToSerialString(data_type).c_str()); + } + } + } + + impl = ComGraphMakeShared(tensor_desc, data, size); // lint !e665 +} + +Tensor::Tensor(TensorDesc &&tensor_desc, std::vector &&data) { + uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); + DataType data_type = tensor_desc.GetDataType(); + uint32_t type_length; + bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); + if (!ret) { + GELOGW("datatype %d is not found.", data_type); + } + + auto data_size = data.size(); + if (ret && (shape_size || (data_size != type_length))) { + if (type_length != 0 && UINT64_MAX / type_length < shape_size) { + GELOGW("mul overflow: %lu, %u", shape_size, type_length); + } else { + if (shape_size * type_length != data_size) { + GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, + data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); + } + } + } + impl = ComGraphMakeShared(std::move(tensor_desc), std::move(data)); // lint !e665 +} + +TensorDesc Tensor::GetTensorDesc() const { + if (impl != nullptr) { + return 
TensorAdapter::GeTensorDesc2TensorDesc(impl->ge_tensor.MutableTensorDesc()); + } + return TensorDesc(); +} + +graphStatus Tensor::SetTensorDesc(const TensorDesc &tensor_desc) { + if (impl != nullptr) { + impl->ge_tensor.SetTensorDesc(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +const uint8_t *Tensor::GetData() const { + if (impl != nullptr) { + return impl->ge_tensor.GetData().data(); + } + return nullptr; +} + +uint8_t *Tensor::GetData() { + if (impl != nullptr) { + return impl->ge_tensor.MutableData().data(); + } + return nullptr; +} + +size_t Tensor::GetSize() const { + if (impl != nullptr) { + return impl->ge_tensor.GetData().size(); + } + return 0; +} + +graphStatus Tensor::SetData(std::vector &&data) { + if (impl != nullptr) { + (void)impl->ge_tensor.SetData(data); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::SetData(const std::vector &data) { + if (impl != nullptr) { + (void)impl->ge_tensor.SetData(data); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::SetData(const uint8_t *data, size_t size) { + if (impl != nullptr) { + (void)impl->ge_tensor.SetData(data, size); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::SetData(const std::string &data) { + if (impl != nullptr && (!data.empty())) { + /// Extra 8 bytes store pointer of string + /// Extra 1 byte store '\0' + size_t total_size = data.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL; + std::unique_ptr buff(new (std::nothrow) char[total_size]()); + if (buff == nullptr) { + GELOGE(GRAPH_FAILED, "allocate string raw data buff failed"); + return GRAPH_FAILED; + } + uint64_t *p = reinterpret_cast(buff.get()); + // Front 8 bytes store pointer of string + char *raw_data = buff.get() + EXTRA_STORE_POINTER_FOR_STRING; + p[0] = reinterpret_cast(raw_data); + int32_t memcpy_ret = memcpy_s(raw_data, total_size - EXTRA_STORE_POINTER_FOR_STRING, data.c_str(), 
data.size() + 1); + GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); + (void)impl->ge_tensor.SetData(reinterpret_cast(buff.get()), total_size); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +graphStatus Tensor::SetData(const std::vector &data) { + if (impl != nullptr) { + if (data.empty()) { + GELOGE(GRAPH_FAILED, "there is no data, please check the input variable"); + return GRAPH_FAILED; + } + size_t total_size = 0; + for (auto str : data) { + /// Extra 8 bytes store pointer of each string + /// Extra 1 byte store '\0' + total_size += (str.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL); + } + std::unique_ptr buff(new (std::nothrow) char[total_size]); + if (buff == nullptr) { + GELOGE(GRAPH_FAILED, "allocate string raw data buff failed"); + return GRAPH_FAILED; + } + uint64_t *p = reinterpret_cast(buff.get()); + // Front some bytes store pointer of each string + char *raw_data = buff.get() + data.size() * sizeof(uint64_t); + uint64_t ptr_size = data.size() * sizeof(uint64_t); + for (size_t i = 0; i < data.size(); ++i) { + p[i] = reinterpret_cast(raw_data); + if (total_size < ptr_size) { + GELOGE(GRAPH_FAILED, "Subtraction invalid, total_size: %zu, ptr_size: %lu", total_size, ptr_size); + return GRAPH_FAILED; + } + int32_t memcpy_ret = memcpy_s(raw_data, total_size - ptr_size, data[i].c_str(), data[i].size() + 1); + GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); + raw_data += (data[i].size() + 1); + ptr_size += (data[i].size() + 1); + } + + (void)impl->ge_tensor.SetData(reinterpret_cast(buff.get()), total_size); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::IsValid() { + uint64_t shape_size = GetTensorDesc().GetShape().GetShapeSize(); + DataType data_type = GetTensorDesc().GetDataType(); + uint32_t type_length; + bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); + if (!ret) { + GELOGW("datatype %d is not found.", data_type); + return 
GRAPH_SUCCESS; + } + + size_t data_size = GetSize(); + if (data_type != DT_STRING) { + if (shape_size || (data_size != type_length)) { + if (type_length != 0 && UINT64_MAX / type_length < shape_size) { + GELOGW("mul overflow: %lu, %u", shape_size, type_length); + } else { + if (shape_size * type_length != data_size) { + GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, + data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); + return GRAPH_FAILED; + } + } + } + } + + return GRAPH_SUCCESS; +} + +Tensor Tensor::Clone() const { + Tensor tensor; + if (impl != nullptr && tensor.impl != nullptr) { + tensor.impl->ge_tensor = impl->ge_tensor.Clone(); + } + return tensor; +} + +GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_desc) { + GeTensorDesc ge_tensor_desc(GeShape(tensor_desc.GetShape().GetDims()), tensor_desc.GetFormat(), + tensor_desc.GetDataType()); + ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); + ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); + ge_tensor_desc.SetName(tensor_desc.GetName()); + std::vector> shape_range; + auto status = tensor_desc.GetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get shape range failed!"); + return ge_tensor_desc; + } + status = ge_tensor_desc.SetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set shape range failed!"); + return ge_tensor_desc; + } + auto size = tensor_desc.GetSize(); + TensorUtils::SetSize(ge_tensor_desc, size); + + auto real_dim_cnt = static_cast(tensor_desc.GetRealDimCnt()); + TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); + return ge_tensor_desc; +} + +TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_desc) { + TensorDesc tensor_desc(Shape(ge_tensor_desc.GetShape().GetDims()), ge_tensor_desc.GetFormat(), + ge_tensor_desc.GetDataType()); + 
tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); + tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); + tensor_desc.SetName(ge_tensor_desc.GetName()); + std::vector> shape_range; + auto status = ge_tensor_desc.GetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get shape range failed!"); + return tensor_desc; + } + status = tensor_desc.SetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set shape range failed!"); + return tensor_desc; + } + int64_t size = 0; + (void)TensorUtils::GetSize(ge_tensor_desc, size); + tensor_desc.SetSize(size); + + uint32_t real_dim_cnt = 0; + (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); + tensor_desc.SetRealDimCnt(real_dim_cnt); + return tensor_desc; +} + +GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) { + GeTensorPtr ge_tensor; + if (tensor.impl != nullptr) { + ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor.Clone()); // lint !e665 + } + return ge_tensor; +} + +Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) { + Tensor tensor; + if (ge_tensor != nullptr && tensor.impl != nullptr) { + tensor.impl->ge_tensor = ge_tensor->Clone(); + } + return tensor; +} + +ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { + GeTensorPtr ge_tensor; + if (tensor.impl != nullptr) { + ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); // lint !e665 + } + return ge_tensor; +} + +GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) { + GeTensorPtr ge_tensor; + if (tensor.impl != nullptr) { + ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); // lint !e665 + } + return ge_tensor; +} + +const GeTensor TensorAdapter::AsGeTensor(const Tensor &tensor) { + if (tensor.impl != nullptr) { + return tensor.impl->ge_tensor; + } + return GeTensor(); +} + +GeTensor TensorAdapter::AsGeTensor(Tensor &tensor) { + if (tensor.impl != nullptr) { + return 
tensor.impl->ge_tensor; + } + return GeTensor(); +} + +const Tensor TensorAdapter::AsTensor(const GeTensor &ge_tensor) { + Tensor tensor; + if (tensor.impl != nullptr) { + tensor.impl->ge_tensor = ge_tensor; + } + return tensor; +} + +Tensor TensorAdapter::AsTensor(GeTensor &ge_tensor) { + Tensor tensor; + if (tensor.impl != nullptr) { + tensor.impl->ge_tensor = ge_tensor; + } + return tensor; +} +} // namespace ge diff --git a/src/common/graph/utils/anchor_utils.cc b/src/common/graph/utils/anchor_utils.cc new file mode 100644 index 00000000..5a042283 --- /dev/null +++ b/src/common/graph/utils/anchor_utils.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "utils/anchor_utils.h" +#include +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Format AnchorUtils::GetFormat(const DataAnchorPtr &data_anchor) { + if (data_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "The input data anchor is invalid."); + return FORMAT_RESERVED; + } + return data_anchor->format_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetFormat(const DataAnchorPtr &data_anchor, + Format data_format) { + if ((data_anchor == nullptr) || (data_format == FORMAT_RESERVED)) { + GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid ."); + return GRAPH_FAILED; + } + data_anchor->format_ = data_format; + return GRAPH_SUCCESS; +} + +// Get anchor status +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorStatus AnchorUtils::GetStatus(const DataAnchorPtr &data_anchor) { + if (data_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "The input data anchor is invalid."); + return ANCHOR_RESERVED; + } + return data_anchor->status_; +} + +// Set anchor status +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetStatus(const DataAnchorPtr &data_anchor, + AnchorStatus anchor_status) { + if ((data_anchor == nullptr) || (anchor_status == ANCHOR_RESERVED)) { + GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid ."); + return GRAPH_FAILED; + } + data_anchor->status_ = anchor_status; + return GRAPH_SUCCESS; +} + +bool AnchorUtils::HasControlEdge(const AnchorPtr &anchor) { + auto control_anchor = Anchor::DynamicAnchorCast(anchor); + if (control_anchor != nullptr) { + return (control_anchor->GetPeerAnchors().size() != 0); + } + + auto data_anchor = Anchor::DynamicAnchorCast(anchor); + if (data_anchor) { + for (const auto &peer : data_anchor->GetPeerAnchors()) { + auto peer_cast = Anchor::DynamicAnchorCast(peer); + if (peer_cast) { + return true; + } + } + return false; + } + 
GELOGE(GRAPH_FAILED, "the anchor is neither control anchor nor data anchor"); + return false; +} + +bool AnchorUtils::IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst) { + GE_CHK_BOOL_EXEC(src != nullptr, return false, "src is null."); + GE_CHK_BOOL_RET_STATUS_NOLOG(src->IsLinkedWith(dst), false); + auto src_control_anchor = Anchor::DynamicAnchorCast(src); + auto dst_control_anchor = Anchor::DynamicAnchorCast(dst); + return (src_control_anchor || dst_control_anchor); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int AnchorUtils::GetIdx(const AnchorPtr &anchor) { + // Check if it can add edge between DataAnchor + auto data_anchor = Anchor::DynamicAnchorCast(anchor); + if (data_anchor != nullptr) { + return data_anchor->GetIdx(); + } + // Check if it can add edge between ControlAnchor + auto control_anchor = Anchor::DynamicAnchorCast(anchor); + if (control_anchor != nullptr) { + return control_anchor->GetIdx(); + } + return -1; +} +} // namespace ge diff --git a/src/common/graph/utils/ge_ir_utils.cc b/src/common/graph/utils/ge_ir_utils.cc new file mode 100644 index 00000000..f238c6e8 --- /dev/null +++ b/src/common/graph/utils/ge_ir_utils.cc @@ -0,0 +1,1178 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/utils/ge_ir_utils.h" +#include +#include "framework/common/debug/ge_log.h" + +namespace { +const char *const kControlAnchorIndex = ":-1"; +const char *const kNodeTypeForSubgraph = "subgraph"; +const char *const kPrefixForInputDesc = "input_desc_attr_"; +const char *const kPrefixForOutputDesc = "output_desc_attr_"; +const char *const kDumpGEGraph = "DUMP_GE_GRAPH"; +const int8_t kMaxRecursionDepth = 10; +const char *const kDumpGeGraph = std::getenv(kDumpGEGraph); +const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP; +const int64_t kInputPrefixLength = 5; +const int64_t kOutputPrefixLength = 6; +using AttrDefPair = ::google::protobuf::MapPair; +} // namespace + +namespace ge { +// Part 1: from IR convert to ONNX Protobuf +static const std::map kGeDataTypeToOnnxMap = { + {DT_INT64, onnx::TensorProto_DataType_INT64}, {DT_UINT64, onnx::TensorProto_DataType_UINT64}, + {DT_FLOAT, onnx::TensorProto_DataType_FLOAT}, {DT_INT32, onnx::TensorProto_DataType_INT32}, + {DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8}, + {DT_UINT8, onnx::TensorProto_DataType_UINT8}, {DT_INT16, onnx::TensorProto_DataType_INT16}, + {DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16}, + {DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL}, +}; + +onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) { + auto it = kGeDataTypeToOnnxMap.find(data_type); + if (it != kGeDataTypeToOnnxMap.end()) { + return it->second; + } else { + GELOGW("EncodeDataType: datatype not support %u", data_type); + return onnx::TensorProto_DataType_UNDEFINED; + } +} + +void OnnxUtils::AddAttrProtoFromAttribute(const std::pair &string_attr_value, + onnx::NodeProto *node_proto) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node proto is nullptr."); + return; + } + auto attr = 
node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + auto attr_name = string_attr_value.first; + attr->set_name(attr_name); + auto attr_value = string_attr_value.second; + auto value_type = attr_value.GetValueType(); + switch (value_type) { + case GeAttrValue::VT_FLOAT: { + GeAttrValue::FLOAT data_f = 0; + (void)attr_value.GetValue(data_f); + attr->set_f(data_f); + attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); + break; + } + case GeAttrValue::VT_LIST_FLOAT: { + GeAttrValue::LIST_FLOAT data_fs = {}; + (void)attr_value.GetValue(data_fs); + attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); + for (auto &v : data_fs) { + attr->add_floats(v); + } + break; + } + case GeAttrValue::VT_INT: { + GeAttrValue::INT data_i = 0; + (void)attr_value.GetValue(data_i); + attr->set_type(onnx::AttributeProto_AttributeType_INT); + attr->set_i(data_i); + break; + } + case GeAttrValue::VT_LIST_INT: { + GeAttrValue::LIST_INT data_is = {}; + (void)attr_value.GetValue(data_is); + attr->set_type(onnx::AttributeProto_AttributeType_INTS); + for (auto &v : data_is) { + attr->add_ints(v); + } + break; + } + case GeAttrValue::VT_STRING: { + GeAttrValue::STR data_s; + (void)attr_value.GetValue(data_s); + attr->set_type(onnx::AttributeProto_AttributeType_STRING); + attr->set_s(data_s); + break; + } + case GeAttrValue::VT_LIST_STRING: { + GeAttrValue::LIST_STR data_ss = {}; + (void)attr_value.GetValue(data_ss); + attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); + for (auto &v : data_ss) { + attr->add_strings(v); + } + break; + } + default: + GELOGW("GeAttrValue ValueType: %u is not supported for now", value_type); + break; + } +} + +void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, + void *data) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); + return; + } + auto attr = 
node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + attr->set_name(name); + switch (type) { + case onnx::AttributeProto_AttributeType_FLOAT: + attr->set_f((*(static_cast(data)))); + attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); + break; + + case onnx::AttributeProto_AttributeType_FLOATS: + attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); + for (auto &v : (*(static_cast *>(data)))) { + attr->add_floats(v); + } + break; + + case onnx::AttributeProto_AttributeType_INT: + attr->set_type(onnx::AttributeProto_AttributeType_INT); + attr->set_i((*(static_cast(data)))); + break; + + case onnx::AttributeProto_AttributeType_INTS: + attr->set_type(onnx::AttributeProto_AttributeType_INTS); + for (auto &v : *(static_cast *>(data))) { + attr->add_ints(v); + } + break; + + case onnx::AttributeProto_AttributeType_STRING: + attr->set_type(onnx::AttributeProto_AttributeType_STRING); + attr->set_s((*(static_cast(data)))); + break; + + case onnx::AttributeProto_AttributeType_STRINGS: + attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); + for (auto &v : *(static_cast *>(data))) { + attr->add_strings(v); + } + break; + + default: + GELOGW("AttributeProto AttributeType: %u is not supported for now", type); + break; + } +} + +void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, + ::google::protobuf::RepeatedField<::google::protobuf::int64> data) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); + return; + } + if (!data.empty()) { + auto attr = node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + attr->set_name(name); + for (auto &v : data) { + attr->add_ints(v); + } + attr->set_type(type); + } +} + +void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, + 
::google::protobuf::RepeatedField data) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); + return; + } + if (!data.empty()) { + auto attr = node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + attr->set_name(name); + for (auto &v : data) { + attr->add_ints(static_cast(v)); + } + attr->set_type(type); + } +} + +void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, + ::google::protobuf::RepeatedField data) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); + return; + } + if (!data.empty()) { + auto attr = node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + attr->set_name(name); + for (auto &v : data) { + attr->add_floats(v); + } + attr->set_type(type); + } +} + +void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, + ::google::protobuf::RepeatedPtrField<::std::string> data) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); + return; + } + if (!data.empty()) { + auto attr = node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + attr->set_name(name); + for (auto &v : data) { + attr->add_strings(v); + } + attr->set_type(type); + } +} + +void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc) { + if (node_proto == nullptr || op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "node_proto or op_desc is nullptr"); + return; + } + // Input describes + auto size_in = op_desc->GetAllInputsSize(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_nums", &size_in); + if (size_in > 0) { + for (uint32_t i = 0; i < size_in; i++) { + auto input_desc = 
op_desc->GetInputDescPtrDfault(i); + if (input_desc != nullptr) { + auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_dtype:" + std::to_string(i), + &data_type); + auto data_type_origin = TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "input_desc_origin_dtype:" + std::to_string(i), &data_type_origin); + auto dims = input_desc->GetShape().GetDims(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_desc_shape:" + std::to_string(i), + &dims); + auto dims_origin = input_desc->GetOriginShape().GetDims(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, + "input_desc_origin_shape:" + std::to_string(i), &dims_origin); + auto layout = TypeUtils::FormatToSerialString(input_desc->GetFormat()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_layout:" + std::to_string(i), + &layout); + auto layout_origin = TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "input_desc_origin_layout:" + std::to_string(i), &layout_origin); + auto tensor_descriptor = input_desc->tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor != nullptr) { + auto size = tensor_descriptor->size(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_size:" + std::to_string(i), + &size); + auto weight_size = tensor_descriptor->weight_size(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_weight_size:" + std::to_string(i), &weight_size); + auto reuse_input = tensor_descriptor->reuse_input(); + auto reuse_input_int = static_cast(reuse_input); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_reuse_input:" + std::to_string(i), &reuse_input_int); + auto 
output_tensor = tensor_descriptor->output_tensor(); + auto output_tensor_int = static_cast(output_tensor); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_output_tensor:" + std::to_string(i), &output_tensor_int); + auto device_type = tensor_descriptor->device_type(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "input_desc_device_type:" + std::to_string(i), &device_type); + auto input_tensor = tensor_descriptor->input_tensor(); + auto input_tensor_int = static_cast(input_tensor); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_input_tensor:" + std::to_string(i), &input_tensor_int); + auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); + auto data_offset = tensor_descriptor->data_offset(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_data_offset:" + std::to_string(i), &data_offset); + auto cmps_size = tensor_descriptor->cmps_size(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_cmps_size:" + std::to_string(i), + &cmps_size); + auto cmps_tab = tensor_descriptor->cmps_tab(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "input_desc_cmps_tab:" + std::to_string(i), &cmps_tab); + auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset); + const auto &tensor_desc_map = tensor_descriptor->attr(); + std::string suffix = ":" + std::to_string(i); + AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix); + } else { + GELOGW("Tensor descriptor is nullptr"); + continue; + } + } else { + GELOGW("Input desc is nullptr"); + continue; + } + } + } + // Output describes + auto size_out = 
op_desc->GetOutputsSize(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_nums", &size_out); + if (size_out > 0) { + for (uint32_t i = 0; i < size_out; i++) { + auto output_desc = op_desc->GetOutputDescPtr(i); + if (output_desc != nullptr) { + auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_dtype:" + std::to_string(i), + &data_type); + auto origin_data_type = TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "output_desc_origin_dtype:" + std::to_string(i), &origin_data_type); + auto dims = output_desc->GetShape().GetDims(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_desc_shape:" + std::to_string(i), + &dims); + auto dims_origin = output_desc->GetOriginShape().GetDims(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, + "output_desc_origin_shape:" + std::to_string(i), &dims_origin); + auto layout = TypeUtils::FormatToSerialString(output_desc->GetFormat()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_layout:" + std::to_string(i), + &layout); + auto layout_origin = TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "output_desc_origin_layout:" + std::to_string(i), &layout_origin); + auto tensor_descriptor = output_desc->tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor != nullptr) { + auto size = tensor_descriptor->size(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_size:" + std::to_string(i), + &size); + auto weight_size = tensor_descriptor->weight_size(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "output_desc_weight_size:" + std::to_string(i), &weight_size); + auto device_type = 
tensor_descriptor->device_type(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "output_desc_device_type:" + std::to_string(i), &device_type); + auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); + const auto &tensor_desc_map = tensor_descriptor->attr(); + std::string suffix = ":" + std::to_string(i); + AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix); + } else { + GELOGW("Tensor descriptor is nullptr"); + continue; + } + } else { + GELOGW("Output desc is nullptr"); + continue; + } + } + } +} + +void OnnxUtils::AddAttrProtoForAttrsFromAttrMap( + const ::google::protobuf::Map &attr_map, onnx::NodeProto *node_proto, + const std::string &prefix, const std::string &suffix) { + for (const auto &item : attr_map) { + auto attr_name = item.first; + auto attr_def = item.second; + auto attr_type = attr_def.value_case(); + if (attr_type == ge::proto::AttrDef::kT) { + const auto &tensor_def = attr_def.t(); + const auto &tensor_desc = tensor_def.desc(); + auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix, + &data_type); + auto dims = tensor_desc.shape().dim(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix, + dims); + auto layout = tensor_desc.layout(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix, + &layout); + auto device_type = tensor_desc.device_type(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + prefix + attr_name + "_desc_device_type" + suffix, &device_type); + if (kDumpLevel == DUMP_ALL) { + auto data = tensor_def.data(); + AddAttrProto(node_proto, 
onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix, + &data); + } + } + if (attr_type == ge::proto::AttrDef::kS) { + if (kDumpLevel == DUMP_ALL) { + auto str_value = attr_def.s(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value); + } + } + if (attr_type == ge::proto::AttrDef::kI) { + auto int_value = attr_def.i(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); + } + if (attr_type == ge::proto::AttrDef::kF) { + auto float_value = attr_def.f(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value); + } + if (attr_type == ge::proto::AttrDef::kB) { + auto int_value = static_cast(attr_def.b()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); + } + if (attr_type == ge::proto::AttrDef::kList) { + const auto &list_value = attr_def.list(); + auto list_value_type = list_value.val_type(); + if (list_value_type == + ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { + if (kDumpLevel == DUMP_ALL) { + const auto &strings = list_value.s(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings); + } + } + if (list_value_type == + ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { + const auto &floats = list_value.f(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats); + } + if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { + const auto &ints = list_value.i(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints); + } + if (list_value_type == 
ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { + const auto &bools = list_value.b(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools); + } + } + } +} + +void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "node is nullptr"); + return; + } + // 1.Attributes added from node's methods + auto send_list = node->send_event_id_list_; + if (!send_list.empty()) { + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "send_event_id_list", &send_list); + } + auto recv_list = node->recv_event_id_list_; + if (!recv_list.empty()) { + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "recv_event_id_list", &recv_list); + } + auto op_desc = node->op_; + if (op_desc != nullptr) { + // for input_name_idx_ in opdesc + auto input_name_2_indexs = op_desc->GetAllInputName(); + ::google::protobuf::RepeatedPtrField<::std::string> input_names; + ::google::protobuf::RepeatedField<::google::protobuf::int64> input_indexes; + for (const auto &input_name_2_index : input_name_2_indexs) { + std::string input_name = input_name_2_index.first; + input_names.Add(std::move(input_name)); + input_indexes.Add(input_name_2_index.second); + } + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "_input_name_key", input_names); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "_input_name_value", input_indexes); + // 2.Attributes added from node's op_(message OpDef) + // Input and out describes + AddAttrProtoForOpInAndOutDesc(node_proto, op_desc); + // Others + auto op_def = op_desc->op_def_.GetProtoMsg(); + if (op_def != nullptr) { + auto id = op_def->id(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "id", &id); + auto stream_id = op_def->stream_id(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "stream_id", 
&stream_id); + const auto &input_name = op_def->input_name(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "input_name", input_name); + const auto &src_name = op_def->src_name(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "src_name", src_name); + const auto &src_index = op_def->src_index(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "src_index", src_index); + const auto &dst_name = op_def->dst_name(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "dst_name", dst_name); + const auto &dst_index = op_def->dst_index(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "dst_index", dst_index); + const auto &input_i = op_def->input_i(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_i", input_i); + const auto &output_i = op_def->output_i(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_i", output_i); + const auto &workspace = op_def->workspace(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace", workspace); + const auto &workspace_bytes = op_def->workspace_bytes(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes); + const auto &is_input_const = op_def->is_input_const(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const); + const auto &op_def_attr_map = op_def->attr(); + AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto); + } else { + GELOGE(FAILED, "Opdef is nullptr"); + return; + } + } else { + GELOGE(FAILED, "Opdesc is nullptr"); + return; + } +} + +bool OnnxUtils::EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto) { + if ((node == nullptr) || (node_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeOpDesc: Input Para Node Invalid"); + return false; + } + + // 2.Encode map attrs_ to AttributeProto + for (auto &node_attr : 
node->attrs_) { + AddAttrProtoFromAttribute(node_attr, node_proto); + } + // 3.Encode ge::Node members to AttributeProto + AddAttrProtoFromNodeMembers(node, node_proto); + return true; +} + +void OnnxUtils::EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto) { + if ((node == nullptr) || (node_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeNodeLinkForNetronVisual: Input Para Node Invalid"); + return; + } + const auto &node_name = node->GetName(); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + if ((out_data_anchor != nullptr) && (!out_data_anchor->GetPeerInDataAnchors().empty())) { + node_proto->add_output(node_name + ":" + std::to_string(out_data_anchor->GetIdx())); + } + } + auto out_control_anchor = node->GetOutControlAnchor(); + if ((out_control_anchor != nullptr) && (!out_control_anchor->GetPeerInControlAnchors().empty())) { + node_proto->add_output(node_name + kControlAnchorIndex); + } +} + +bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) { + if ((node == nullptr) || (node_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeNodeLink: Input Para Node Invalid"); + return false; + } + node_proto->clear_input(); + // 1. Add input by in data edge + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { + node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + + std::to_string(peer_out_anchor->GetIdx())); + } else { + // Add "" input + node_proto->add_input(""); + } + } + + // 2. 
Add input by in control edge + auto in_control_anchor = node->GetInControlAnchor(); + if (in_control_anchor != nullptr) { + auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors(); + for (const auto &peer_out_anchor : peer_out_anchors) { + if (peer_out_anchor->GetOwnerNode()) { + node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); + } + } + } else { + GELOGE(FAILED, "Incontrol anchor is nullptr"); + return false; + } + + // 3. Add output for Netron visual support + EncodeNodeLinkForNetronVisual(node, node_proto); + return true; +} + +bool OnnxUtils::EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto) { + if ((node == nullptr) || (node_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeNode: Input Para Node Invalid"); + return false; + } + // 1. Encode name and type + node_proto->set_name(node->GetName()); + /// Netron believes that some operators, such as the activation operator of softplus, only have one input, + /// while the link relation of control anchor may exist in ge, resulting in two inputs. 
Therefore, "ge:" prefix + /// is added to correctly display the link relation at the expense of some color features + node_proto->set_op_type("ge:" + node->GetType()); + + if (kDumpLevel != DUMP_WITH_OUT_DESC) { + // 2.for attr + if (!EncodeNodeDesc(node, node_proto)) { + GELOGE(GRAPH_FAILED, "Encode NodeDesc: %s failed", node->GetName().c_str()); + return false; + } + } + // 3.for link info + return EncodeNodeLink(node, node_proto); +} + +void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type) { + if ((node == nullptr) || (tensor_type == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeTypeProtoTensorType: Input Para Node or tensor_type Invalid"); + return; + } + const auto &op_desc = node->GetOpDesc(); + if (op_desc != nullptr) { + uint32_t size_out = static_cast(op_desc->GetOutputsSize()); + if (size_out > 0) { + for (uint32_t i = 0; i < size_out; i++) { + const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); + if (ge_tensor != nullptr) { + auto ge_data_type = ge_tensor->GetDataType(); + auto onnx_data_type = EncodeDataType(ge_data_type); + tensor_type->set_elem_type(onnx_data_type); + onnx::TensorShapeProto *shape = tensor_type->mutable_shape(); + if (shape != nullptr) { + for (auto d : ge_tensor->GetShape().GetDims()) { + auto dim = shape->add_dim(); + dim->set_dim_value(d); + } + } else { + GELOGW("Shape is nullptr"); + continue; + } + } else { + GELOGW("Ge tensor is nullptr"); + continue; + } + } + } + } else { + GELOGW("OpDesc Is Empty, nodeName %s nodeType %s", node->GetName().c_str(), node->GetType().c_str()); + return; + } +} + +void OnnxUtils::EncodeValueInfo(const NodePtr &node, onnx::ValueInfoProto *value_info_proto) { + if ((node == nullptr) || (value_info_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeValueInfo: Input Para Node or value_info_proto Invalid"); + return; + } + value_info_proto->set_name(node->GetName()); + onnx::TypeProto *t = value_info_proto->mutable_type(); + 
onnx::TypeProto_Tensor *tensor_type = t->mutable_tensor_type(); + EncodeTypeProtoTensorType(node, tensor_type); +} + +bool OnnxUtils::EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto) { + if ((graph == nullptr) || (graph_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeGraph: Input para Invalid"); + return false; + } + graph_proto->set_name(graph->GetName()); + // 1. Add graph inputs + for (const auto &input : graph->GetInputNodes()) { + auto value_info_proto = graph_proto->add_input(); + EncodeValueInfo(input, value_info_proto); + } + // 2. Add graph outputs + for (const auto &output : graph->GetOutputNodes()) { + auto value_info_proto = graph_proto->add_output(); + EncodeValueInfo(output, value_info_proto); + } + // 3. Add nodes + for (const auto &node : graph->GetDirectNode()) { + if (!EncodeNode(node, graph_proto->add_node())) { + GELOGW("EncodeNode failed"); + continue; + } + } + return true; +} + +bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto) { + model_proto.set_model_version(model.GetVersion()); + model_proto.set_ir_version(onnx::IR_VERSION); + model_proto.set_producer_name(model.GetName()); + auto &graph = model.graph_; + auto compute_graph = GraphUtils::GetComputeGraph(graph); + if (compute_graph == nullptr) { + GELOGE(GRAPH_FAILED, "GetComputeGraph: return nullptr"); + return false; + } + auto graph_proto = model_proto.mutable_graph(); + if (graph_proto == nullptr) { + GELOGE(GRAPH_FAILED, "mutable_graph: %s return nullptr", compute_graph->GetName().c_str()); + return false; + } + if (!EncodeGraph(compute_graph, graph_proto)) { + GELOGE(GRAPH_FAILED, "EncodeGraph: %s fail", compute_graph->GetName().c_str()); + return false; + } + + // For subgraphs: a subgraph is represented by a node + for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) { + if (sub_compute_graph != nullptr) { + auto node_proto = graph_proto->add_node(); + if (node_proto == nullptr) { + 
GELOGW("Node proto is nullptr"); + continue; + } + node_proto->set_name(sub_compute_graph->GetName()); + node_proto->set_op_type(kNodeTypeForSubgraph); + auto attr = node_proto->add_attribute(); + attr->set_name("graph"); + attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); + auto sub_graph_proto = attr->mutable_g(); + if (sub_graph_proto == nullptr) { + GELOGW("Sub graph proto is nullptr"); + continue; + } + if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { + GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); + continue; + } + } else { + GELOGW("Graph: %s subgraph is nullptr, skip EncodeGraph", compute_graph->GetName().c_str()); + continue; + } + } + return true; +} + +// Part 2: from ONNX Protobuf convert to IR +static std::map onnxDataTypeToGeMap = { + {onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64}, + {onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32}, + {onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8}, + {onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16}, + {onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16}, + {onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL}, +}; + +ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) { + auto it = onnxDataTypeToGeMap.find(data_type); + if (it != onnxDataTypeToGeMap.end()) { + return it->second; + } else { + GELOGW("DecodeDataType: datatype not support %u", data_type); + return ge::DT_UNDEFINED; + } +} + +bool OnnxUtils::ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index) { + auto sep = node_name_index.rfind(':'); + if (sep == std::string::npos) { + return false; + } + node_name = node_name_index.substr(0, sep); + auto index_str = node_name_index.substr(sep + 1); + index = 
static_cast(std::strtol(index_str.c_str(), nullptr, 10)); + return true; +} + +bool OnnxUtils::DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr) { + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp: node_ptr is nullptr"); + return false; + } + // Data edge + if (item.src_out_index >= 0) { + auto src_anchor = node_ptr->GetOutDataAnchor(item.src_out_index); + auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); + if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { + GELOGE(GRAPH_FAILED, "Get data anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, + item.dst_node_name.c_str(), item.dst_in_index); + return false; + } + if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Data Anchor: src_anchor->LinkTo(dst_anchor) failed"); + return false; + } + // Control edge + } else { + auto src_anchor = node_ptr->GetOutControlAnchor(); + auto dst_anchor = item.dst_node->GetInControlAnchor(); + if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { + GELOGE(GRAPH_FAILED, "Get control anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, + item.dst_node_name.c_str(), item.dst_in_index); + return false; + } + if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Control Anchor: src_anchor->LinkTo(dst_anchor) failed"); + return false; + } + } + return true; +} + +bool OnnxUtils::DecodeNodeLink(const std::vector &node_proto_vector, + const std::map &node_map) { + for (const auto &node_proto : node_proto_vector) { + const auto &node_name = node_proto.name(); + auto dst_node = node_map.find(node_name); + if ((dst_node == node_map.end()) || (dst_node->second == nullptr)) { + GELOGE(GRAPH_FAILED, "destination node: %s find failed or is nullptr", node_name.c_str()); + return false; + } + int32_t dst_index = 0; + for (const auto &input : node_proto.input()) { + std::string input_node_name; + int32_t index = 0; + if 
(ParseNameIndex(input, input_node_name, index)) { + auto item = NodeLinkInfo{input_node_name, index, dst_node->second, dst_index, node_proto.name()}; + auto src_node = node_map.find(input_node_name); + if (src_node == node_map.end()) { + GELOGE(GRAPH_FAILED, "find src node: %s failed", input_node_name.c_str()); + return false; + } + auto node_ptr = src_node->second; + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "src node: %s is nullptr", input_node_name.c_str()); + return false; + } + if (!DecodeNodeLinkImp(item, node_ptr)) { + GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp node: %s failed", input_node_name.c_str()); + return false; + } + } + if (index >= 0) { + dst_index++; + } + } + } + return true; +} + +void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &strings) { + if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRINGS) { + GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); + return; + } + for (int i = 0; i < attr_proto.strings_size(); i++) { + strings.push_back(attr_proto.strings(i)); + } +} + +void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value) { + if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRING) { + GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); + return; + } + value = attr_proto.s(); +} + +void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &ints) { + if (attr_proto.type() != onnx::AttributeProto_AttributeType_INTS) { + GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); + return; + } + for (int i = 0; i < attr_proto.ints_size(); i++) { + ints.push_back(attr_proto.ints(i)); + } +} + +void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value) { + if (attr_proto.type() != onnx::AttributeProto_AttributeType_INT) { + GELOGE(GRAPH_FAILED, 
"Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); + return; + } + value = attr_proto.i(); +} + +void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_input_desc, int32_t index, + OpDescPtr &op_desc) { + if (op_desc->MutableInputDesc(static_cast(index)) == nullptr) { + GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast(index)) is nullptr", + op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); + return; + } + if (attr_name_for_input_desc == "input_desc_dtype") { + auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); + op_desc->MutableInputDesc(static_cast(index))->SetDataType(data_type); + } else if (attr_name_for_input_desc == "input_desc_shape") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + GeShape ge_shape(ints); + op_desc->MutableInputDesc(static_cast(index))->SetShape(ge_shape); + } else if (attr_name_for_input_desc == "input_desc_layout") { + auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); + op_desc->MutableInputDesc(static_cast(index))->SetFormat(data_format); + } else if (attr_name_for_input_desc == "input_desc_origin_shape") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + GeShape ge_shape(ints); + op_desc->MutableInputDesc(static_cast(index))->SetOriginShape(ge_shape); + } else if (attr_name_for_input_desc == "input_desc_origin_layout") { + auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); + op_desc->MutableInputDesc(static_cast(index))->SetOriginFormat(data_format); + } else if (attr_name_for_input_desc == "input_desc_size") { + int64_t input_size = 0; + auto tensor_descriptor = op_desc->MutableInputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); + DecodeAttribute(attr_proto, input_size); + tensor_descriptor->set_size(input_size); + } else if (attr_name_for_input_desc == "input_desc_data_offset") { + auto tensor_descriptor = 
op_desc->MutableInputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); + int64_t offset = 0; + DecodeAttribute(attr_proto, offset); + tensor_descriptor->set_data_offset(offset); + } else { + return; + } +} + +void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_output_desc, int32_t index, + OpDescPtr &op_desc) { + if (op_desc->MutableOutputDesc(static_cast(index)) == nullptr) { + GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast(index)) is nullptr", + op_desc->GetName().c_str(), attr_name_for_output_desc.c_str()); + return; + } + if (attr_name_for_output_desc == "output_desc_dtype") { + auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); + op_desc->MutableOutputDesc(static_cast(index))->SetDataType(data_type); + } else if (attr_name_for_output_desc == "output_desc_shape") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + GeShape ge_shape(ints); + op_desc->MutableOutputDesc(static_cast(index))->SetShape(ge_shape); + } else if (attr_name_for_output_desc == "output_desc_layout") { + auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); + op_desc->MutableOutputDesc(static_cast(index))->SetFormat(data_format); + } else if (attr_name_for_output_desc == "output_desc_origin_shape") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + GeShape ge_shape(ints); + op_desc->MutableOutputDesc(static_cast(index))->SetOriginShape(ge_shape); + } else if (attr_name_for_output_desc == "output_desc_origin_layout") { + auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); + op_desc->MutableOutputDesc(static_cast(index))->SetOriginFormat(data_format); + } else if (attr_name_for_output_desc == "output_desc_size") { + int64_t output_size = 0; + auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); + DecodeAttribute(attr_proto, output_size); + 
tensor_descriptor->set_size(output_size); + } else if (attr_name_for_output_desc == "output_desc_data_offset") { + auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); + int64_t offset = 0; + DecodeAttribute(attr_proto, offset); + tensor_descriptor->set_data_offset(offset); + } else { + return; + } +} + +void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_input_output_desc, int32_t index, + OpDescPtr &op_desc) { + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "op_desc is nullptr"); + return; + } + if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") { + DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); + } else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") { + DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); + } else { + return; + } +} + +void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) { + auto attr_map = op_def.mutable_attr(); + const auto &attr_name = attr_proto.name(); + ge::proto::AttrDef op_attr; + int64_t value = 0; + DecodeAttribute(attr_proto, value); + op_attr.set_i(value); + attr_map->insert(AttrDefPair(attr_name, op_attr)); +} + +void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); + return; + } + const auto &attr_name = attr_proto.name(); + std::string attr_name_for_input_output_desc; + int32_t index = 0; + if (!ParseNameIndex(attr_name, attr_name_for_input_output_desc, index)) { + if (attr_name == "id") { + op_desc->SetId(attr_proto.i()); + } else if (attr_name == "stream_id") { + op_desc->SetStreamId(attr_proto.i()); + } else if (attr_name == "src_name") { + 
std::vector strings; + DecodeAttribute(attr_proto, strings); + op_desc->SetSrcName(strings); + } else if (attr_name == "dst_name") { + std::vector strings; + DecodeAttribute(attr_proto, strings); + op_desc->SetDstName(strings); + } else if (attr_name == "src_index") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + op_desc->SetSrcIndex(ints); + } else if (attr_name == "dst_index") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + op_desc->SetDstIndex(ints); + } else if (attr_name == "fusion_scope") { + DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg()); + } else if (attr_name == "input_i") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + op_desc->SetInputOffset(ints); + } else if (attr_name == "output_i") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + op_desc->SetOutputOffset(ints); + } else { + return; + } + // Update input and output desc + } else { + DecodeNodeAttributeForOpInAndOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); + } +} + +bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_desc) { + if (op_desc == nullptr || node_proto == nullptr) { + GELOGE(GRAPH_FAILED, " Op_desc is nullptr or node_proto is nullptr"); + return false; + } + // 1. Decode node_proto name and type + op_desc->SetName(node_proto->name()); + const auto &node_type_with_ge_prefix = node_proto->op_type(); + auto sep = node_type_with_ge_prefix.find(':'); + if (sep == std::string::npos) { + return false; + } + auto node_type = node_type_with_ge_prefix.substr(sep + 1); + op_desc->SetType(node_type); + // 2. 
Add empty input and output desc + for (const auto &attr : node_proto->attribute()) { + if (attr.name() == "input_desc_nums") { + auto size_in = attr.i(); + for (int64_t i = 0; i < size_in; i++) { + GeTensorDesc ge_tensor_desc; + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed."); + } + } + if (attr.name() == "output_desc_nums") { + auto size_out = attr.i(); + for (int64_t i = 0; i < size_out; i++) { + GeTensorDesc ge_tensor_desc; + GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed."); + } + } + } + // 3.Decode node_proto attributes + for (int i = 0; i < node_proto->attribute_size(); i++) { + DecodeNodeAttributeForOpDesc(node_proto->attribute(i), op_desc); + } + return true; +} + +bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph) { + if (recursion_depth > kMaxRecursionDepth) { + GELOGE(GRAPH_FAILED, "DecodeGraph: recursion depth is too large, abort"); + return false; + } + + graph = ComGraphMakeShared(graph_proto.name()); + GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed"); + /// 1. Decode all nodes first, node should include input + /// and output nodes and nodes which represent sub graphs + std::map node_map; + std::vector node_proto_vector; + for (const auto &node_proto : graph_proto.node()) { + // a. 
nodes represent sub graphs + if (node_proto.op_type() == kNodeTypeForSubgraph) { + ComputeGraphPtr compute_graph; + // in this case, node only have one attr, whose type is AttributeProto_AttributeType_GRAPH + const auto &node_attr = node_proto.attribute(0); + if ((node_attr.type() == onnx::AttributeProto_AttributeType_GRAPH) && + DecodeGraph(recursion_depth + 1, node_attr.g(), compute_graph)) { + (void)graph->AddSubGraph(compute_graph); + } else { + GELOGE(GRAPH_FAILED, "Decode sub graph %s failed with node type:%d", node_proto.name().c_str(), + node_attr.type()); + return false; + } + // b. direct nodes in graph + } else { + node_proto_vector.push_back(node_proto); + OpDescPtr op_desc = ComGraphMakeShared(); + // b.1 For node desc + if (!DecodeNodeDesc(&node_proto, op_desc)) { + GELOGE(GRAPH_FAILED, "Decode node desc %s failed ", node_proto.name().c_str()); + return false; + } + auto node = graph->AddNode(op_desc); + node_map.insert(std::make_pair(node_proto.name(), node)); + } + } + /// We get all nodes in graph here + /// b.2 For node link + if (!DecodeNodeLink(node_proto_vector, node_map)) { + GELOGE(GRAPH_FAILED, "Decode node link failed"); + return false; + } + + // 2. Add inputs nodes for graph + for (const auto &input : graph_proto.input()) { + const auto &input_node_name = input.name(); + auto input_node_item = node_map.find(input_node_name); + if (input_node_item == node_map.end()) { + GELOGE(GRAPH_FAILED, "cannot find graph's input node %s in node_", input_node_name.c_str()); + return false; + } + auto ret = graph->AddInputNode(input_node_item->second); + GE_CHK_BOOL_EXEC(ret != nullptr, continue, "Add inputnode failed"); + } + // 3. 
Add outputs nodes for graph + for (const auto &output : graph_proto.output()) { + const auto &output_node_name = output.name(); + auto output_node_item = node_map.find(output_node_name); + if (output_node_item == node_map.end()) { + GELOGE(GRAPH_FAILED, "cannot find graph's output node %s in node_", output_node_name.c_str()); + return false; + } + auto ret = graph->AddOutputNode(output_node_item->second); + if (ret == nullptr) { + GELOGW("Add outputnode failed,out put node is %s", output_node_name.c_str()); + continue; + } + } + return true; +} + +bool OnnxUtils::ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model) { + model.name_ = model_proto.producer_name(); + model.version_ = static_cast(model_proto.model_version()); + + auto &graph_proto = model_proto.graph(); + ComputeGraphPtr compute_graph; + // 0 means recursion depth, father call + if (!DecodeGraph(0, graph_proto, compute_graph)) { + GELOGE(GRAPH_FAILED, "Decode compute graph from graph_proto failed"); + return false; + } + model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + return true; +} +} // namespace ge diff --git a/src/common/graph/utils/ge_ir_utils.h b/src/common/graph/utils/ge_ir_utils.h new file mode 100644 index 00000000..9b16be18 --- /dev/null +++ b/src/common/graph/utils/ge_ir_utils.h @@ -0,0 +1,206 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ +#define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "proto/ge_ir.pb.h" +#include "proto/onnx.pb.h" + +namespace ge { +const int kOffsetToString = 2; + +/// +/// @ingroup ge_ir_utils +/// @brief RepeatedField->String +/// @param [in] const rpd_field RepeatedField +/// @return String +/// +template +const std::string ToString(const google::protobuf::RepeatedField &rpd_field) { + std::stringstream ss; + ss << "["; + for (const T &x : rpd_field) { + ss << x; + ss << ", "; + } + std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); + str_ret += "]"; + return str_ret; +} + +/// +/// @ingroup ge_ir_utils +/// @brief RepeatedPtrField->String +/// @param [in] const rpd_field RepeatedPtrField +/// @return String +/// +template +const std::string ToString(const google::protobuf::RepeatedPtrField &rpd_ptr_field) { + std::stringstream ss; + ss << "["; + for (const T &x : rpd_ptr_field) { + ss << x; + ss << ", "; + } + std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); + str_ret += "]"; + return str_ret; +} + +/// +/// @ingroup ge_ir_utils +/// @brief check, if not equal, log with tag +/// @param [in] const left_value, right_value reference, log_info_tag +/// @return bool +/// +template +bool IsEqual(const T &l_value, const T &r_value, const std::string &log_info_tag) { + if (l_value == r_value) { + return true; + } else { + GELOGE(GRAPH_FAILED, "Check failed with %s", log_info_tag.c_str()); + return false; + } +} + +class OnnxUtils { + public: + enum DumpLevel { NO_DUMP = 0, DUMP_ALL = 1, DUMP_WITH_OUT_DATA = 2, DUMP_WITH_OUT_DESC = 3, DUMP_LEVEL_END }; + + static bool ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto); + + static bool 
ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model); + + private: + // Part 1: from IR convert to ONNX Protobuf + static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, + const std::string &name, void *data); + + static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, + const std::string &name, ::google::protobuf::RepeatedField<::google::protobuf::int64> data); + + static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, + const std::string &name, ::google::protobuf::RepeatedField data); + + static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, + const std::string &name, ::google::protobuf::RepeatedField data); + + static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, + const std::string &name, ::google::protobuf::RepeatedPtrField<::std::string> data); + + static void AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto); + + static void AddAttrProtoFromAttribute(const std::pair &string_attr_value, + onnx::NodeProto *node_proto); + + static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); + + static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map &attr_map, + onnx::NodeProto *node_proto, const std::string &prefix = "", + const std::string &suffix = ""); + + static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto); + + static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); + + static void EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto); + + static bool EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto); + + static bool EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto); + + static bool EncodeNode(const NodePtr &node, onnx::NodeProto 
*node_proto); + + static void EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type); + + static void EncodeValueInfo(const NodePtr &n, onnx::ValueInfoProto *v); + + static bool EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto); + + /// Part 2: from ONNX Protobuf convert to IR + /// Describes node's link relationships + struct NodeLinkInfo { + std::string src_node_name; + int32_t src_out_index; + NodePtr dst_node; + int32_t dst_in_index; + std::string dst_node_name; + }; + + // Parse node name and index + static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index); + + static ge::DataType DecodeDataType(onnx::TensorProto_DataType data_type); + + static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &strings); + + static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &ints); + + static void DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value); + + static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); + + static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_output_desc, int32_t index, + OpDescPtr &op_desc); + + static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_input_desc, int32_t index, + OpDescPtr &op_desc); + + static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_input_output_desc, int32_t index, + OpDescPtr &op_desc); + + static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def); + + static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); + + static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); + + static bool DecodeNodeLink(const std::vector 
&node_proto_vector, + const std::map &node_map); + + static bool DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &node); + + static bool DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph); +}; +} // namespace ge + +#endif // COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ diff --git a/src/common/graph/utils/graph_utils.cc b/src/common/graph/utils/graph_utils.cc new file mode 100644 index 00000000..c741a316 --- /dev/null +++ b/src/common/graph/utils/graph_utils.cc @@ -0,0 +1,2767 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "utils/graph_utils.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "./ge_context.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "proto/ge_ir.pb.h" +#include "utils/attr_utils.h" +#include "utils/ge_ir_utils.h" +#include "utils/node_utils.h" +#include "debug/ge_op_types.h" +#include "external/ge/ge_api_types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" + +using google::protobuf::io::FileOutputStream; + +namespace ge { +enum DumpGraphLevel { + kDumpLevel1 = 1, + kDumpLevel2 = 2, + kDumpLevel3 = 3, + kDumpLevelOther, +}; + +namespace { +const int32_t kBaseOfIntegerValue = 10; +#ifdef FMK_SUPPORT_DUMP +const char *const kDumpGeGraph = "DUMP_GE_GRAPH"; +const int kDumpGraphIndexWidth = 5; +#endif +const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; +const char *const kDumpStrBuild = "Build"; +const char *const kDumpStrPartition = "partition"; +const char *const kDumpStrOptimizeSubgraph = "OptimizeSubGraph"; +const char *const kDumpStrSubgraphFunc = "sub_graph"; +const char *const kDumpStrAicpu = "Aicpu"; +}; // namespace + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, + const InDataAnchorPtr &dst) { + if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Add edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const AnchorPtr &src, + const AnchorPtr &dst) { + OutDataAnchorPtr src_data = Anchor::DynamicAnchorCast(src); + InDataAnchorPtr dst_data = Anchor::DynamicAnchorCast(dst); + OutControlAnchorPtr src_control = Anchor::DynamicAnchorCast(src); + InControlAnchorPtr dst_control = Anchor::DynamicAnchorCast(dst); + if ((src_data != nullptr) && 
(dst_data != nullptr) && (src_data->LinkTo(dst_data) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + if ((src_data != nullptr) && (dst_control != nullptr) && (src_data->LinkTo(dst_control) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + if ((src_control != nullptr) && (dst_control != nullptr) && (src_control->LinkTo(dst_control) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + if ((src_control != nullptr) && (dst_data != nullptr) && (src_control->LinkTo(dst_data) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Add edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, + const Format &src_format, + const InDataAnchorPtr &dst, + const Format &dst_format) { + if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { + auto ret = AnchorUtils::SetFormat(src, src_format); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed, format is %d", static_cast(src_format)); + return ret; + } + ret = AnchorUtils::SetFormat(dst, dst_format); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(dst_format)); + return ret; + } + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Add edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutControlAnchorPtr &src, + const InControlAnchorPtr &dst) { + if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Add edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, + const InControlAnchorPtr &dst) { + if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Add edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY 
graphStatus GraphUtils::RemoveEdge(const OutDataAnchorPtr &src, + const InDataAnchorPtr &dst) { + if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Remove edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const AnchorPtr &src, + const AnchorPtr &dst) { + if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Remove edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutControlAnchorPtr &src, + const InControlAnchorPtr &dst) { + if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Remove edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutDataAnchorPtr &src, + const InControlAnchorPtr &dst) { + if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Remove edge Failed."); + return GRAPH_FAILED; +} + +graphStatus GraphUtils::ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, + const InDataAnchorPtr &new_dst) { + if (RemoveEdge(src, dst) == GRAPH_SUCCESS && AddEdge(src, new_dst) == GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Replace edge dst Failed."); + return GRAPH_FAILED; +} + +graphStatus GraphUtils::ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, + const InControlAnchorPtr &new_dst) { + if (RemoveEdge(src, dst) == GRAPH_SUCCESS && AddEdge(src, new_dst) == GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Replace edge dst Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNodeBetweenDataAnchors( + const OutDataAnchorPtr &src, 
const InDataAnchorPtr &dst, const NodePtr &new_node) { + GE_CHECK_NOTNULL(src); + GE_CHECK_NOTNULL(dst); + GE_CHECK_NOTNULL(new_node); + + InDataAnchorPtr node_in_anchor = new_node->GetInDataAnchor(0); + GE_CHK_BOOL_RET_STATUS(node_in_anchor != nullptr, GRAPH_FAILED, "this node has not inDataAnchor"); + OutDataAnchorPtr node_out_anchor = new_node->GetOutDataAnchor(0); + GE_CHK_BOOL_RET_STATUS(node_out_anchor != nullptr, GRAPH_FAILED, "this node has not outDataAnchor"); + GE_CHK_STATUS_RET(src->ReplacePeer(dst, node_in_anchor, node_out_anchor), "ReplacePeer Failed"); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node) { + GE_CHECK_NOTNULL(compute_graph); + if (remove_node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return GRAPH_FAILED; + } + + // Check if this node is belong to this compute graph, maybe a little slow + const auto &all_nodes_in_graph = compute_graph->GetDirectNode(); + if (std::find(all_nodes_in_graph.begin(), all_nodes_in_graph.end(), remove_node) == all_nodes_in_graph.end()) { + GELOGE(GRAPH_FAILED, "Can not find node %s in graph %s.", remove_node->GetName().c_str(), + compute_graph->GetName().c_str()); + return GRAPH_FAILED; + } + // Find all subgraph of this node + const auto &root_graph = GraphUtils::FindRootGraph(compute_graph); + std::vector subgraphs; + std::vector all_nodes; + std::deque candidates; + NodePtr remove_node_new = remove_node; + candidates.emplace_back(remove_node_new); + while (!candidates.empty()) { + const NodePtr node = candidates.front(); + all_nodes.emplace_back(node); + candidates.pop_front(); + + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + + const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); + for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { + 
auto subgraph = root_graph->GetSubgraph(*name_iter); + if (subgraph != nullptr) { + subgraphs.emplace_back(subgraph); + candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); + } + } + } + // Remove all subgraph + for (const auto &remove_graph : subgraphs) { + if (root_graph->RemoveSubGraph(remove_graph) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Remove subgraph failed, sub graph name is %s, compute graph is %s.", + remove_node->GetName().c_str(), compute_graph->GetName().c_str()); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node) { + GE_CHECK_NOTNULL(compute_graph); + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return GRAPH_FAILED; + } + + // If the node save as input node, delete it + (void)compute_graph->RemoveInputNode(node); + + // If the node save as output node, delete it + (void)compute_graph->RemoveOutputNode(node); + + // If the node has sub-graphs, delete them + auto ret = RemoveSubgraphRecursively(compute_graph, node); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Remove subgraph recursively failed."); + return GRAPH_FAILED; + } + + auto iter = find(compute_graph->nodes_.begin(), compute_graph->nodes_.end(), node); + if (iter != compute_graph->nodes_.end()) { + compute_graph->nodes_.erase(iter); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +/// Add two edges to the new node, respectively connecting the SRC and DST +/// associated with the original edge +/// A ---> B transfered to A ---> N ---> B +graphStatus InsertTransNode(ComputeGraph &compute_graph, const InDataAnchorPtr &in_data_anchor, + const std::vector &vec_op_desc) { + GE_CHECK_NOTNULL(in_data_anchor); + for (const auto &op_desc : vec_op_desc) { + GE_CHECK_NOTNULL(op_desc); + + auto ret = op_desc->AddInputDesc(GeTensorDesc()); + 
GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, "Add input desc failed"); + ret = op_desc->AddOutputDesc(GeTensorDesc()); + GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, "Add input desc failed"); + auto node_to_insert = compute_graph.AddNode(op_desc); + + GE_CHECK_NOTNULL(node_to_insert); + GE_CHECK_NOTNULL(in_data_anchor->GetPeerOutAnchor()); + + auto src = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); + if (!src) { + GELOGE(GRAPH_FAILED, "src nullptr error."); + return GRAPH_FAILED; + } + + auto src_out_index = in_data_anchor->GetPeerOutAnchor()->GetIdx(); + + auto dst = in_data_anchor->GetOwnerNode(); + if (!dst) { + GELOGE(GRAPH_FAILED, "dst nullptr error."); + return GRAPH_FAILED; + } + + auto dst_in_index = in_data_anchor->GetIdx(); + + auto in_data_anchor_src_format = AnchorUtils::GetFormat(in_data_anchor->GetPeerOutAnchor()); + auto in_data_anchor_dst_format = AnchorUtils::GetFormat(in_data_anchor); + + GE_CHECK_NOTNULL(src->GetOutDataAnchor(src_out_index)); + GE_CHECK_NOTNULL(dst->GetInDataAnchor(dst_in_index)); + + ret = GraphUtils::RemoveEdge(src->GetOutDataAnchor(src_out_index), dst->GetInDataAnchor(dst_in_index)); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Remove edge failed"); + return GRAPH_FAILED; + } + + GE_CHECK_NOTNULL(node_to_insert->GetInDataAnchor(0)); + GE_CHECK_NOTNULL(node_to_insert->GetOutDataAnchor(0)); + + ret = GraphUtils::AddEdge(src->GetOutDataAnchor(src_out_index), node_to_insert->GetInDataAnchor(0)); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add edge failed"); + return ret; + } + ret = GraphUtils::AddEdge(node_to_insert->GetOutDataAnchor(0), dst->GetInDataAnchor(dst_in_index)); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add edge failed"); + return ret; + } + + if (op_desc->HasAttr("input_format")) { + int64_t input_format = 0; + int64_t output_format = 0; + if (!AttrUtils::GetInt(op_desc, "input_format", input_format)) { + GELOGW("get attr input_format failed"); + 
continue; + } + if (!AttrUtils::GetInt(op_desc, "output_format", output_format)) { + GELOGW("get attr output_format failed"); + continue; + } + + GE_CHECK_NOTNULL(node_to_insert->GetInDataAnchor(0)->GetPeerOutAnchor()); + GE_CHK_BOOL_RET_STATUS(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().empty(), GRAPH_FAILED, + "Vistor is empty"); + GE_CHECK_NOTNULL(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)); + + auto status = + AnchorUtils::SetFormat(node_to_insert->GetInDataAnchor(0)->GetPeerOutAnchor(), in_data_anchor_src_format); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(in_data_anchor_src_format)); + return status; + } + status = AnchorUtils::SetFormat(node_to_insert->GetInDataAnchor(0), static_cast(input_format)); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed,format is %ld", input_format); + return status; + } + status = AnchorUtils::SetFormat(node_to_insert->GetOutDataAnchor(0), static_cast(output_format)); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed,format is %ld", output_format); + return status; + } + status = AnchorUtils::SetFormat(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0), + in_data_anchor_dst_format); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(in_data_anchor_dst_format)); + return status; + } + } + std::vector original_nodes; + GraphUtils::RecordOriginalNames(original_nodes, node_to_insert); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTransNode( + ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, const std::vector &vec_op_desc) { + GE_CHECK_NOTNULL(compute_graph); + GE_CHECK_NOTNULL(in_data_anchor); + graphStatus ret = + ge::InsertTransNode(*compute_graph, in_data_anchor, vec_op_desc) == GRAPH_SUCCESS ? 
GRAPH_SUCCESS : GRAPH_FAILED; + return ret; +} + +/// +/// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst +/// @param [in] src +/// @param [in] dsts +/// @param [in] insert_node +/// @param [in] input_index +/// @param [in] output_index +/// @return graphStatus +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { + GE_CHECK_NOTNULL(src); + GE_CHECK_NOTNULL(insert_node); + + NodePtr src_node = src->GetOwnerNode(); + if (src_node->GetOwnerComputeGraph() != insert_node->GetOwnerComputeGraph()) { + GELOGE(GRAPH_FAILED, "src:%s and insert_node:%s not exist in the same graph.", src_node->GetName().c_str(), + insert_node->GetName().c_str()); + return GRAPH_FAILED; + } + + if (AddEdge(src, insert_node->GetInDataAnchor(input_index)) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "AddEdge %s->%s failed.", src_node->GetName().c_str(), insert_node->GetName().c_str()); + return GRAPH_FAILED; + } + + OutControlAnchorPtr src_out_ctrl_anchor = src_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(src_out_ctrl_anchor); + + bool ctrl_edge_flag = true; + std::string type = NodeUtils::GetNodeType(src->GetOwnerNode()); + if ((type == SWITCH) || (type == REFSWITCH) || (type == SWITCHN)) { + ctrl_edge_flag = false; + } + + for (auto &dst : dsts) { + GE_CHECK_NOTNULL(dst); + NodePtr dst_node = dst->GetOwnerNode(); + GELOGI("Insert node %s between %s->%s.", insert_node->GetName().c_str(), src_node->GetName().c_str(), + dst_node->GetName().c_str()); + if (src_node->GetOwnerComputeGraph() != dst_node->GetOwnerComputeGraph()) { + GELOGE(GRAPH_FAILED, "src:%s and dst:%s not exist in the same graph.", src_node->GetName().c_str(), + dst_node->GetName().c_str()); + return GRAPH_FAILED; + } + + (void)RemoveEdge(src, dst); + if (AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS) 
{ + GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), + dst_node->GetName().c_str(), insert_node->GetName().c_str(), dst_node->GetName().c_str()); + return GRAPH_FAILED; + } + + if (!ctrl_edge_flag) { + continue; + } + for (const InControlAnchorPtr &peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { + if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || + (AddEdge(insert_node->GetOutControlAnchor(), peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { + GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), + peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(), insert_node->GetName().c_str(), + peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); + return GRAPH_FAILED; + } + } + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveJustNode(ComputeGraph &compute_graph, + const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should be not null."); + return GRAPH_FAILED; + } + auto iter = find(compute_graph.nodes_.begin(), compute_graph.nodes_.end(), node); + if (iter != compute_graph.nodes_.end()) { + compute_graph.nodes_.erase(iter); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveJustNode(ComputeGraphPtr compute_graph, + const NodePtr &node) { + GE_CHECK_NOTNULL(compute_graph); + GE_CHECK_NOTNULL(node); + graphStatus ret = (RemoveJustNode(*compute_graph, node) == GRAPH_SUCCESS ? 
GRAPH_SUCCESS : GRAPH_FAILED); + return ret; +} + +void GraphUtils::RecordOriginalNames(std::vector original_nodes, const ge::NodePtr &node) { + GE_CHK_BOOL_EXEC(node != nullptr, return, "node is null."); + std::vector original_names; + for (const auto &node_tmp : original_nodes) { + std::vector names_tmp; + ge::OpDescPtr opdesc_tmp = node_tmp->GetOpDesc(); + if (opdesc_tmp == nullptr) { + GELOGE(GRAPH_FAILED, "Node %s get opdesc is nullptr", node_tmp->GetName().c_str()); + continue; + } + auto ret = ge::AttrUtils::GetListStr(opdesc_tmp, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, names_tmp); + if (!ret) { + GELOGW("Get list str failed"); + continue; + } + if (names_tmp.size() != 0) { + original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); + } else { + original_names.push_back(opdesc_tmp->GetName()); + } + } + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names), + return, "Set original_op_names fail."); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::RecordOriginalNames(std::vector names_tmp, + const ge::NodePtr &node) { + GE_CHK_BOOL_EXEC(node != nullptr, return, "node is null."); + std::vector original_names; + if (names_tmp.size() != 0) { + original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); + } else { + std::string tmp; + original_names.push_back(tmp); + } + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names), + return, "Set original_op_names fail."); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::MatchDumpStr(const std::string &suffix) { + char *dump_level = std::getenv(kDumpGraphLevel); + int64_t dump_graph_level = + (dump_level != nullptr) ? 
std::strtol(dump_level, nullptr, kBaseOfIntegerValue) : kDumpLevel2; + + if (dump_graph_level == kDumpLevel1) { + return false; + } + + if (dump_graph_level == kDumpLevel2 && + ((suffix.find(kDumpStrPartition) != std::string::npos) || + (suffix.find(kDumpStrOptimizeSubgraph) != std::string::npos) || + (suffix.find(kDumpStrAicpu) != std::string::npos) || (suffix.find(kDumpStrSubgraphFunc) != std::string::npos))) { + return true; + } + + if (dump_graph_level == kDumpLevel3 && suffix.compare(kDumpStrBuild) != 0) { + return true; + } + + return false; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(const ge::ComputeGraphPtr &graph, + const std::string &suffix, + bool is_always_dump, + const std::string &user_graph_name) { +#ifdef FMK_SUPPORT_DUMP + char *dump_ge_graph = std::getenv(kDumpGeGraph); + GE_IF_BOOL_EXEC(dump_ge_graph == nullptr && !is_always_dump, return;); + + // dump the graph according to different graph level + if (GraphUtils::MatchDumpStr(suffix)) { + return; + } + + // file name + static std::atomic_long atomic_file_index(0); + auto file_index = atomic_file_index.fetch_add(1); + GELOGD("Start to dump om txt: %ld", file_index); + + thread_local long max_dump_file_num = 0; + if (max_dump_file_num == 0) { + string opt = "0"; + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); + max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); + } + if (max_dump_file_num != 0 && file_index > max_dump_file_num) { + GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileCnt=%ld.", max_dump_file_num); + return; + } + + std::stringstream stream_file_name; + stream_file_name << "ge_proto_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; + stream_file_name << "_" << suffix << ".txt"; + std::string proto_file = user_graph_name.empty() ? 
stream_file_name.str() : user_graph_name; + + // Create buffer + ge::Model model("", ""); + model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast(graph))); + Buffer buffer; + const int64_t kDumpLevel = + (dump_ge_graph != nullptr) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : ge::OnnxUtils::NO_DUMP; + model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL); + + // Write file + ge::proto::ModelDef ge_proto; + if (buffer.GetData() != nullptr) { + std::string str(reinterpret_cast(buffer.GetData()), buffer.GetSize()); + if (!ge_proto.ParseFromString(str)) { + GELOGE(GRAPH_FAILED, "parse from string failed."); + return; + } + char real_path[PATH_MAX] = {0x00}; + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(proto_file.c_str()) >= PATH_MAX, return, "file path is too longer!"); + GE_IF_BOOL_EXEC(realpath(proto_file.c_str(), real_path) == nullptr, + GELOGI("file %s does not exist, it will be created.", proto_file.c_str())); + + GraphUtils::WriteProtoToTextFile(ge_proto, real_path); + } +#else + GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); +#endif +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, + ge::ComputeGraph &compute_graph) { + ge::proto::ModelDef model_def; + // Get ModelDef object from file generated by DumpGEGraph() + if (!ReadProtoFromTextFile(file, &model_def)) { + GELOGE(GRAPH_FAILED, "Get ModelDef failed from file"); + return false; + } + ge::Model model; + // Get Model object from ModelDef by deserialize ModelDef + if (model.Load(model_def) == GRAPH_SUCCESS) { + GE_CHK_BOOL_EXEC(GraphUtils::GetComputeGraph(model.GetGraph()) != nullptr, return false, + "Get computer graph is nullptr"); + compute_graph = *(GraphUtils::GetComputeGraph(model.GetGraph())); + return true; + } else { + GELOGE(GRAPH_FAILED, "Get Model failed from ModelDef"); + return false; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, + 
ge::ComputeGraphPtr &compute_graph) { + ge::proto::ModelDef model_def; + // Get ModelDef object from file generated by DumpGEGraph() + if (!ReadProtoFromTextFile(file, &model_def)) { + GELOGE(GRAPH_FAILED, "Get ModelDef failed from file"); + return false; + } + ge::Model model; + // Get Model object from ModelDef by deserialize ModelDef + if (model.Load(model_def) == GRAPH_SUCCESS) { + GE_CHK_BOOL_EXEC(GraphUtils::GetComputeGraph(model.GetGraph()) != nullptr, return false, + "Get computer graph is nullptr"); + compute_graph = GraphUtils::GetComputeGraph(model.GetGraph()); + for (const auto &node : compute_graph->GetDirectNode()) { + GELOGI("Node %s set owner graph", node->GetName().c_str()); + GE_CHECK_NOTNULL(node); + if (node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Node %s set owner graph failed", node->GetName().c_str()); + return false; + } + } + return true; + } else { + GELOGE(GRAPH_FAILED, "Get Model failed from ModelDef"); + return false; + } +} + +// Printing protocol messages in text format is useful for debugging and human editing of messages. 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToTextFile( + const google::protobuf::Message &proto, const char *real_path) { +#ifdef FMK_SUPPORT_DUMP + const int FILE_AUTHORITY = 0600; + int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, FILE_AUTHORITY); + if (fd < 0) { + GELOGE(GRAPH_FAILED, "fail to open the file: %s, %s", real_path, strerror(errno)); + return; + } + google::protobuf::io::FileOutputStream *output = new (std::nothrow) FileOutputStream(fd); + if (output == nullptr) { + GELOGE(GRAPH_FAILED, "Output is nullptr"); + if (close(fd) != 0) { + GELOGE(GRAPH_FAILED, "Close fileoutputstream failed"); + } + return; + } + bool ret = google::protobuf::TextFormat::Print(proto, output); + if (!ret) { + GELOGE(GRAPH_FAILED, "Fail to write the file: %s", real_path); + delete output; + output = nullptr; + GE_CHK_BOOL_EXEC(close(fd) == 0, return, "Close fileoutputstream failed"); + return; + } + delete output; + output = nullptr; + GE_CHK_BOOL_EXEC(close(fd) == 0, return, "Close fileoutputstream failed"); + + FILE *file = fopen(real_path, "rb"); + if (file == nullptr) { + return; + } + if (fseek(file, 0L, SEEK_END) == 0) { + long fileSize = ftell(file); + thread_local long max_dump_file_size = 0; + if (max_dump_file_size == 0) { + string opt = "0"; + // Can not check return value + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); + max_dump_file_size = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); + } + if (max_dump_file_size != 0 && fileSize != -1 && fileSize > max_dump_file_size) { + GELOGW("dump graph file size > maxDumpFileSize, maxDumpFileSize=%ld.", max_dump_file_size); + GE_IF_BOOL_EXEC(std::remove(real_path) != 0, GELOGW("remove %s failed", real_path)); + GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose %s failed", real_path); + return; + } + } + GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose fileoutputstream failed"); +#else + GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); +#endif 
+} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::ReadProtoFromTextFile( + const char *file, google::protobuf::Message *proto) { + if (file == nullptr || proto == nullptr) { + GELOGE(GRAPH_FAILED, "incorrect parameter. file path or message is invalid"); + return false; + } + std::ifstream fs(file, std::ifstream::in); + if (!fs.is_open()) { + GELOGE(GRAPH_FAILED, "proto file '%s' open fail.", file); + return false; + } + google::protobuf::io::IstreamInputStream input(&fs); + bool ret = google::protobuf::TextFormat::Parse(&input, proto); + if (!ret) { + GELOGE(GRAPH_FAILED, "parse proto from text ret fail, please check your text file '%s'.", file); + } + fs.close(); + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, + const std::string &suffix) { +#ifdef FMK_SUPPORT_DUMP + char *dump_ge_graph = std::getenv(kDumpGeGraph); + int64_t dump_ge_graph_level = + (dump_ge_graph != nullptr) ? 
std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : OnnxUtils::NO_DUMP; + if ((dump_ge_graph_level == OnnxUtils::NO_DUMP) || (dump_ge_graph_level >= OnnxUtils::DUMP_LEVEL_END)) { + GELOGD("Skip DumpGEGraphToOnnx with dump_ge_graph_level %ld.", dump_ge_graph_level); + return; + } + + // dump the graph according to different graph level + if (GraphUtils::MatchDumpStr(suffix)) { + return; + } + + // 1.Get ge::onnx::ModelProto from ge::Model + ge::Model model("GE", ""); + std::shared_ptr compute_graph_ptr = ComGraphMakeShared(compute_graph); + model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast(compute_graph_ptr))); + onnx::ModelProto model_proto; + if (!OnnxUtils::ConvertGeModelToModelProto(model, model_proto)) { + GELOGE(GRAPH_FAILED, "DumpGEGraphToOnnx failed."); + return; + } + + // 2.Set file name + static std::atomic_long atomic_file_index(0); + auto file_index = atomic_file_index.fetch_add(1); + GELOGD("Start to dump ge onnx file: %ld", file_index); + + thread_local long max_dump_file_num = 0; + if (max_dump_file_num == 0) { + string opt = "0"; + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); + max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); + } + if (max_dump_file_num != 0 && file_index > max_dump_file_num) { + GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileNum=%ld.", max_dump_file_num); + return; + } + + std::stringstream stream_file_name; + stream_file_name << "ge_onnx_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; + stream_file_name << "_graph_" << compute_graph.GetGraphID(); + stream_file_name << "_" << suffix << ".pbtxt"; + std::string proto_file = stream_file_name.str(); + if ((proto_file.length()) >= NAME_MAX) { + GELOGE(GRAPH_FAILED, "File name is too longer!"); + return; + } + std::unique_ptr real_path(new (std::nothrow) char[PATH_MAX]{0}); + if (real_path == nullptr) { + GELOGE(GRAPH_FAILED, "New real_path failed."); + return; + } + /// 
Returning nullptr means 3 case as follows: + /// a.path is PATH_MAX chars or more + /// b.the file does not exist + /// c.the path has no permissions + /// Distinguish between last the two cases in the function WriteProtoToTextFile call open() + if (realpath(proto_file.c_str(), real_path.get()) == nullptr) { + // For case a + if (errno == ENAMETOOLONG) { + GELOGE(GRAPH_FAILED, "Call realpath failed: path is PATH_MAX chars or more."); + return; + } + } + + // 3. Serialize to file in current path + GraphUtils::WriteProtoToTextFile(model_proto, real_path.get()); +#else + GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); +#endif +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraphFromOnnx(const char *file, + ge::ComputeGraph &compute_graph) { + if (file == nullptr) { + GELOGE(GRAPH_FAILED, "incorrect parameter. file path is invalid"); + return false; + } + onnx::ModelProto model_proto; + // 1. Get ModelDef object from file generated by DumpGEGraphToOnnx() + if (!ReadProtoFromTextFile(file, &model_proto)) { + GELOGE(GRAPH_FAILED, "Get ModelDef from file failed"); + return false; + } + // 2.Convert onnx::ModelProto To ge::Model + ge::Model model; + if (!OnnxUtils::ConvertModelProtoToGeModel(model_proto, model)) { + GELOGE(GRAPH_FAILED, "Convert ModelDef to Model failed"); + return false; + } + auto compute_graph_ptr = GraphUtils::GetComputeGraph(model.GetGraph()); + if (compute_graph_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "Get compute graph from Model failed"); + return false; + } + compute_graph = *(compute_graph_ptr); + return true; +} + +namespace { +using InNodesToOut = std::unordered_map>; + +inline std::string GetNodeNameByAnchor(const Anchor *anchor) { + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Anchor is nullptr"); + return "Null"; + } + auto node = anchor->GetOwnerNode(); + return node == nullptr ? 
"Null" : node->GetName(); +} + +graphStatus ReplaceOutDataAnchor(const OutDataAnchorPtr &new_anchor, const OutDataAnchorPtr &old_anchor, + InNodesToOut *in_nodes_to_out = nullptr) { + if (new_anchor == nullptr || old_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "new_anchor or old_anchor is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto new_node = new_anchor->GetOwnerNode(); + for (const auto &peer_in_anchor : old_anchor->GetPeerInDataAnchors()) { + auto ret = peer_in_anchor->Unlink(old_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to unlink old anchor link from %s(%d) to %s(%d)", + GetNodeNameByAnchor(old_anchor.get()).c_str(), old_anchor->GetIdx(), + GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); + return GRAPH_FAILED; + } + ret = peer_in_anchor->LinkFrom(new_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to relink new anchors from %s(%d) to %s(%d)", + GetNodeNameByAnchor(new_anchor.get()).c_str(), new_anchor->GetIdx(), + GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); + return GRAPH_FAILED; + } + + if (in_nodes_to_out != nullptr) { + (*in_nodes_to_out)[new_node].insert(peer_in_anchor->GetOwnerNode()); + } + } + return GRAPH_SUCCESS; +} + +graphStatus RelinkDataIO(const NodePtr &node, const std::vector &io_map, InNodesToOut &in_nodes_to_out) { + GE_CHECK_NOTNULL(node); + auto in_data_anchors = node->GetAllInDataAnchors(); + auto out_data_anchors = node->GetAllOutDataAnchors(); + if (out_data_anchors.size() < io_map.size()) { + GELOGE(GRAPH_FAILED, "The io_map specified for node %s type %s is larger %zu than the actual size %zu", + node->GetName().c_str(), node->GetType().c_str(), io_map.size(), out_data_anchors.size()); + return GRAPH_PARAM_INVALID; + } + + for (size_t i = 0; i < out_data_anchors.size(); ++i) { + auto out_data_anchor = out_data_anchors.at(i); + if (out_data_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to relink for node %s type 
%s, the out data anchor at index %zu is null", + node->GetName().c_str(), node->GetType().c_str(), i); + return GRAPH_FAILED; + } + + int in_index = -1; + if (i < io_map.size()) { + in_index = io_map.at(i); + } + if (in_index < 0) { + out_data_anchor->UnlinkAll(); + continue; + } + + if (in_index >= static_cast(in_data_anchors.size())) { + GELOGE(GRAPH_PARAM_INVALID, "Failed to relink for node %s type %s, invalid index %d specified for input(%zu)", + node->GetName().c_str(), node->GetType().c_str(), in_index, in_data_anchors.size()); + return GRAPH_PARAM_INVALID; + } + auto in_anchor = in_data_anchors.at(in_index); + if (in_anchor == nullptr) { + GELOGW("Invalid in data anchors(null) found at node %s type %s index %d, ignore it.", node->GetName().c_str(), + node->GetType().c_str(), in_index); + continue; + } + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + continue; + } + if (peer_out_anchor->Unlink(in_anchor) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, + "Failed relink node %s type %s, failed to unlink the data link" + " from %s(%d) to it at input-index %d", + node->GetName().c_str(), node->GetType().c_str(), GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), + peer_out_anchor->GetIdx(), in_index); + return GRAPH_FAILED; + } + auto ret = ReplaceOutDataAnchor(peer_out_anchor, out_data_anchor, &in_nodes_to_out); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to relink node %s type %s for relinking data anchors", node->GetName().c_str(), + node->GetType().c_str()); + return GRAPH_FAILED; + } + } + + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + in_anchor->UnlinkAll(); + } + return GRAPH_SUCCESS; +} + +InNodesToOut GetFullConnectIONodes(const NodePtr &node) { + InNodesToOut in_nodes_to_out; + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Node is nullptr"); + return in_nodes_to_out; + } + auto in_nodes_list = node->GetInNodes(); + auto out_nodes_list = node->GetOutNodes(); + auto out_nodes = 
std::unordered_set(out_nodes_list.begin(), out_nodes_list.end()); + + for (const auto &in_node : in_nodes_list) { + in_nodes_to_out.insert(std::make_pair(in_node, out_nodes)); + } + return in_nodes_to_out; +} + +graphStatus RelinkControlNodeIfNeed(const NodePtr &node, InNodesToOut &in_nodes_to_out, + InNodesToOut &connected_data_in_to_out) { + GE_CHECK_NOTNULL(node); + for (const auto &in_node_to_out : in_nodes_to_out) { + auto &in_node = in_node_to_out.first; + GE_CHECK_NOTNULL(in_node); + auto &connected_data_out = connected_data_in_to_out[in_node]; + for (const auto &out_node : in_node_to_out.second) { + GE_CHECK_NOTNULL(out_node); + if (connected_data_out.count(out_node) == 0) { + GE_CHECK_NOTNULL(in_node->GetOutControlAnchor()); + if (in_node->GetOutControlAnchor()->IsLinkedWith(out_node->GetInControlAnchor())) { + continue; + } + auto ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), out_node->GetInControlAnchor()); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when isolating node %s type %s", + in_node->GetName().c_str(), out_node->GetName().c_str(), node->GetName().c_str(), + node->GetType().c_str()); + return GRAPH_FAILED; + } + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus ReplaceOutDataAnchors(const Node::Vistor &new_outs, + const Node::Vistor &old_outs, const std::vector &outputs_map) { + auto new_out_size = new_outs.size(); + if (new_out_size < outputs_map.size()) { + GELOGE(GRAPH_PARAM_INVALID, + "Failed to replace out data anchors, the actual size %zu is less than the mapping size %zu", new_out_size, + outputs_map.size()); + return GRAPH_PARAM_INVALID; + } + for (size_t i = 0; i < new_out_size; ++i) { + auto &new_out_anchor = new_outs.at(i); + if (new_out_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to replace out data anchors, the out data anchor on new node is null, index %zu", i); + return GRAPH_FAILED; + } + if (i >= outputs_map.size()) { + continue; + } + auto old_index = 
outputs_map.at(i); + if (old_index < 0) { + continue; + } + + const OutDataAnchorPtr &old_out_anchor = old_outs.at(old_index); + if (old_out_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to replace out data anchors, the out data anchor on old node is null, index %d", + old_index); + return GRAPH_FAILED; + } + auto ret = ReplaceOutDataAnchor(new_out_anchor, old_out_anchor); + if (ret != GRAPH_SUCCESS) { + return ret; + } + } + + return GRAPH_SUCCESS; +} + +graphStatus ReplaceInDataAnchors(const Node::Vistor &new_ins, + const Node::Vistor &old_ins, const std::vector &inputs_map) { + auto new_in_size = new_ins.size(); + if (new_in_size < inputs_map.size()) { + GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the actual size %zu is less than the mapping size %zu", + new_in_size, inputs_map.size()); + return GRAPH_PARAM_INVALID; + } + + for (size_t i = 0; i < new_in_size; ++i) { + auto &new_in_anchor = new_ins.at(i); + if (new_in_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the out data anchor on new node is null, index %zu", i); + return GRAPH_FAILED; + } + if (i >= inputs_map.size()) { + continue; + } + auto old_index = inputs_map.at(i); + if (old_index < 0) { + continue; + } + const InDataAnchorPtr &old_in_anchor = old_ins.at(old_index); + if (old_in_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the out data anchor on old node is null, index %d", + old_index); + return GRAPH_FAILED; + } + + auto peer_out_anchor = old_in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + GELOGW("Peer out anchor is nullptr"); + continue; + } + auto ret = peer_out_anchor->Unlink(old_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to unlink old anchors, unlink from %s(%d) to %s(%d)", + GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), + GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); + return GRAPH_FAILED; + } 
+ ret = peer_out_anchor->LinkTo(new_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to link new anchors, link from %s(%d) to %s(%d)", + GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), + GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +graphStatus ReplaceControlAnchors(const NodePtr &new_node, const NodePtr &old_node) { + GE_CHECK_NOTNULL(new_node); + GE_CHECK_NOTNULL(new_node->GetInControlAnchor()); + GE_CHECK_NOTNULL(old_node); + GE_CHECK_NOTNULL(old_node->GetInControlAnchor()); + auto peer_out_anchors = old_node->GetInControlAnchor()->GetPeerAnchors(); + auto new_in_control_anchor = new_node->GetInControlAnchor(); + auto exists_out_anchors = new_in_control_anchor->GetPeerAnchors(); + auto exists_out_anchors_set = std::set(exists_out_anchors.begin(), exists_out_anchors.end()); + for (const auto &peer_out_anchor : peer_out_anchors) { + if (peer_out_anchor != nullptr) { + if (exists_out_anchors_set.count(peer_out_anchor) > 0) { + continue; + } + auto ret = GraphUtils::AddEdge(peer_out_anchor, new_in_control_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add edge failed"); + return GRAPH_FAILED; + } + } else { + GELOGW("peer outanchor is nullptr"); + continue; + } + } + auto old_out_control_anchor = old_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(old_out_control_anchor); + auto peer_in_anchors = old_out_control_anchor->GetPeerAnchors(); + auto new_out_control_anchor = new_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(new_out_control_anchor); + auto exists_in_anchors = new_out_control_anchor->GetPeerAnchors(); + auto exists_in_anchors_set = std::set(exists_in_anchors.begin(), exists_in_anchors.end()); + for (const auto &peer_in_anchor : peer_in_anchors) { + if (peer_in_anchor != nullptr) { + if (exists_in_anchors_set.count(peer_in_anchor) > 0) { + continue; + } + auto ret = 
GraphUtils::AddEdge(new_out_control_anchor, peer_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add edge failed"); + return GRAPH_FAILED; + } + } else { + GELOGW("Peer inanchor is nullptr"); + continue; + } + } + + return GRAPH_SUCCESS; +} +} // namespace + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNode(const NodePtr &node, + const std::vector &io_map) { + if (node == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "Failed to isolate node(null)"); + return GRAPH_PARAM_INVALID; + } + + /// We must get full connections info before re-link data io, because the data + /// edges may be unlinked when relink data io + auto in_nodes_to_out = GetFullConnectIONodes(node); + + InNodesToOut data_in_to_out; + auto ret = RelinkDataIO(node, io_map, data_in_to_out); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to isolate node %s type %s when relink data IO", node->GetName().c_str(), + node->GetType().c_str()); + return ret; + } + + ret = RelinkControlNodeIfNeed(node, in_nodes_to_out, data_in_to_out); + if (ret != GRAPH_SUCCESS) { + return ret; + } + NodeUtils::UnlinkAll(*node); + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::IsolateNode(const NodePtr &node, const std::initializer_list &io_map) { + return IsolateNode(node, std::vector(io_map)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNodeOneIO(const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "incorrect parameter. 
node is invalid"); + return GRAPH_PARAM_INVALID; + } + if (node->GetAllInDataAnchorsSize() != 1) { + return GRAPH_PARAM_INVALID; + } + if (node->GetAllOutDataAnchorsSize() != 1) { + return GRAPH_PARAM_INVALID; + } + return IsolateNode(node, {0}); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, const std::vector &inputs_map, + const std::vector &outputs_map) { + if ((new_node == nullptr) || (old_node == nullptr)) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto ret = ReplaceNodeDataAnchors(new_node, old_node, inputs_map, outputs_map); + if (ret != GRAPH_SUCCESS) { + // The error log was printed in `ReplaceNodeDataAnchors` + return GRAPH_FAILED; + } + ret = ReplaceControlAnchors(new_node, old_node); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, + "Failed to replace control anchors when replace node from old node %s type %s to new node %s type %s", + old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), + new_node->GetType().c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ReplaceNodeAnchors( + const NodePtr &new_node, const NodePtr &old_node, const std::initializer_list inputs_map, + const std::initializer_list outputs_map) { + return ReplaceNodeAnchors(new_node, old_node, std::vector(inputs_map), std::vector(outputs_map)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, + std::initializer_list inputs_map, std::initializer_list outputs_map) { + return ReplaceNodeDataAnchors(new_node, old_node, std::vector(inputs_map), std::vector(outputs_map)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, const 
std::vector &inputs_map, + const std::vector &outputs_map) { + if (new_node == nullptr || old_node == nullptr) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + + auto ret = ReplaceOutDataAnchors(new_node->GetAllOutDataAnchors(), old_node->GetAllOutDataAnchors(), outputs_map); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, + "Failed to replace out data anchors when replace node from old node %s type %s to new node %s type %s", + old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), + new_node->GetType().c_str()); + return GRAPH_FAILED; + } + ret = ReplaceInDataAnchors(new_node->GetAllInDataAnchors(), old_node->GetAllInDataAnchors(), inputs_map); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, + "Failed to replace in data anchors when replace node from old node %s type %s to new node %s type %s", + old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), + new_node->GetType().c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInCtrlEdges(const NodePtr &src_node, + NodePtr &dst_node) { + if ((src_node == nullptr) || (dst_node == nullptr)) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto src_ctrl_in_nodes = src_node->GetInControlNodes(); + if (src_ctrl_in_nodes.empty()) { + return GRAPH_SUCCESS; + } + + std::unordered_set exist_in_ctrl_nodes_set; + auto exist_in_ctrl_nodes = dst_node->GetInControlNodes(); + if (!exist_in_ctrl_nodes.empty()) { + exist_in_ctrl_nodes_set.insert(exist_in_ctrl_nodes.begin(), exist_in_ctrl_nodes.end()); + } + + auto dst_ctrl = dst_node->GetInControlAnchor(); + for (const auto &in_node : src_ctrl_in_nodes) { + if (exist_in_ctrl_nodes_set.count(in_node) > 0) { + continue; + } + auto ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), dst_ctrl); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, 
"Failed to add control edge from %s to %s when copy control dependencies from %s to %s", + in_node->GetName().c_str(), dst_node->GetName().c_str(), src_node->GetName().c_str(), + dst_node->GetName().c_str()); + return ret; + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveInCtrlEdges(const NodePtr &src_node, + NodePtr &dst_node) { + if (src_node == nullptr || dst_node == nullptr) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_FAILED; + } + auto ret = CopyInCtrlEdges(src_node, dst_node); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Copy in ctrl edges failed"); + return ret; + } + GE_CHECK_NOTNULL(src_node->GetInControlAnchor()); + src_node->GetInControlAnchor()->UnlinkAll(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyOutCtrlEdges(const NodePtr &src_node, + NodePtr &dst_node) { + if (src_node == nullptr || dst_node == nullptr) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_FAILED; + } + auto out_ctrl_nodes = src_node->GetOutControlNodes(); + if (out_ctrl_nodes.empty()) { + return GRAPH_SUCCESS; + } + + std::unordered_set exists_out_ctrl_nodes_set; + for (const auto &node : dst_node->GetOutControlNodes()) { + exists_out_ctrl_nodes_set.insert(node.get()); + } + + auto dst_out_ctrl = dst_node->GetOutControlAnchor(); + for (const auto &node : out_ctrl_nodes) { + if (exists_out_ctrl_nodes_set.count(node.get()) > 0) { + continue; + } + auto ret = GraphUtils::AddEdge(dst_out_ctrl, node->GetInControlAnchor()); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when copy control dependencies from %s to %s", + dst_node->GetName().c_str(), node->GetName().c_str(), src_node->GetName().c_str(), + dst_node->GetName().c_str()); + return ret; + } + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus 
GraphUtils::MoveOutCtrlEdges(NodePtr &src_node, + NodePtr &dst_node) { + if (src_node == nullptr || dst_node == nullptr) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_FAILED; + } + auto ret = CopyOutCtrlEdges(src_node, dst_node); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Copyout ctrl edges failed"); + return ret; + } + GE_CHECK_NOTNULL(src_node->GetOutControlAnchor()); + src_node->GetOutControlAnchor()->UnlinkAll(); + return GRAPH_SUCCESS; +} + +/// +/// Copy all in-data edges from `src_node` to `dst_node`. +/// @param src_node +/// @param dst_node +/// @return +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node, + NodePtr &dst_node) { + if ((src_node == nullptr) || (dst_node == nullptr)) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto src_data_in_nodes = src_node->GetInDataNodes(); + if (src_data_in_nodes.empty()) { + return GRAPH_SUCCESS; + } + for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { + auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); + auto ret = + GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", + in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), + src_node->GetName().c_str(), dst_node->GetName().c_str()); + return ret; + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph, + const NodePtr &node) { + if (graph->AddInputNode(node) == nullptr) { + GELOGE(GRAPH_FAILED, "Copyout ctrl edges failed"); + return GRAPH_FAILED; + } + graph->SetInputSize(graph->GetInputSize() + 1); + graph->inputs_order_.emplace_back(node->GetName()); + 
return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindRootGraph(ComputeGraphPtr graph) { + ComputeGraphPtr result = nullptr; + while (graph != nullptr) { + result = std::move(graph); + graph = result->GetParentGraph(); + } + return result; +} + +/// +/// Make a copy of ComputeGraph. +/// @param graph: original graph. +/// @param prefix: node name prefix of new graph. +/// @param output_nodes: output nodes of new graph. +/// @return ComputeGraphPtr +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr +GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix, std::vector &input_nodes, + std::vector &output_nodes) { + GE_CHK_BOOL_EXEC(graph != nullptr, return nullptr, "Original graph is null"); + ComputeGraphPtr new_graph = ComGraphMakeShared(graph->GetName()); + GE_CHK_BOOL_EXEC(new_graph != nullptr, return nullptr, "Create new graph failed"); + + std::unordered_map all_new_nodes; + for (const auto &n : graph->GetDirectNode()) { + OpDescPtr op_desc = AttrUtils::CopyOpDesc(n->GetOpDesc()); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return nullptr, "Create new node failed"); + + if (CopyTensorAttrs(op_desc, n) != GRAPH_SUCCESS) { + return nullptr; + } + + op_desc->SetName(prefix + n->GetName()); + NodePtr node = new_graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Add node[%s] to graph failed", op_desc->GetName().c_str()); + all_new_nodes[node->GetName()] = node; + + if (node->GetType() == DATA) { + input_nodes.emplace_back(node); + } else if (node->GetType() == NETOUTPUT) { + output_nodes.emplace_back(node); + } + } + + for (const auto &n : graph->GetDirectNode()) { + if (RelinkGraphEdges(n, prefix, all_new_nodes) != GRAPH_SUCCESS) { + return nullptr; + } + } + + std::string session_graph_id; + if (AttrUtils::GetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) { + bool ret = AttrUtils::SetStr(*new_graph, ATTR_NAME_SESSION_GRAPH_ID, 
session_graph_id); + if (!ret) { + GELOGE(GRAPH_FAILED, "Set attr ATTR_NAME_SESSION_GRAPH_ID failed."); + return nullptr; + } + } + return new_graph; +} + +/// +/// Copy tensor attribute to new node. +/// @param [in] dst_node: cloned node. +/// @param [in] src_node: original node. +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node) { + if (dst_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Input param dst node not valid"); + return GRAPH_FAILED; + } + if (src_node == nullptr || src_node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "Input param src node not valid"); + return GRAPH_FAILED; + } + + const auto &src_desc = src_node->GetOpDesc(); + dst_desc->CopyAttrsFrom(*src_desc); + + for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { + auto input_desc = dst_desc->MutableInputDesc(i); + if (input_desc == nullptr) { + continue; + } + input_desc->CopyAttrsFrom(src_desc->GetInputDesc(i)); + } + + for (uint32_t i = 0; i < src_node->GetAllOutDataAnchorsSize(); ++i) { + auto output_desc = dst_desc->MutableOutputDesc(i); + if (output_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Param dst node not valid"); + return GRAPH_FAILED; + } + output_desc->CopyAttrsFrom(src_desc->GetOutputDesc(i)); + } + + return GRAPH_SUCCESS; +} + +/// +/// Relink all edges for cloned ComputeGraph. +/// @param [in] node: original node. +/// @param [in] prefix: node name prefix of new node. +/// @param [in] all_nodes: all nodes in new graph. 
+/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &prefix, + const std::unordered_map &all_nodes) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "Input node not valid"); + return GRAPH_FAILED; + } + + auto it = all_nodes.find(prefix + node->GetName()); + if (it == all_nodes.end()) { + GELOGE(GRAPH_FAILED, "node[%s] not found", node->GetName().c_str()); + return GRAPH_FAILED; + } + const auto &new_node = it->second; + + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, return GRAPH_FAILED, "In data anchor is null"); + const auto &out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + GELOGW("Peer out anchor is null: %s", node->GetName().c_str()); + continue; + } + GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); + + it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); + if (it == all_nodes.end()) { + GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); + return GRAPH_FAILED; + } + const auto &new_out_node = it->second; + + auto rslt = + GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInAnchor(in_anchor->GetIdx())); + GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", + new_out_node->GetName().c_str(), new_node->GetName().c_str()); + } + + if (node->GetInControlAnchor() != nullptr) { + for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) { + GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "Peer out anchor is null: %s", node->GetName().c_str()); + GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); + + it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); + if (it == all_nodes.end()) { + GELOGE(GRAPH_FAILED, "node[%s] not found", 
out_anchor->GetOwnerNode()->GetName().c_str()); + return GRAPH_FAILED; + } + const auto &new_out_node = it->second; + + auto rslt = GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInControlAnchor()); + GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", + new_out_node->GetName().c_str(), new_node->GetName().c_str()); + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Get reference-mapping of all data_anchors in graph +/// @param [in] graph +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(graph); + for (const auto &node : graph->GetAllNodes()) { + // in_data_anchor + if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + + // out_data_anchor + if (HandleOutAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Find ref_mapping for out_data_anchors of node %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr GraphUtils::FindNodeFromAllNodes(ComputeGraphPtr &graph, + const std::string &name) { + auto root_graph = FindRootGraph(graph); + if (root_graph == nullptr) { + GE_LOGE("Failed find node %s, null root graph", name.c_str()); + return nullptr; + } + + for (const auto &node : root_graph->GetAllNodes()) { + if (node == nullptr) { + continue; + } + if (node->GetName() == name) { + return node; + } + } + + return nullptr; +} + +/// +/// Get reference-mapping for in_data_anchors of node +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: 
GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + + if (NodeUtils::IsSubgraphOutput(node)) { + return HandleSubgraphOutput(node, symbol_to_anchors, anchor_to_symbol); + } + + if (NodeUtils::IsSubgraphInput(node)) { + return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); + } + + const std::string &type = node->GetType(); + if ((type == MERGE) || (type == STREAMMERGE)) { + return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); + } + + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + const std::string &symbol = cur_node_info.ToString(); + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + symbol_to_anchors[symbol] = {cur_node_info}; + anchor_to_symbol[symbol] = symbol; + } else { + NodeIndexIO exist_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); + if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Update symbol mapping failed."); + return GRAPH_FAILED; + } + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Get reference-mapping for out_data_anchors of node +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); + if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { + continue; + } + + int32_t 
reuse_in_index = -1; + if (IsRefFromInput(out_data_anchor, reuse_in_index)) { + NodeIndexIO exist_node_info(node, reuse_in_index, kIn); + if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Update symbol mapping failed."); + return GRAPH_FAILED; + } + } else { + const std::string &symbol = cur_node_info.ToString(); + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + symbol_to_anchors.emplace(std::make_pair(symbol, std::list{cur_node_info})); + anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Handle input of subgraph +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + + // Data in subgraph + uint32_t index = 0; + if (!ge::AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index)) { + GE_LOGE("Get attr ATTR_NAME_PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); + return GRAPH_FAILED; + } + NodePtr parent_node = node->GetOwnerComputeGraph()->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + InDataAnchorPtr parent_in_anchor = parent_node->GetInDataAnchor(index); + GE_CHECK_NOTNULL(parent_in_anchor); + OutDataAnchorPtr peer_out_anchor = parent_in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor != nullptr) { + // Data has and only has one input + NodeIndexIO cur_node_info(node, 0, kIn); + NodeIndexIO exist_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); + if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Update symbol mapping failed."); + return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; +} + +/// 
+/// Handle input of Merge op +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + std::vector exist_node_infos; + std::vector cur_node_infos; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + std::string next_name; + if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) { + ComputeGraphPtr graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + ge::NodePtr next_node = graph->FindNode(next_name); + GE_CHECK_NOTNULL(next_node); + // NextIteration has and only has one output + peer_out_anchor = next_node->GetOutDataAnchor(0); + GE_CHECK_NOTNULL(peer_out_anchor); + cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); + cur_node_infos.emplace_back(NodeIndexIO(next_node, peer_out_anchor->GetIdx(), kOut)); + } + } else { + cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); + exist_node_infos.emplace_back(NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut)); + } + } + + size_t anchor_nums = 0; + NodeIndexIO max_node_index_io(nullptr, 0, kOut); + for (const auto &temp_node_info : exist_node_infos) { + auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); + if (iter1 != anchor_to_symbol.end()) { + const std::string &temp_symbol = iter1->second; + auto iter2 = symbol_to_anchors.find(temp_symbol); + if (iter2 != symbol_to_anchors.end()) { + if (iter2->second.size() > anchor_nums) { + max_node_index_io = temp_node_info; + anchor_nums = iter2->second.size(); + } + } + } + } + + std::string symbol; + for (const auto &temp_node_info : exist_node_infos) { + if 
((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != + GRAPH_SUCCESS) || + symbol.empty()) { + GE_LOGE("Union symbol map anchor1:%s & anchor2:%s.", max_node_index_io.ToString().c_str(), + temp_node_info.ToString().c_str()); + return GRAPH_FAILED; + } + } + + auto iter = symbol_to_anchors.find(symbol); + if (iter != symbol_to_anchors.end()) { + for (const auto &temp_node_info : cur_node_infos) { + GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); + iter->second.emplace_back(temp_node_info); + anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Handle output of subgraph +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + NodePtr parent_node = owner_graph->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + + GeTensorDesc in_tensor = op_desc->GetInputDesc(in_data_anchor->GetIdx()); + uint32_t index = 0; + if (!ge::AttrUtils::GetInt(in_tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { + continue; + } + GE_CHECK_NOTNULL(parent_node->GetOutDataAnchor(index)); + // Union symbol of peer_out_anchor & parent_out_anchor + NodeIndexIO peer_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); + NodeIndexIO parent_node_info(parent_node, index, kOut); + std::string symbol; + if ((UnionSymbolMapping(peer_node_info, 
parent_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != + GRAPH_SUCCESS) || + symbol.empty()) { + GE_LOGE("Union symbol map anchor1:%s, anchor2:%s.", peer_node_info.ToString().c_str(), + parent_node_info.ToString().c_str()); + return GRAPH_FAILED; + } + + NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + symbol_to_anchors[symbol].emplace_back(cur_node_info); + anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); + } + + return GRAPH_SUCCESS; +} + +/// +/// Union ref-mapping +/// @param [in] exist_node_info1 +/// @param [in] exist_node_info2 +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @param [out] symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol, std::string &symbol) { + const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; + const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; + if (symbol1 == symbol2) { + symbol = symbol1; + GELOGI("no need to union."); + return GRAPH_SUCCESS; + } + + auto iter1 = symbol_to_anchors.find(symbol1); + auto iter2 = symbol_to_anchors.find(symbol2); + if ((iter1 == symbol_to_anchors.end()) || (iter2 == symbol_to_anchors.end())) { + GE_LOGE("symbol %s or %s not exist.", symbol1.c_str(), symbol2.c_str()); + return GRAPH_FAILED; + } + + auto &max_iter = (iter1->second.size() > iter2->second.size() ? iter1 : iter2); + auto &min_iter = (iter1->second.size() > iter2->second.size() ? iter2 : iter1); + symbol = (iter1->second.size() > iter2->second.size() ? symbol1 : symbol2); + std::string min_symbol = (iter1->second.size() > iter2->second.size() ? 
symbol2 : symbol1); + for (auto &node_index_io : min_iter->second) { + GELOGD("Update anchor %s, symbol %s.", node_index_io.ToString().c_str(), symbol.c_str()); + max_iter->second.emplace_back(node_index_io); + auto iter = anchor_to_symbol.find(node_index_io.ToString()); + if (iter == anchor_to_symbol.end()) { + GE_LOGE("anchor %s not exist.", node_index_io.ToString().c_str()); + return GRAPH_FAILED; + } + if (iter->second != min_symbol) { + GELOGW("not expected symbol of anchor %s, expect %s but %s exactly.", iter->first.c_str(), min_symbol.c_str(), + iter->second.c_str()); + } + iter->second = symbol; + } + + GELOGI("Union symbol %s and %s succ.", symbol.c_str(), min_symbol.c_str()); + symbol_to_anchors.erase(min_iter); + return GRAPH_SUCCESS; +} + +/// +/// Update symbol mapping with a new reference pair +/// @param [in] cur_node_info +/// @param [in] exist_node_info +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + auto iter1 = anchor_to_symbol.find(exist_node_info.ToString()); + if (iter1 == anchor_to_symbol.end()) { + GE_LOGE("data_anchor %s is not visible before data_anchor %s, maybe TopoSorting is missing.", + exist_node_info.ToString().c_str(), cur_node_info.ToString().c_str()); + return GRAPH_FAILED; + } + + const std::string &symbol = iter1->second; + auto iter2 = symbol_to_anchors.find(symbol); + if (iter2 == symbol_to_anchors.end()) { + GE_LOGE("symbol %s not found.", symbol.c_str()); + return GRAPH_FAILED; + } + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + iter2->second.emplace_back(cur_node_info); + anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); + + return GRAPH_SUCCESS; +} + +/// +/// Check if out_data_anchor is reference of input +/// @param 
[in] out_data_anchor +/// @param [out] reuse_in_index +/// @return bool +/// +bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { + if (out_data_anchor == nullptr) { + GELOGW("out_data_anchor is NULL."); + return false; + } + int32_t output_index = out_data_anchor->GetIdx(); + + // pass-through op + NodePtr node = out_data_anchor->GetOwnerNode(); + const std::string &type = node->GetType(); + const std::set pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; + if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { + reuse_in_index = output_index; + GELOGI("Pass-Through node name[%s] index[%u].", node->GetName().c_str(), reuse_in_index); + return true; + } + + // Merge op 0th output + if ((type == MERGE) && (output_index == 0)) { + reuse_in_index = 0; + GELOGI("Merge name[%s] output_index[0].", node->GetName().c_str()); + return true; + } + + // ref op + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGW("op_desc is NULL."); + return false; + } + bool is_ref = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); + if (is_ref) { + const string &output_name = op_desc->GetOutputNameByIndex(output_index); + for (const auto &input_name : op_desc->GetAllInputNames()) { + if (!input_name.empty() && (output_name == input_name)) { + reuse_in_index = op_desc->GetInputIndexByName(input_name); + GELOGI("Reference name[%s] output[%s][%d] ref to input[%s][%d].", op_desc->GetName().c_str(), + output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); + return true; + } + } + } + + // reuse input + auto output_op_desc = op_desc->GetOutputDescPtr(output_index); + bool reuse_input = false; + if (output_op_desc != nullptr) { + if ((TensorUtils::GetReuseInput(*output_op_desc, reuse_input) == GRAPH_SUCCESS) && reuse_input) { + uint32_t reuse_input_index = 0; + if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == 
GRAPH_SUCCESS) { + reuse_in_index = static_cast(reuse_input_index); + GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index, + reuse_in_index); + return true; + } + } + } + + return false; +} + +/// +/// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs +/// of the graph have UNKNOWN_SHAPE operators or not. +/// Note: This function will only look 'down' from the graph, not 'up'. For example, the following +/// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE +/// ROOT graph: A -----> B -----> C +/// K subgraph U +/// | +/// V +/// SUB graph: D --> E --> F +/// K K K +/// @param [in] graph +/// @return bool +/// +bool GraphUtils::IsUnknownShapeGraph(const ComputeGraphPtr &graph) { + if (graph == nullptr) { + GELOGW("Input graph is nullptr."); + return false; + } + for (const auto &node : graph->GetDirectNode()) { + bool is_unknown = false; + auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); + if (ret != GRAPH_SUCCESS) { + GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), + node->GetType().c_str()); + continue; + } + if (is_unknown) { + GELOGD("Node %s, type %s is unknown shape in graph %s.", node->GetName().c_str(), node->GetType().c_str(), + graph->GetName().c_str()); + return true; + } + } + GELOGD("Graph %s does not have unknown shape node.", graph->GetName().c_str()); + return false; +} + +/// +/// @brief Add node to graph +/// @param [in] op_desc +/// @return ComputeGraphBuilder +/// +ComputeGraphBuilder &ComputeGraphBuilder::AddNode(const OpDescPtr &op_desc) { + nodes_.emplace_back(op_desc); + return *this; +} + +/// +/// @brief Add data-link among nodes in graph +/// @param [in] src_name +/// @param [in] out_anchor_ind +/// @param [in] dst_name +/// @param [in] in_anchor_ind +/// @return ComputeGraphBuilder +/// +ComputeGraphBuilder 
&ComputeGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, + const std::string &dst_name, uint32_t in_anchor_ind) { + data_links_.emplace_back( + std::make_pair(std::make_pair(src_name, out_anchor_ind), std::make_pair(dst_name, in_anchor_ind))); + return *this; +} + +/// +/// @brief Add ctrl-link among nodes in graph +/// @param [in] src_name +/// @param [in] dst_name +/// @return ComputeGraphBuilder +/// +ComputeGraphBuilder &ComputeGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { + ctrl_links_.emplace_back(std::make_pair(src_name, dst_name)); + return *this; +} + +/// +/// @brief Build nodes +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void ComputeGraphBuilder::BuildNodes(graphStatus &error_code, std::string &error_msg) { + if (owner_graph_ == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "graph is NULL."; + return; + } + + std::string node_name; + for (auto &op_desc : nodes_) { + if (op_desc == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "op_desc is NULL."; + return; + } + + node_name = op_desc->GetName(); + NodePtr node = owner_graph_->AddNode(op_desc); + if (node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "Add node " + node_name + " failed."; + return; + } + + GELOGD("Add node name:%s, type:%s.", node_name.c_str(), op_desc->GetType().c_str()); + node_names_[node_name] = node; + } + + GELOGD("BuildNodes succ."); +} + +/// +/// @brief Build data-links +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void ComputeGraphBuilder::BuildDataLinks(graphStatus &error_code, std::string &error_msg) { + for (auto &pair : data_links_) { + std::string src_name = pair.first.first; + uint32_t out_ind = pair.first.second; + std::string dst_name = pair.second.first; + uint32_t in_ind = pair.second.second; + std::string log_msg = "Add data-edge "; + log_msg.append(src_name) + .append(":") + 
.append(std::to_string(out_ind)) + .append("->") + .append(dst_name) + .append(":") + .append(std::to_string(in_ind)); + + auto src_iter = node_names_.find(src_name); + auto dst_iter = node_names_.find(dst_name); + if ((src_iter == node_names_.end()) || (dst_iter == node_names_.end())) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed: node not exist in graph."; + return; + } + + NodePtr src_node = node_names_[src_name]; + NodePtr dst_node = node_names_[dst_name]; + if ((src_node == nullptr) || (dst_node == nullptr)) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed: node is NULL."; + return; + } + + if (GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_ind), dst_node->GetInDataAnchor(in_ind)) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed."; + return; + } + + GELOGD("%s succ.", log_msg.c_str()); + } + + GELOGD("BuildDataLinks succ."); +} + +/// +/// @brief Build ctrl-links +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void ComputeGraphBuilder::BuildCtrlLinks(graphStatus &error_code, std::string &error_msg) { + for (auto &pair : ctrl_links_) { + std::string src_name = pair.first; + std::string dst_name = pair.second; + std::string log_msg = "Add ctrl-edge "; + log_msg.append(src_name).append("->").append(dst_name); + + auto src_iter = node_names_.find(src_name); + auto dst_iter = node_names_.find(dst_name); + if ((src_iter == node_names_.end()) || (dst_iter == node_names_.end())) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed: node not exist in graph."; + return; + } + + NodePtr src_node = node_names_[src_name]; + NodePtr dst_node = node_names_[dst_name]; + if ((src_node == nullptr) || (dst_node == nullptr)) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed: node is NULL."; + return; + } + + if (GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + 
error_msg = log_msg + " failed."; + return; + } + + GELOGD("%s succ.", log_msg.c_str()); + } + + GELOGD("BuildCtrlLinks succ."); +} + +/// @brief Get node with name +/// @param [in] name +/// @return NodePtr +/// +NodePtr ComputeGraphBuilder::GetNode(const std::string &name) { + auto iter = node_names_.find(name); + if (iter == node_names_.end()) { + GE_LOGE("node %s not exist.", name.c_str()); + return nullptr; + } + return iter->second; +} + +/// @brief Get all nodes +/// @return std::vector +/// +std::vector ComputeGraphBuilder::GetAllNodes() { + std::vector nodes; + for (const auto &iter : node_names_) { + nodes.emplace_back(iter.second); + } + return nodes; +} + +/// +/// @brief Add node to graph +/// @param [in] op_desc +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder &CompleteGraphBuilder::AddNode(const OpDescPtr &op_desc) { + ComputeGraphBuilder::AddNode(op_desc); + return *this; +} + +/// +/// @brief Add data-link among nodes in graph +/// @param [in] src_name +/// @param [in] out_anchor_ind +/// @param [in] dst_name +/// @param [in] in_anchor_ind +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder &CompleteGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, + const std::string &dst_name, uint32_t in_anchor_ind) { + ComputeGraphBuilder::AddDataLink(src_name, out_anchor_ind, dst_name, in_anchor_ind); + return *this; +} + +/// +/// @brief Add ctrl-link among nodes in graph +/// @param [in] src_name +/// @param [in] dst_name +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder &CompleteGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { + ComputeGraphBuilder::AddControlLink(src_name, dst_name); + return *this; +} + +/// +/// @brief Set index_th input anchor for graph +/// @param [in] index +/// @param [in] node_names +/// @param [in] anchor_inds +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder &CompleteGraphBuilder::SetInput(uint32_t index, const 
std::vector &node_names, + const std::vector &anchor_inds) { + graph_inputs_[index] = std::make_pair(node_names, anchor_inds); + return *this; +} + +/// +/// @brief Set index_th input of graph as useless +/// @param [in] index +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder &CompleteGraphBuilder::SetUselessInput(uint32_t index) { + graph_inputs_[index] = std::make_pair(std::vector(), std::vector()); + return *this; +} + +/// +/// @brief Add output anchor for graph +/// @param [in] owner_node_name +/// @param [in] anchor_ind +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder &CompleteGraphBuilder::AddOutput(const std::string &owner_node_name, uint32_t anchor_ind) { + graph_outputs_.emplace_back(std::make_pair(owner_node_name, anchor_ind)); + return *this; +} + +/// +/// @brief Add target for graph +/// @param [in] target_name +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder &CompleteGraphBuilder::AddTarget(const std::string &target_name) { + graph_targets_.emplace_back(target_name); + return *this; +} + +/// +/// @brief Set parent-node of graph +/// @param [in] parent_node +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder &CompleteGraphBuilder::SetParentNode(const NodePtr &parent_node) { + parent_node_ = parent_node; + return *this; +} + +/// +/// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node +/// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder &CompleteGraphBuilder::SetInputMapping(const std::map &input_mapping) { + for (auto &item : input_mapping) { + input_mapping_[item.first] = item.second; + } + return *this; +} + +/// +/// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind +/// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder 
&CompleteGraphBuilder::SetOutputMapping(const std::map &output_mapping) { + for (auto &item : output_mapping) { + output_mapping_[item.first] = item.second; + } + return *this; +} + +/// +/// @brief Build graph +/// @param [out] error_code +/// @param [out] error_msg +/// @return ComputeGraphPtr +/// +ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { + owner_graph_ = shared_ptr(new (std::nothrow) ComputeGraph(name_)); + if ((owner_graph_ == nullptr) || (parent_node_ == nullptr)) { + error_code = GRAPH_FAILED; + error_msg = "graph / parent_node is NULL."; + return nullptr; + } + + owner_graph_->SetParentNode(parent_node_); + owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph()); + + BuildNodes(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildDataLinks(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildCtrlLinks(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + AddDataNodes(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + AddRetValNodes(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildGraphTargets(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + // ATTR_NAME_SESSION_GRAPH_ID + std::string graph_id; + if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { + error_code = GRAPH_FAILED; + error_msg = "Get attr session_graph_id failed."; + return nullptr; + } + if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { + error_code = GRAPH_FAILED; + error_msg = "Set attr session_graph_id failed."; + return nullptr; + } + + // refresh node name + for (const NodePtr &node : owner_graph_->GetDirectNode()) { + if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { + 
continue; + } + node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); + } + + return owner_graph_; +} + +/// +/// @brief Add data nodes +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void CompleteGraphBuilder::AddDataNodes(graphStatus &error_code, std::string &error_msg) { + for (auto &input : graph_inputs_) { + NodePtr data_node = AddDataNode(input.first, error_code, error_msg); + if (data_node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: add node Data:" + std::to_string(input.first) + +" failed."; + return; + } + + if (owner_graph_->AddInputNode(data_node) == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: add input node Data:" + std::to_string(input.first) + +" failed."; + return; + } + + // useless input + std::vector input_names = input.second.first; + std::vector anchor_indes = input.second.second; + if (input_names.size() != anchor_indes.size()) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: num of input_names and indexs not equal."; + return; + } + if (input_names.empty()) { + continue; + } + + size_t input_num = input_names.size(); + for (size_t i = 0; i < input_num; i++) { + std::string input_name = input_names[i]; + uint32_t ind = anchor_indes[i]; + auto iter = node_names_.find(input_name); + if (iter == node_names_.end()) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: node " + input_name + " not exist in graph."; + return; + } + + NodePtr in_node = node_names_[input_name]; + if (in_node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: node " + input_name + " is NULL."; + return; + } + + if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), in_node->GetInDataAnchor(ind)) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + + ":" + std::to_string(ind) + " 
failed."; + return; + } + } + + GELOGD("AddDataNodes : Add %u input succ.", input.first); + } + + GELOGD("AddDataNodes succ."); +} + +/// +/// @brief Add data node +/// @param [in] index +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +NodePtr CompleteGraphBuilder::AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { + std::string data_name = "Data_" + std::to_string(index); + OpDescBuilder op_desc_builder(data_name, "Data"); + OpDescPtr op_desc = op_desc_builder.AddInput("x").AddOutput("y").Build(); + if (op_desc == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNode failed: create op_desc " + data_name + " failed."; + return nullptr; + } + + auto index_iter = input_mapping_.find(index); + if (index_iter != input_mapping_.end()) { + if (!ge::AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, index_iter->second)) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNode failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; + return nullptr; + } + } + + NodePtr data_node = owner_graph_->AddNode(op_desc); + if (data_node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNode failed: add node " + data_name + " failed."; + return nullptr; + } + node_names_[data_name] = data_node; + + return data_node; +} + +/// +/// @brief Add RetVal nodes +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &error_msg) { + size_t output_num = graph_outputs_.size(); + for (size_t i = 0; i < output_num; i++) { + int32_t index = graph_outputs_[i].second; + auto out_iter = node_names_.find(graph_outputs_[i].first); + if (out_iter == node_names_.end()) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode failed: node " + graph_outputs_[i].first + " not exist in graph."; + return; + } + NodePtr node = out_iter->second; + if ((node == nullptr) || (node->GetOpDesc() 
== nullptr)) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode failed: node is NULL."; + return; + } + + std::string name = node->GetName() + "_RetVal_" + std::to_string(index); + OpDescPtr ret_val_desc = shared_ptr(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); + if (ret_val_desc == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: op_desc is NULL."; + return; + } + ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); + if ((ret_val_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) || + (ret_val_desc->AddOutputDesc(tensor) != GRAPH_SUCCESS)) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: add input_desc / output_desc failed."; + return; + } + + if (!(ge::AttrUtils::SetStr(ret_val_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_RetVal") && + ge::AttrUtils::SetInt(ret_val_desc, RETVAL_ATTR_NAME_INDEX, i))) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: set FRAMEWORK_ORIGINAL_TYPE / RETVAL_ATTR_NAME_INDEX failed."; + return; + } + auto iter = output_mapping_.find(i); + if (iter != output_mapping_.end()) { + if (!ge::AttrUtils::SetInt(ret_val_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: set attr PARENT_NODE_INDEX failed."; + return; + } + } + + NodePtr ret_val_node = owner_graph_->AddNode(ret_val_desc); + if (ret_val_node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: add node failed."; + return; + } + + if (GraphUtils::AddEdge(node->GetOutDataAnchor(index), ret_val_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: add data-edge " + node->GetName() + ":" + std::to_string(index) + + "->" + ret_val_node->GetName() + ":0 failed."; + return; + } + } + + GELOGD("AddRetValNodes succ."); +} + +/// +/// @brief Build target-nodes for graph +/// @param 
[out] error_code +/// @param [out] error_msg +/// @return void +/// +void CompleteGraphBuilder::BuildGraphTargets(graphStatus &error_code, std::string &error_msg) { + std::vector target_nodes; + for (const std::string &target_name : graph_targets_) { + auto target_iter = node_names_.find(target_name); + if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) { + error_code = GRAPH_FAILED; + error_msg = "BuildGraphTargets failed: target_node " + target_name + " not exist in graph."; + return; + } + target_nodes.emplace_back(target_iter->second); + } + owner_graph_->SetGraphTargetNodesInfo(target_nodes); + return; +} + +/// +/// @brief Add node to graph +/// @param [in] op_desc +/// @return PartialGraphBuilder +/// +PartialGraphBuilder &PartialGraphBuilder::AddNode(const OpDescPtr &op_desc) { + ComputeGraphBuilder::AddNode(op_desc); + return *this; +} + +/// +/// @brief Add data-link among nodes in graph +/// @param [in] src_name +/// @param [in] out_anchor_ind +/// @param [in] dst_name +/// @param [in] in_anchor_ind +/// @return PartialGraphBuilder +/// +PartialGraphBuilder &PartialGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, + const std::string &dst_name, uint32_t in_anchor_ind) { + ComputeGraphBuilder::AddDataLink(src_name, out_anchor_ind, dst_name, in_anchor_ind); + return *this; +} + +/// +/// @brief Add ctrl-link among nodes in graph +/// @param [in] src_name +/// @param [in] dst_name +/// @return PartialGraphBuilder +/// +PartialGraphBuilder &PartialGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { + ComputeGraphBuilder::AddControlLink(src_name, dst_name); + return *this; +} + +/// +/// @brief Set owner graph +/// @param [in] graph +/// @return PartialGraphBuilder +/// +PartialGraphBuilder &PartialGraphBuilder::SetOwnerGraph(const ComputeGraphPtr &graph) { + owner_graph_ = graph; + return *this; +} + +/// +/// @brief Add exist node +/// @param [in] node +/// @return 
PartialGraphBuilder +/// +PartialGraphBuilder &PartialGraphBuilder::AddExistNode(const NodePtr &node) { + exist_nodes_.emplace_back(node); + return *this; +} + +/// +/// @brief Build partial graph +/// @param [out] error_code +/// @param [out] error_msg +/// @return ComputeGraphPtr +/// +ComputeGraphPtr PartialGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { + if (owner_graph_ == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "graph is NULL."; + return nullptr; + } + + BuildNodes(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildExistNodes(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildDataLinks(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildCtrlLinks(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + return owner_graph_; +} + +/// +/// @brief Build exist nodes +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string &error_msg) { + std::string node_name; + for (auto &node : exist_nodes_) { + if (node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "Build exist nodes failed: node is NULL."; + return; + } + + node_name = node->GetName(); + if (node->GetOwnerComputeGraph() != owner_graph_) { + error_code = GRAPH_FAILED; + error_msg = "Build exist nodes failed: node " + node_name + " not belongs to this graph."; + return; + } + + GELOGD("Add exist_node name:%s.", node_name.c_str()); + node_names_[node_name] = node; + } + + GELOGD("Build exist nodes succ."); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec) { + std::vector stack_input; + std::map map_in_edge_num; + graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num); + if (ret != 
GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Sort nodes failed."); + return GRAPH_FAILED; + } + const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1; + std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index, + [](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); }); + + std::queue stack; + NodePtr cur_node = nullptr; + std::map name_node_map; + vector nodes_name; + while (!stack_input.empty() || !stack.empty()) { + if (!stack.empty()) { + cur_node = stack.front(); + stack.pop(); + } else { + cur_node = stack_input.back(); + stack_input.pop_back(); + } + node_vec.emplace_back(cur_node); + compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map); + for (const auto &iter : name_node_map) { + nodes_name.emplace_back(iter.first); + } + std::sort(nodes_name.begin(), nodes_name.end()); + for (const auto &iter : nodes_name) { + stack.push(name_node_map[iter]); + } + name_node_map.clear(); + nodes_name.clear(); + } + // If they are not equal, there is a closed loop + if (node_vec.size() != compute_graph->nodes_.size()) { + std::set itered_nodes_set; + for (auto &node : node_vec) { + itered_nodes_set.insert(node.get()); + } + GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", + compute_graph->nodes_.size(), node_vec.size()); + for (auto &node : compute_graph->nodes_) { + if (itered_nodes_set.count(node.get()) == 0) { + GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); + } + } + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +} // namespace ge diff --git a/src/common/graph/utils/mem_utils.h b/src/common/graph/utils/mem_utils.h new file mode 100644 index 00000000..7e8dd9fd --- /dev/null +++ b/src/common/graph/utils/mem_utils.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_UTILS_MEM_UTILS_H_ +#define COMMON_GRAPH_UTILS_MEM_UTILS_H_ + +#include +#include + +namespace ge { +template +static inline std::shared_ptr<_Tp> MakeShared(_Args &&... __args) { + typedef typename std::remove_const<_Tp>::type _Tp_nc; + std::shared_ptr<_Tp> ret(new (std::nothrow) _Tp_nc(std::forward<_Args>(__args)...)); + return ret; +} +} + +#endif // COMMON_GRAPH_UTILS_MEM_UTILS_H_ diff --git a/src/common/graph/utils/node_utils.cc b/src/common/graph/utils/node_utils.cc new file mode 100644 index 00000000..684e37ac --- /dev/null +++ b/src/common/graph/utils/node_utils.cc @@ -0,0 +1,1005 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/utils/node_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/graph_utils.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/anchor.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/types.h" +#include "external/graph/operator.h" +#include "graph/ge_context.h" +#include "graph/runtime_inference_context.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/tensor_adapter.h" +#include "graph/utils/type_utils.h" + +namespace ge { +std::map> NodeUtils::map_send_info_{}; +std::map> NodeUtils::map_recv_info_{}; + +const std::set kConstOpTypes = {"Const", "Constant"}; + +const std::set kIfOpTypes = {"If", "_If", "StatelessIf"}; +const std::set kWhileOpTypes = {"While", "_While", "StatelessWhile"}; +const std::set kCaseOpTypes = {"Case"}; +const std::set kForOpTypes = {"For"}; + +bool OpShapeIsUnknown(const OpDescPtr &desc) { + for (const auto &ptr : desc->GetAllInputsDescPtr()) { + auto ge_shape = ptr->GetShape(); + for (const auto &dim : ge_shape.GetDims()) { + if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { + return true; + } + } + } + for (const auto &ptr : desc->GetAllOutputsDescPtr()) { + auto ge_shape = ptr->GetShape(); + for (const auto &dim : ge_shape.GetDims()) { + if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { + return true; + } + } + } + return false; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node, + const uint32_t &event_id) { + GE_CHECK_NOTNULL(node); + map_send_info_[node].push_back(event_id); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node, + const uint32_t &event_id) { + GE_CHECK_NOTNULL(node); + map_recv_info_[node].push_back(event_id); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY 
graphStatus +NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector &vec_send) { + GE_CHECK_NOTNULL(node); + auto find = map_send_info_.find(node); + if (find == map_send_info_.end()) { + return GRAPH_FAILED; + } else { + vec_send = find->second; + return GRAPH_SUCCESS; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector &vec_recv) { + GE_CHECK_NOTNULL(node); + auto find = map_recv_info_.find(node); + if (find == map_recv_info_.end()) { + return GRAPH_FAILED; + } else { + vec_recv = find->second; + return GRAPH_SUCCESS; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() { + map_send_info_.clear(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() { + map_recv_info_.clear(); + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) { + GE_CHECK_NOTNULL(src); + NodePtr cur_ptr; + if (depth < 1) { + return GRAPH_FAILED; + } + for (int i = 0; i < depth; i++) { + if (src->GetOutDataNodes().size() != 1) { + return GRAPH_FAILED; + } + cur_ptr = src->GetOutDataNodes().at(0); + GE_CHECK_NOTNULL(cur_ptr); + } + dst = cur_ptr; + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, + InControlAnchorPtr &in_control) { + GE_CHECK_NOTNULL(node_ptr); + for (const auto &p : node_ptr->GetAllOutDataAnchors()) { + GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr"); + for (const auto &p_in : p->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr"); + out_data = p; + in_control = p_in; + return GRAPH_SUCCESS; + } + } + return GRAPH_FAILED; +} + +graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { + 
GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, + "node or in_data_anchor is nullptr"); + + bool find_flag = false; + uint32_t index = 0; + vector::iterator it = node_ptr->in_data_anchors_.end(); + for (const auto &tmp : node_ptr->in_data_anchors_) { + if (tmp == in_data_anchor) { + find_flag = true; + auto iter = node_ptr->in_data_anchors_.begin() + index; + if (iter != node_ptr->in_data_anchors_.end()) { + it = node_ptr->in_data_anchors_.erase(iter); + } + break; + } + index++; + } + for (; it != node_ptr->in_data_anchors_.end(); ++it) { + (*it)->SetIdx(index); + index++; + } + + if (!find_flag) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) { + GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr"); + GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed"); + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::SetAllAnchorStatus(Node &node) { + node.anchor_status_updated_ = true; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) { + GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr"); + return IsAnchorStatusSet(*node_ptr); +} + +bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; } + +graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) { + if ((origin_node == nullptr) || (new_node == nullptr)) { + return GRAPH_FAILED; + } + auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors(); + auto new_out_data_anchors = new_node->GetAllOutDataAnchors(); + if (origin_out_data_anchors.size() != new_out_data_anchors.size()) { + return GRAPH_FAILED; + } + + for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) { + for (const auto 
&peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, + "unlink peer_anchor failed"); + GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, + "linkto peer_anchor failed"); + } + + for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, + "unlink peer_anchor failed"); + GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, + "linkto peer_anchor failed"); + } + } + + auto origin_out_control_anchor = origin_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(origin_out_control_anchor); + auto new_out_control_anchor = new_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(new_out_control_anchor); + for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, + "linkto peer_anchor failed"); + } + for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, + "linkto peer_anchor failed"); + } + origin_out_control_anchor->UnlinkAll(); + + return GRAPH_SUCCESS; +} + +bool NodeUtils::IsConst(const Node &node) { + auto src_node_type = node.GetType(); + bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP)); + return is_const; +} + +void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) { + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "node is null"); + return; + } + UpdateIsInputConst(*node_ptr); +} + +/// +/// update is_input_const +/// @param node +/// @return void +/// +void NodeUtils::UpdateIsInputConst(Node &node) { + std::vector is_input_const; + size_t anchor_num = 
node.GetAllInDataAnchors().size(); + for (size_t i = 0; i < anchor_num; i++) { + auto in_anchor = node.GetInDataAnchor(static_cast(i)); + if (in_anchor == nullptr) { + is_input_const.push_back(false); + continue; + } + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + is_input_const.push_back(false); + continue; + } + auto src_node = peer_out_anchor->GetOwnerNode(); + if (src_node == nullptr) { + is_input_const.push_back(false); + continue; + } + if (IsConst(*(src_node))) { + is_input_const.push_back(true); + } else { + is_input_const.push_back(false); + } + } + if (node.GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr"); + return; + } + node.GetOpDesc()->SetIsInputConst(is_input_const); +} + +void NodeUtils::UnlinkAll(const Node &node) { + for (const auto &anchor : node.GetAllOutAnchors()) { + anchor->UnlinkAll(); + } + for (const auto &anchor : node.GetAllInAnchors()) { + anchor->UnlinkAll(); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) { + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); + return GRAPH_FAILED; + } + auto op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + return GRAPH_FAILED; + } + bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + if (is_unknown_graph) { + return GRAPH_SUCCESS; + } + for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + auto out_dims = output_tensor->GetShape().GetDims(); + auto out_dtype = output_tensor->GetDataType(); + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); + output_tensor->SetOriginShape(output_tensor->GetShape()); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + + GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin 
data type is %s", + node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + + for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { + if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); + continue; + } + auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); + if (peer_input_desc == nullptr) { + GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); + continue; + } + // check shape and dtype continuity. do not stop process + auto peer_input_dims = peer_input_desc->GetShape().GetDims(); + auto peer_input_dtype = peer_input_desc->GetDataType(); + if (out_dtype != peer_input_dtype) { + GELOGW( + "current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th " + "input_dtype is [%s].The two dtype should be same! Please check graph and fix it", + node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(), + peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), + TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str()); + } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) { + string out_shape_str, peer_in_shape_str; + out_shape_str += "["; + for (int64_t dim : out_dims) { + out_shape_str += std::to_string(dim) + " "; + } + out_shape_str += "]"; + peer_in_shape_str += "["; + for (int64_t dim : peer_input_dims) { + peer_in_shape_str += std::to_string(dim) + " "; + } + peer_in_shape_str += "]"; + + GELOGW( + "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " + "input_shape is [%s].The two shape should be same! 
Please check graph and fix it", + node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(), + peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str()); + } + GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), + output_tensor->GetDataType(), output_tensor->GetOriginDataType()); + peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); + peer_input_desc->SetShape(output_tensor->GetShape()); + peer_input_desc->SetDataType(output_tensor->GetDataType()); + peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); + std::vector> shape_range; + (void)output_tensor->GetShapeRange(shape_range); + peer_input_desc->SetShapeRange(shape_range); + ge::TensorUtils::SetRealDimCnt(*peer_input_desc, + static_cast(output_tensor->GetShape().GetDims().size())); + GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), + peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, + uint32_t num) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Input node is null"); + return GRAPH_FAILED; + } + + GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); + const auto &op_desc = node->GetOpDesc(); + for (size_t i = op_desc->GetInputsSize(); i < num; ++i) { + if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add input desc failed"); + return GRAPH_FAILED; + } + + auto anchor = ComGraphMakeShared(node, i); + if (anchor == nullptr) { + GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed."); + 
return GRAPH_FAILED; + } + node->in_data_anchors_.push_back(anchor); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, + uint32_t num) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Input node is null"); + return GRAPH_FAILED; + } + + const auto &op_desc = node->GetOpDesc(); + while (op_desc->GetInputsSize() > num) { + if (!OpDescUtils::ClearInputDesc(op_desc, num)) { + return GRAPH_FAILED; + } + } + + auto input_names = op_desc->GetAllInputName(); + (void)op_desc->UpdateInputName(input_names); + auto is_input_const = op_desc->GetIsInputConst(); + is_input_const.resize(num); + op_desc->SetIsInputConst(is_input_const); + + while (node->in_data_anchors_.size() > num) { + node->in_data_anchors_.pop_back(); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node, + uint32_t num) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Input node is null"); + return GRAPH_FAILED; + } + + GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); + const OpDescPtr &op_desc = node->GetOpDesc(); + for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) { + if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add output desc failed"); + return GRAPH_FAILED; + } + + auto anchor = ComGraphMakeShared(node, i); + if (anchor == nullptr) { + GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed."); + return GRAPH_FAILED; + } + node->out_data_anchors_.push_back(anchor); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node, + uint32_t num) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Input node is null"); + return GRAPH_FAILED; + } + + const auto &op_desc = node->GetOpDesc(); + auto output_names = op_desc->GetAllOutputName(); + while 
(op_desc->GetOutputsSize() > num) { + if (!OpDescUtils::ClearOutputDesc(op_desc, num)) { + return GRAPH_FAILED; + } + } + (void)op_desc->UpdateOutputName(output_names); + + while (node->out_data_anchors_.size() > num) { + node->out_data_anchors_.pop_back(); + } + + return GRAPH_SUCCESS; +} + +bool NodeUtils::IsInNodesEmpty(const Node &node) { + for (const auto &in_anchor : node.in_data_anchors_) { + if (in_anchor != nullptr) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor != nullptr) { + if (out_anchor->GetOwnerNode() != nullptr) { + return false; + } + } + } + } + + if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) { + auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors(); + for (const auto &out_control_anchor : peer_out_control_anchors) { + if (out_control_anchor != nullptr) { + if (out_control_anchor->GetOwnerNode() != nullptr) { + return false; + } + } + } + } + + return true; +} +GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) { + auto desc = node.GetOpDesc(); + if (desc == nullptr) { + return GeTensorDesc(); + } + return desc->GetOutputDesc(index); +} +GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) { + auto desc = node.GetOpDesc(); + if (desc == nullptr) { + return GeTensorDesc(); + } + return desc->GetInputDesc(index); +} +graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) { + auto desc = node.GetOpDesc(); + if (desc == nullptr) { + return GRAPH_PARAM_INVALID; + } + auto output_desc = desc->MutableOutputDesc(index); + if (output_desc == nullptr) { + return GRAPH_PARAM_INVALID; + } + output_desc->SetShape(shape); + return GRAPH_SUCCESS; +} +graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) { + auto desc = node.GetOpDesc(); + if (desc == nullptr) { + return GRAPH_PARAM_INVALID; + } + auto input_desc = 
desc->MutableInputDesc(index); + if (input_desc == nullptr) { + return GRAPH_PARAM_INVALID; + } + input_desc->SetShape(shape); + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { + auto desc = node.GetOpDesc(); + GE_CHECK_NOTNULL(desc); + // check self + is_unknow = OpShapeIsUnknown(desc); + if (is_unknow) { + return GRAPH_SUCCESS; + } + auto sub_graph_names = desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + return GRAPH_SUCCESS; + } else { + auto owner_graph = node.GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (root_graph == nullptr) { + GE_LOGE("Node %s gets null root graph", node.GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + for (auto &sub_graph_name : sub_graph_names) { + auto sub_graph = root_graph->GetSubgraph(sub_graph_name); + GE_CHECK_NOTNULL(sub_graph); + for (const auto &node_ptr : sub_graph->GetDirectNode()) { + auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); + if (status != GRAPH_SUCCESS) { + GE_LOGE("get node unknown shape status failed!"); + return status; + } + if (is_unknow) { + return GRAPH_SUCCESS; + } + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor) { + GE_CHECK_NOTNULL(node_ptr); + return NodeUtils::GetInputConstData(*node_ptr, dst_name, ge_tensor); +} + +graphStatus NodeUtils::GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor) { + // For inner compute graph + auto op_desc = node.GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto index = op_desc->GetInputIndexByName(dst_name); + auto in_data_anchor = node.GetInDataAnchor(index); + GE_CHECK_NOTNULL(in_data_anchor); + auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(out_data_anchor); + auto peer_node = 
out_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_node); + auto peer_op_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(peer_op_desc); + auto peer_op_type = peer_op_desc->GetType(); + if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { + if (!AttrUtils::MutableTensor(peer_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor)) { + GELOGW("get attr name %s failed.", ATTR_NAME_WEIGHTS.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } else if (peer_op_type == DATA) { + auto parent_node = NodeUtils::GetParentInput(peer_node); + while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { + parent_node = NodeUtils::GetParentInput(parent_node); + } + if ((parent_node != nullptr) && ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { + if (!AttrUtils::MutableTensor(parent_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor)) { + GELOGW("get attr name %s failed.", ATTR_NAME_WEIGHTS.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } + } + // Try get from runtime inference context + auto session_id = std::to_string(GetContext().SessionId()); + RuntimeInferenceContext *runtime_infer_ctx = nullptr; + if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { + GELOGD("To get constant from runtime inference context. 
session_id = %s", session_id.c_str()); + auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), ge_tensor); + if (ret == GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + } + GELOGW("node[%s]'s input[%s]'s peer node is not const", node.GetName().c_str(), dst_name.c_str()); + return GRAPH_FAILED; +} + +std::string NodeUtils::GetNodeType(const Node &node) { + if (node.GetType() != FRAMEWORKOP) { + return node.GetType(); + } + + std::string type; + (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); + return type; +} + +std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? "" : GetNodeType(*node); } + +ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { + auto op_desc = node.GetOpDesc(); + if (op_desc == nullptr) { + return nullptr; + } + auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (root_graph == nullptr) { + return nullptr; + } + return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); +} + +graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) { + if (subgraph == nullptr) { + GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); + return GRAPH_PARAM_INVALID; + } + auto op_desc = node.GetOpDesc(); + if (op_desc == nullptr) { + return GRAPH_PARAM_INVALID; + } + auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (root_graph == nullptr) { + GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); + if (ret != GRAPH_SUCCESS) { + GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); + return ret; + } + subgraph->SetParentNode(node.shared_from_this()); + subgraph->SetParentGraph(node.GetOwnerComputeGraph()); + return 
root_graph->AddSubgraph(subgraph); +} + +/// +/// Check if node is input of subgraph +/// @param [in] node +/// @return bool +/// +bool NodeUtils::IsSubgraphInput(const NodePtr &node) { + if ((node == nullptr) || (node->GetOpDesc() == nullptr) || + (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { + return false; + } + + auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); + if (parent_op_desc == nullptr) { + return false; + } + + // dynamic shape unknown graph false + // dynamic shape known graph with functional subgraph maybe true + if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { + if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { + return false; + } else { + if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { + return false; + } + } + } + + return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); +} + +/// +/// Check if node is output of subgraph +/// @param [in] node +/// @return bool +/// +bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { + if ((node == nullptr) || (node->GetOpDesc() == nullptr) || + (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { + return false; + } + + auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); + if (parent_op_desc == nullptr) { + return false; + } + + if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { + if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { + return false; + } else { + if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { + return false; + } + } + } + + for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { + if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { + return true; + } + } + + return false; +} + +/// +/// @brief Get subgraph original input node. 
+/// @param [in] node +/// @return Node +/// +NodePtr NodeUtils::GetParentInput(const Node &node) { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + return nullptr; + } + + // Subgraph Data Node, check for constant input. + const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); + GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + + const NodePtr &parent_node = graph->GetParentNode(); + GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr); + + const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index); + GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); + + const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); + + return peer_out_anchor->GetOwnerNode(); +} + +NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return node == nullptr ? node : GetParentInput(*node); } + +/// +/// @brief Get is dynamic shape graph from node. +/// @param [in] node +/// @return bool +/// +bool NodeUtils::IsDynamicShape(const Node &node) { + const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (graph == nullptr) { + return false; + } + + bool is_dynamic_shape = false; + (void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); + return is_dynamic_shape; +} + +bool NodeUtils::IsDynamicShape(const NodePtr &node) { return node == nullptr ? 
false : IsDynamicShape(*node); } + +/// +/// @brief Check is varying_input for while node +/// @param [in] node: Data node for subgraph +/// @return bool +/// +bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { + if (node == nullptr) { + return false; + } + if (node->GetType() != DATA) { + return false; // not input_node for subgraph + } + + const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); + if (parent_node == nullptr) { + return false; // root graph + } + + if (kWhileOpTypes.count(parent_node->GetType()) == 0) { + return false; // not input_node for while subgraph + } + + uint32_t index_i = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { + GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); + return false; + } + bool varying_flag = true; + for (const auto &item : node->GetOutDataNodesAndAnchors()) { + if (item.first->GetType() != NETOUTPUT) { + continue; + } + OpDescPtr op_desc = item.first->GetOpDesc(); + uint32_t index_o = 0; + if ((op_desc == nullptr) || + !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { + continue; // input for while-cond subgraph + } + if (index_i != index_o) { + continue; // varying input for while-body subgraph + } + varying_flag = false; + break; + } + return varying_flag; +} + +/// +/// @brief Get subgraph input is constant. 
+/// @param [in] node +/// @param [out] string +/// @return bool +/// +bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) { + if (node == nullptr) { + return false; + } + + if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { + type = node->GetType(); + return true; + } + + if (node->GetType() != DATA) { + return false; // not subgraph input node + } + + const auto &parent = GetParentInput(node); + return GetConstOpType(parent, type); +} + +/// +/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. +/// @param [in] node +/// @return return GRAPH_SUCCESS if remove successfully, other for failed. +/// +Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + return GRAPH_SUCCESS; + } else { + auto owner_graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + auto root_graph = GraphUtils::FindRootGraph(owner_graph); + GE_CHECK_NOTNULL(root_graph); + + std::unordered_set subgraph_to_remove; + for (auto &subgraph_name : subgraph_names) { + std::deque queue; + queue.push_back(subgraph_name); + subgraph_to_remove.insert(subgraph_name); + op_desc->RemoveSubgraphInstanceName(subgraph_name); + while (!queue.empty()) { + auto graph_name = queue.front(); + queue.pop_front(); + + auto subgraph = root_graph->GetSubgraph(graph_name); + GE_CHECK_NOTNULL(subgraph); + for (const auto &sub_node : subgraph->GetDirectNode()) { + auto sub_op_desc = sub_node->GetOpDesc(); + GE_CHECK_NOTNULL(sub_op_desc); + auto sub_names = sub_op_desc->GetSubgraphInstanceNames(); + // Subgraph and all nodes in it will be removed later, + // no need to remove 'SubgraphInstanceName' in op desc here. 
+ for (auto &name : sub_names) { + if (subgraph_to_remove.insert(name).second) { + queue.push_back(name); + } + } + } + } + } + // Remove subgraph from root_graph + for (const auto &name : subgraph_to_remove) { + GELOGI("Remove subgraph:%s.", name.c_str()); + root_graph->RemoveSubgraph(name); + } + } + + return GRAPH_SUCCESS; +} +/// +/// @brief Get subgraph input data node by index. +/// @param [in] node +/// @return Node +/// +vector NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { + vector in_data_node_vec; + auto op_desc = node.GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); + auto subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); + return in_data_node_vec; + } + auto compute_graph = node.GetOwnerComputeGraph(); + for (const std::string &instance_name : subgraph_names) { + auto subgraph = compute_graph->GetSubgraph(instance_name); + for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { + int parent_index = -1; + if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { + (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); + if (parent_index == index) { + in_data_node_vec.emplace_back(node_in_subgraph); + } + } + } + } + return in_data_node_vec; +} +/// +/// @brief Get subgraph input data node by index. 
+/// @param [in] node +/// @return Node +/// +vector NodeUtils::GetSubgraphOutputNodes(const Node &node) { + vector out_data_node_vec; + auto op_desc = node.GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); + auto subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); + return out_data_node_vec; + } + auto compute_graph = node.GetOwnerComputeGraph(); + for (const std::string &instance_name : subgraph_names) { + auto subgraph = compute_graph->GetSubgraph(instance_name); + for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { + if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { + out_data_node_vec.emplace_back(node_in_subgraph); + } + } + } + return out_data_node_vec; +} + +NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) { + if (node.GetInDataAnchor(index) == nullptr) { + return nullptr; + } + if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { + return nullptr; + } + return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); +} + +vector> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) { + vector> out_data_nodes; + auto out_data_anchor = node.GetOutDataAnchor(index); + if (out_data_anchor == nullptr) { + return out_data_nodes; + } + + for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + if (peer_in_anchor == nullptr) { + continue; + } + if (peer_in_anchor->GetOwnerNode() == nullptr) { + continue; + } + out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode())); + } + return out_data_nodes; +} + +ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); } +} // namespace ge diff --git a/src/common/graph/utils/op_desc_utils.cc b/src/common/graph/utils/op_desc_utils.cc new file mode 100644 index 00000000..17c80b2c --- /dev/null +++ 
b/src/common/graph/utils/op_desc_utils.cc @@ -0,0 +1,825 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/op_desc_utils.h" +#include +#include "debug/ge_attr_define.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/anchor.h" +#include "graph/compute_graph.h" +#include "graph/ge_attr_value.h" +#include "utils/graph_utils.h" +#include "utils/node_utils.h" + +using std::vector; + +/*lint -e512 -e737 -e752*/ +namespace ge { +const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; +static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; + +bool OpDescUtils::ClearInputDesc(const NodePtr &node) { + GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); + GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr"); + vector index_list; + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + if (in_anchor->GetPeerOutAnchor() == nullptr) { + index_list.push_back(in_anchor->GetIdx()); + } + } + std::sort(index_list.begin(), index_list.end()); + // Node's in anchor index need shrink + for (size_t i = 0; i < index_list.size(); ++i) { + auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i]; + if (iter < node->GetOpDesc()->inputs_desc_.end()) { + (void)node->GetOpDesc()->inputs_desc_.erase(iter); + } else { + GELOGW("inputs_desc_ iterator out of range."); + } + } + + 
return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc, + const uint32_t index) { + GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr"); + GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index); + + auto iter = op_desc->inputs_desc_.begin() + index; + if (iter < op_desc->inputs_desc_.end()) { + (void)op_desc->inputs_desc_.erase(iter); + } else { + GELOGW("inputs_desc_ iterator out of range."); + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) { + GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr"); + return op_desc->HasAttr(OP_DESC_QUANT_PARAMS); +} + +bool OpDescUtils::ClearOutputDesc(const NodePtr &node) { + GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); + GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr"); + vector index_list; + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + if (out_anchor->GetPeerInDataAnchors().empty()) { + index_list.push_back(out_anchor->GetIdx()); + } + } + std::sort(index_list.begin(), index_list.end()); + // Node's out anchor index need shrink + for (size_t i = 0; i < index_list.size(); ++i) { + auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i]; + if (iter < node->GetOpDesc()->outputs_desc_.end()) { + (void)node->GetOpDesc()->outputs_desc_.erase(iter); + } else { + GELOGW("outputs_desc_ iterator out of range."); + } + } + + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc, + uint32_t index) { + GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr"); + GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index); + + auto iter = op_desc->outputs_desc_.begin() + index; + if (iter < 
op_desc->outputs_desc_.end()) { + (void)op_desc->outputs_desc_.erase(iter); + } else { + GELOGW("outputs_desc_ iterator out of range."); + } + return true; +} + +bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) { + GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); + GeAttrValue attr_value; + GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, + "GetQuantizeFactorParams failed"); + return attr_value.GetValue(quant); +} + +graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) { + GeAttrValue attr_value; + GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, + "GetQuantizeFactorParams failed"); + return attr_value.GetValue(quant); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { + GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); + return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); // lint !e732 +} + +graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { + return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); // lint !e732 +} + +GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { + GeTensorPtr weight = nullptr; + if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) { + GELOGW("MutableTensor error"); + } + + return weight; +} + +GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) { + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "op_desc is null"); + 
return nullptr; + } + return MutableWeights(*op_desc); +} + +graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) { + if (weight == nullptr) { + GELOGE(GRAPH_FAILED, "weight is null"); + return GRAPH_FAILED; + } + return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) { + GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(weight); + return SetWeights(*op_desc, weight); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetWeights(const ge::Node &node) { + auto weights = MutableWeights(node); + vector ret(weights.size()); + std::copy(weights.begin(), weights.end(), ret.begin()); + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetWeights( + const ge::ConstNodePtr &node) { + if (node == nullptr) { + return vector(); + } + return GetWeights(*node); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputNode( + const ge::Node &node) { + vector ret; + auto in_anchors = node.GetAllInDataAnchors(); + for (const auto &in_anchor : in_anchors) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + // normally out_anchor could be null, this is ok + GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str()); + continue; + } + auto in_node = out_anchor->GetOwnerNode(); + while (true) { + if (in_node == nullptr) { + break; + } + if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { + ret.push_back(in_node); + break; + } else if (in_node->GetType() == DATA) { + if (NodeUtils::IsWhileVaryingInput(in_node)) { + break; + } + in_node = NodeUtils::GetParentInput(in_node); + } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { + bool is_constant = false; + (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); + if (!is_constant) 
{ + break; + } + // Enter node has and only has one input + if (in_node->GetInDataNodes().size() != 1) { + GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), + in_node->GetInDataNodes().size()); + break; + } + in_node = in_node->GetInDataNodes().at(0); + } else { + break; + } + } + } + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetInputData( + const vector &input_nodes) { + vector ret; + + for (const auto &input_node : input_nodes) { + auto temp_weight = MutableWeights(input_node->GetOpDesc()); + if (temp_weight == nullptr) { + GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str()); + return vector(); + } + ret.push_back(temp_weight); + } + + return ret; +} +size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) { + if (NodeUtils::IsAnchorStatusSet(node)) { + size_t input_num = 0; + for (const auto &anchor : node.GetAllInDataAnchors()) { + if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { + input_num++; + continue; + } + } + return input_num; // lint !e712 + } else { + GE_IF_BOOL_EXEC( + node.GetInDataNodes().size() < GetConstInputs(node).size(), + GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size()); + return 0); + return node.GetInDataNodes().size() - GetConstInputs(node).size(); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Node is nullptr"); + return 0; + } + return GetNonConstInputsSize(*node); +} + +GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) { + GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!"); + size_t i = 0; + if (NodeUtils::IsAnchorStatusSet(node)) { + for (const auto &anchor : node.GetAllInDataAnchors()) { + if 
(ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { + if (index_non_const == i) { + return node.GetOpDesc()->GetInputDesc(static_cast(anchor->GetIdx())); + } + ++i; + } + } + } else { + for (const auto &anchor : node.GetAllInDataAnchors()) { + auto peer_anchor = anchor->GetPeerOutAnchor(); + if (peer_anchor == nullptr) { + continue; + } + auto owner_node = peer_anchor->GetOwnerNode(); + if (owner_node == nullptr) { + continue; + } + if (owner_node->GetType() == CONSTANT) { + continue; + } + if (index_non_const == i) { + return node.GetOpDesc()->GetInputDesc(anchor->GetIdx()); + } + ++i; + } + } + return GeTensorDesc(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc +OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) { + CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc()); + return GetNonConstInputTensorDesc(*node, index_non_const); +} + +bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) { + bool ret = false; + size_t i = 0; + if (NodeUtils::IsAnchorStatusSet(node)) { + for (const auto &anchor : node.GetAllInDataAnchors()) { + if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { + if (index_non_const == i) { + index = static_cast(anchor->GetIdx()); + ret = true; + } + ++i; + } + } + } else { + for (const auto &anchor : node.GetAllInDataAnchors()) { + auto peer_anchor = anchor->GetPeerOutAnchor(); + if (peer_anchor == nullptr) { + continue; + } + auto owner_node = peer_anchor->GetOwnerNode(); + if (owner_node == nullptr) { + continue; + } + if (owner_node->GetType() == CONSTANT) { + continue; + } + if (index_non_const == i) { + index = static_cast(anchor->GetIdx()); + ret = true; + } + ++i; + } + } + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node, + size_t index_non_const, + size_t &index) { + CHECK_FALSE_EXEC(node != nullptr, return false); + return 
GetNonConstInputIndex(*node, index_non_const, index); +} + +bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) { + bool ret = false; + if (index < node.GetAllInDataAnchors().size()) { + if (NodeUtils::IsAnchorStatusSet(node)) { + ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast(index))) == ANCHOR_DATA); // lint !e712 + } else { + for (const auto &anchor : node.GetAllInDataAnchors()) { + if (anchor->GetIdx() != static_cast(index)) { + continue; + } + auto peer_anchor = anchor->GetPeerOutAnchor(); + if (peer_anchor == nullptr) { + break; + } + auto owner_node = peer_anchor->GetOwnerNode(); + if (owner_node == nullptr) { + break; + } + ret = (owner_node->GetType() != CONSTANT); + } + } + } + + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node, + size_t index) { + CHECK_FALSE_EXEC(node != nullptr, return false); + return IsNonConstInput(*node, index); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputs( + const ge::ConstNodePtr &node) { + if (node == nullptr) { + return vector(); + } + return GetConstInputs(*node); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetNonConstTensorDesc( + const ge::ConstNodePtr &node) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + return vector(); + } + vector ret; + if (NodeUtils::IsAnchorStatusSet(*node)) { + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { + ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); + } + } + } else { + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { + continue; + } + if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { + 
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); + } + } + } + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputs(const ge::Node &node) { + vector ret; + auto in_anchors = node.GetAllInDataAnchors(); + for (const auto &in_anchor : in_anchors) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) continue; + + auto in_node = out_anchor->GetOwnerNode(); + if (in_node->GetType() == CONSTANT) { + ret.push_back(in_node); + } else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) { + // const --> switch --> matmul + auto switch_input = GetConstInputs(*in_node); + if (switch_input.size() > 0) { + ret.insert(ret.end(), switch_input.begin(), switch_input.end()); + } + } else if (in_node->GetType() == DATA) { + auto parent = NodeUtils::GetParentInput(in_node); + if ((parent != nullptr) && (parent->GetType() == CONSTANT)) { + ret.push_back(parent); + } + } + } + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::Node &node) { + vector ret; + auto op_desc = node.GetOpDesc(); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); + // Place holder operator, try to get the weight from parent node + // when parent node is const operator + if (node.GetType() == PLACEHOLDER) { + std::string parent_op; + (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op); + // This if judgment is necessary because the current subgraph optimization is multithreaded + // and the parent node of the PLD operation should be a stable type, such as const + if (parent_op == CONSTANT || parent_op == CONSTANTOP) { + NodePtr parent_node = nullptr; + parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); + if (parent_node != nullptr) { + op_desc = parent_node->GetOpDesc(); + GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); + } + } + } + // Const operator, take 
the weight directly + if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { + auto weight = MutableWeights(op_desc); + if (weight == nullptr) { + GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); + return ret; + } + ret.push_back(weight); + return ret; + } + + if (node.GetType() == DATA) { + auto parent = NodeUtils::GetParentInput(node); + if ((parent != nullptr) && NodeUtils::IsConst(*parent)) { + auto weight = MutableWeights(parent->GetOpDesc()); + if (weight == nullptr) { + GELOGI("const op has no weight, op name:%s", parent->GetName().c_str()); + return ret; + } + ret.push_back(weight); + } + return ret; + } + + // Other operators, get weights from connected constop + auto input_nodes = GetConstInputs(node); + for (const auto &input_node : input_nodes) { + auto temp_weight = MutableWeights(input_node->GetOpDesc()); + if (temp_weight == nullptr) { + GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str()); + return vector(); + } + ret.push_back(temp_weight); + } + + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::NodePtr node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Node is nullptr"); + return vector(); + } + return MutableWeights(*node); +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::SetWeights(ge::Node &node, const vector &weights) { + GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!"); + if (node.GetOpDesc()->GetType() == CONSTANT) { + if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) { + return SetWeights(node.GetOpDesc(), weights[0]); + } + GELOGI("const op weight size %zu should be 1", weights.size()); + return GRAPH_PARAM_INVALID; + } + + auto input_nodes = GetConstInputs(node); + if (weights.size() < input_nodes.size()) { + GELOGE(GRAPH_FAILED, "weights count can't be less than const input count"); + return 
GRAPH_PARAM_INVALID; + } + + ge::GeAttrValue::NAMED_ATTRS named_attrs; + (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); + vector copy_weights; + (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); + + for (size_t i = 0; i < input_nodes.size(); ++i) { + if (input_nodes[i]->GetOpDesc() != nullptr) { + SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]); + } + } + + // If set more weights than constop, need to add constop + for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) { + // Use org weight before SetWeights Overwrite + auto const_opdesc = CreateConstOp(copy_weights[i]); + GE_CHECK_NOTNULL(const_opdesc); + + auto owner_graph = node.GetOwnerComputeGraph(); + if (owner_graph == nullptr) { + GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + auto const_node = owner_graph->AddNodeFront(const_opdesc); + GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, return GRAPH_FAILED, "graph add link failedï¼"); + std::vector original_nodes; + ge::GraphUtils::RecordOriginalNames(original_nodes, const_node); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::SetWeights(ge::Node &node, const map &weights_map) { + GE_CHECK_NOTNULL(node.GetOpDesc()); + // 1. node is const + if (node.GetOpDesc()->GetType() == CONSTANT) { + if (weights_map.size() == CONST_OP_NORMAL_WEIGHT_SIZE) { + return SetWeights(node.GetOpDesc(), weights_map.begin()->second); + } + GELOGE(GRAPH_PARAM_INVALID, "const op %s weight size %zu should be 1", node.GetName().c_str(), weights_map.size()); + return GRAPH_PARAM_INVALID; + } + // 2. node is not const + for (const auto &pair : weights_map) { + auto in_data_anchor = node.GetInDataAnchor(pair.first); + if (in_data_anchor != nullptr && in_data_anchor->GetPeerOutAnchor() != nullptr) { + // a. 
update const input node + auto out_anchor = in_data_anchor->GetPeerOutAnchor(); + auto peer_node = out_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "op %s [%d]'s input node is null", node.GetName().c_str(), pair.first); + return GRAPH_PARAM_INVALID; + } + if (peer_node->GetType() != CONSTANT) { + GELOGE(GRAPH_PARAM_INVALID, " op %s [%d]'s input node should be const, but is %s type:%s ", + node.GetName().c_str(), pair.first, peer_node->GetName().c_str(), peer_node->GetType().c_str()); + } + SetWeights(peer_node->GetOpDesc(), pair.second); + } else { + // b. create new const input node + auto const_opdesc = CreateConstOp(pair.second); + GE_CHECK_NOTNULL(const_opdesc); + auto owner_graph = node.GetOwnerComputeGraph(); + if (owner_graph == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", node.GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + auto const_node = owner_graph->AddNodeFront(const_opdesc); + if (node.AddLinkFrom(static_cast(pair.first), const_node) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "op %s add const to input index[%d] failed", node.GetName().c_str(), pair.first); + return GRAPH_FAILED; + } + } + } + NodeUtils::UpdateIsInputConst(node); + return GRAPH_SUCCESS; +} + +OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) { + GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!"); + shared_ptr const_opdesc = ComGraphMakeShared(); + if (const_opdesc == nullptr) { + GELOGE(GRAPH_FAILED, "failed to make_shared "); + return nullptr; + } + + CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr); + + const_opdesc->SetType(CONSTANT); + + thread_local int64_t const_count = 0; + const_opdesc->SetName("dynamic_const_" + std::to_string(GetTid()) + "_" + std::to_string(const_count)); + GELOGI("add const op: %s", const_opdesc->GetName().c_str()); + ++const_count; + + 
(void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc()); + + GELOGI("after add const op: %s", const_opdesc->GetName().c_str()); + + return const_opdesc; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) { + GE_CHECK_NOTNULL(in_anchor); + GE_CHECK_NOTNULL(tensor_ptr); + auto const_opdesc = CreateConstOp(tensor_ptr); + GE_CHECK_NOTNULL(const_opdesc); + auto in_node = in_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(in_node); + auto owner_graph = in_node->GetOwnerComputeGraph(); + if (owner_graph == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc); + GE_CHECK_NOTNULL(const_node); + if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) { + GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed."); + return GRAPH_PARAM_INVALID; + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::SetWeights(ge::NodePtr node, const vector &weights) { + GE_CHECK_NOTNULL(node); + return SetWeights(*node, weights); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) { + GE_CHECK_NOTNULL(node); + auto const_ops = GetConstInputs(node); + auto graph = node->GetOwnerComputeGraph(); + if (graph == nullptr) { + GELOGE(GRAPH_FAILED, "Graph is nullptr"); + return GRAPH_PARAM_INVALID; + } + for (const auto &const_op : const_ops) { + GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed", + const_op->GetName().c_str(), const_op->GetType().c_str()); + GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op), + "Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(), + 
const_op->GetType().c_str()); + } + return GRAPH_SUCCESS; +} + +/// +/// @brief Add input +/// @param [in] name +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { + inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); + return *this; +} + +/// +/// @brief Add input +/// @param [in] name +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name, + const GeTensorDesc &tensor) { + inputs_.emplace_back(std::make_pair(name, tensor)); + return *this; +} + +/// +/// @brief Add dynamic input +/// @param [in] name +/// @param [in] num +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, + uint32_t num) { + for (uint32_t i = 0; i < num; i++) { + inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); + } + return *this; +} + +/// +/// @brief Add dynamic input +/// @param [in] name +/// @param [in] num +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput( + const std::string &name, uint32_t num, const GeTensorDesc &tensor) { + for (uint32_t i = 0; i < num; i++) { + inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); + } + return *this; +} + +/// +/// @brief Add output +/// @param [in] name +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { + outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); + return *this; +} + +/// +/// @brief Add output +/// @param [in] name +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name, + 
const GeTensorDesc &tensor) { + outputs_.emplace_back(std::make_pair(name, tensor)); + return *this; +} + +/// +/// @brief Add dynamic output +/// @param [in] name +/// @param [in] num +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, + uint32_t num) { + for (uint32_t i = 0; i < num; i++) { + outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); + } + return *this; +} + +/// +/// @brief Add dynamic output +/// @param [in] name +/// @param [in] num +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput( + const std::string &name, uint32_t num, const GeTensorDesc &tensor) { + for (uint32_t i = 0; i < num; i++) { + outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); + } + return *this; +} + +/// +/// @brief Build op_desc +/// @return OpDescPtr +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { + OpDescPtr op_desc = shared_ptr(new (std::nothrow) OpDesc(name_, type_)); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "OpDesc is nullptr"); + return nullptr; + } + + for (auto &input : inputs_) { + if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add input_desc failed."); + return nullptr; + } + } + + for (auto &output : outputs_) { + if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add output_desc failed."); + return nullptr; + } + } + + return op_desc; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName( + const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) { + const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); + auto iter = subgraph_names_to_index.find(subgraph_name); + if (iter 
== subgraph_names_to_index.end()) { + GELOGE(GRAPH_PARAM_INVALID, + "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists", + subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), + subgraph_name.c_str()); + return GRAPH_PARAM_INVALID; + } + + return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); +} +} // namespace ge +/*lint +e512 +e737 +e752*/ diff --git a/src/common/graph/utils/string_utils.h b/src/common/graph/utils/string_utils.h new file mode 100644 index 00000000..a9700469 --- /dev/null +++ b/src/common/graph/utils/string_utils.h @@ -0,0 +1,68 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef COMMON_GRAPH_UTILS_STRING_UTILS_H_
#define COMMON_GRAPH_UTILS_STRING_UTILS_H_

#include <algorithm>
#include <cctype>
#include <sstream>
#include <string>
#include <vector>
// NOTE(review): the original also included "securec.h"; nothing in this
// header uses it, so it was dropped — confirm no .cc relies on it
// transitively before merging.

namespace ge {
// Small header-only helpers for in-place trimming and delimiter splitting.
class StringUtils {
 public:
  // Erase leading whitespace in place; returns the same string.
  static std::string &Ltrim(std::string &s) {
    (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); }));
    return s;
  }

  // Erase trailing whitespace in place; returns the same string.
  static std::string &Rtrim(std::string &s) {
    (void)s.erase(std::find_if(s.rbegin(), s.rend(), [](int c) { return !std::isspace(c); }).base(), s.end());
    return s;
  }

  /// @ingroup domi_common
  /// @brief trim space from both ends, in place
  static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); }

  // Split `str` on `delim`. An empty input yields {""}; a trailing delimiter
  // yields a trailing empty element (std::getline drops it, so it is
  // re-appended explicitly).
  static std::vector<std::string> Split(const std::string &str, char delim) {
    std::vector<std::string> elems;

    if (str.empty()) {
      elems.emplace_back("");
      return elems;
    }

    std::stringstream ss(str);
    std::string item;

    while (getline(ss, item, delim)) {
      elems.push_back(item);
    }
    auto str_size = str.size();
    if (str_size > 0 && str[str_size - 1] == delim) {
      elems.emplace_back("");
    }

    return elems;
  }
};
}  // namespace ge
#endif  // COMMON_GRAPH_UTILS_STRING_UTILS_H_
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/utils/tensor_utils.h" +#include + +#include "debug/ge_log.h" +#include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" +#include "graph/ge_tensor.h" +#include "graph/types.h" +#include "graph/utils/type_utils.h" + +namespace ge { +namespace { +// When nc1hwc0 dim size = 5, calc element count directly. +const uint32_t kNc1hwc0CalcByDimsSize = 5; + +// Unknown shape element num +const int64_t kElementCntUnknownShape = -1; + +// Unknown shape mem size +const int64_t kMemSizeUnknownShape = -1; + +// Nchw and nhwc dim size must be 4 +const uint32_t kDimSize4d = 4; + +// C1HWNCoC0 dim size must be 6 +const uint32_t kDimSizeC1hwncoc0 = 6; + +// Cube size is 16 +const uint32_t kTheCubeSize = 16; + +// Default c0 size equals cube size. +const uint32_t kC0SizeDefault = kTheCubeSize; + +// Size equals int8 cube size is 32 +const uint32_t kC0SizeInt8 = 32; + +// NCHW dim N index +const int32_t kNchwDimIdxN = 0; +// NCHW dim C index +const int32_t kNchwDimIdxC = 1; +// NCHW dim H index +const int32_t kNchwDimIdxH = 2; +// NCHW dim W index +const int32_t kNchwDimIdxW = 3; + +const int kDataMemAlignSize = 32; +const int kNum2 = 2; +} // namespace + +/// +/// Check if a * b overflow. +/// @param a multiplier +/// @param b Multiplicand +/// @return true: overflow +/// false: not overflow +/// +static bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) { + if (a > 0) { + if (b > 0) { + if (a > (INT64_MAX / b)) { + return true; + } + } else { + if (b < (INT64_MIN / a)) { + return true; + } + } + } else { + if (b > 0) { + if (a < (INT64_MIN / b)) { + return true; + } + } else { + if ((a != 0) && (b < (INT64_MAX / a))) { + return true; + } + } + } + return false; +} + +/// +/// Calculate element num by dims directly. 
+/// @param dims dim info +/// @param element_cnt element count +/// @return GRAPH_SUCCESS:success +/// other:failed +/// +static graphStatus CalcElementCntByDims(const std::vector &dims, int64_t &element_cnt) { + element_cnt = 1; + for (int64_t dim : dims) { + if (CheckMultiplyOverflowInt64(element_cnt, dim)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19013", {"function", "var1", "var2"}, + {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); + GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); + return GRAPH_FAILED; + } + element_cnt *= dim; + } + return GRAPH_SUCCESS; +} + +/// +/// Calculate fixed dims element num. +/// @param dims dim info +/// @param fixed_dim_size fixed dim size +/// @param element_cnt element count +/// @return GRAPH_SUCCESS:success +/// other:failed +/// +static graphStatus CalcElementCntOfFixedDims(const std::vector &dims, Format format, uint32_t fixed_dim_size, + int64_t &element_cnt) { + if (dims.size() != fixed_dim_size) { + GELOGW("Format %d(%s) need dim size=%u but %zu, calc as ND.", format, + TypeUtils::FormatToSerialString(format).c_str(), fixed_dim_size, dims.size()); + } + return CalcElementCntByDims(dims, element_cnt); +} + +/// +/// Get dim c0 size by type +/// @param data_type data type +/// @return c0 size +/// +static uint32_t GetDimC0(DataType &data_type) { + bool is_int8_size = (data_type == DT_INT8) || (data_type == DT_UINT8) || (data_type == DT_DUAL_SUB_UINT8) || + (data_type == DT_DUAL_SUB_INT8) || (data_type == DT_BOOL) || (data_type == DT_QINT8); + return is_int8_size ? kC0SizeInt8 : kC0SizeDefault; +} + +/// +/// Calculate nc1hwc0 element num. 
+/// @param dims dim info +/// @param data_type data type +/// @param element_cnt element count +/// @return GRAPH_SUCCESS:success +/// other:failed +/// +static graphStatus CalcElementCntOfNc1hwc0(const std::vector &dims, DataType data_type, int64_t &element_cnt) { + // When nc1hwc0 dims size = 5, no need split dim c + if (dims.size() == kNc1hwc0CalcByDimsSize) { + return CalcElementCntByDims(dims, element_cnt); + } else if (dims.size() != kDimSize4d) { + GELOGE(GRAPH_FAILED, "CalcElementCntOfNc1hwc0 failed as dims.size=%zu is not %u or %u.", dims.size(), kDimSize4d, + kNc1hwc0CalcByDimsSize); + return GRAPH_FAILED; + } + + auto c0 = static_cast(GetDimC0(data_type)); + // Nc1hwc0 dims is according to nchw, dim c index is 1. + auto c1 = static_cast(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0)); + // Store dims is split c to c1 and c0. + std::vector store_dims = {dims[kNchwDimIdxN], c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0}; + return CalcElementCntByDims(store_dims, element_cnt); +} + +/// +/// Calculate FractalZ element num. +/// @param dims dim info +/// @param data_type data type +/// @param element_cnt element count +/// @return GRAPH_SUCCESS:success +/// other:failed +/// +static graphStatus CalcElementCntOfFractalZ(const std::vector &dims, DataType data_type, + int64_t &element_cnt) { + static char *parser_priority = std::getenv("PARSER_PRIORITY"); + if (parser_priority != nullptr && string(parser_priority) == "cce") { + if (dims.size() != kDimSize4d) { + GELOGE(GRAPH_FAILED, "CalcElementCntOfFractalZ failed as dims.size=%zu is not %u.", dims.size(), kDimSize4d); + return GRAPH_FAILED; + } + auto c0 = static_cast(GetDimC0(data_type)); + // FractalZ dims is according to nchw, dim c index is 1. 
+ auto c1 = static_cast(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0)); + + // Spread NC1HWC0 as a two dimension array, n as column dimension, + // C1HWC0 as row dimension + std::vector r_count_vec = {c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0}; + + int64_t r_count = 1; + graphStatus graph_status = CalcElementCntByDims(r_count_vec, r_count); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(graph_status, "Calc [%ld, %ld, %ld, %ld] element count failed.", c1, dims[kNchwDimIdxH], + dims[kNchwDimIdxW], c0); + return graph_status; + } + + // Cube count in n + auto nc_cnt = static_cast(std::ceil(dims[kNchwDimIdxN] * 1.0 / kTheCubeSize)); + + // Cube count in vertical direction(C1HWC0) + int64_t vc_cnt = r_count / c0; + // Element count in each cube + int64_t cube_elem_cnt = c0 * kTheCubeSize; + + if (CheckMultiplyOverflowInt64(nc_cnt, vc_cnt)) { + GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", nc_cnt, vc_cnt); + return GRAPH_FAILED; + } + // Read data times needed by cube + int64_t c_cnt = nc_cnt * vc_cnt; + + if (CheckMultiplyOverflowInt64(c_cnt, cube_elem_cnt)) { + GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", c_cnt, cube_elem_cnt); + return GRAPH_FAILED; + } + // Element count after fractal arrangement + element_cnt = c_cnt * cube_elem_cnt; + return GRAPH_SUCCESS; + } else { + return CalcElementCntByDims(dims, element_cnt); + } +} + +/// +/// Calculate tensor element num. 
+/// @param dims dim info +/// @param format tensor format +/// @param data_type data type +/// @param element_cnt element count +/// @return GRAPH_SUCCESS:success +/// other:failed +/// +static graphStatus CalcTensorElementCnt(const std::vector &dims, Format format, DataType data_type, + int64_t &element_cnt) { + const string format_str = TypeUtils::FormatToSerialString(format); + // Check dims + for (size_t i = 0; i < dims.size(); ++i) { + int64_t dim = dims[i]; + if (dim < 0) { + GELOGI("It's unknown shape, as dims[%zu]=%ld negative, format=%d(%s).", i, dim, format, format_str.c_str()); + element_cnt = kElementCntUnknownShape; + return GRAPH_SUCCESS; + } else if (dim == 0) { + GELOGI("No need calc element count, as dims[%zu]=%ld, format=%d(%s).", i, dim, format, format_str.c_str()); + element_cnt = 0; + return GRAPH_SUCCESS; + } + } + + graphStatus graph_status; + switch (format) { + case FORMAT_ND: + case FORMAT_MD: + graph_status = CalcElementCntByDims(dims, element_cnt); + break; + case FORMAT_NCHW: + case FORMAT_HWCN: + case FORMAT_NHWC: + case FORMAT_CHWN: + graph_status = CalcElementCntOfFixedDims(dims, format, kDimSize4d, element_cnt); + break; + case FORMAT_C1HWNCoC0: + graph_status = CalcElementCntOfFixedDims(dims, format, kDimSizeC1hwncoc0, element_cnt); + break; + case FORMAT_NC1HWC0: + graph_status = CalcElementCntOfNc1hwc0(dims, data_type, element_cnt); + break; + case FORMAT_FRACTAL_Z: + graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); + break; + case FORMAT_FRACTAL_NZ: + case FORMAT_FRACTAL_ZZ: + case FORMAT_NDHWC: + case FORMAT_NCDHW: + case FORMAT_DHWCN: + case FORMAT_DHWNC: + case FORMAT_FRACTAL_Z_3D: + case FORMAT_FRACTAL_Z_3D_TRANSPOSE: + case FORMAT_NDC1HWC0: + case FORMAT_FRACTAL_Z_C04: + case FORMAT_FRACTAL_ZN_LSTM: + case FORMAT_NC1HWC0_C04: + graph_status = CalcElementCntByDims(dims, element_cnt); + break; + default: + GELOGE(GRAPH_FAILED, "unsupported format, format=%d(%s).", format, format_str.c_str()); + 
graph_status = GRAPH_FAILED; + break; + } + + const string type_str = TypeUtils::DataTypeToSerialString(data_type); + if (graph_status == GRAPH_SUCCESS) { + GELOGD( + "CalcTensorElementCnt end, format=%d(%s)," + " data_type=%d(%s), element_cnt=%ld.", + format, format_str.c_str(), data_type, type_str.c_str(), element_cnt); + } else { + GELOGE(GRAPH_FAILED, "CalcTensorElementCnt failed, format=%d(%s), data_type=%d(%s).", format, format_str.c_str(), + data_type, type_str.c_str()); + } + return graph_status; +} + +/// +/// Calculate tensor mem size. +/// @param shape tensor shape +/// @param format tensor format +/// @param data_type tensor data type +/// @param mem_size -1 means unknown shape,other means mem size +/// @return GRAPH_SUCCESS:success, other:failed +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTensorMemSize(const GeShape &shape, + Format format, + DataType data_type, + int64_t &mem_size) { + const string format_str = TypeUtils::FormatToSerialString(format); + const string type_str = TypeUtils::DataTypeToSerialString(data_type); + uint32_t type_size = 0; + bool result = TypeUtils::GetDataTypeLength(data_type, type_size); + if (!result) { + GELOGE(GRAPH_FAILED, "GetDataTypeLength failed, data_type=%d(%s).", data_type, type_str.c_str()); + return GRAPH_FAILED; + } + + std::vector dims = shape.GetDims(); + int64_t element_cnt = 0; + graphStatus status = CalcTensorElementCnt(dims, format, data_type, element_cnt); + if (status != GRAPH_SUCCESS) { + GELOGE(status, "CalcTensorElementCnt failed, status=%u format=%d(%s) data_type=%d(%s).", status, format, + format_str.c_str(), data_type, type_str.c_str()); + return status; + } + // Support unknown shape + if (element_cnt < 0) { + mem_size = kMemSizeUnknownShape; + GELOGD( + "element_cnt is unknown. 
" + "format=%d(%s), data_type=%d(%s), mem_size=%ld", + format, format_str.c_str(), data_type, type_str.c_str(), mem_size); + return GRAPH_SUCCESS; + } + auto type_size_int64 = static_cast(type_size); + if (CheckMultiplyOverflowInt64(element_cnt, type_size_int64)) { + GELOGE(GRAPH_FAILED, "CalcTensorMemSize overflow, when multiplying %ld and %ld, format=%d(%s), data_type=%d(%s).", + element_cnt, type_size_int64, format, format_str.c_str(), data_type, type_str.c_str()); + return GRAPH_FAILED; + } + mem_size = element_cnt * type_size_int64; + + GELOGD( + "CalcTensorMemSize end, " + "format=%d(%s), data_type=%d(%s), mem_size=%ld", + format, format_str.c_str(), data_type, type_str.c_str(), mem_size); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { + graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); + if (graph_status != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + // 64-byte alignment, if size is 0, align to 32 bytes + if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) { + GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp); + } else { + size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; + } + return GRAPH_SUCCESS; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { + GeShape output_shape = desc_temp.GetShape(); + Format format = desc_temp.GetFormat(); + DataType data_type = desc_temp.GetDataType(); + int64_t output_mem_size = 0; + graphStatus graph_status = CalcTensorMemSize(output_shape, format, data_type, output_mem_size); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "CalcTensorMemSize failed!"); + return GRAPH_FAILED; + } + + if (output_mem_size < 0) { + GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, 
output_mem_size = %ld, out of data range [0, %ld]", + output_mem_size, INT64_MAX); + return GRAPH_FAILED; + } + + size_temp = output_mem_size; + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/src/common/graph/utils/tuning_utils.cc b/src/common/graph/utils/tuning_utils.cc new file mode 100644 index 00000000..0f07a197 --- /dev/null +++ b/src/common/graph/utils/tuning_utils.cc @@ -0,0 +1,684 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/tuning_utils.h" +#include "../debug/ge_util.h" +#include "../debug/ge_op_types.h" + +namespace ge { +const std::string peer_node_name_attr = "_peerNodeName"; +const std::string parent_node_name_attr = "_parentNodeName"; +const std::string alias_name_attr = "_aliasName"; +const std::string parent_node_attr = "parentNode"; +const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex"; +const std::string tuning_subgraph_prefix = "/aicore_subgraph_"; +const std::string non_tuning_subgraph_prefix = "/subgraph_"; +const std::set kPartitionOpTypes = {PLACEHOLDER, END}; +const std::set kExeTypes = {DATA, NETOUTPUT}; +NodeNametoNodeNameMap TuningUtils::data_2_netoutput_; +NodetoNodeNameMap TuningUtils::data_node_2_netoutput_; +NodetoNodeMap TuningUtils::data_node_2_netoutput_node_; +NodeSet TuningUtils::netoutput_nodes_; +NodeSet TuningUtils::merged_graph_nodes_; +SubgraphCreateOutNode TuningUtils::create_output_; +std::mutex TuningUtils::mutex_; + +std::string TuningUtils::PrintCheckLog() { + std::stringstream ss; + ss << "d2n:{"; + for (const auto &pair : data_2_netoutput_) { + ss << "data:" << pair.first << "-" + << "netoutput:" << pair.second; + ss << " | "; + } + ss << "}"; + ss << "netoutputs:{"; + for (const auto &node : netoutput_nodes_) { + ss << "netoutput:" << node->GetName(); + ss << " | "; + } + ss << "}"; + return ss.str(); +} + +std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) { + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Anchor is nullptr"); + return "Null"; + } + auto node = anchor->GetOwnerNode(); + return node == nullptr ? 
"Null" : node->GetName(); +} + +// part 1 +graphStatus TuningUtils::ConvertGraphToFile(std::vector tuning_subgraphs, + std::vector non_tuning_subgraphs, bool exe_flag, + const std::string &path, const std::string &user_path) { + int64_t i = 0; + int64_t j = 0; + std::lock_guard lock(mutex_); + for (auto &subgraph : tuning_subgraphs) { + create_output_.emplace(subgraph, nullptr); + auto help_info = HelpInfo{i, exe_flag, true, path, user_path}; + if (MakeExeGraph(subgraph, help_info) != SUCCESS) { + GELOGE(GRAPH_FAILED, "TUU:subgraph %zu generate exe graph failed", i); + return GRAPH_FAILED; + } + i++; + } + + for (auto &subgraph : non_tuning_subgraphs) { + create_output_.emplace(subgraph, nullptr); + auto help_info = HelpInfo{j, true, false, path, user_path}; + if (MakeExeGraph(subgraph, help_info) != SUCCESS) { + GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %zu generate exe graph failed", j); + return GRAPH_FAILED; + } + j++; + } + create_output_.clear(); + return SUCCESS; +} + +// +---------------+ +// | pld pld | +// | \ / | +// | relu relu | +// | \ / | +// | add | +// | | | +// | end | +// +---------------+ +// | +// | +// V +// +---------------+ +// | data data | +// | \ / | +// | relu relu | +// | \ / | +// | add | +// | | | +// | netoutput | +// +---------------+ +graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) { + GE_CHECK_NOTNULL(exe_graph); + // if not make exe, just dump and return + if (!help_info.exe_flag) { + DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); + GELOGI("TUU:just return, dump original sub_graph[%s]index[%d]", exe_graph->GetName().c_str(), help_info.index); + return SUCCESS; + } + // modify sub graph + for (NodePtr &node : exe_graph->GetDirectNode()) { + // 1.handle pld + if (node->GetType() == PLACEHOLDER) { + if (HandlePld(node) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), + 
exe_graph->GetName().c_str()); + return FAILED; + } + } + // 2.handle end + if (node->GetType() == END) { + if (HandleEnd(node) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), + exe_graph->GetName().c_str()); + return FAILED; + } + } + } + graphStatus ret = exe_graph->TopologicalSorting(); + if (ret != SUCCESS) { + GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret); + return ret; + } + // dump subgraphs which modified by us + if (help_info.user_path.empty()) { + DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); + } else { + GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path); + } + return SUCCESS; +} + +void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path) { + if (!path.empty()) { + if (is_tuning_graph) { + GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); + } else { + GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); + } + } else { + path = "./"; + if (is_tuning_graph) { + GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); + } else { + GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); + } + } +} + +graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) { + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + auto data_op_desc = ComGraphMakeShared(node->GetName(), DATA); + GE_CHECK_NOTNULL(data_op_desc); + auto pld_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(pld_op_desc); + auto output_desc = pld_op_desc->GetOutputDesc(0); // only one output for pld and data + // data inputdesc & outputdesc set as same + if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) { + 
GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); + return FAILED; + } + if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) { + GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); + return FAILED; + } + data_node = graph->AddNode(data_op_desc); + GE_CHECK_NOTNULL(data_node); + if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); + return FAILED; + } + return SUCCESS; +} + +graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) { + auto op_desc = data_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + auto pld_desc = pld->GetOpDesc(); + GE_CHECK_NOTNULL(pld_desc); + // inherit + // a. set `end's input node type` as attr + std::string parent_op_type; + if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) { + GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str()); + return FAILED; + } + (void)AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type); + // b. set `end's input node name` as attr + std::string parent_op_name; + if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) { + GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str()); + return FAILED; + } + (void)AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name); + // c. set `end's input node's out anchor index` as attr + int parent_node_anchor_index; + if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) { + GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str()); + return FAILED; + } + (void)AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index); + GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(), + data_node->GetName().c_str(), data_node->GetType().c_str()); + // d. 
set `end node name` as attr + std::string peer_end_name; + if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) { + GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str()); + return FAILED; + } + (void)AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name); + GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(), + data_node->GetName().c_str(), data_node->GetType().c_str()); + return SUCCESS; +} + +graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) { + auto type_pld = node->GetType(); + auto type_data = data_node->GetType(); + if (type_pld != PLACEHOLDER || type_data != DATA) { + GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s", node->GetName().c_str(), type_pld.c_str(), + type_data.c_str()); + return FAILED; + } + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + std::vector output_map(node->GetAllOutDataAnchorsSize()); + for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { + output_map[i] = static_cast(i); + } + + auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error node %u", node->GetName().c_str(), + data_node->GetName().c_str(), ret); + return FAILED; + } + + NodeUtils::UnlinkAll(*node); + + ret = GraphUtils::RemoveNodeWithoutRelink(graph, node); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); + return FAILED; + } + + GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", node->GetName().c_str(), + node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); + return ret; +} + +graphStatus TuningUtils::HandlePld(NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + 
NodePtr data_node = nullptr; + + // 1. create data node + if (CreateDataNode(node, data_node) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + // 2. add necessary info to data_node for recovery whole graph + if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + // 3. replace pld node by data node created before + if (ChangePld2Data(node, data_node) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + GELOGD("TUU:pld[%s] handle success", node->GetName().c_str()); + return SUCCESS; +} + +graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) { + GE_CHECK_NOTNULL(node); + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + auto search = create_output_.find(graph); + if (search == create_output_.end()) { + GELOGE(FAILED, "TUU:node %s's owner sub graph %s not exist in create_output map", node->GetName().c_str(), + graph->GetName().c_str()); + return FAILED; + } + if (search->second != nullptr) { + out_node = search->second; + GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str()); + return SUCCESS; + } + auto out_op_desc = ComGraphMakeShared(node->GetName(), NETOUTPUT); + GE_CHECK_NOTNULL(out_op_desc); + out_node = graph->AddNode(out_op_desc); + GE_CHECK_NOTNULL(out_node); + if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); + return FAILED; + } + create_output_[graph] = out_node; + return SUCCESS; +} + +graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) { + GE_CHECK_NOTNULL(end); + GE_CHECK_NOTNULL(out_node); + auto op_desc = 
out_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::vector alias_names = {}; + (void)AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names); + alias_names.push_back(end->GetName()); + (void)AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names); + return SUCCESS; +} + +graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { + GE_CHECK_NOTNULL(end_node); + GE_CHECK_NOTNULL(out_node); + // get end in node is control node or normal node + AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr) + ? Anchor::DynamicAnchorCast(end_node->GetInControlAnchor()) + : Anchor::DynamicAnchorCast(end_node->GetInDataAnchor(0)); + auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1 + if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", + GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), + GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), end_node->GetName().c_str(), + end_node->GetOwnerComputeGraph()->GetName().c_str()); + return FAILED; + } + // add edge between `end in node` and `out_node` + if (src_anchor->IsTypeOf()) { + std::shared_ptr anchor = + ComGraphMakeShared(out_node, out_node->GetAllInDataAnchors().size()); + GE_CHECK_NOTNULL(anchor); + out_node->in_data_anchors_.push_back(anchor); + if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. 
node_name:%s, graph_name:%s", + GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), + GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), + end_node->GetOwnerComputeGraph()->GetName().c_str()); + return FAILED; + } + auto end_op_desc = end_node->GetOpDesc(); + GE_CHECK_NOTNULL(end_op_desc); + auto out_node_op_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(out_node_op_desc); + // end node always has one input + if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str()); + return FAILED; + } + } else if (src_anchor->IsTypeOf()) { + auto anchor = out_node->GetInControlAnchor(); + if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", + GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), + GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), + end_node->GetOwnerComputeGraph()->GetName().c_str()); + return FAILED; + } + } else { + GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed", end_node->GetName().c_str(), + end_node->GetOwnerComputeGraph()->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + +graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { + GE_CHECK_NOTNULL(end_node); + GE_CHECK_NOTNULL(out_node); + auto type_end = end_node->GetType(); + auto type_out = out_node->GetType(); + if (type_end != END || type_out != NETOUTPUT) { + GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s", end_node->GetName().c_str(), + type_end.c_str(), type_out.c_str()); + return FAILED; + } + // link all `end nodes's in node` to this out_node + if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) { + GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", 
end_node->GetName().c_str()); + return FAILED; + } + // remove `end node` + NodeUtils::UnlinkAll(*end_node); + auto graph = end_node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + if (GraphUtils::RemoveNodeWithoutRelink(graph, end_node) != SUCCESS) { + GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str()); + return FAILED; + } + return SUCCESS; +} + +graphStatus TuningUtils::HandleEnd(NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + NodePtr out_node = nullptr; + + // 1. create net_output node , add only one NetOutput node to one subgraph + if (CreateNetOutput(node, out_node) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + // 2. add necessary info to out_node for recovery whole graph + if (AddAttrToNetOutputForMergeGraph(node, out_node) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + // 3. replace all end nodes by one output node created before + if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + GELOGD("TUU:end[%s] handle success", node->GetName().c_str()); + return SUCCESS; +} + +// part 2 +graphStatus TuningUtils::ConvertFileToGraph(const map &options, ge::Graph &graph) { + // 1. get all subgraph object + std::vector graphs; + // options format like {index:"subgraph_path"} + for (const auto &pair : options) { + ComputeGraphPtr compute_graph = ComGraphMakeShared(std::to_string(pair.first)); + if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) { + GELOGE(FAILED, "TUU:load graph from file failed"); + } + graphs.push_back(compute_graph); + } + // 2. 
merge graph + ComputeGraphPtr merged_graph = ComGraphMakeShared("whole_graph_after_tune"); + GE_CHECK_NOTNULL(merged_graph); + if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) { + GELOGE(FAILED, "TUU:MergeGraph failed"); + return FAILED; + } + // 3. set parent graph + for (const auto &node : merged_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str()); + return FAILED; + } + } + graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph); + return SUCCESS; +} + +// +----------------------------------+ +// | const const | +// | \ / | +// | netoutput(end,end) | +// +----------------------------------+ +// + +// +----------------------------------+ +// | data(pld) data(pld) | +// | \ / | +// | relu relu | +// | \ / | +// | \ / | +// | add | +// | | | +// | netoutput(end) | +// +----------------------------------+ +// + +// +----------------------------------+ +// | data(pld) | +// | / | +// | netoutput | +// +----------------------------------+ +// | +// | +// V +// +----------------------------------+ +// | const const | +// | \ / | +// | relu relu | +// | \ / | +// | \ / | +// | add | +// | | | +// | netoutput | +// +----------------------------------+ +graphStatus TuningUtils::MergeAllSubGraph(std::vector &subgraphs, + ComputeGraphPtr &output_merged_compute_graph) { + GE_CHECK_NOTNULL(output_merged_compute_graph); + // 1. handle all subgraphs + for (auto &subgraph : subgraphs) { + Status ret_status = MergeSubGraph(subgraph); + if (ret_status != SUCCESS) { + GELOGE(ret_status, "TUU:subgraph %s merge failed", subgraph->GetName().c_str()); + return ret_status; + } + } + + for (const auto &node : merged_graph_nodes_) { + (void)output_merged_compute_graph->AddNode(node); + GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str()); + } + + // 2. 
remove data and output node added by us + if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str()); + return FAILED; + } + graphStatus ret = output_merged_compute_graph->TopologicalSorting(); + if (ret != SUCCESS) { + GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(), ret); + return ret; + } + GELOGD("TUU:Print-%s", PrintCheckLog().c_str()); + GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str()); + return SUCCESS; +} + +graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) { + for (auto &node : subgraph->GetDirectNode()) { + if (kPartitionOpTypes.count(node->GetType()) > 0) { + GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type"); + return FAILED; + } + // handle data converted from pld node + if (node->GetType() == DATA) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::string peer_out_name; + bool has_valid_str = (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty()); + if (has_valid_str) { + std::lock_guard lock(mutex_); + data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name); + data_node_2_netoutput_.emplace(node, peer_out_name); + continue; + } + } + // handle netoutput converted from end node + if (node->GetType() == NETOUTPUT) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::vector out_alias_name; + bool has_valid_str = + (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty()); + if (has_valid_str) { + std::lock_guard lock(mutex_); + netoutput_nodes_.insert(node); + } + } + { + std::lock_guard lock(mutex_); + merged_graph_nodes_.emplace(node); + } + GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str()); + } + GELOGI("TUU:merge 
subgraph %s success", subgraph->GetName().c_str()); + return SUCCESS; +} + +graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + // 1. traverse + for (auto &pair : data_node_2_netoutput_) { + auto data_node = pair.first; + GE_CHECK_NOTNULL(data_node); + auto netoutput_name = pair.second; + auto netoutput_node = graph->FindNode(netoutput_name); + GE_CHECK_NOTNULL(netoutput_node); + data_node_2_netoutput_node_.emplace(data_node, netoutput_node); + // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor` + AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr) + ? Anchor::DynamicAnchorCast(data_node->GetOutControlAnchor()) + : Anchor::DynamicAnchorCast(data_node->GetOutDataAnchor(0)); + AnchorPtr net_output_in_anchor = nullptr; + AnchorPtr src_out_anchor = nullptr; + if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed", + netoutput_node->GetName().c_str(), data_node->GetName().c_str()); + return FAILED; + } + // 3. relink + if (GraphUtils::RemoveEdge(src_out_anchor, net_output_in_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", + GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), + GetNodeNameByAnchor(net_output_in_anchor.get()).c_str(), net_output_in_anchor->GetIdx(), + data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + GE_CHECK_NOTNULL(data_out_anchor); + for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) { + if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. 
node_name:(data:%s;netoutput:%s), graph_name:%s", + GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(), + GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), + data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", + GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), + GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), + data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + } + } + // 4. remove out nodes added by us + for (auto &node : netoutput_nodes_) { + NodeUtils::UnlinkAll(*node); + if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); + return FAILED; + } + GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str()); + } + return SUCCESS; +} + +graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor, + AnchorPtr &src_out_anchor) { + // 1. get `data parent node name`, i.e. `netoutput input node name` + std::string netoutput_input_name; + auto op_desc = data_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (!AttrUtils::GetStr(op_desc, parent_node_name_attr, netoutput_input_name)) { + GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", op_desc->GetName().c_str()); + return FAILED; + } + // 2. 
find index + int parent_node_anchor_index; + if (!AttrUtils::GetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index)) { + GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", op_desc->GetName().c_str()); + return FAILED; + } + // 3.find in data or ctrl anchor by 1&2 step + for (auto &in_anchor : out_node->GetAllInAnchors()) { + GE_CHECK_NOTNULL(in_anchor); + for (auto &src_anchor : in_anchor->GetPeerAnchors()) { // get all peer anchors for ctrl + GE_CHECK_NOTNULL(src_anchor); + auto src_node = src_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + if (src_node->GetName() == netoutput_input_name && src_anchor->GetIdx() == parent_node_anchor_index) { + dest_in_anchor = in_anchor; + src_out_anchor = src_anchor; + GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s", + out_node->GetName().c_str(), dest_in_anchor->GetIdx(), netoutput_input_name.c_str(), + parent_node_anchor_index, data_node->GetName().c_str()); + break; + } + } + } + GE_CHECK_NOTNULL(dest_in_anchor); + GE_CHECK_NOTNULL(src_out_anchor); + return SUCCESS; +} + +} // namespace ge \ No newline at end of file diff --git a/src/common/graph/utils/type_utils.cc b/src/common/graph/utils/type_utils.cc new file mode 100644 index 00000000..2efc530e --- /dev/null +++ b/src/common/graph/utils/type_utils.cc @@ -0,0 +1,448 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/utils/type_utils.h" +#include "debug/ge_util.h" + +using domi::domiTensorFormat_t; + +namespace ge { +static const std::map kFormatToStringMap = { + {FORMAT_NCHW, "NCHW"}, + {FORMAT_NHWC, "NHWC"}, + {FORMAT_ND, "ND"}, + {FORMAT_NC1HWC0, "NC1HWC0"}, + {FORMAT_FRACTAL_Z, "FRACTAL_Z"}, + {FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, + {FORMAT_NHWC1C0, "NHWC1C0"}, + {FORMAT_FSR_NCHW, "FSR_NCHW"}, + {FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, + {FORMAT_C1HWNC0, "C1HWNC0"}, + {FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, + {FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, + {FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, + {FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, + {FORMAT_CHWN, "CHWN"}, + {FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, + {FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, + {FORMAT_BN_WEIGHT, "BN_WEIGHT"}, + {FORMAT_FILTER_HWCK, "FILTER_HWCK"}, + {FORMAT_HWCN, "HWCN"}, + {FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, + {FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, + {FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, + {FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, + {FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, + {FORMAT_MD, "MD"}, + {FORMAT_NDHWC, "NDHWC"}, + {FORMAT_NCDHW, "NCDHW"}, + {FORMAT_DHWCN, "DHWCN"}, + {FORMAT_DHWNC, "DHWNC"}, + {FORMAT_NDC1HWC0, "NDC1HWC0"}, + {FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, + {FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, + {FORMAT_C1HWNCoC0, "C1HWNCoC0"}, + {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, + {FORMAT_CN, "CN"}, + {FORMAT_NC, "NC"}, + {FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, + {FORMAT_FRACTAL_Z_G, "FRACTAL_Z_G"}, + {FORMAT_RESERVED, "FORMAT_RESERVED"}, + {FORMAT_ALL, "ALL"}}; + +static const std::map kDomiFormatToGeFormat = { + {domi::DOMI_TENSOR_NCHW, FORMAT_NCHW}, + {domi::DOMI_TENSOR_NHWC, FORMAT_NHWC}, + {domi::DOMI_TENSOR_ND, 
FORMAT_ND}, + {domi::DOMI_TENSOR_NC1HWC0, FORMAT_NC1HWC0}, + {domi::DOMI_TENSOR_FRACTAL_Z, FORMAT_FRACTAL_Z}, + {domi::DOMI_TENSOR_NC1C0HWPAD, FORMAT_NC1C0HWPAD}, + {domi::DOMI_TENSOR_NHWC1C0, FORMAT_NHWC1C0}, + {domi::DOMI_TENSOR_FSR_NCHW, FORMAT_FSR_NCHW}, + {domi::DOMI_TENSOR_FRACTAL_DECONV, FORMAT_FRACTAL_DECONV}, + {domi::DOMI_TENSOR_BN_WEIGHT, FORMAT_BN_WEIGHT}, + {domi::DOMI_TENSOR_CHWN, FORMAT_CHWN}, + {domi::DOMI_TENSOR_FILTER_HWCK, FORMAT_FILTER_HWCK}, + {domi::DOMI_TENSOR_NDHWC, FORMAT_NDHWC}, + {domi::DOMI_TENSOR_NCDHW, FORMAT_NCDHW}, + {domi::DOMI_TENSOR_DHWCN, FORMAT_DHWCN}, + {domi::DOMI_TENSOR_DHWNC, FORMAT_DHWNC}, + {domi::DOMI_TENSOR_RESERVED, FORMAT_RESERVED}}; + +static const std::unordered_set kInternalFormat = {"NC1HWC0", + "FRACTAL_Z", + "NC1C0HWPAD", + "NHWC1C0", + "FRACTAL_DECONV", + "C1HWNC0", + "FRACTAL_DECONV_TRANSPOSE", + "FRACTAL_DECONV_SP_STRIDE_TRANS", + "NC1HWC0_C04", + "FRACTAL_Z_C04", + "FRACTAL_DECONV_SP_STRIDE8_TRANS", + "NC1KHKWHWC0", + "C1HWNCoC0", + "FRACTAL_ZZ", + "FRACTAL_NZ", + "NDC1HWC0", + "FORMAT_FRACTAL_Z_3D", + "FORMAT_FRACTAL_Z_3D_TRANSPOSE", + "FORMAT_FRACTAL_ZN_LSTM", + "FORMAT_FRACTAL_Z_G"}; + +static const std::map kDataFormatMap = { + {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; + +static const std::map kStringToFormatMap = { + {"NCHW", FORMAT_NCHW}, + {"NHWC", FORMAT_NHWC}, + {"ND", FORMAT_ND}, + {"NC1HWC0", FORMAT_NC1HWC0}, + {"FRACTAL_Z", FORMAT_FRACTAL_Z}, + {"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, + {"NHWC1C0", FORMAT_NHWC1C0}, + {"FSR_NCHW", FORMAT_FSR_NCHW}, + {"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, + {"C1HWNC0", FORMAT_C1HWNC0}, + {"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, + {"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, + {"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, + {"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, + {"CHWN", FORMAT_CHWN}, + {"DECONV_SP_STRIDE8_TRANS", 
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, + {"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, + {"BN_WEIGHT", FORMAT_BN_WEIGHT}, + {"FILTER_HWCK", FORMAT_FILTER_HWCK}, + {"HWCN", FORMAT_HWCN}, + {"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, + {"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, + {"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, + {"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, + {"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, + {"MD", FORMAT_MD}, + {"C1HWNCoC0", FORMAT_C1HWNCoC0}, + {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, + {"NDHWC", FORMAT_NDHWC}, + {"NCDHW", FORMAT_NCDHW}, + {"DHWCN", FORMAT_DHWCN}, + {"DHWNC", FORMAT_DHWNC}, + {"NDC1HWC0", FORMAT_NDC1HWC0}, + {"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, + {"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, + {"CN", FORMAT_CN}, + {"NC", FORMAT_NC}, + {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, + {"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, + {"FORMAT_RESERVED", FORMAT_RESERVED}, + {"ALL", FORMAT_ALL}, + {"NULL", FORMAT_NULL}}; + +static const std::map kDataTypeToStringMap = { + {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. 
+ {DT_FLOAT, "DT_FLOAT"}, // float type + {DT_FLOAT16, "DT_FLOAT16"}, // fp16 type + {DT_INT8, "DT_INT8"}, // int8 type + {DT_INT16, "DT_INT16"}, // int16 type + {DT_UINT16, "DT_UINT16"}, // uint16 type + {DT_UINT8, "DT_UINT8"}, // uint8 type + {DT_INT32, "DT_INT32"}, // uint32 type + {DT_INT64, "DT_INT64"}, // int64 type + {DT_UINT32, "DT_UINT32"}, // unsigned int32 + {DT_UINT64, "DT_UINT64"}, // unsigned int64 + {DT_BOOL, "DT_BOOL"}, // bool type + {DT_DOUBLE, "DT_DOUBLE"}, // double type + {DT_DUAL, "DT_DUAL"}, // dual output type + {DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type + {DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type + {DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type + {DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type + {DT_QINT8, "DT_QINT8"}, // qint8 type + {DT_QINT16, "DT_QINT16"}, // qint16 type + {DT_QINT32, "DT_QINT32"}, // qint32 type + {DT_QUINT8, "DT_QUINT8"}, // quint8 type + {DT_QUINT16, "DT_QUINT16"}, // quint16 type + {DT_RESOURCE, "DT_RESOURCE"}, // resource type + {DT_STRING_REF, "DT_STRING_REF"}, // string ref type + {DT_STRING, "DT_STRING"}, // string type +}; + +static const std::map kStringTodataTypeMap = { + {"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. 
+ {"DT_FLOAT", DT_FLOAT}, // float type + { + "DT_FLOAT16", + DT_FLOAT16, + }, // fp16 type + {"DT_INT8", DT_INT8}, // int8 type + {"DT_INT16", DT_INT16}, // int16 type + {"DT_UINT16", DT_UINT16}, // uint16 type + {"DT_UINT8", DT_UINT8}, // uint8 type + {"DT_INT32", DT_INT32}, // uint32 type + {"DT_INT64", DT_INT64}, // int64 type + {"DT_UINT32", DT_UINT32}, // unsigned int32 + {"DT_UINT64", DT_UINT64}, // unsigned int64 + {"DT_BOOL", DT_BOOL}, // bool type + {"DT_DOUBLE", DT_DOUBLE}, // double type + {"DT_DUAL", DT_DUAL}, // dual output type + {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type + {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type + {"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type + {"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type + {"DT_QINT8", DT_QINT8}, // qint8 type + {"DT_QINT16", DT_QINT16}, // qint16 type + {"DT_QINT32", DT_QINT32}, // qint32 type + {"DT_QUINT8", DT_QUINT8}, // quint8 type + {"DT_QUINT16", DT_QUINT16}, // quint16 type + {"DT_RESOURCE", DT_RESOURCE}, // resource type + {"DT_STRING_REF", DT_STRING_REF}, // string ref type + {"DT_STRING", DT_STRING}, // string type +}; + +static const std::map kDataTypeToLength = { + {DT_BOOL, sizeof(bool)}, + {DT_INT64, sizeof(int64_t)}, + {DT_UINT64, sizeof(int64_t)}, + {DT_FLOAT, sizeof(float)}, + {DT_INT32, sizeof(int32_t)}, + {DT_UINT32, sizeof(int32_t)}, + {DT_INT8, sizeof(char)}, + {DT_UINT8, sizeof(char)}, + {DT_INT16, sizeof(int16_t)}, + {DT_UINT16, sizeof(int16_t)}, + {DT_FLOAT16, sizeof(int16_t)}, + {DT_DOUBLE, sizeof(double)}, + {DT_DUAL, sizeof(float) + sizeof(int8_t)}, + {DT_DUAL_SUB_INT8, sizeof(int8_t)}, + {DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, + {DT_COMPLEX64, sizeof(int64_t)}, + {DT_COMPLEX128, sizeof(int64_t) * 2}, + {DT_QINT8, sizeof(int8_t)}, + {DT_QINT16, sizeof(int16_t)}, + {DT_QINT32, sizeof(int32_t)}, + {DT_QUINT8, sizeof(uint8_t)}, + {DT_QUINT16, sizeof(uint16_t)}, + {DT_STRING_REF, sizeof(uint64_t) * 2}, + {DT_STRING, 
sizeof(uint64_t)}, + {DT_RESOURCE, sizeof(uint64_t)}, +}; + +static const std::map kFmkTypeToString = { + {domi::CAFFE, "caffe"}, {domi::MINDSPORE, "mindspore"}, {domi::TENSORFLOW, "tensorflow"}, + {domi::ANDROID_NN, "android_nn"}, {domi::ONNX, "onnx"}, {domi::FRAMEWORK_RESERVED, "framework_reserved"}, +}; + +static const std::map kImplyTypeToString = { + {domi::ImplyType::BUILDIN, "buildin"}, {domi::ImplyType::TVM, "tvm"}, {domi::ImplyType::CUSTOM, "custom"}, + {domi::ImplyType::AI_CPU, "ai_cpu"}, {domi::ImplyType::CCE, "cce"}, {domi::ImplyType::GELOCAL, "gelocal"}, + {domi::ImplyType::HCCL, "hccl"}, {domi::ImplyType::INVALID, "invalid"}}; + +std::string TypeUtils::ImplyTypeToSerialString(domi::ImplyType imply_type) { + auto it = kImplyTypeToString.find(imply_type); + if (it != kImplyTypeToString.end()) { + return it->second; + } else { + GELOGE(GRAPH_FAILED, "ImplyTypeToSerialString: imply_type not support %u", imply_type); + return "UNDEFINED"; + } +} + +bool TypeUtils::IsDataTypeValid(DataType dt) { + uint32_t num = static_cast(dt); + GE_CHK_BOOL_EXEC((num <= DT_UNDEFINED), return false, "The DataType is invalid"); + return true; +} + +std::string TypeUtils::DataTypeToSerialString(DataType data_type) { + auto it = kDataTypeToStringMap.find(data_type); + if (it != kDataTypeToStringMap.end()) { + return it->second; + } else { + GELOGE(GRAPH_FAILED, "DataTypeToSerialString: datatype not support %u", data_type); + return "UNDEFINED"; + } +} + +DataType TypeUtils::SerialStringToDataType(const std::string &str) { + auto it = kStringTodataTypeMap.find(str); + if (it != kStringTodataTypeMap.end()) { + return it->second; + } else { + GELOGE(GRAPH_FAILED, "SerialStringToDataType: datatype not support %s", str.c_str()); + return DT_UNDEFINED; + } +} + +bool TypeUtils::IsFormatValid(Format format) { + uint32_t num = static_cast(format); + GE_CHK_BOOL_EXEC((num <= FORMAT_RESERVED), return false, "The Format is invalid"); + return true; +} + +bool 
TypeUtils::IsInternalFormat(Format format) { + std::string serial_format = FormatToSerialString(format); + auto iter = kInternalFormat.find(serial_format); + bool result = (iter == kInternalFormat.end()) ? false : true; + return result; +} + +std::string TypeUtils::FormatToSerialString(Format format) { + auto it = kFormatToStringMap.find(format); + if (it != kFormatToStringMap.end()) { + return it->second; + } else { + GELOGE(GRAPH_FAILED, "Format not support %u", format); + return "RESERVED"; + } +} +Format TypeUtils::SerialStringToFormat(const std::string &str) { + auto it = kStringToFormatMap.find(str); + if (it != kStringToFormatMap.end()) { + return it->second; + } else { + GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str()); + return FORMAT_RESERVED; + } +} + +Format TypeUtils::DataFormatToFormat(const std::string &str) { + auto it = kDataFormatMap.find(str); + if (it != kDataFormatMap.end()) { + return it->second; + } else { + GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str()); + return FORMAT_RESERVED; + } +} + +Format TypeUtils::DomiFormatToFormat(domi::domiTensorFormat_t domi_format) { + auto it = kDomiFormatToGeFormat.find(domi_format); + if (it != kDomiFormatToGeFormat.end()) { + return it->second; + } + GELOGE(GRAPH_FAILED, "do not find domi Format %d from map", domi_format); + return FORMAT_RESERVED; +} + +std::string TypeUtils::FmkTypeToSerialString(domi::FrameworkType fmk_type) { + auto it = kFmkTypeToString.find(fmk_type); + if (it != kFmkTypeToString.end()) { + return it->second; + } else { + GELOGW("Framework type not support %d.", fmk_type); + return ""; + } +} + +static inline void CopyDataFromBuffer(vector &data, const Buffer &buffer) { + data.clear(); + if (buffer.GetData() != nullptr && buffer.GetSize() != 0) { + data.assign(buffer.GetData(), buffer.GetData() + buffer.GetSize()); + } +} + +graphStatus Usr2DefQuantizeFactor(const UsrQuantizeFactor &usr, QuantizeFactor &def) { + def.scale_mode = uint32_t(usr.scale_mode); + 
def.set_scale_value(usr.scale_value.data(), usr.scale_value.size()); + def.scale_offset = usr.scale_offset; + def.set_offset_data_value(usr.offset_data_value.data(), usr.offset_data_value.size()); + def.offset_data_offset = usr.offset_data_offset; + def.set_offset_weight_value(usr.offset_weight_value.data(), usr.offset_weight_value.size()); + def.offset_weight_offset = usr.offset_weight_offset; + def.set_offset_pad_value(usr.offset_pad_value.data(), usr.offset_pad_value.size()); + def.offset_pad_offset = usr.offset_pad_offset; + return GRAPH_SUCCESS; +} +graphStatus Def2UsrQuantizeFactor(const QuantizeFactor &def, UsrQuantizeFactor &usr) { + usr.scale_mode = UsrQuantizeScaleMode(def.scale_mode); + CopyDataFromBuffer(usr.scale_value, def.scale_value); + usr.scale_offset = def.scale_offset; + CopyDataFromBuffer(usr.offset_data_value, def.offset_data_value); + usr.offset_data_offset = def.offset_data_offset; + CopyDataFromBuffer(usr.offset_weight_value, def.offset_weight_value); + usr.offset_weight_offset = def.offset_weight_offset; + CopyDataFromBuffer(usr.offset_pad_value, def.offset_pad_value); + usr.offset_pad_offset = def.offset_pad_offset; + return GRAPH_SUCCESS; +} +graphStatus Usr2DefUsrQuantizeCalcFactor(const UsrQuantizeCalcFactor &usr, QuantizeCalcFactor &def) { + def.set_offsetw(usr.offsetw.data(), usr.offsetw.size()); + def.offsetw_offset = usr.offsetw_offset; + def.set_offsetd(usr.offsetd.data(), usr.offsetd.size()); + def.offsetd_offset = usr.offsetd_offset; + def.set_scalereq(usr.scalereq.data(), usr.scalereq.size()); + def.scaledreq_offset = usr.scaledreq_offset; + def.set_offsetdnext(usr.offsetdnext.data(), usr.offsetdnext.size()); + def.offsetdnext_offset = usr.offsetdnext_offset; + return GRAPH_SUCCESS; +} +graphStatus Def2UsrQuantizeCalcFactor(const QuantizeCalcFactor &def, UsrQuantizeCalcFactor &usr) { + CopyDataFromBuffer(usr.offsetw, def.offsetw); + usr.offsetw_offset = def.offsetw_offset; + CopyDataFromBuffer(usr.offsetd, def.offsetd); + 
usr.offsetd_offset = def.offsetd_offset; + CopyDataFromBuffer(usr.scalereq, def.scalereq); + usr.scaledreq_offset = def.scaledreq_offset; + CopyDataFromBuffer(usr.offsetdnext, def.offsetdnext); + usr.offsetdnext_offset = def.offsetdnext_offset; + return GRAPH_SUCCESS; +} +graphStatus TypeUtils::Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def) { + def.quantize_algo = uint32_t(usr.quantize_algo); + def.scale_type = uint32_t(usr.scale_type); + GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.quantize_param, def.quantize_param), + "Usr2DefQuantizeFactor quantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.dequantize_param, def.dequantize_param), + "Usr2DefQuantizeFactor dequantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.requantize_param, def.requantize_param), + "Usr2DefQuantizeFactor requantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefUsrQuantizeCalcFactor(usr.quantizecalc_param, def.quantizecalc_param), + "Usr2DefQuantizeFactor quantizecalc_param failed"); + return GRAPH_SUCCESS; +} +graphStatus TypeUtils::Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr) { + usr.quantize_algo = UsrQuantizeAlgorithm(def.quantize_algo); + usr.scale_type = UsrQuantizeScaleType(def.scale_type); + GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.quantize_param, usr.quantize_param), + "Def2UsrQuantizeFactor quantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.dequantize_param, usr.dequantize_param), + "Def2UsrQuantizeFactor dequantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.requantize_param, usr.requantize_param), + "Def2UsrQuantizeFactor requantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeCalcFactor(def.quantizecalc_param, usr.quantizecalc_param), + "Def2UsrQuantizeCalcFactor quantizecalc_param failed"); + return GRAPH_SUCCESS; +} +bool 
TypeUtils::GetDataTypeLength(ge::DataType data_type, uint32_t &length) { + auto it = kDataTypeToLength.find(data_type); + if (it != kDataTypeToLength.end()) { + length = it->second; + return true; + } else { + GELOGE(GRAPH_FAILED, "data_type not support %d", data_type); + return false; + } +} +bool TypeUtils::CheckUint64MulOverflow(uint64_t a, uint32_t b) { + // Not overflow + if (a == 0) { + return false; + } + if ((ULLONG_MAX / a) >= b) { + return false; + } + return true; +} +} // namespace ge diff --git a/src/ge/CMakeLists.txt b/src/ge/CMakeLists.txt new file mode 100755 index 00000000..3f4f1a8b --- /dev/null +++ b/src/ge/CMakeLists.txt @@ -0,0 +1,380 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +# libge_compiler.so & libge_runner.so +# will later be integrated into libgraph_runner.so, works for both training and inference +# compiling proto files generates some warnings, use no-unused-variable to suppress them +set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") +file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../proto/fusion_model.proto" + "../proto/optimizer_priority.proto" + ) +file(GLOB PROTO_CLIENT_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../proto/ge_api.proto" + ) +file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../proto/om.proto" + "../proto/task.proto" + "../proto/insert_op.proto" + "../proto/ge_ir.proto" + "../proto/fwk_adapter.proto" + "../proto/op_mapping_info.proto" + "../proto/dump_task.proto" + ) +ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +ge_protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) +ge_protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${GE_SOURCE_DIR}) +include_directories(${GE_SOURCE_DIR}/src) +include_directories(${GE_SOURCE_DIR}/src/ge/analyzer) +include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/common/util) +include_directories(${GE_SOURCE_DIR}/inc/external) +include_directories(${GE_SOURCE_DIR}/inc/external/graph) +include_directories(${GE_SOURCE_DIR}/inc/framework) +include_directories(${GE_SOURCE_DIR}/inc/framework/common) +include_directories(${GE_SOURCE_DIR}/inc/graph) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/toolchain) +include_directories(${CMAKE_BINARY_DIR}) 
+include_directories(${CMAKE_BINARY_DIR}/proto/ge) + +######### libge_runner.so ############# +# need to remove dependencies on pb files later +file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "analyzer/analyzer.cc" + "client/ge_prof.cc" + "client/ge_api.cc" + "common/dump/dump_manager.cc" + "common/dump/dump_properties.cc" + "common/dump/dump_op.cc" + "common/formats/format_transfers/*.cc" + "common/formats/formats.cc" + "common/formats/utils/formats_trans_utils.cc" + "common/fp16_t.cc" + "common/ge/op_tiling_manager.cc" + "common/ge/plugin_manager.cc" + "common/helper/model_cache_helper.cc" + "common/profiling/profiling_manager.cc" + "engine_manager/dnnengine_manager.cc" + "executor/ge_executor.cc" + "ge_local_engine/engine/host_cpu_engine.cc" + "generator/ge_generator.cc" + "generator/generator_api.cc" + "graph/build/*.cc" + "graph/common/*.cc" + "graph/execute/graph_execute.cc" + "graph/label/*.cc" + "graph/load/graph_loader.cc" + "graph/load/new_model_manager/*.cc" + "graph/load/new_model_manager/task_info/end_graph_task_info.cc" + "graph/load/new_model_manager/task_info/event_record_task_info.cc" + "graph/load/new_model_manager/task_info/event_wait_task_info.cc" + "graph/load/new_model_manager/task_info/fusion_start_task_info.cc" + "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" + "graph/load/new_model_manager/task_info/hccl_task_info.cc" + "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" + "graph/load/new_model_manager/task_info/kernel_task_info.cc" + "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" + "graph/load/new_model_manager/task_info/label_set_task_info.cc" + "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" + "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" + "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" + "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" + 
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" + "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" + "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" + "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" + "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" + "graph/load/new_model_manager/task_info/task_info.cc" + "graph/manager/graph_context.cc" + "graph/manager/graph_manager.cc" + "graph/manager/graph_manager_utils.cc" + "graph/manager/graph_mem_allocator.cc" + "graph/manager/graph_caching_allocator.cc" + "graph/manager/graph_var_manager.cc" + "graph/manager/model_manager/event_manager.cc" + "graph/manager/rdma_pool_allocator.cc" + "graph/manager/trans_var_data_utils.cc" + "graph/manager/util/debug.cc" + "graph/manager/util/hcom_util.cc" + "graph/manager/util/rt_context_util.cc" + "graph/manager/util/variable_accelerate_ctrl.cc" + "graph/manager/util/debug.cc" + "graph/manager/util/hcom_util.cc" + "graph/manager/util/rt_context_util.cc" + "graph/manager/util/variable_accelerate_ctrl.cc" + "graph/optimize/graph_optimize.cc" + "graph/optimize/mem_rw_conflict_optimize.cc" + "graph/optimize/optimizer/allreduce_fusion_pass.cc" + "graph/optimize/summary_optimize.cc" + "graph/partition/dynamic_shape_partition.cc" + "graph/partition/engine_place.cc" + "graph/partition/graph_partition.cc" + "graph/passes/*.cc" + "graph/preprocess/graph_preprocess.cc" + "graph/preprocess/insert_op/ge_aipp_op.cc" + "graph/preprocess/insert_op/util_insert_aipp_op.cc" + "graph/preprocess/multi_batch_copy_graph.cc" + "graph/preprocess/multi_batch_options.cc" + "host_kernels/add_kernel.cc" + "host_kernels/broadcast_args_kernel.cc" + "host_kernels/broadcast_gradient_args_kernel.cc" + "host_kernels/cast_kernel.cc" + "host_kernels/concat_offset_kernel.cc" + "host_kernels/concat_v2_kernel.cc" + "host_kernels/dynamic_stitch_kernel.cc" + "host_kernels/empty_kernel.cc" + 
"host_kernels/expanddims_kernel.cc" + "host_kernels/fill_kernel.cc" + "host_kernels/floordiv_kernel.cc" + "host_kernels/floormod_kernel.cc" + "host_kernels/gather_v2_kernel.cc" + "host_kernels/greater_kernel.cc" + "host_kernels/identity_kernel.cc" + "host_kernels/kernel_utils.cc" + "host_kernels/maximum_kernel.cc" + "host_kernels/mul_kernel.cc" + "host_kernels/pack_kernel.cc" + "host_kernels/permute_kernel.cc" + "host_kernels/range_kernel.cc" + "host_kernels/rank_kernel.cc" + "host_kernels/reduce_prod_kernel.cc" + "host_kernels/reshape_kernel.cc" + "host_kernels/rsqrt_kernel.cc" + "host_kernels/shape_kernel.cc" + "host_kernels/shape_n_kernel.cc" + "host_kernels/size_kernel.cc" + "host_kernels/slice_d_kernel.cc" + "host_kernels/slice_kernel.cc" + "host_kernels/squeeze_kernel.cc" + "host_kernels/ssd_prior_box_kernel.cc" + "host_kernels/strided_slice_kernel.cc" + "host_kernels/sub_kernel.cc" + "host_kernels/transdata_kernel.cc" + "host_kernels/transpose_kernel.cc" + "host_kernels/unpack_kernel.cc" + "host_kernels/unsqueeze_kernel.cc" + "hybrid/common/npu_memory_allocator.cc" + "hybrid/common/tensor_value.cc" + "hybrid/executor/*.cc" + "hybrid/executor/worker/*.cc" + "hybrid/hybrid_davinci_model.cc" + "hybrid/model/*.cc" + "hybrid/node_executor/aicore/*.cc" + "hybrid/node_executor/aicpu/aicpu_ext_info.cc" + "hybrid/node_executor/aicpu/aicpu_node_executor.cc" + "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" + "hybrid/node_executor/controlop/control_op_executor.cc" + "hybrid/node_executor/ge_local/ge_local_node_executor.cc" + "hybrid/node_executor/hccl/hccl_node_executor.cc" + "hybrid/node_executor/hostcpu/ge_local_node_executor.cc" + "hybrid/node_executor/host_cpu/host_cpu_node_executor.cc" + "hybrid/node_executor/host_cpu/kernel_factory.cc" + "hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc" + "hybrid/node_executor/host_cpu/kernel/variable_kernel.cc" + "hybrid/node_executor/host_cpu/kernel/assign_kernel.cc" + 
"hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc" + "hybrid/node_executor/node_executor.cc" + "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" + "hybrid/node_executor/rts/rts_node_executor.cc" + "hybrid/node_executor/task_context.cc" + "init/gelib.cc" + "model/ge_model.cc" + "model/ge_root_model.cc" + "omm/csa_interact.cc" + "opskernel_manager/ops_kernel_manager.cc" + "session/inner_session.cc" + "session/session_manager.cc" + "single_op/*.cc" + "single_op/task/*.cc" + ) + + +######### libge_runner.so ############# +add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS} ${PROTO_HEADER_HDRS}) +target_compile_definitions(ge_runner PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + DAVINCI_SUPPORT_PROFILING + REUSE_MEMORY=1 + DAVINCI_CLOUD) +target_link_libraries(ge_runner + graph + ge_common + ge_memory + ${PROTOBUF_LIBRARY} + ${register} + ${c_sec} + ${slog} + ${mmpa} + ${hccl} + ${msprof} + ${runtime} + ${resouce} + ${ascend_hal} + ${adump_server} + ${msprofiler} + rt + dl) + +######### libge_compiler.so ############# +# need to remove dependencies on pb files later +file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "analyzer/analyzer.cc" + "common/dump/dump_properties.cc" + "common/dump/dump_manager.cc" + "common/dump/dump_op.cc" + "common/dump/dump_server.cc" + "common/formats/format_transfers/*.cc" + "common/formats/formats.cc" + "common/formats/utils/formats_trans_utils.cc" + "common/fp16_t.cc" + "common/ge/op_tiling_manager.cc" + "common/ge/plugin_manager.cc" + "common/helper/model_cache_helper.cc" + "common/profiling/profiling_manager.cc" + "engine_manager/dnnengine_manager.cc" + "ge_local_engine/engine/host_cpu_engine.cc" + "generator/ge_generator.cc" + "generator/generator_api.cc" + "graph/build/*.cc" + "graph/common/*.cc" + "graph/execute/graph_execute.cc" + "graph/label/*.cc" + "graph/load/graph_loader.cc" + "graph/load/new_model_manager/*.cc" + 
"graph/load/new_model_manager/task_info/end_graph_task_info.cc" + "graph/load/new_model_manager/task_info/event_record_task_info.cc" + "graph/load/new_model_manager/task_info/event_wait_task_info.cc" + "graph/load/new_model_manager/task_info/fusion_start_task_info.cc" + "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" + "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" + "graph/load/new_model_manager/task_info/kernel_task_info.cc" + "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" + "graph/load/new_model_manager/task_info/label_set_task_info.cc" + "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" + "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" + "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" + "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" + "graph/load/new_model_manager/task_info/stream_active_task_info.cc" + "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" + "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" + "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" + "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" + "graph/load/new_model_manager/task_info/task_info.cc" + "graph/manager/graph_caching_allocator.cc" + "graph/manager/graph_context.cc" + "graph/manager/graph_manager.cc" + "graph/manager/graph_manager_utils.cc" + "graph/manager/graph_mem_allocator.cc" + "graph/manager/trans_var_data_utils.cc" + "graph/manager/graph_var_manager.cc" + "graph/manager/model_manager/event_manager.cc" + "graph/manager/rdma_pool_allocator.cc" + "graph/manager/util/debug.cc" + "graph/manager/util/rt_context_util.cc" + "graph/manager/util/variable_accelerate_ctrl.cc" + "graph/optimize/graph_optimize.cc" + "graph/optimize/mem_rw_conflict_optimize.cc" + "graph/optimize/summary_optimize.cc" + "graph/partition/dynamic_shape_partition.cc" + 
"graph/partition/engine_place.cc" + "graph/partition/graph_partition.cc" + "graph/passes/*.cc" + "graph/preprocess/graph_preprocess.cc" + "graph/preprocess/insert_op/ge_aipp_op.cc" + "graph/preprocess/insert_op/util_insert_aipp_op.cc" + "graph/preprocess/multi_batch_copy_graph.cc" + "graph/preprocess/multi_batch_options.cc" + "host_kernels/add_kernel.cc" + "host_kernels/broadcast_args_kernel.cc" + "host_kernels/broadcast_gradient_args_kernel.cc" + "host_kernels/cast_kernel.cc" + "host_kernels/concat_offset_kernel.cc" + "host_kernels/concat_v2_kernel.cc" + "host_kernels/dynamic_stitch_kernel.cc" + "host_kernels/empty_kernel.cc" + "host_kernels/expanddims_kernel.cc" + "host_kernels/fill_kernel.cc" + "host_kernels/floordiv_kernel.cc" + "host_kernels/floormod_kernel.cc" + "host_kernels/gather_v2_kernel.cc" + "host_kernels/greater_kernel.cc" + "host_kernels/identity_kernel.cc" + "host_kernels/kernel_utils.cc" + "host_kernels/maximum_kernel.cc" + "host_kernels/mul_kernel.cc" + "host_kernels/pack_kernel.cc" + "host_kernels/permute_kernel.cc" + "host_kernels/range_kernel.cc" + "host_kernels/rank_kernel.cc" + "host_kernels/reduce_prod_kernel.cc" + "host_kernels/reshape_kernel.cc" + "host_kernels/rsqrt_kernel.cc" + "host_kernels/shape_kernel.cc" + "host_kernels/shape_n_kernel.cc" + "host_kernels/size_kernel.cc" + "host_kernels/slice_d_kernel.cc" + "host_kernels/slice_kernel.cc" + "host_kernels/squeeze_kernel.cc" + "host_kernels/ssd_prior_box_kernel.cc" + "host_kernels/strided_slice_kernel.cc" + "host_kernels/sub_kernel.cc" + "host_kernels/transdata_kernel.cc" + "host_kernels/transpose_kernel.cc" + "host_kernels/unpack_kernel.cc" + "host_kernels/unsqueeze_kernel.cc" + "hybrid/hybrid_davinci_model_stub.cc" + "hybrid/node_executor/aicpu/aicpu_ext_info.cc" + "init/gelib.cc" + "ir_build/atc_ir_common.cc" + "ir_build/ge_ir_build.cc" + "model/ge_model.cc" + "model/ge_root_model.cc" + "omm/csa_interact.cc" + "opskernel_manager/ops_kernel_manager.cc" + "session/inner_session.cc" + 
"session/session_manager.cc" + "single_op/*.cc" + "single_op/task/*.cc" + ) + +add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) +target_compile_definitions(ge_compiler PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + REUSE_MEMORY=1 + FMK_HOST_INFER + FMK_SUPPORT_DUMP + COMPILE_OMG_PACKAGE + REUSE_MEMORY=1) +target_link_libraries(ge_compiler + graph + ge_common + ge_memory + ${PROTOBUF_LIBRARY} + ${register} + ${c_sec} + ${slog} + ${mmpa} + ${msprof} + ${runtime} + ${resouce} + ${error_manager} + rt + dl) diff --git a/ge/analyzer/analyzer.cc b/src/ge/analyzer/analyzer.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/analyzer/analyzer.cc rename to src/ge/analyzer/analyzer.cc diff --git a/ge/analyzer/analyzer.h b/src/ge/analyzer/analyzer.h old mode 100755 new mode 100644 similarity index 100% rename from ge/analyzer/analyzer.h rename to src/ge/analyzer/analyzer.h diff --git a/src/ge/client/CMakeLists.txt b/src/ge/client/CMakeLists.txt new file mode 100755 index 00000000..b568e3f6 --- /dev/null +++ b/src/ge/client/CMakeLists.txt @@ -0,0 +1,74 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +# libge_client.so +# add all proto files, generate corresponding .h and .cc files +set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") +file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../../proto/ge_api.proto" + ) + +file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../../proto/ge_ir.proto" + "../../proto/task.proto" + "../../proto/om.proto" + "../../proto/insert_op.proto" + ) + +file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "ge_api.cc" + "ge_prof.cc" + ) + +ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +ge_protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) + +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${GE_SOURCE_DIR}/src/ge) +include_directories(${GE_SOURCE_DIR}/src) +include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/external) +include_directories(${GE_SOURCE_DIR}/inc/external/graph) +include_directories(${GE_SOURCE_DIR}/inc/common) +include_directories(${GE_SOURCE_DIR}/inc/framework) +include_directories(${GE_SOURCE_DIR}/inc/graph) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) +include_directories(${CMAKE_BINARY_DIR}) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) + +############ libge_client.so ################ +add_library(ge_client SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) +target_compile_definitions(ge_client PRIVATE + Werror + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + REUSE_MEMORY=1 + PLATFORM_CLOUD) +target_link_libraries(ge_client + graph + ge_compiler + ge_common + ${PROTOBUF_LIBRARY} + ${register} + ${c_sec} + ${slog} + ${mmpa} + ${runtime} + ${msprof} + ${msprofiler} + ${ascend_hal} + rt + dl) diff --git a/ge/client/ge_api.cc b/src/ge/client/ge_api.cc similarity index 100% rename from ge/client/ge_api.cc 
rename to src/ge/client/ge_api.cc diff --git a/ge/client/ge_prof.cc b/src/ge/client/ge_prof.cc similarity index 100% rename from ge/client/ge_prof.cc rename to src/ge/client/ge_prof.cc diff --git a/ge/client/module.mk b/src/ge/client/module.mk similarity index 100% rename from ge/client/module.mk rename to src/ge/client/module.mk diff --git a/src/ge/common/CMakeLists.txt b/src/ge/common/CMakeLists.txt new file mode 100755 index 00000000..f6c75f87 --- /dev/null +++ b/src/ge/common/CMakeLists.txt @@ -0,0 +1,103 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +# libge_common.so +file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../../proto/om.proto" + "../../proto/ge_ir.proto" + "../../proto/task.proto" + "../../proto/insert_op.proto" + ) + +file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../model/ge_model.cc" + "auth/file_saver.cc" + "context/ctx.cc" + "cust_aicpu_kernel_store.cc" + "debug/memory_dumper.cc" + "dump/dump_properties.cc" + "fmk_error_codes.cc" + "formats/format_transfers/datatype_transfer.cc" + "formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" + "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" + "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" + "formats/format_transfers/format_transfer_fractal_nz.cc" + "formats/format_transfers/format_transfer_fractal_z.cc" + "formats/format_transfers/format_transfer_fractal_zz.cc" + "formats/format_transfers/format_transfer_fracz_hwcn.cc" + "formats/format_transfers/format_transfer_fracz_nchw.cc" + "formats/format_transfers/format_transfer_fracz_nhwc.cc" + "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" + "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" + "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" + "formats/format_transfers/format_transfer_nchw_fz_c04.cc" + "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" + "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" + "formats/format_transfers/format_transfer_transpose.cc" + "formats/formats.cc" + "formats/utils/formats_trans_utils.cc" + "fp16_t.cc" + "ge/datatype_util.cc" + "ge/tbe_plugin_manager.cc" + "ge_format_util.cc" + "helper/model_helper.cc" + "helper/om_file_helper.cc" + "kernel_store.cc" + "math/fp16_math.cc" + "model_parser/base.cc" + "model_saver.cc" + "op/attr_value_util.cc" + "op/ge_op_utils.cc" + "properties_manager.cc" + "tbe_kernel_store.cc" + "thread_pool.cc" + "types.cc" + "util.cc" + ) + +ge_protobuf_generate(ge 
PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${CMAKE_CURRENT_LIST_DIR}/op) +include_directories(${GE_SOURCE_DIR}/src/ge) +include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/common/util) +include_directories(${GE_SOURCE_DIR}/inc/external) +include_directories(${GE_SOURCE_DIR}/inc/external/graph) +include_directories(${GE_SOURCE_DIR}/inc/framework) +include_directories(${GE_SOURCE_DIR}/inc/graph) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) +include_directories(${CMAKE_BINARY_DIR}) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) + +############ libge_common.so ################ +add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) +target_compile_definitions(ge_common PUBLIC + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + HOST_VISIBILITY + OS_CENTOS) +target_link_libraries(ge_common + graph + ${PROTOBUF_LIBRARY} + ${register} + ${c_sec} + ${slog} + ${mmpa} + ${resource} + ${error_manager} + rt + dl) diff --git a/ge/common/auth/file_saver.cc b/src/ge/common/auth/file_saver.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/auth/file_saver.cc rename to src/ge/common/auth/file_saver.cc diff --git a/ge/common/auth/file_saver.h b/src/ge/common/auth/file_saver.h similarity index 100% rename from ge/common/auth/file_saver.h rename to src/ge/common/auth/file_saver.h diff --git a/ge/common/base64.h b/src/ge/common/base64.h similarity index 100% rename from ge/common/base64.h rename to src/ge/common/base64.h diff --git a/ge/common/context/ctx.cc b/src/ge/common/context/ctx.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/context/ctx.cc rename to src/ge/common/context/ctx.cc diff --git a/src/ge/common/convert/pb2json.cc b/src/ge/common/convert/pb2json.cc new file mode 100644 index 00000000..0a5d24ee --- /dev/null +++ 
b/src/ge/common/convert/pb2json.cc @@ -0,0 +1,248 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// File: pb2json.h +// Description: This imply file for protobuf message and json interconversion + +#include "common/convert/pb2json.h" +#include +#include +#include "securec.h" +#include "framework/common/fmk_types.h" +#include "framework/common/debug/ge_log.h" + +using std::set; +using std::string; + +namespace ge { +namespace { +const int kSignificantDigits = 10; +} +// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, + const set &black_fields, Json &json, + bool enum2str) { + auto descriptor = message.GetDescriptor(); + auto reflection = message.GetReflection(); + if (descriptor == nullptr || reflection == nullptr) { + return; + } + + auto count = descriptor->field_count(); + + for (auto i = 0; i < count; ++i) { + const auto field = descriptor->field(i); + if (field == nullptr) { + return; + } + + // Do not display weight data + if (black_fields.find(field->name()) != black_fields.end()) { + continue; + } + + if (field->is_repeated()) { + if (reflection->FieldSize(message, field) > 0) { + RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str); + } + continue; + } + + if 
(!reflection->HasField(message, field)) { + continue; + } + + OneField2Json(message, field, reflection, black_fields, json, enum2str); + } +} + +void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, + const ProtobufReflection *reflection, const set &black_fields, Json &json, + bool enum2str) { + switch (field->type()) { + case ProtobufFieldDescriptor::TYPE_MESSAGE: { + const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); + if (0 != tmp_message.ByteSize()) { + Message2Json(tmp_message, black_fields, json[field->name()], enum2str); + } + break; + } + + case ProtobufFieldDescriptor::TYPE_BOOL: + json[field->name()] = reflection->GetBool(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_ENUM: { + auto *enum_value_desc = reflection->GetEnum(message, field); + Enum2Json(enum_value_desc, field, enum2str, json); + break; + } + + case ProtobufFieldDescriptor::TYPE_INT32: + case ProtobufFieldDescriptor::TYPE_SINT32: + case ProtobufFieldDescriptor::TYPE_SFIXED32: + json[field->name()] = reflection->GetInt32(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_UINT32: + case ProtobufFieldDescriptor::TYPE_FIXED32: + json[field->name()] = reflection->GetUInt32(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_INT64: + case ProtobufFieldDescriptor::TYPE_SINT64: + case ProtobufFieldDescriptor::TYPE_SFIXED64: + json[field->name()] = reflection->GetInt64(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_UINT64: + case ProtobufFieldDescriptor::TYPE_FIXED64: + json[field->name()] = reflection->GetUInt64(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_FLOAT: + char str[kSignificantDigits]; + if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) { + json[field->name()] = str; + } else { + json[field->name()] = reflection->GetFloat(message, field); + } + + break; + + case ProtobufFieldDescriptor::TYPE_STRING: + 
json[field->name()] = reflection->GetString(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_BYTES: { + string field_name = field->name(); + string type_bytes = reflection->GetString(message, field); + json[field_name] = TypeBytes2String(field_name, type_bytes); + break; + } + + default: + break; + } +} + +string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { + if (field_name != "offset") { + return type_bytes; + } + string result = ""; + for (char temp_value : type_bytes) { + uint8_t *value = 0; + value = reinterpret_cast(&temp_value); + char str[kSignificantDigits]; + if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1) { + GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); + continue; + } + result += str; + } + return result; +} + +void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, + const ProtobufReflection *reflection, const set &black_fields, Json &json, + bool enum2str) { + if ((field == nullptr) || (reflection == nullptr)) { + Message2Json(message, black_fields, json, enum2str); + return; + } + + for (auto i = 0; i < reflection->FieldSize(message, field); ++i) { + Json tmp_json; + switch (field->type()) { + case ProtobufFieldDescriptor::TYPE_MESSAGE: { + const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); + if (0 != tmp_message.ByteSize()) { + Message2Json(tmp_message, black_fields, tmp_json, enum2str); + } + } break; + + case ProtobufFieldDescriptor::TYPE_BOOL: + tmp_json = reflection->GetRepeatedBool(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_ENUM: { + auto *enum_value_desc = reflection->GetRepeatedEnum(message, field, i); + RepeatedEnum2Json(enum_value_desc, enum2str, tmp_json); + } break; + + case ProtobufFieldDescriptor::TYPE_INT32: + case ProtobufFieldDescriptor::TYPE_SINT32: + case ProtobufFieldDescriptor::TYPE_SFIXED32: + tmp_json = reflection->GetRepeatedInt32(message, field, 
i); + break; + + case ProtobufFieldDescriptor::TYPE_UINT32: + case ProtobufFieldDescriptor::TYPE_FIXED32: + tmp_json = reflection->GetRepeatedUInt32(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_INT64: + case ProtobufFieldDescriptor::TYPE_SINT64: + case ProtobufFieldDescriptor::TYPE_SFIXED64: + tmp_json = reflection->GetRepeatedInt64(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_UINT64: + case ProtobufFieldDescriptor::TYPE_FIXED64: + tmp_json = reflection->GetRepeatedUInt64(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_FLOAT: + tmp_json = reflection->GetRepeatedFloat(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_STRING: + case ProtobufFieldDescriptor::TYPE_BYTES: + tmp_json = reflection->GetRepeatedString(message, field, i); + break; + + default: + break; + } + json += tmp_json; + } +} + +void Pb2Json::Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, + bool enum2str, Json &json) { + if (enum_value_desc != nullptr) { + if (field == nullptr) { + return; + } + if (enum2str) { + json[field->name()] = enum_value_desc->name(); + } else { + json[field->name()] = enum_value_desc->number(); + } + } +} + +void Pb2Json::RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json) { + if (enum_value_desc != nullptr) { + if (enum2str) { + json = enum_value_desc->name(); + } else { + json = enum_value_desc->number(); + } + } +} +} // namespace ge diff --git a/src/ge/common/convert/pb2json.h b/src/ge/common/convert/pb2json.h new file mode 100644 index 00000000..88ded50e --- /dev/null +++ b/src/ge/common/convert/pb2json.h @@ -0,0 +1,68 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// File: pb2json.h +// Description: This header file for protobuf message and json interconversion + +#ifndef GE_COMMON_CONVERT_PB2JSON_H_ +#define GE_COMMON_CONVERT_PB2JSON_H_ +#include +#include +#include +#include +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "nlohmann/json.hpp" + +namespace ge { +using Json = nlohmann::json; +using ProtobufMsg = ::google::protobuf::Message; +using ProtobufReflection = ::google::protobuf::Reflection; +using ProtobufFieldDescriptor = ::google::protobuf::FieldDescriptor; +using ProtobufDescriptor = ::google::protobuf::Descriptor; +using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor; + +class Pb2Json { + public: + /** + * @ingroup domi_omg + * @brief Transfer protobuf object to JSON object + * @param [out] json Converted JSON object + * @return void success + * @author + */ + static void Message2Json(const ProtobufMsg &message, const std::set &black_fields, Json &json, + bool enum2str = false); + + protected: + static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, + const ProtobufReflection *reflection, const std::set &black_fields, + Json &json, bool enum2str); + + static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, + bool enum2str, Json &json); + + static void RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json); + + static void OneField2Json(const ProtobufMsg &message, const 
ProtobufFieldDescriptor *field, + const ProtobufReflection *reflection, const std::set &black_fields, Json &json, + bool enum2str); + + static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); +}; +} // namespace ge + +#endif // GE_COMMON_CONVERT_PB2JSON_H_ diff --git a/ge/common/cust_aicpu_kernel_store.cc b/src/ge/common/cust_aicpu_kernel_store.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/cust_aicpu_kernel_store.cc rename to src/ge/common/cust_aicpu_kernel_store.cc diff --git a/ge/common/cust_aicpu_kernel_store.h b/src/ge/common/cust_aicpu_kernel_store.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/cust_aicpu_kernel_store.h rename to src/ge/common/cust_aicpu_kernel_store.h diff --git a/ge/common/debug/memory_dumper.cc b/src/ge/common/debug/memory_dumper.cc similarity index 100% rename from ge/common/debug/memory_dumper.cc rename to src/ge/common/debug/memory_dumper.cc diff --git a/ge/common/debug/memory_dumper.h b/src/ge/common/debug/memory_dumper.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/debug/memory_dumper.h rename to src/ge/common/debug/memory_dumper.h diff --git a/ge/common/dump/dump_manager.cc b/src/ge/common/dump/dump_manager.cc similarity index 100% rename from ge/common/dump/dump_manager.cc rename to src/ge/common/dump/dump_manager.cc diff --git a/ge/common/dump/dump_manager.h b/src/ge/common/dump/dump_manager.h similarity index 100% rename from ge/common/dump/dump_manager.h rename to src/ge/common/dump/dump_manager.h diff --git a/ge/common/dump/dump_op.cc b/src/ge/common/dump/dump_op.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/dump/dump_op.cc rename to src/ge/common/dump/dump_op.cc diff --git a/ge/common/dump/dump_op.h b/src/ge/common/dump/dump_op.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/dump/dump_op.h rename to src/ge/common/dump/dump_op.h diff --git 
a/ge/common/dump/dump_properties.cc b/src/ge/common/dump/dump_properties.cc similarity index 100% rename from ge/common/dump/dump_properties.cc rename to src/ge/common/dump/dump_properties.cc diff --git a/ge/common/dump/dump_properties.h b/src/ge/common/dump/dump_properties.h similarity index 100% rename from ge/common/dump/dump_properties.h rename to src/ge/common/dump/dump_properties.h diff --git a/ge/common/dump/dump_server.cc b/src/ge/common/dump/dump_server.cc similarity index 100% rename from ge/common/dump/dump_server.cc rename to src/ge/common/dump/dump_server.cc diff --git a/ge/common/fmk_error_codes.cc b/src/ge/common/fmk_error_codes.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/fmk_error_codes.cc rename to src/ge/common/fmk_error_codes.cc diff --git a/ge/common/formats/format_transfers/datatype_transfer.cc b/src/ge/common/formats/format_transfers/datatype_transfer.cc similarity index 100% rename from ge/common/formats/format_transfers/datatype_transfer.cc rename to src/ge/common/formats/format_transfers/datatype_transfer.cc diff --git a/ge/common/formats/format_transfers/datatype_transfer.h b/src/ge/common/formats/format_transfers/datatype_transfer.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/datatype_transfer.h rename to src/ge/common/formats/format_transfers/datatype_transfer.h diff --git a/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc b/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc rename to src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc diff --git a/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h b/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h similarity index 100% rename from 
ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h rename to src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h diff --git a/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc rename to src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc diff --git a/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h rename to src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h diff --git a/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc rename to src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc diff --git a/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h rename to src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fractal_nz.cc rename to 
src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_nz.h b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fractal_nz.h rename to src/ge/common/formats/format_transfers/format_transfer_fractal_nz.h diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fractal_z.cc rename to src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.h b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fractal_z.h rename to src/ge/common/formats/format_transfers/format_transfer_fractal_z.h diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fractal_zz.cc rename to src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_zz.h b/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fractal_zz.h rename to src/ge/common/formats/format_transfers/format_transfer_fractal_zz.h diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc b/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc old mode 100755 new mode 100644 
similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc rename to src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h b/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h rename to src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc b/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc rename to src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h b/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fracz_nchw.h rename to src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc b/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc rename to src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h b/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h rename to src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h diff --git 
a/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc b/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc rename to src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h b/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h rename to src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc rename to src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h rename to src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc rename to src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h old mode 100755 new mode 100644 similarity index 
100% rename from ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h rename to src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h diff --git a/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc b/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc rename to src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc diff --git a/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h b/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h rename to src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h diff --git a/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc b/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc rename to src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc diff --git a/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h b/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h rename to src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h diff --git a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc b/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc rename to src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc diff --git a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h 
b/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h rename to src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.cc b/src/ge/common/formats/format_transfers/format_transfer_transpose.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_transpose.cc rename to src/ge/common/formats/format_transfers/format_transfer_transpose.cc diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.h b/src/ge/common/formats/format_transfers/format_transfer_transpose.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/format_transfers/format_transfer_transpose.h rename to src/ge/common/formats/format_transfers/format_transfer_transpose.h diff --git a/ge/common/formats/formats.cc b/src/ge/common/formats/formats.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/formats.cc rename to src/ge/common/formats/formats.cc diff --git a/ge/common/formats/formats.h b/src/ge/common/formats/formats.h similarity index 100% rename from ge/common/formats/formats.h rename to src/ge/common/formats/formats.h diff --git a/ge/common/formats/utils/formats_definitions.h b/src/ge/common/formats/utils/formats_definitions.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/utils/formats_definitions.h rename to src/ge/common/formats/utils/formats_definitions.h diff --git a/ge/common/formats/utils/formats_trans_utils.cc b/src/ge/common/formats/utils/formats_trans_utils.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/utils/formats_trans_utils.cc rename to src/ge/common/formats/utils/formats_trans_utils.cc diff --git 
a/ge/common/formats/utils/formats_trans_utils.h b/src/ge/common/formats/utils/formats_trans_utils.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/formats/utils/formats_trans_utils.h rename to src/ge/common/formats/utils/formats_trans_utils.h diff --git a/ge/common/fp16_t.cc b/src/ge/common/fp16_t.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/fp16_t.cc rename to src/ge/common/fp16_t.cc diff --git a/ge/common/fp16_t.h b/src/ge/common/fp16_t.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/fp16_t.h rename to src/ge/common/fp16_t.h diff --git a/ge/common/ge/datatype_util.cc b/src/ge/common/ge/datatype_util.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/ge/datatype_util.cc rename to src/ge/common/ge/datatype_util.cc diff --git a/ge/common/ge/datatype_util.h b/src/ge/common/ge/datatype_util.h similarity index 100% rename from ge/common/ge/datatype_util.h rename to src/ge/common/ge/datatype_util.h diff --git a/ge/common/ge/ge_util.h b/src/ge/common/ge/ge_util.h similarity index 100% rename from ge/common/ge/ge_util.h rename to src/ge/common/ge/ge_util.h diff --git a/ge/common/ge/op_tiling_manager.cc b/src/ge/common/ge/op_tiling_manager.cc similarity index 100% rename from ge/common/ge/op_tiling_manager.cc rename to src/ge/common/ge/op_tiling_manager.cc diff --git a/ge/common/ge/op_tiling_manager.h b/src/ge/common/ge/op_tiling_manager.h similarity index 100% rename from ge/common/ge/op_tiling_manager.h rename to src/ge/common/ge/op_tiling_manager.h diff --git a/ge/common/ge/plugin_manager.cc b/src/ge/common/ge/plugin_manager.cc similarity index 100% rename from ge/common/ge/plugin_manager.cc rename to src/ge/common/ge/plugin_manager.cc diff --git a/ge/common/ge/plugin_manager.h b/src/ge/common/ge/plugin_manager.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/ge/plugin_manager.h rename to src/ge/common/ge/plugin_manager.h 
diff --git a/ge/common/ge/tbe_plugin_manager.cc b/src/ge/common/ge/tbe_plugin_manager.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/ge/tbe_plugin_manager.cc rename to src/ge/common/ge/tbe_plugin_manager.cc diff --git a/ge/common/ge/tbe_plugin_manager.h b/src/ge/common/ge/tbe_plugin_manager.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/ge/tbe_plugin_manager.h rename to src/ge/common/ge/tbe_plugin_manager.h diff --git a/ge/common/ge_common.mk b/src/ge/common/ge_common.mk old mode 100755 new mode 100644 similarity index 100% rename from ge/common/ge_common.mk rename to src/ge/common/ge_common.mk diff --git a/ge/common/ge_format_util.cc b/src/ge/common/ge_format_util.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/ge_format_util.cc rename to src/ge/common/ge_format_util.cc diff --git a/ge/common/helper/model_cache_helper.cc b/src/ge/common/helper/model_cache_helper.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/helper/model_cache_helper.cc rename to src/ge/common/helper/model_cache_helper.cc diff --git a/ge/common/helper/model_cache_helper.h b/src/ge/common/helper/model_cache_helper.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/helper/model_cache_helper.h rename to src/ge/common/helper/model_cache_helper.h diff --git a/ge/common/helper/model_helper.cc b/src/ge/common/helper/model_helper.cc similarity index 100% rename from ge/common/helper/model_helper.cc rename to src/ge/common/helper/model_helper.cc diff --git a/ge/common/helper/om_file_helper.cc b/src/ge/common/helper/om_file_helper.cc similarity index 100% rename from ge/common/helper/om_file_helper.cc rename to src/ge/common/helper/om_file_helper.cc diff --git a/ge/common/kernel_store.cc b/src/ge/common/kernel_store.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/kernel_store.cc rename to src/ge/common/kernel_store.cc diff 
--git a/ge/common/kernel_store.h b/src/ge/common/kernel_store.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/kernel_store.h rename to src/ge/common/kernel_store.h diff --git a/ge/common/math/fp16_math.cc b/src/ge/common/math/fp16_math.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/math/fp16_math.cc rename to src/ge/common/math/fp16_math.cc diff --git a/ge/common/math/fp16_math.h b/src/ge/common/math/fp16_math.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/math/fp16_math.h rename to src/ge/common/math/fp16_math.h diff --git a/ge/common/math/math_util.h b/src/ge/common/math/math_util.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/math/math_util.h rename to src/ge/common/math/math_util.h diff --git a/ge/common/math_util.h b/src/ge/common/math_util.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/math_util.h rename to src/ge/common/math_util.h diff --git a/ge/common/model_parser/base.cc b/src/ge/common/model_parser/base.cc similarity index 100% rename from ge/common/model_parser/base.cc rename to src/ge/common/model_parser/base.cc diff --git a/ge/common/model_parser/base.h b/src/ge/common/model_parser/base.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/model_parser/base.h rename to src/ge/common/model_parser/base.h diff --git a/ge/common/model_saver.cc b/src/ge/common/model_saver.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/model_saver.cc rename to src/ge/common/model_saver.cc diff --git a/ge/common/model_saver.h b/src/ge/common/model_saver.h similarity index 100% rename from ge/common/model_saver.h rename to src/ge/common/model_saver.h diff --git a/ge/common/module.mk b/src/ge/common/module.mk old mode 100755 new mode 100644 similarity index 100% rename from ge/common/module.mk rename to src/ge/common/module.mk diff --git 
a/ge/common/op/attr_value_util.cc b/src/ge/common/op/attr_value_util.cc similarity index 100% rename from ge/common/op/attr_value_util.cc rename to src/ge/common/op/attr_value_util.cc diff --git a/ge/common/op/ge_op_utils.cc b/src/ge/common/op/ge_op_utils.cc similarity index 100% rename from ge/common/op/ge_op_utils.cc rename to src/ge/common/op/ge_op_utils.cc diff --git a/ge/common/profiling/profiling_manager.cc b/src/ge/common/profiling/profiling_manager.cc similarity index 100% rename from ge/common/profiling/profiling_manager.cc rename to src/ge/common/profiling/profiling_manager.cc diff --git a/ge/common/profiling/profiling_manager.h b/src/ge/common/profiling/profiling_manager.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/profiling/profiling_manager.h rename to src/ge/common/profiling/profiling_manager.h diff --git a/ge/common/properties_manager.cc b/src/ge/common/properties_manager.cc similarity index 100% rename from ge/common/properties_manager.cc rename to src/ge/common/properties_manager.cc diff --git a/ge/common/properties_manager.h b/src/ge/common/properties_manager.h similarity index 100% rename from ge/common/properties_manager.h rename to src/ge/common/properties_manager.h diff --git a/ge/common/singleton.h b/src/ge/common/singleton.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/singleton.h rename to src/ge/common/singleton.h diff --git a/ge/common/tbe_kernel_store.cc b/src/ge/common/tbe_kernel_store.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/tbe_kernel_store.cc rename to src/ge/common/tbe_kernel_store.cc diff --git a/ge/common/tbe_kernel_store.h b/src/ge/common/tbe_kernel_store.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/tbe_kernel_store.h rename to src/ge/common/tbe_kernel_store.h diff --git a/ge/common/thread_pool.cc b/src/ge/common/thread_pool.cc similarity index 100% rename from ge/common/thread_pool.cc rename to 
src/ge/common/thread_pool.cc diff --git a/ge/common/thread_pool.h b/src/ge/common/thread_pool.h old mode 100755 new mode 100644 similarity index 100% rename from ge/common/thread_pool.h rename to src/ge/common/thread_pool.h diff --git a/ge/common/types.cc b/src/ge/common/types.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/common/types.cc rename to src/ge/common/types.cc diff --git a/ge/common/util.cc b/src/ge/common/util.cc similarity index 100% rename from ge/common/util.cc rename to src/ge/common/util.cc diff --git a/ge/engine_manager/dnnengine_manager.cc b/src/ge/engine_manager/dnnengine_manager.cc similarity index 100% rename from ge/engine_manager/dnnengine_manager.cc rename to src/ge/engine_manager/dnnengine_manager.cc diff --git a/ge/engine_manager/dnnengine_manager.h b/src/ge/engine_manager/dnnengine_manager.h old mode 100755 new mode 100644 similarity index 100% rename from ge/engine_manager/dnnengine_manager.h rename to src/ge/engine_manager/dnnengine_manager.h diff --git a/ge/engine_manager/engine_conf.json b/src/ge/engine_manager/engine_conf.json similarity index 100% rename from ge/engine_manager/engine_conf.json rename to src/ge/engine_manager/engine_conf.json diff --git a/src/ge/executor/CMakeLists.txt b/src/ge/executor/CMakeLists.txt new file mode 100755 index 00000000..b68507bd --- /dev/null +++ b/src/ge/executor/CMakeLists.txt @@ -0,0 +1,126 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# libge_executor.so +# add all proto files, generate corresponding .h and .cc files +# add src files +file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../../proto/task.proto" + "../../proto/om.proto" + "../../proto/insert_op.proto" + "../../proto/op_mapping_info.proto" + "../../proto/ge_ir.proto" + "../../proto/dump_task.proto" + ) + +file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "ge_executor.cc" + "../common/dump/dump_properties.cc" + "../common/dump/dump_manager.cc" + "../common/dump/dump_op.cc" + "../common/ge/op_tiling_manager.cc" + "../common/ge/plugin_manager.cc" + "../common/profiling/profiling_manager.cc" + "../graph/execute/graph_execute.cc" + "../graph/load/graph_loader.cc" + "../graph/load/new_model_manager/aipp_utils.cc" + "../graph/load/new_model_manager/cpu_queue_schedule.cc" + "../graph/load/new_model_manager/data_dumper.cc" + "../graph/load/new_model_manager/data_inputer.cc" + "../graph/load/new_model_manager/davinci_model.cc" + "../graph/load/new_model_manager/davinci_model_parser.cc" + "../graph/load/new_model_manager/model_manager.cc" + "../graph/load/new_model_manager/model_utils.cc" + "../graph/load/new_model_manager/task_info/end_graph_task_info.cc" + "../graph/load/new_model_manager/task_info/event_record_task_info.cc" + "../graph/load/new_model_manager/task_info/event_wait_task_info.cc" + "../graph/load/new_model_manager/task_info/fusion_start_task_info.cc" + "../graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" + "../graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" + "../graph/load/new_model_manager/task_info/kernel_task_info.cc" + "../graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" + "../graph/load/new_model_manager/task_info/label_set_task_info.cc" + 
"../graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" + "../graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" + "../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" + "../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" + "../graph/load/new_model_manager/task_info/stream_active_task_info.cc" + "../graph/load/new_model_manager/task_info/stream_switch_task_info.cc" + "../graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" + "../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" + "../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" + "../graph/load/new_model_manager/task_info/task_info.cc" + "../graph/load/new_model_manager/tbe_handle_store.cc" + "../graph/load/new_model_manager/zero_copy_offset.cc" + "../graph/load/new_model_manager/zero_copy_task.cc" + "../graph/manager/graph_caching_allocator.cc" + "../graph/manager/graph_manager_utils.cc" + "../graph/manager/graph_mem_allocator.cc" + "../graph/manager/graph_var_manager.cc" + "../graph/manager/rdma_pool_allocator.cc" + "../graph/manager/trans_var_data_utils.cc" + "../graph/manager/util/debug.cc" + "../hybrid/hybrid_davinci_model_stub.cc" + "../hybrid/node_executor/aicpu/aicpu_ext_info.cc" + "../model/ge_model.cc" + "../model/ge_root_model.cc" + "../omm/csa_interact.cc" + "../single_op/single_op.cc" + "../single_op/single_op_manager.cc" + "../single_op/single_op_model.cc" + "../single_op/stream_resource.cc" + "../single_op/task/aicpu_task_builder.cc" + "../single_op/task/build_task_utils.cc" + "../single_op/task/op_task.cc" + "../single_op/task/tbe_task_builder.cc" + ) + +ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${GE_SOURCE_DIR}/src/ge) +include_directories(${GE_SOURCE_DIR}/inc/external) +include_directories(${GE_SOURCE_DIR}/inc/external/graph) 
+include_directories(${GE_SOURCE_DIR}/inc/framework) +include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/graph) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) +include_directories(${CMAKE_BINARY_DIR}) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) + +######## libge_executor.so ######## +add_library(ge_executor SHARED ${SRC_LIST} ${PROTO_HDRS}) +target_compile_definitions(ge_executor PRIVATE + Werror + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + DAVINCI_SUPPORT_PROFILING + FMK_HOST_INFER) +target_link_libraries(ge_executor + ge_common + graph + ${PROTOBUF_LIBRARY} + ${register} + ${c_sec} + ${runtime} + ${slog} + ${mmpa} + ${msprof} + ${error_manager} + ${ascend_hal} + rt + dl) + diff --git a/ge/executor/ge_executor.cc b/src/ge/executor/ge_executor.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/executor/ge_executor.cc rename to src/ge/executor/ge_executor.cc diff --git a/ge/executor/module.mk b/src/ge/executor/module.mk old mode 100755 new mode 100644 similarity index 100% rename from ge/executor/module.mk rename to src/ge/executor/module.mk diff --git a/ge/ge_inference.mk b/src/ge/ge_inference.mk old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_inference.mk rename to src/ge/ge_inference.mk diff --git a/src/ge/ge_local_engine/CMakeLists.txt b/src/ge/ge_local_engine/CMakeLists.txt new file mode 100755 index 00000000..bcbc3e4c --- /dev/null +++ b/src/ge/ge_local_engine/CMakeLists.txt @@ -0,0 +1,52 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# libge_local_engine.so +# add all proto files, generate corresponding .h and .cc files +file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../../proto/task.proto" + ) + +file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "engine/ge_local_engine.cc" + "ops_kernel_store/*.cc" + "ops_kernel_store/op/*.cc" + ) + +ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${GE_SOURCE_DIR}/src/ge) +include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/external) +include_directories(${GE_SOURCE_DIR}/inc/external/graph) +include_directories(${GE_SOURCE_DIR}/inc/framework) +include_directories(${GE_SOURCE_DIR}/inc/graph) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${CMAKE_BINARY_DIR}) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) + +######### libge_local_engine.so ############# +add_library(ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) +target_compile_definitions(ge_local_engine PRIVATE Werror COMPILE_OMG_PACKAGE) +target_link_libraries(ge_local_engine + graph + ${PROTOBUF_LIBRARY} + ${register} + ${c_sec} + ${slog} + ${runtime}) diff --git a/ge/ge_local_engine/common/constant/constant.h b/src/ge/ge_local_engine/common/constant/constant.h similarity index 100% rename from ge/ge_local_engine/common/constant/constant.h rename to 
src/ge/ge_local_engine/common/constant/constant.h diff --git a/ge/ge_local_engine/engine/ge_local_engine.cc b/src/ge/ge_local_engine/engine/ge_local_engine.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_local_engine/engine/ge_local_engine.cc rename to src/ge/ge_local_engine/engine/ge_local_engine.cc diff --git a/ge/ge_local_engine/engine/ge_local_engine.h b/src/ge/ge_local_engine/engine/ge_local_engine.h similarity index 100% rename from ge/ge_local_engine/engine/ge_local_engine.h rename to src/ge/ge_local_engine/engine/ge_local_engine.h diff --git a/ge/ge_local_engine/engine/host_cpu_engine.cc b/src/ge/ge_local_engine/engine/host_cpu_engine.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_local_engine/engine/host_cpu_engine.cc rename to src/ge/ge_local_engine/engine/host_cpu_engine.cc diff --git a/ge/ge_local_engine/engine/host_cpu_engine.h b/src/ge/ge_local_engine/engine/host_cpu_engine.h similarity index 100% rename from ge/ge_local_engine/engine/host_cpu_engine.h rename to src/ge/ge_local_engine/engine/host_cpu_engine.h diff --git a/ge/ge_local_engine/module.mk b/src/ge/ge_local_engine/module.mk old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_local_engine/module.mk rename to src/ge/ge_local_engine/module.mk diff --git a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc b/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc rename to src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc diff --git a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h b/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h rename to 
src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h diff --git a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc b/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc rename to src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc diff --git a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.h b/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.h similarity index 100% rename from ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.h rename to src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.h diff --git a/ge/ge_local_engine/ops_kernel_store/op/no_op.cc b/src/ge/ge_local_engine/ops_kernel_store/op/no_op.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_local_engine/ops_kernel_store/op/no_op.cc rename to src/ge/ge_local_engine/ops_kernel_store/op/no_op.cc diff --git a/ge/ge_local_engine/ops_kernel_store/op/no_op.h b/src/ge/ge_local_engine/ops_kernel_store/op/no_op.h similarity index 100% rename from ge/ge_local_engine/ops_kernel_store/op/no_op.h rename to src/ge/ge_local_engine/ops_kernel_store/op/no_op.h diff --git a/ge/ge_local_engine/ops_kernel_store/op/op.cc b/src/ge/ge_local_engine/ops_kernel_store/op/op.cc similarity index 100% rename from ge/ge_local_engine/ops_kernel_store/op/op.cc rename to src/ge/ge_local_engine/ops_kernel_store/op/op.cc diff --git a/ge/ge_local_engine/ops_kernel_store/op/op.h b/src/ge/ge_local_engine/ops_kernel_store/op/op.h similarity index 100% rename from ge/ge_local_engine/ops_kernel_store/op/op.h rename to src/ge/ge_local_engine/ops_kernel_store/op/op.h diff --git a/ge/ge_local_engine/ops_kernel_store/op/op_factory.cc b/src/ge/ge_local_engine/ops_kernel_store/op/op_factory.cc similarity index 100% rename from ge/ge_local_engine/ops_kernel_store/op/op_factory.cc rename to src/ge/ge_local_engine/ops_kernel_store/op/op_factory.cc diff 
--git a/ge/ge_local_engine/ops_kernel_store/op/op_factory.h b/src/ge/ge_local_engine/ops_kernel_store/op/op_factory.h similarity index 100% rename from ge/ge_local_engine/ops_kernel_store/op/op_factory.h rename to src/ge/ge_local_engine/ops_kernel_store/op/op_factory.h diff --git a/ge/ge_runner.mk b/src/ge/ge_runner.mk similarity index 100% rename from ge/ge_runner.mk rename to src/ge/ge_runner.mk diff --git a/src/ge/ge_runtime/CMakeLists.txt b/src/ge/ge_runtime/CMakeLists.txt new file mode 100755 index 00000000..aa4e3470 --- /dev/null +++ b/src/ge/ge_runtime/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +# libge_runtime.so +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${GE_SOURCE_DIR}/src/ge) +include_directories(${GE_SOURCE_DIR}/src) +include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/graph) +include_directories(${GE_SOURCE_DIR}/inc/external) +include_directories(${GE_SOURCE_DIR}/inc/framework) +include_directories(${GE_SOURCE_DIR}/inc/framework/common) +include_directories(${GE_SOURCE_DIR}/inc/framework/ge_runtime) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) +include_directories(${CMAKE_BINARY_DIR}) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) + +######### libge_runtime.so ############# +file(GLOB_RECURSE GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "model_runner.cc" + "runtime_model.cc" + "output.cc" + "task/*.cc" + ) + +add_library(ge_runtime SHARED ${GE_SRC_LIST}) +target_compile_definitions(ge_runtime PUBLIC + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + Werror) +target_link_libraries(ge_runtime + graph + ${slog} + ${runtime} + ${c_sec} + rt + dl + ) diff --git a/ge/ge_runtime/model_context.h b/src/ge/ge_runtime/model_context.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/model_context.h rename to src/ge/ge_runtime/model_context.h diff --git a/ge/ge_runtime/model_runner.cc b/src/ge/ge_runtime/model_runner.cc similarity index 100% rename from ge/ge_runtime/model_runner.cc rename to src/ge/ge_runtime/model_runner.cc diff --git a/ge/ge_runtime/output.cc b/src/ge/ge_runtime/output.cc similarity index 100% rename from ge/ge_runtime/output.cc rename to src/ge/ge_runtime/output.cc diff --git a/ge/ge_runtime/output.h b/src/ge/ge_runtime/output.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/output.h rename to src/ge/ge_runtime/output.h diff --git 
a/ge/ge_runtime/runtime_model.cc b/src/ge/ge_runtime/runtime_model.cc similarity index 100% rename from ge/ge_runtime/runtime_model.cc rename to src/ge/ge_runtime/runtime_model.cc diff --git a/ge/ge_runtime/runtime_model.h b/src/ge/ge_runtime/runtime_model.h similarity index 100% rename from ge/ge_runtime/runtime_model.h rename to src/ge/ge_runtime/runtime_model.h diff --git a/ge/ge_runtime/task/aicpu_task.cc b/src/ge/ge_runtime/task/aicpu_task.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/aicpu_task.cc rename to src/ge/ge_runtime/task/aicpu_task.cc diff --git a/ge/ge_runtime/task/aicpu_task.h b/src/ge/ge_runtime/task/aicpu_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/aicpu_task.h rename to src/ge/ge_runtime/task/aicpu_task.h diff --git a/ge/ge_runtime/task/cce_task.cc b/src/ge/ge_runtime/task/cce_task.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/cce_task.cc rename to src/ge/ge_runtime/task/cce_task.cc diff --git a/ge/ge_runtime/task/cce_task.h b/src/ge/ge_runtime/task/cce_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/cce_task.h rename to src/ge/ge_runtime/task/cce_task.h diff --git a/ge/ge_runtime/task/event_record_task.cc b/src/ge/ge_runtime/task/event_record_task.cc similarity index 100% rename from ge/ge_runtime/task/event_record_task.cc rename to src/ge/ge_runtime/task/event_record_task.cc diff --git a/ge/ge_runtime/task/event_record_task.h b/src/ge/ge_runtime/task/event_record_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/event_record_task.h rename to src/ge/ge_runtime/task/event_record_task.h diff --git a/ge/ge_runtime/task/event_wait_task.cc b/src/ge/ge_runtime/task/event_wait_task.cc similarity index 100% rename from ge/ge_runtime/task/event_wait_task.cc rename to src/ge/ge_runtime/task/event_wait_task.cc diff --git 
a/ge/ge_runtime/task/event_wait_task.h b/src/ge/ge_runtime/task/event_wait_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/event_wait_task.h rename to src/ge/ge_runtime/task/event_wait_task.h diff --git a/ge/ge_runtime/task/hccl_task.cc b/src/ge/ge_runtime/task/hccl_task.cc similarity index 100% rename from ge/ge_runtime/task/hccl_task.cc rename to src/ge/ge_runtime/task/hccl_task.cc diff --git a/ge/ge_runtime/task/hccl_task.h b/src/ge/ge_runtime/task/hccl_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/hccl_task.h rename to src/ge/ge_runtime/task/hccl_task.h diff --git a/ge/ge_runtime/task/label_goto_task.cc b/src/ge/ge_runtime/task/label_goto_task.cc similarity index 100% rename from ge/ge_runtime/task/label_goto_task.cc rename to src/ge/ge_runtime/task/label_goto_task.cc diff --git a/ge/ge_runtime/task/label_goto_task.h b/src/ge/ge_runtime/task/label_goto_task.h similarity index 100% rename from ge/ge_runtime/task/label_goto_task.h rename to src/ge/ge_runtime/task/label_goto_task.h diff --git a/ge/ge_runtime/task/label_set_task.cc b/src/ge/ge_runtime/task/label_set_task.cc similarity index 100% rename from ge/ge_runtime/task/label_set_task.cc rename to src/ge/ge_runtime/task/label_set_task.cc diff --git a/ge/ge_runtime/task/label_set_task.h b/src/ge/ge_runtime/task/label_set_task.h similarity index 100% rename from ge/ge_runtime/task/label_set_task.h rename to src/ge/ge_runtime/task/label_set_task.h diff --git a/ge/ge_runtime/task/label_switch_task.cc b/src/ge/ge_runtime/task/label_switch_task.cc similarity index 100% rename from ge/ge_runtime/task/label_switch_task.cc rename to src/ge/ge_runtime/task/label_switch_task.cc diff --git a/ge/ge_runtime/task/label_switch_task.h b/src/ge/ge_runtime/task/label_switch_task.h similarity index 100% rename from ge/ge_runtime/task/label_switch_task.h rename to src/ge/ge_runtime/task/label_switch_task.h diff --git 
a/ge/ge_runtime/task/memcpy_async_task.cc b/src/ge/ge_runtime/task/memcpy_async_task.cc similarity index 100% rename from ge/ge_runtime/task/memcpy_async_task.cc rename to src/ge/ge_runtime/task/memcpy_async_task.cc diff --git a/ge/ge_runtime/task/memcpy_async_task.h b/src/ge/ge_runtime/task/memcpy_async_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/memcpy_async_task.h rename to src/ge/ge_runtime/task/memcpy_async_task.h diff --git a/ge/ge_runtime/task/profiler_task.cc b/src/ge/ge_runtime/task/profiler_task.cc similarity index 100% rename from ge/ge_runtime/task/profiler_task.cc rename to src/ge/ge_runtime/task/profiler_task.cc diff --git a/ge/ge_runtime/task/profiler_task.h b/src/ge/ge_runtime/task/profiler_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/profiler_task.h rename to src/ge/ge_runtime/task/profiler_task.h diff --git a/ge/ge_runtime/task/stream_active_task.cc b/src/ge/ge_runtime/task/stream_active_task.cc similarity index 100% rename from ge/ge_runtime/task/stream_active_task.cc rename to src/ge/ge_runtime/task/stream_active_task.cc diff --git a/ge/ge_runtime/task/stream_active_task.h b/src/ge/ge_runtime/task/stream_active_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/stream_active_task.h rename to src/ge/ge_runtime/task/stream_active_task.h diff --git a/ge/ge_runtime/task/stream_switch_task.cc b/src/ge/ge_runtime/task/stream_switch_task.cc similarity index 100% rename from ge/ge_runtime/task/stream_switch_task.cc rename to src/ge/ge_runtime/task/stream_switch_task.cc diff --git a/ge/ge_runtime/task/stream_switch_task.h b/src/ge/ge_runtime/task/stream_switch_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/stream_switch_task.h rename to src/ge/ge_runtime/task/stream_switch_task.h diff --git a/ge/ge_runtime/task/task.h b/src/ge/ge_runtime/task/task.h old mode 100755 new mode 
100644 similarity index 100% rename from ge/ge_runtime/task/task.h rename to src/ge/ge_runtime/task/task.h diff --git a/ge/ge_runtime/task/task_factory.h b/src/ge/ge_runtime/task/task_factory.h similarity index 100% rename from ge/ge_runtime/task/task_factory.h rename to src/ge/ge_runtime/task/task_factory.h diff --git a/ge/ge_runtime/task/tbe_task.cc b/src/ge/ge_runtime/task/tbe_task.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/tbe_task.cc rename to src/ge/ge_runtime/task/tbe_task.cc diff --git a/ge/ge_runtime/task/tbe_task.h b/src/ge/ge_runtime/task/tbe_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/ge_runtime/task/tbe_task.h rename to src/ge/ge_runtime/task/tbe_task.h diff --git a/ge/generator/ge_generator.cc b/src/ge/generator/ge_generator.cc similarity index 100% rename from ge/generator/ge_generator.cc rename to src/ge/generator/ge_generator.cc diff --git a/ge/generator/generator_api.cc b/src/ge/generator/generator_api.cc similarity index 100% rename from ge/generator/generator_api.cc rename to src/ge/generator/generator_api.cc diff --git a/ge/graph/build/graph_builder.cc b/src/ge/graph/build/graph_builder.cc similarity index 100% rename from ge/graph/build/graph_builder.cc rename to src/ge/graph/build/graph_builder.cc diff --git a/ge/graph/build/graph_builder.h b/src/ge/graph/build/graph_builder.h similarity index 100% rename from ge/graph/build/graph_builder.h rename to src/ge/graph/build/graph_builder.h diff --git a/ge/graph/build/label_allocator.cc b/src/ge/graph/build/label_allocator.cc similarity index 100% rename from ge/graph/build/label_allocator.cc rename to src/ge/graph/build/label_allocator.cc diff --git a/ge/graph/build/label_allocator.h b/src/ge/graph/build/label_allocator.h similarity index 100% rename from ge/graph/build/label_allocator.h rename to src/ge/graph/build/label_allocator.h diff --git a/ge/graph/build/logical_stream_allocator.cc 
b/src/ge/graph/build/logical_stream_allocator.cc similarity index 100% rename from ge/graph/build/logical_stream_allocator.cc rename to src/ge/graph/build/logical_stream_allocator.cc diff --git a/ge/graph/build/logical_stream_allocator.h b/src/ge/graph/build/logical_stream_allocator.h similarity index 100% rename from ge/graph/build/logical_stream_allocator.h rename to src/ge/graph/build/logical_stream_allocator.h diff --git a/src/ge/graph/build/memory/CMakeLists.txt b/src/ge/graph/build/memory/CMakeLists.txt new file mode 100644 index 00000000..ea87b906 --- /dev/null +++ b/src/ge/graph/build/memory/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +# libge_memosy.a +file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "memory_assigner.cc" + "graph_mem_assigner.cc" + "binary_block_mem_assigner.cc" + "block_mem_assigner.cc" + "hybrid_mem_assigner.cc" + "max_block_mem_assigner.cc" + "var_mem_assign_util.cc" + ) + +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${GE_SOURCE_DIR}/src) +include_directories(${GE_SOURCE_DIR}/src/ge) +include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/external) +include_directories(${GE_SOURCE_DIR}/inc/external/graph) +include_directories(${GE_SOURCE_DIR}/inc/framework) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${CMAKE_BINARY_DIR}) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) + +######### libge_memory.a ############# +add_library(ge_memory STATIC ${SRC_LIST}) +target_compile_definitions(ge_memory PRIVATE + Werror + DAVINCI_CLOUD) +target_link_libraries(ge_memory + graph + ge_common + ${PROTOBUF_LIBRARY} + ${c_sec} + ${slog} + rt + dl) diff --git a/ge/graph/build/memory/binary_block_mem_assigner.cc b/src/ge/graph/build/memory/binary_block_mem_assigner.cc similarity index 100% rename from ge/graph/build/memory/binary_block_mem_assigner.cc rename to src/ge/graph/build/memory/binary_block_mem_assigner.cc diff --git a/ge/graph/build/memory/binary_block_mem_assigner.h b/src/ge/graph/build/memory/binary_block_mem_assigner.h similarity index 100% rename from ge/graph/build/memory/binary_block_mem_assigner.h rename to src/ge/graph/build/memory/binary_block_mem_assigner.h diff --git a/ge/graph/build/memory/block_mem_assigner.cc b/src/ge/graph/build/memory/block_mem_assigner.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/memory/block_mem_assigner.cc rename to src/ge/graph/build/memory/block_mem_assigner.cc diff --git 
a/ge/graph/build/memory/block_mem_assigner.h b/src/ge/graph/build/memory/block_mem_assigner.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/memory/block_mem_assigner.h rename to src/ge/graph/build/memory/block_mem_assigner.h diff --git a/ge/graph/build/memory/graph_mem_assigner.cc b/src/ge/graph/build/memory/graph_mem_assigner.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/memory/graph_mem_assigner.cc rename to src/ge/graph/build/memory/graph_mem_assigner.cc diff --git a/ge/graph/build/memory/graph_mem_assigner.h b/src/ge/graph/build/memory/graph_mem_assigner.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/memory/graph_mem_assigner.h rename to src/ge/graph/build/memory/graph_mem_assigner.h diff --git a/ge/graph/build/memory/hybrid_mem_assigner.cc b/src/ge/graph/build/memory/hybrid_mem_assigner.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/memory/hybrid_mem_assigner.cc rename to src/ge/graph/build/memory/hybrid_mem_assigner.cc diff --git a/ge/graph/build/memory/hybrid_mem_assigner.h b/src/ge/graph/build/memory/hybrid_mem_assigner.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/memory/hybrid_mem_assigner.h rename to src/ge/graph/build/memory/hybrid_mem_assigner.h diff --git a/ge/graph/build/memory/max_block_mem_assigner.cc b/src/ge/graph/build/memory/max_block_mem_assigner.cc similarity index 100% rename from ge/graph/build/memory/max_block_mem_assigner.cc rename to src/ge/graph/build/memory/max_block_mem_assigner.cc diff --git a/ge/graph/build/memory/max_block_mem_assigner.h b/src/ge/graph/build/memory/max_block_mem_assigner.h similarity index 100% rename from ge/graph/build/memory/max_block_mem_assigner.h rename to src/ge/graph/build/memory/max_block_mem_assigner.h diff --git a/ge/graph/build/memory/mem_assigner.h b/src/ge/graph/build/memory/mem_assigner.h old mode 100755 new mode 100644 
similarity index 100% rename from ge/graph/build/memory/mem_assigner.h rename to src/ge/graph/build/memory/mem_assigner.h diff --git a/ge/graph/build/memory/memory_assigner.cc b/src/ge/graph/build/memory/memory_assigner.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/memory/memory_assigner.cc rename to src/ge/graph/build/memory/memory_assigner.cc diff --git a/ge/graph/build/memory/module.mk b/src/ge/graph/build/memory/module.mk old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/memory/module.mk rename to src/ge/graph/build/memory/module.mk diff --git a/ge/graph/build/memory/var_mem_assign_util.cc b/src/ge/graph/build/memory/var_mem_assign_util.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/memory/var_mem_assign_util.cc rename to src/ge/graph/build/memory/var_mem_assign_util.cc diff --git a/ge/graph/build/memory/var_mem_assign_util.h b/src/ge/graph/build/memory/var_mem_assign_util.h similarity index 100% rename from ge/graph/build/memory/var_mem_assign_util.h rename to src/ge/graph/build/memory/var_mem_assign_util.h diff --git a/ge/graph/build/model_builder.cc b/src/ge/graph/build/model_builder.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/model_builder.cc rename to src/ge/graph/build/model_builder.cc diff --git a/ge/graph/build/model_builder.h b/src/ge/graph/build/model_builder.h similarity index 100% rename from ge/graph/build/model_builder.h rename to src/ge/graph/build/model_builder.h diff --git a/ge/graph/build/run_context.cc b/src/ge/graph/build/run_context.cc similarity index 100% rename from ge/graph/build/run_context.cc rename to src/ge/graph/build/run_context.cc diff --git a/ge/graph/build/run_context.h b/src/ge/graph/build/run_context.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/run_context.h rename to src/ge/graph/build/run_context.h diff --git 
a/ge/graph/build/stream_allocator.cc b/src/ge/graph/build/stream_allocator.cc similarity index 100% rename from ge/graph/build/stream_allocator.cc rename to src/ge/graph/build/stream_allocator.cc diff --git a/ge/graph/build/stream_allocator.h b/src/ge/graph/build/stream_allocator.h similarity index 100% rename from ge/graph/build/stream_allocator.h rename to src/ge/graph/build/stream_allocator.h diff --git a/ge/graph/build/stream_graph_optimizer.cc b/src/ge/graph/build/stream_graph_optimizer.cc similarity index 100% rename from ge/graph/build/stream_graph_optimizer.cc rename to src/ge/graph/build/stream_graph_optimizer.cc diff --git a/ge/graph/build/stream_graph_optimizer.h b/src/ge/graph/build/stream_graph_optimizer.h similarity index 100% rename from ge/graph/build/stream_graph_optimizer.h rename to src/ge/graph/build/stream_graph_optimizer.h diff --git a/ge/graph/build/task_generator.cc b/src/ge/graph/build/task_generator.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/task_generator.cc rename to src/ge/graph/build/task_generator.cc diff --git a/ge/graph/build/task_generator.h b/src/ge/graph/build/task_generator.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/build/task_generator.h rename to src/ge/graph/build/task_generator.h diff --git a/ge/graph/common/bcast.cc b/src/ge/graph/common/bcast.cc similarity index 100% rename from ge/graph/common/bcast.cc rename to src/ge/graph/common/bcast.cc diff --git a/ge/graph/common/bcast.h b/src/ge/graph/common/bcast.h similarity index 100% rename from ge/graph/common/bcast.h rename to src/ge/graph/common/bcast.h diff --git a/ge/graph/common/ge_call_wrapper.h b/src/ge/graph/common/ge_call_wrapper.h similarity index 100% rename from ge/graph/common/ge_call_wrapper.h rename to src/ge/graph/common/ge_call_wrapper.h diff --git a/ge/graph/common/local_context.cc b/src/ge/graph/common/local_context.cc similarity index 100% rename from ge/graph/common/local_context.cc 
rename to src/ge/graph/common/local_context.cc diff --git a/ge/graph/common/local_context.h b/src/ge/graph/common/local_context.h similarity index 100% rename from ge/graph/common/local_context.h rename to src/ge/graph/common/local_context.h diff --git a/ge/graph/common/omg_util.cc b/src/ge/graph/common/omg_util.cc similarity index 100% rename from ge/graph/common/omg_util.cc rename to src/ge/graph/common/omg_util.cc diff --git a/ge/graph/common/omg_util.h b/src/ge/graph/common/omg_util.h similarity index 100% rename from ge/graph/common/omg_util.h rename to src/ge/graph/common/omg_util.h diff --git a/ge/graph/common/transop_util.cc b/src/ge/graph/common/transop_util.cc similarity index 100% rename from ge/graph/common/transop_util.cc rename to src/ge/graph/common/transop_util.cc diff --git a/ge/graph/common/transop_util.h b/src/ge/graph/common/transop_util.h similarity index 100% rename from ge/graph/common/transop_util.h rename to src/ge/graph/common/transop_util.h diff --git a/ge/graph/execute/graph_execute.cc b/src/ge/graph/execute/graph_execute.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/execute/graph_execute.cc rename to src/ge/graph/execute/graph_execute.cc diff --git a/ge/graph/execute/graph_execute.h b/src/ge/graph/execute/graph_execute.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/execute/graph_execute.h rename to src/ge/graph/execute/graph_execute.h diff --git a/ge/graph/label/case_label_maker.cc b/src/ge/graph/label/case_label_maker.cc similarity index 100% rename from ge/graph/label/case_label_maker.cc rename to src/ge/graph/label/case_label_maker.cc diff --git a/ge/graph/label/case_label_maker.h b/src/ge/graph/label/case_label_maker.h similarity index 100% rename from ge/graph/label/case_label_maker.h rename to src/ge/graph/label/case_label_maker.h diff --git a/ge/graph/label/if_label_maker.cc b/src/ge/graph/label/if_label_maker.cc similarity index 100% rename from 
ge/graph/label/if_label_maker.cc rename to src/ge/graph/label/if_label_maker.cc diff --git a/ge/graph/label/if_label_maker.h b/src/ge/graph/label/if_label_maker.h similarity index 100% rename from ge/graph/label/if_label_maker.h rename to src/ge/graph/label/if_label_maker.h diff --git a/ge/graph/label/label_maker.cc b/src/ge/graph/label/label_maker.cc similarity index 100% rename from ge/graph/label/label_maker.cc rename to src/ge/graph/label/label_maker.cc diff --git a/ge/graph/label/label_maker.h b/src/ge/graph/label/label_maker.h similarity index 100% rename from ge/graph/label/label_maker.h rename to src/ge/graph/label/label_maker.h diff --git a/ge/graph/label/label_maker_factory.h b/src/ge/graph/label/label_maker_factory.h similarity index 100% rename from ge/graph/label/label_maker_factory.h rename to src/ge/graph/label/label_maker_factory.h diff --git a/ge/graph/label/partitioned_call_label_maker.cc b/src/ge/graph/label/partitioned_call_label_maker.cc similarity index 100% rename from ge/graph/label/partitioned_call_label_maker.cc rename to src/ge/graph/label/partitioned_call_label_maker.cc diff --git a/ge/graph/label/partitioned_call_label_maker.h b/src/ge/graph/label/partitioned_call_label_maker.h similarity index 100% rename from ge/graph/label/partitioned_call_label_maker.h rename to src/ge/graph/label/partitioned_call_label_maker.h diff --git a/ge/graph/label/while_label_maker.cc b/src/ge/graph/label/while_label_maker.cc similarity index 100% rename from ge/graph/label/while_label_maker.cc rename to src/ge/graph/label/while_label_maker.cc diff --git a/ge/graph/label/while_label_maker.h b/src/ge/graph/label/while_label_maker.h similarity index 100% rename from ge/graph/label/while_label_maker.h rename to src/ge/graph/label/while_label_maker.h diff --git a/ge/graph/load/graph_loader.cc b/src/ge/graph/load/graph_loader.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/graph_loader.cc rename to 
src/ge/graph/load/graph_loader.cc diff --git a/ge/graph/load/graph_loader.h b/src/ge/graph/load/graph_loader.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/graph_loader.h rename to src/ge/graph/load/graph_loader.h diff --git a/ge/graph/load/new_model_manager/aipp_utils.cc b/src/ge/graph/load/new_model_manager/aipp_utils.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/aipp_utils.cc rename to src/ge/graph/load/new_model_manager/aipp_utils.cc diff --git a/ge/graph/load/new_model_manager/aipp_utils.h b/src/ge/graph/load/new_model_manager/aipp_utils.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/aipp_utils.h rename to src/ge/graph/load/new_model_manager/aipp_utils.h diff --git a/ge/graph/load/new_model_manager/cpu_queue_schedule.cc b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc similarity index 100% rename from ge/graph/load/new_model_manager/cpu_queue_schedule.cc rename to src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc diff --git a/ge/graph/load/new_model_manager/cpu_queue_schedule.h b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h similarity index 100% rename from ge/graph/load/new_model_manager/cpu_queue_schedule.h rename to src/ge/graph/load/new_model_manager/cpu_queue_schedule.h diff --git a/ge/graph/load/new_model_manager/data_dumper.cc b/src/ge/graph/load/new_model_manager/data_dumper.cc similarity index 100% rename from ge/graph/load/new_model_manager/data_dumper.cc rename to src/ge/graph/load/new_model_manager/data_dumper.cc diff --git a/ge/graph/load/new_model_manager/data_dumper.h b/src/ge/graph/load/new_model_manager/data_dumper.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/data_dumper.h rename to src/ge/graph/load/new_model_manager/data_dumper.h diff --git a/ge/graph/load/new_model_manager/data_inputer.cc 
b/src/ge/graph/load/new_model_manager/data_inputer.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/data_inputer.cc rename to src/ge/graph/load/new_model_manager/data_inputer.cc diff --git a/ge/graph/load/new_model_manager/data_inputer.h b/src/ge/graph/load/new_model_manager/data_inputer.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/data_inputer.h rename to src/ge/graph/load/new_model_manager/data_inputer.h diff --git a/ge/graph/load/new_model_manager/davinci_model.cc b/src/ge/graph/load/new_model_manager/davinci_model.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/davinci_model.cc rename to src/ge/graph/load/new_model_manager/davinci_model.cc diff --git a/ge/graph/load/new_model_manager/davinci_model.h b/src/ge/graph/load/new_model_manager/davinci_model.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/davinci_model.h rename to src/ge/graph/load/new_model_manager/davinci_model.h diff --git a/ge/graph/load/new_model_manager/davinci_model_parser.cc b/src/ge/graph/load/new_model_manager/davinci_model_parser.cc similarity index 100% rename from ge/graph/load/new_model_manager/davinci_model_parser.cc rename to src/ge/graph/load/new_model_manager/davinci_model_parser.cc diff --git a/ge/graph/load/new_model_manager/davinci_model_parser.h b/src/ge/graph/load/new_model_manager/davinci_model_parser.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/davinci_model_parser.h rename to src/ge/graph/load/new_model_manager/davinci_model_parser.h diff --git a/ge/graph/load/new_model_manager/model_manager.cc b/src/ge/graph/load/new_model_manager/model_manager.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/model_manager.cc rename to 
src/ge/graph/load/new_model_manager/model_manager.cc diff --git a/ge/graph/load/new_model_manager/model_manager.h b/src/ge/graph/load/new_model_manager/model_manager.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/model_manager.h rename to src/ge/graph/load/new_model_manager/model_manager.h diff --git a/ge/graph/load/new_model_manager/model_utils.cc b/src/ge/graph/load/new_model_manager/model_utils.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/model_utils.cc rename to src/ge/graph/load/new_model_manager/model_utils.cc diff --git a/ge/graph/load/new_model_manager/model_utils.h b/src/ge/graph/load/new_model_manager/model_utils.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/model_utils.h rename to src/ge/graph/load/new_model_manager/model_utils.h diff --git a/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc similarity index 100% rename from ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/end_graph_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/event_record_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc diff --git 
a/ge/graph/load/new_model_manager/task_info/event_record_task_info.h b/src/ge/graph/load/new_model_manager/task_info/event_record_task_info.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/event_record_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/event_record_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/event_wait_task_info.h b/src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/event_wait_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.h b/src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/fusion_start_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc old mode 100755 new mode 
100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.h b/src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc similarity index 100% rename from ge/graph/load/new_model_manager/task_info/hccl_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/hccl_task_info.h b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/hccl_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc similarity index 100% rename from ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc old mode 100755 new mode 100644 similarity 
index 100% rename from ge/graph/load/new_model_manager/task_info/kernel_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/kernel_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/kernel_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h b/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc similarity index 100% rename from ge/graph/load/new_model_manager/task_info/label_set_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/label_set_task_info.h b/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/label_set_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/label_set_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc 
b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc similarity index 100% rename from ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h old mode 100755 new mode 100644 similarity index 100% rename from 
ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.h b/src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/stream_active_task_info.h b/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/stream_active_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc similarity index 100% rename from ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc diff 
--git a/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc similarity index 100% rename from ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc rename to src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc diff --git a/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h rename to src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h diff --git a/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc 
similarity index 100% rename from ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc rename to src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc diff --git a/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h rename to src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h diff --git a/ge/graph/load/new_model_manager/task_info/task_info.cc b/src/ge/graph/load/new_model_manager/task_info/task_info.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/task_info/task_info.cc rename to src/ge/graph/load/new_model_manager/task_info/task_info.cc diff --git a/ge/graph/load/new_model_manager/task_info/task_info.h b/src/ge/graph/load/new_model_manager/task_info/task_info.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/task_info.h rename to src/ge/graph/load/new_model_manager/task_info/task_info.h diff --git a/ge/graph/load/new_model_manager/task_info/task_info_factory.h b/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h similarity index 100% rename from ge/graph/load/new_model_manager/task_info/task_info_factory.h rename to src/ge/graph/load/new_model_manager/task_info/task_info_factory.h diff --git a/ge/graph/load/new_model_manager/tbe_handle_store.cc b/src/ge/graph/load/new_model_manager/tbe_handle_store.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/tbe_handle_store.cc rename to src/ge/graph/load/new_model_manager/tbe_handle_store.cc diff --git a/ge/graph/load/new_model_manager/tbe_handle_store.h b/src/ge/graph/load/new_model_manager/tbe_handle_store.h similarity index 100% rename from 
ge/graph/load/new_model_manager/tbe_handle_store.h rename to src/ge/graph/load/new_model_manager/tbe_handle_store.h diff --git a/ge/graph/load/new_model_manager/zero_copy_offset.cc b/src/ge/graph/load/new_model_manager/zero_copy_offset.cc similarity index 100% rename from ge/graph/load/new_model_manager/zero_copy_offset.cc rename to src/ge/graph/load/new_model_manager/zero_copy_offset.cc diff --git a/ge/graph/load/new_model_manager/zero_copy_offset.h b/src/ge/graph/load/new_model_manager/zero_copy_offset.h similarity index 100% rename from ge/graph/load/new_model_manager/zero_copy_offset.h rename to src/ge/graph/load/new_model_manager/zero_copy_offset.h diff --git a/ge/graph/load/new_model_manager/zero_copy_task.cc b/src/ge/graph/load/new_model_manager/zero_copy_task.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/load/new_model_manager/zero_copy_task.cc rename to src/ge/graph/load/new_model_manager/zero_copy_task.cc diff --git a/ge/graph/load/new_model_manager/zero_copy_task.h b/src/ge/graph/load/new_model_manager/zero_copy_task.h similarity index 100% rename from ge/graph/load/new_model_manager/zero_copy_task.h rename to src/ge/graph/load/new_model_manager/zero_copy_task.h diff --git a/ge/graph/manager/block_memory.h b/src/ge/graph/manager/block_memory.h similarity index 100% rename from ge/graph/manager/block_memory.h rename to src/ge/graph/manager/block_memory.h diff --git a/ge/graph/manager/graph_caching_allocator.cc b/src/ge/graph/manager/graph_caching_allocator.cc similarity index 100% rename from ge/graph/manager/graph_caching_allocator.cc rename to src/ge/graph/manager/graph_caching_allocator.cc diff --git a/ge/graph/manager/graph_caching_allocator.h b/src/ge/graph/manager/graph_caching_allocator.h similarity index 100% rename from ge/graph/manager/graph_caching_allocator.h rename to src/ge/graph/manager/graph_caching_allocator.h diff --git a/ge/graph/manager/graph_context.cc b/src/ge/graph/manager/graph_context.cc similarity 
index 100% rename from ge/graph/manager/graph_context.cc rename to src/ge/graph/manager/graph_context.cc diff --git a/ge/graph/manager/graph_context.h b/src/ge/graph/manager/graph_context.h similarity index 100% rename from ge/graph/manager/graph_context.h rename to src/ge/graph/manager/graph_context.h diff --git a/ge/graph/manager/graph_manager.cc b/src/ge/graph/manager/graph_manager.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/manager/graph_manager.cc rename to src/ge/graph/manager/graph_manager.cc diff --git a/ge/graph/manager/graph_manager.h b/src/ge/graph/manager/graph_manager.h similarity index 100% rename from ge/graph/manager/graph_manager.h rename to src/ge/graph/manager/graph_manager.h diff --git a/ge/graph/manager/graph_manager_utils.cc b/src/ge/graph/manager/graph_manager_utils.cc similarity index 100% rename from ge/graph/manager/graph_manager_utils.cc rename to src/ge/graph/manager/graph_manager_utils.cc diff --git a/ge/graph/manager/graph_manager_utils.h b/src/ge/graph/manager/graph_manager_utils.h similarity index 100% rename from ge/graph/manager/graph_manager_utils.h rename to src/ge/graph/manager/graph_manager_utils.h diff --git a/ge/graph/manager/graph_mem_allocator.cc b/src/ge/graph/manager/graph_mem_allocator.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/manager/graph_mem_allocator.cc rename to src/ge/graph/manager/graph_mem_allocator.cc diff --git a/ge/graph/manager/graph_mem_allocator.h b/src/ge/graph/manager/graph_mem_allocator.h similarity index 100% rename from ge/graph/manager/graph_mem_allocator.h rename to src/ge/graph/manager/graph_mem_allocator.h diff --git a/ge/graph/manager/graph_var_manager.cc b/src/ge/graph/manager/graph_var_manager.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/manager/graph_var_manager.cc rename to src/ge/graph/manager/graph_var_manager.cc diff --git a/ge/graph/manager/graph_var_manager.h 
b/src/ge/graph/manager/graph_var_manager.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/manager/graph_var_manager.h rename to src/ge/graph/manager/graph_var_manager.h diff --git a/ge/graph/manager/host_mem_manager.cc b/src/ge/graph/manager/host_mem_manager.cc similarity index 100% rename from ge/graph/manager/host_mem_manager.cc rename to src/ge/graph/manager/host_mem_manager.cc diff --git a/ge/graph/manager/host_mem_manager.h b/src/ge/graph/manager/host_mem_manager.h similarity index 100% rename from ge/graph/manager/host_mem_manager.h rename to src/ge/graph/manager/host_mem_manager.h diff --git a/ge/graph/manager/memory_api.cc b/src/ge/graph/manager/memory_api.cc similarity index 100% rename from ge/graph/manager/memory_api.cc rename to src/ge/graph/manager/memory_api.cc diff --git a/ge/graph/manager/model_manager/event_manager.cc b/src/ge/graph/manager/model_manager/event_manager.cc similarity index 100% rename from ge/graph/manager/model_manager/event_manager.cc rename to src/ge/graph/manager/model_manager/event_manager.cc diff --git a/ge/graph/manager/model_manager/event_manager.h b/src/ge/graph/manager/model_manager/event_manager.h similarity index 100% rename from ge/graph/manager/model_manager/event_manager.h rename to src/ge/graph/manager/model_manager/event_manager.h diff --git a/ge/graph/manager/rdma_pool_allocator.cc b/src/ge/graph/manager/rdma_pool_allocator.cc similarity index 100% rename from ge/graph/manager/rdma_pool_allocator.cc rename to src/ge/graph/manager/rdma_pool_allocator.cc diff --git a/ge/graph/manager/rdma_pool_allocator.h b/src/ge/graph/manager/rdma_pool_allocator.h similarity index 100% rename from ge/graph/manager/rdma_pool_allocator.h rename to src/ge/graph/manager/rdma_pool_allocator.h diff --git a/ge/graph/manager/trans_var_data_utils.cc b/src/ge/graph/manager/trans_var_data_utils.cc similarity index 100% rename from ge/graph/manager/trans_var_data_utils.cc rename to 
src/ge/graph/manager/trans_var_data_utils.cc diff --git a/ge/graph/manager/trans_var_data_utils.h b/src/ge/graph/manager/trans_var_data_utils.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/manager/trans_var_data_utils.h rename to src/ge/graph/manager/trans_var_data_utils.h diff --git a/ge/graph/manager/util/debug.cc b/src/ge/graph/manager/util/debug.cc similarity index 100% rename from ge/graph/manager/util/debug.cc rename to src/ge/graph/manager/util/debug.cc diff --git a/ge/graph/manager/util/debug.h b/src/ge/graph/manager/util/debug.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/manager/util/debug.h rename to src/ge/graph/manager/util/debug.h diff --git a/ge/graph/manager/util/hcom_util.cc b/src/ge/graph/manager/util/hcom_util.cc similarity index 100% rename from ge/graph/manager/util/hcom_util.cc rename to src/ge/graph/manager/util/hcom_util.cc diff --git a/ge/graph/manager/util/hcom_util.h b/src/ge/graph/manager/util/hcom_util.h similarity index 100% rename from ge/graph/manager/util/hcom_util.h rename to src/ge/graph/manager/util/hcom_util.h diff --git a/ge/graph/manager/util/rt_context_util.cc b/src/ge/graph/manager/util/rt_context_util.cc similarity index 100% rename from ge/graph/manager/util/rt_context_util.cc rename to src/ge/graph/manager/util/rt_context_util.cc diff --git a/ge/graph/manager/util/rt_context_util.h b/src/ge/graph/manager/util/rt_context_util.h similarity index 100% rename from ge/graph/manager/util/rt_context_util.h rename to src/ge/graph/manager/util/rt_context_util.h diff --git a/ge/graph/manager/util/variable_accelerate_ctrl.cc b/src/ge/graph/manager/util/variable_accelerate_ctrl.cc similarity index 100% rename from ge/graph/manager/util/variable_accelerate_ctrl.cc rename to src/ge/graph/manager/util/variable_accelerate_ctrl.cc diff --git a/ge/graph/manager/util/variable_accelerate_ctrl.h b/src/ge/graph/manager/util/variable_accelerate_ctrl.h similarity index 100% rename from 
ge/graph/manager/util/variable_accelerate_ctrl.h rename to src/ge/graph/manager/util/variable_accelerate_ctrl.h diff --git a/ge/graph/optimize/common/params.h b/src/ge/graph/optimize/common/params.h similarity index 100% rename from ge/graph/optimize/common/params.h rename to src/ge/graph/optimize/common/params.h diff --git a/ge/graph/optimize/graph_optimize.cc b/src/ge/graph/optimize/graph_optimize.cc similarity index 100% rename from ge/graph/optimize/graph_optimize.cc rename to src/ge/graph/optimize/graph_optimize.cc diff --git a/ge/graph/optimize/graph_optimize.h b/src/ge/graph/optimize/graph_optimize.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/optimize/graph_optimize.h rename to src/ge/graph/optimize/graph_optimize.h diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/src/ge/graph/optimize/mem_rw_conflict_optimize.cc similarity index 100% rename from ge/graph/optimize/mem_rw_conflict_optimize.cc rename to src/ge/graph/optimize/mem_rw_conflict_optimize.cc diff --git a/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc b/src/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc similarity index 100% rename from ge/graph/optimize/optimizer/allreduce_fusion_pass.cc rename to src/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc diff --git a/ge/graph/optimize/optimizer/allreduce_fusion_pass.h b/src/ge/graph/optimize/optimizer/allreduce_fusion_pass.h similarity index 100% rename from ge/graph/optimize/optimizer/allreduce_fusion_pass.h rename to src/ge/graph/optimize/optimizer/allreduce_fusion_pass.h diff --git a/ge/graph/optimize/summary_optimize.cc b/src/ge/graph/optimize/summary_optimize.cc similarity index 100% rename from ge/graph/optimize/summary_optimize.cc rename to src/ge/graph/optimize/summary_optimize.cc diff --git a/ge/graph/partition/dynamic_shape_partition.cc b/src/ge/graph/partition/dynamic_shape_partition.cc old mode 100755 new mode 100644 similarity index 100% rename from 
ge/graph/partition/dynamic_shape_partition.cc rename to src/ge/graph/partition/dynamic_shape_partition.cc diff --git a/ge/graph/partition/dynamic_shape_partition.h b/src/ge/graph/partition/dynamic_shape_partition.h similarity index 100% rename from ge/graph/partition/dynamic_shape_partition.h rename to src/ge/graph/partition/dynamic_shape_partition.h diff --git a/ge/graph/partition/engine_place.cc b/src/ge/graph/partition/engine_place.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/partition/engine_place.cc rename to src/ge/graph/partition/engine_place.cc diff --git a/ge/graph/partition/engine_place.h b/src/ge/graph/partition/engine_place.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/partition/engine_place.h rename to src/ge/graph/partition/engine_place.h diff --git a/ge/graph/partition/graph_partition.cc b/src/ge/graph/partition/graph_partition.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/partition/graph_partition.cc rename to src/ge/graph/partition/graph_partition.cc diff --git a/ge/graph/partition/graph_partition.h b/src/ge/graph/partition/graph_partition.h similarity index 100% rename from ge/graph/partition/graph_partition.h rename to src/ge/graph/partition/graph_partition.h diff --git a/ge/graph/passes/addn_pass.cc b/src/ge/graph/passes/addn_pass.cc similarity index 100% rename from ge/graph/passes/addn_pass.cc rename to src/ge/graph/passes/addn_pass.cc diff --git a/ge/graph/passes/addn_pass.h b/src/ge/graph/passes/addn_pass.h similarity index 100% rename from ge/graph/passes/addn_pass.h rename to src/ge/graph/passes/addn_pass.h diff --git a/ge/graph/passes/aicpu_constant_folding_pass.cc b/src/ge/graph/passes/aicpu_constant_folding_pass.cc similarity index 100% rename from ge/graph/passes/aicpu_constant_folding_pass.cc rename to src/ge/graph/passes/aicpu_constant_folding_pass.cc diff --git a/ge/graph/passes/aicpu_constant_folding_pass.h 
b/src/ge/graph/passes/aicpu_constant_folding_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/aicpu_constant_folding_pass.h rename to src/ge/graph/passes/aicpu_constant_folding_pass.h diff --git a/ge/graph/passes/assert_pass.cc b/src/ge/graph/passes/assert_pass.cc similarity index 100% rename from ge/graph/passes/assert_pass.cc rename to src/ge/graph/passes/assert_pass.cc diff --git a/ge/graph/passes/assert_pass.h b/src/ge/graph/passes/assert_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/assert_pass.h rename to src/ge/graph/passes/assert_pass.h diff --git a/ge/graph/passes/assign_pass.cc b/src/ge/graph/passes/assign_pass.cc similarity index 100% rename from ge/graph/passes/assign_pass.cc rename to src/ge/graph/passes/assign_pass.cc diff --git a/ge/graph/passes/assign_pass.h b/src/ge/graph/passes/assign_pass.h similarity index 100% rename from ge/graph/passes/assign_pass.h rename to src/ge/graph/passes/assign_pass.h diff --git a/ge/graph/passes/atomic_addr_clean_pass.cc b/src/ge/graph/passes/atomic_addr_clean_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/atomic_addr_clean_pass.cc rename to src/ge/graph/passes/atomic_addr_clean_pass.cc diff --git a/ge/graph/passes/atomic_addr_clean_pass.h b/src/ge/graph/passes/atomic_addr_clean_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/atomic_addr_clean_pass.h rename to src/ge/graph/passes/atomic_addr_clean_pass.h diff --git a/ge/graph/passes/attach_stream_label_pass.cc b/src/ge/graph/passes/attach_stream_label_pass.cc similarity index 100% rename from ge/graph/passes/attach_stream_label_pass.cc rename to src/ge/graph/passes/attach_stream_label_pass.cc diff --git a/ge/graph/passes/attach_stream_label_pass.h b/src/ge/graph/passes/attach_stream_label_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/attach_stream_label_pass.h 
rename to src/ge/graph/passes/attach_stream_label_pass.h diff --git a/ge/graph/passes/base_pass.cc b/src/ge/graph/passes/base_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/base_pass.cc rename to src/ge/graph/passes/base_pass.cc diff --git a/ge/graph/passes/base_pass.h b/src/ge/graph/passes/base_pass.h similarity index 100% rename from ge/graph/passes/base_pass.h rename to src/ge/graph/passes/base_pass.h diff --git a/ge/graph/passes/bitcast_pass.cc b/src/ge/graph/passes/bitcast_pass.cc similarity index 100% rename from ge/graph/passes/bitcast_pass.cc rename to src/ge/graph/passes/bitcast_pass.cc diff --git a/ge/graph/passes/bitcast_pass.h b/src/ge/graph/passes/bitcast_pass.h similarity index 100% rename from ge/graph/passes/bitcast_pass.h rename to src/ge/graph/passes/bitcast_pass.h diff --git a/ge/graph/passes/cast_remove_pass.cc b/src/ge/graph/passes/cast_remove_pass.cc similarity index 100% rename from ge/graph/passes/cast_remove_pass.cc rename to src/ge/graph/passes/cast_remove_pass.cc diff --git a/ge/graph/passes/cast_remove_pass.h b/src/ge/graph/passes/cast_remove_pass.h similarity index 100% rename from ge/graph/passes/cast_remove_pass.h rename to src/ge/graph/passes/cast_remove_pass.h diff --git a/ge/graph/passes/cast_translate_pass.cc b/src/ge/graph/passes/cast_translate_pass.cc similarity index 100% rename from ge/graph/passes/cast_translate_pass.cc rename to src/ge/graph/passes/cast_translate_pass.cc diff --git a/ge/graph/passes/cast_translate_pass.h b/src/ge/graph/passes/cast_translate_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/cast_translate_pass.h rename to src/ge/graph/passes/cast_translate_pass.h diff --git a/ge/graph/passes/common_subexpression_elimination_pass.cc b/src/ge/graph/passes/common_subexpression_elimination_pass.cc similarity index 100% rename from ge/graph/passes/common_subexpression_elimination_pass.cc rename to 
src/ge/graph/passes/common_subexpression_elimination_pass.cc diff --git a/ge/graph/passes/common_subexpression_elimination_pass.h b/src/ge/graph/passes/common_subexpression_elimination_pass.h similarity index 100% rename from ge/graph/passes/common_subexpression_elimination_pass.h rename to src/ge/graph/passes/common_subexpression_elimination_pass.h diff --git a/ge/graph/passes/compile_nodes_pass.cc b/src/ge/graph/passes/compile_nodes_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/compile_nodes_pass.cc rename to src/ge/graph/passes/compile_nodes_pass.cc diff --git a/ge/graph/passes/compile_nodes_pass.h b/src/ge/graph/passes/compile_nodes_pass.h similarity index 100% rename from ge/graph/passes/compile_nodes_pass.h rename to src/ge/graph/passes/compile_nodes_pass.h diff --git a/ge/graph/passes/cond_pass.cc b/src/ge/graph/passes/cond_pass.cc similarity index 100% rename from ge/graph/passes/cond_pass.cc rename to src/ge/graph/passes/cond_pass.cc diff --git a/ge/graph/passes/cond_pass.h b/src/ge/graph/passes/cond_pass.h similarity index 100% rename from ge/graph/passes/cond_pass.h rename to src/ge/graph/passes/cond_pass.h diff --git a/ge/graph/passes/cond_remove_pass.cc b/src/ge/graph/passes/cond_remove_pass.cc similarity index 100% rename from ge/graph/passes/cond_remove_pass.cc rename to src/ge/graph/passes/cond_remove_pass.cc diff --git a/ge/graph/passes/cond_remove_pass.h b/src/ge/graph/passes/cond_remove_pass.h similarity index 100% rename from ge/graph/passes/cond_remove_pass.h rename to src/ge/graph/passes/cond_remove_pass.h diff --git a/ge/graph/passes/constant_folding_pass.cc b/src/ge/graph/passes/constant_folding_pass.cc similarity index 100% rename from ge/graph/passes/constant_folding_pass.cc rename to src/ge/graph/passes/constant_folding_pass.cc diff --git a/ge/graph/passes/constant_folding_pass.h b/src/ge/graph/passes/constant_folding_pass.h similarity index 100% rename from 
ge/graph/passes/constant_folding_pass.h rename to src/ge/graph/passes/constant_folding_pass.h diff --git a/ge/graph/passes/constant_fuse_same_pass.cc b/src/ge/graph/passes/constant_fuse_same_pass.cc similarity index 100% rename from ge/graph/passes/constant_fuse_same_pass.cc rename to src/ge/graph/passes/constant_fuse_same_pass.cc diff --git a/ge/graph/passes/constant_fuse_same_pass.h b/src/ge/graph/passes/constant_fuse_same_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/constant_fuse_same_pass.h rename to src/ge/graph/passes/constant_fuse_same_pass.h diff --git a/ge/graph/passes/control_trigger_pass.cc b/src/ge/graph/passes/control_trigger_pass.cc similarity index 100% rename from ge/graph/passes/control_trigger_pass.cc rename to src/ge/graph/passes/control_trigger_pass.cc diff --git a/ge/graph/passes/control_trigger_pass.h b/src/ge/graph/passes/control_trigger_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/control_trigger_pass.h rename to src/ge/graph/passes/control_trigger_pass.h diff --git a/ge/graph/passes/ctrl_edge_transfer_pass.cc b/src/ge/graph/passes/ctrl_edge_transfer_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/ctrl_edge_transfer_pass.cc rename to src/ge/graph/passes/ctrl_edge_transfer_pass.cc diff --git a/ge/graph/passes/ctrl_edge_transfer_pass.h b/src/ge/graph/passes/ctrl_edge_transfer_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/ctrl_edge_transfer_pass.h rename to src/ge/graph/passes/ctrl_edge_transfer_pass.h diff --git a/ge/graph/passes/data_pass.cc b/src/ge/graph/passes/data_pass.cc similarity index 100% rename from ge/graph/passes/data_pass.cc rename to src/ge/graph/passes/data_pass.cc diff --git a/ge/graph/passes/data_pass.h b/src/ge/graph/passes/data_pass.h similarity index 100% rename from ge/graph/passes/data_pass.h rename to src/ge/graph/passes/data_pass.h diff --git 
a/ge/graph/passes/dimension_adjust_pass.cc b/src/ge/graph/passes/dimension_adjust_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/dimension_adjust_pass.cc rename to src/ge/graph/passes/dimension_adjust_pass.cc diff --git a/ge/graph/passes/dimension_adjust_pass.h b/src/ge/graph/passes/dimension_adjust_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/dimension_adjust_pass.h rename to src/ge/graph/passes/dimension_adjust_pass.h diff --git a/ge/graph/passes/dimension_compute_pass.cc b/src/ge/graph/passes/dimension_compute_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/dimension_compute_pass.cc rename to src/ge/graph/passes/dimension_compute_pass.cc diff --git a/ge/graph/passes/dimension_compute_pass.h b/src/ge/graph/passes/dimension_compute_pass.h similarity index 100% rename from ge/graph/passes/dimension_compute_pass.h rename to src/ge/graph/passes/dimension_compute_pass.h diff --git a/ge/graph/passes/dropout_pass.cc b/src/ge/graph/passes/dropout_pass.cc similarity index 100% rename from ge/graph/passes/dropout_pass.cc rename to src/ge/graph/passes/dropout_pass.cc diff --git a/ge/graph/passes/dropout_pass.h b/src/ge/graph/passes/dropout_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/dropout_pass.h rename to src/ge/graph/passes/dropout_pass.h diff --git a/ge/graph/passes/end_of_sequence_add_control_pass.cc b/src/ge/graph/passes/end_of_sequence_add_control_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/end_of_sequence_add_control_pass.cc rename to src/ge/graph/passes/end_of_sequence_add_control_pass.cc diff --git a/ge/graph/passes/end_of_sequence_add_control_pass.h b/src/ge/graph/passes/end_of_sequence_add_control_pass.h similarity index 100% rename from ge/graph/passes/end_of_sequence_add_control_pass.h rename to 
src/ge/graph/passes/end_of_sequence_add_control_pass.h diff --git a/ge/graph/passes/enter_pass.cc b/src/ge/graph/passes/enter_pass.cc similarity index 100% rename from ge/graph/passes/enter_pass.cc rename to src/ge/graph/passes/enter_pass.cc diff --git a/ge/graph/passes/enter_pass.h b/src/ge/graph/passes/enter_pass.h similarity index 100% rename from ge/graph/passes/enter_pass.h rename to src/ge/graph/passes/enter_pass.h diff --git a/ge/graph/passes/flow_ctrl_pass.cc b/src/ge/graph/passes/flow_ctrl_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/flow_ctrl_pass.cc rename to src/ge/graph/passes/flow_ctrl_pass.cc diff --git a/ge/graph/passes/flow_ctrl_pass.h b/src/ge/graph/passes/flow_ctrl_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/flow_ctrl_pass.h rename to src/ge/graph/passes/flow_ctrl_pass.h diff --git a/ge/graph/passes/folding_pass.cc b/src/ge/graph/passes/folding_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/folding_pass.cc rename to src/ge/graph/passes/folding_pass.cc diff --git a/ge/graph/passes/folding_pass.h b/src/ge/graph/passes/folding_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/folding_pass.h rename to src/ge/graph/passes/folding_pass.h diff --git a/ge/graph/passes/for_pass.cc b/src/ge/graph/passes/for_pass.cc similarity index 100% rename from ge/graph/passes/for_pass.cc rename to src/ge/graph/passes/for_pass.cc diff --git a/ge/graph/passes/for_pass.h b/src/ge/graph/passes/for_pass.h similarity index 100% rename from ge/graph/passes/for_pass.h rename to src/ge/graph/passes/for_pass.h diff --git a/ge/graph/passes/get_original_format_pass.cc b/src/ge/graph/passes/get_original_format_pass.cc similarity index 100% rename from ge/graph/passes/get_original_format_pass.cc rename to src/ge/graph/passes/get_original_format_pass.cc diff --git a/ge/graph/passes/get_original_format_pass.h 
b/src/ge/graph/passes/get_original_format_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/get_original_format_pass.h rename to src/ge/graph/passes/get_original_format_pass.h diff --git a/ge/graph/passes/global_step_insert_pass.cc b/src/ge/graph/passes/global_step_insert_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/global_step_insert_pass.cc rename to src/ge/graph/passes/global_step_insert_pass.cc diff --git a/ge/graph/passes/global_step_insert_pass.h b/src/ge/graph/passes/global_step_insert_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/global_step_insert_pass.h rename to src/ge/graph/passes/global_step_insert_pass.h diff --git a/ge/graph/passes/guarantee_const_pass.cc b/src/ge/graph/passes/guarantee_const_pass.cc similarity index 100% rename from ge/graph/passes/guarantee_const_pass.cc rename to src/ge/graph/passes/guarantee_const_pass.cc diff --git a/ge/graph/passes/guarantee_const_pass.h b/src/ge/graph/passes/guarantee_const_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/guarantee_const_pass.h rename to src/ge/graph/passes/guarantee_const_pass.h diff --git a/ge/graph/passes/hccl_group_pass.cc b/src/ge/graph/passes/hccl_group_pass.cc similarity index 100% rename from ge/graph/passes/hccl_group_pass.cc rename to src/ge/graph/passes/hccl_group_pass.cc diff --git a/ge/graph/passes/hccl_group_pass.h b/src/ge/graph/passes/hccl_group_pass.h similarity index 100% rename from ge/graph/passes/hccl_group_pass.h rename to src/ge/graph/passes/hccl_group_pass.h diff --git a/ge/graph/passes/hccl_memcpy_pass.cc b/src/ge/graph/passes/hccl_memcpy_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/hccl_memcpy_pass.cc rename to src/ge/graph/passes/hccl_memcpy_pass.cc diff --git a/ge/graph/passes/hccl_memcpy_pass.h b/src/ge/graph/passes/hccl_memcpy_pass.h old mode 100755 new 
mode 100644 similarity index 100% rename from ge/graph/passes/hccl_memcpy_pass.h rename to src/ge/graph/passes/hccl_memcpy_pass.h diff --git a/ge/graph/passes/identity_pass.cc b/src/ge/graph/passes/identity_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/identity_pass.cc rename to src/ge/graph/passes/identity_pass.cc diff --git a/ge/graph/passes/identity_pass.h b/src/ge/graph/passes/identity_pass.h similarity index 100% rename from ge/graph/passes/identity_pass.h rename to src/ge/graph/passes/identity_pass.h diff --git a/ge/graph/passes/infershape_pass.cc b/src/ge/graph/passes/infershape_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/infershape_pass.cc rename to src/ge/graph/passes/infershape_pass.cc diff --git a/ge/graph/passes/infershape_pass.h b/src/ge/graph/passes/infershape_pass.h similarity index 100% rename from ge/graph/passes/infershape_pass.h rename to src/ge/graph/passes/infershape_pass.h diff --git a/ge/graph/passes/input_output_connection_identify_pass.cc b/src/ge/graph/passes/input_output_connection_identify_pass.cc similarity index 100% rename from ge/graph/passes/input_output_connection_identify_pass.cc rename to src/ge/graph/passes/input_output_connection_identify_pass.cc diff --git a/ge/graph/passes/input_output_connection_identify_pass.h b/src/ge/graph/passes/input_output_connection_identify_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/input_output_connection_identify_pass.h rename to src/ge/graph/passes/input_output_connection_identify_pass.h diff --git a/ge/graph/passes/isolated_op_remove_pass.cc b/src/ge/graph/passes/isolated_op_remove_pass.cc similarity index 100% rename from ge/graph/passes/isolated_op_remove_pass.cc rename to src/ge/graph/passes/isolated_op_remove_pass.cc diff --git a/ge/graph/passes/isolated_op_remove_pass.h b/src/ge/graph/passes/isolated_op_remove_pass.h old mode 100755 new mode 100644 
similarity index 100% rename from ge/graph/passes/isolated_op_remove_pass.h rename to src/ge/graph/passes/isolated_op_remove_pass.h diff --git a/ge/graph/passes/iterator_op_pass.cc b/src/ge/graph/passes/iterator_op_pass.cc similarity index 100% rename from ge/graph/passes/iterator_op_pass.cc rename to src/ge/graph/passes/iterator_op_pass.cc diff --git a/ge/graph/passes/iterator_op_pass.h b/src/ge/graph/passes/iterator_op_pass.h similarity index 100% rename from ge/graph/passes/iterator_op_pass.h rename to src/ge/graph/passes/iterator_op_pass.h diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/link_gen_mask_nodes_pass.cc rename to src/ge/graph/passes/link_gen_mask_nodes_pass.cc diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.h b/src/ge/graph/passes/link_gen_mask_nodes_pass.h similarity index 100% rename from ge/graph/passes/link_gen_mask_nodes_pass.h rename to src/ge/graph/passes/link_gen_mask_nodes_pass.h diff --git a/ge/graph/passes/mark_agnostic_pass.cc b/src/ge/graph/passes/mark_agnostic_pass.cc similarity index 100% rename from ge/graph/passes/mark_agnostic_pass.cc rename to src/ge/graph/passes/mark_agnostic_pass.cc diff --git a/ge/graph/passes/mark_agnostic_pass.h b/src/ge/graph/passes/mark_agnostic_pass.h similarity index 100% rename from ge/graph/passes/mark_agnostic_pass.h rename to src/ge/graph/passes/mark_agnostic_pass.h diff --git a/ge/graph/passes/mark_graph_unknown_status_pass.cc b/src/ge/graph/passes/mark_graph_unknown_status_pass.cc similarity index 100% rename from ge/graph/passes/mark_graph_unknown_status_pass.cc rename to src/ge/graph/passes/mark_graph_unknown_status_pass.cc diff --git a/ge/graph/passes/mark_graph_unknown_status_pass.h b/src/ge/graph/passes/mark_graph_unknown_status_pass.h similarity index 100% rename from ge/graph/passes/mark_graph_unknown_status_pass.h rename to 
src/ge/graph/passes/mark_graph_unknown_status_pass.h diff --git a/ge/graph/passes/mark_same_addr_pass.cc b/src/ge/graph/passes/mark_same_addr_pass.cc similarity index 100% rename from ge/graph/passes/mark_same_addr_pass.cc rename to src/ge/graph/passes/mark_same_addr_pass.cc diff --git a/ge/graph/passes/mark_same_addr_pass.h b/src/ge/graph/passes/mark_same_addr_pass.h similarity index 100% rename from ge/graph/passes/mark_same_addr_pass.h rename to src/ge/graph/passes/mark_same_addr_pass.h diff --git a/ge/graph/passes/memcpy_addr_async_pass.cc b/src/ge/graph/passes/memcpy_addr_async_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/memcpy_addr_async_pass.cc rename to src/ge/graph/passes/memcpy_addr_async_pass.cc diff --git a/ge/graph/passes/memcpy_addr_async_pass.h b/src/ge/graph/passes/memcpy_addr_async_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/memcpy_addr_async_pass.h rename to src/ge/graph/passes/memcpy_addr_async_pass.h diff --git a/ge/graph/passes/merge_pass.cc b/src/ge/graph/passes/merge_pass.cc similarity index 100% rename from ge/graph/passes/merge_pass.cc rename to src/ge/graph/passes/merge_pass.cc diff --git a/ge/graph/passes/merge_pass.h b/src/ge/graph/passes/merge_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/merge_pass.h rename to src/ge/graph/passes/merge_pass.h diff --git a/ge/graph/passes/merge_to_stream_merge_pass.cc b/src/ge/graph/passes/merge_to_stream_merge_pass.cc similarity index 100% rename from ge/graph/passes/merge_to_stream_merge_pass.cc rename to src/ge/graph/passes/merge_to_stream_merge_pass.cc diff --git a/ge/graph/passes/merge_to_stream_merge_pass.h b/src/ge/graph/passes/merge_to_stream_merge_pass.h similarity index 100% rename from ge/graph/passes/merge_to_stream_merge_pass.h rename to src/ge/graph/passes/merge_to_stream_merge_pass.h diff --git a/ge/graph/passes/multi_batch_clone_pass.cc 
b/src/ge/graph/passes/multi_batch_clone_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/multi_batch_clone_pass.cc rename to src/ge/graph/passes/multi_batch_clone_pass.cc diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/src/ge/graph/passes/multi_batch_clone_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/multi_batch_clone_pass.h rename to src/ge/graph/passes/multi_batch_clone_pass.h diff --git a/ge/graph/passes/multi_batch_pass.cc b/src/ge/graph/passes/multi_batch_pass.cc similarity index 100% rename from ge/graph/passes/multi_batch_pass.cc rename to src/ge/graph/passes/multi_batch_pass.cc diff --git a/ge/graph/passes/multi_batch_pass.h b/src/ge/graph/passes/multi_batch_pass.h similarity index 100% rename from ge/graph/passes/multi_batch_pass.h rename to src/ge/graph/passes/multi_batch_pass.h diff --git a/ge/graph/passes/net_output_pass.cc b/src/ge/graph/passes/net_output_pass.cc similarity index 100% rename from ge/graph/passes/net_output_pass.cc rename to src/ge/graph/passes/net_output_pass.cc diff --git a/ge/graph/passes/net_output_pass.h b/src/ge/graph/passes/net_output_pass.h similarity index 100% rename from ge/graph/passes/net_output_pass.h rename to src/ge/graph/passes/net_output_pass.h diff --git a/ge/graph/passes/next_iteration_pass.cc b/src/ge/graph/passes/next_iteration_pass.cc similarity index 100% rename from ge/graph/passes/next_iteration_pass.cc rename to src/ge/graph/passes/next_iteration_pass.cc diff --git a/ge/graph/passes/next_iteration_pass.h b/src/ge/graph/passes/next_iteration_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/next_iteration_pass.h rename to src/ge/graph/passes/next_iteration_pass.h diff --git a/ge/graph/passes/no_use_reshape_remove_pass.cc b/src/ge/graph/passes/no_use_reshape_remove_pass.cc similarity index 100% rename from ge/graph/passes/no_use_reshape_remove_pass.cc rename to 
src/ge/graph/passes/no_use_reshape_remove_pass.cc diff --git a/ge/graph/passes/no_use_reshape_remove_pass.h b/src/ge/graph/passes/no_use_reshape_remove_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/no_use_reshape_remove_pass.h rename to src/ge/graph/passes/no_use_reshape_remove_pass.h diff --git a/ge/graph/passes/parallel_concat_start_op_pass.cc b/src/ge/graph/passes/parallel_concat_start_op_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/parallel_concat_start_op_pass.cc rename to src/ge/graph/passes/parallel_concat_start_op_pass.cc diff --git a/ge/graph/passes/parallel_concat_start_op_pass.h b/src/ge/graph/passes/parallel_concat_start_op_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/parallel_concat_start_op_pass.h rename to src/ge/graph/passes/parallel_concat_start_op_pass.h diff --git a/ge/graph/passes/pass_manager.cc b/src/ge/graph/passes/pass_manager.cc similarity index 100% rename from ge/graph/passes/pass_manager.cc rename to src/ge/graph/passes/pass_manager.cc diff --git a/ge/graph/passes/pass_utils.cc b/src/ge/graph/passes/pass_utils.cc similarity index 100% rename from ge/graph/passes/pass_utils.cc rename to src/ge/graph/passes/pass_utils.cc diff --git a/ge/graph/passes/pass_utils.h b/src/ge/graph/passes/pass_utils.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/pass_utils.h rename to src/ge/graph/passes/pass_utils.h diff --git a/ge/graph/passes/permute_pass.cc b/src/ge/graph/passes/permute_pass.cc similarity index 100% rename from ge/graph/passes/permute_pass.cc rename to src/ge/graph/passes/permute_pass.cc diff --git a/ge/graph/passes/permute_pass.h b/src/ge/graph/passes/permute_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/permute_pass.h rename to src/ge/graph/passes/permute_pass.h diff --git a/ge/graph/passes/placeholder_with_default_pass.cc 
b/src/ge/graph/passes/placeholder_with_default_pass.cc similarity index 100% rename from ge/graph/passes/placeholder_with_default_pass.cc rename to src/ge/graph/passes/placeholder_with_default_pass.cc diff --git a/ge/graph/passes/placeholder_with_default_pass.h b/src/ge/graph/passes/placeholder_with_default_pass.h similarity index 100% rename from ge/graph/passes/placeholder_with_default_pass.h rename to src/ge/graph/passes/placeholder_with_default_pass.h diff --git a/ge/graph/passes/prevent_gradient_pass.cc b/src/ge/graph/passes/prevent_gradient_pass.cc similarity index 100% rename from ge/graph/passes/prevent_gradient_pass.cc rename to src/ge/graph/passes/prevent_gradient_pass.cc diff --git a/ge/graph/passes/prevent_gradient_pass.h b/src/ge/graph/passes/prevent_gradient_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/prevent_gradient_pass.h rename to src/ge/graph/passes/prevent_gradient_pass.h diff --git a/ge/graph/passes/print_op_pass.cc b/src/ge/graph/passes/print_op_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/print_op_pass.cc rename to src/ge/graph/passes/print_op_pass.cc diff --git a/ge/graph/passes/print_op_pass.h b/src/ge/graph/passes/print_op_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/print_op_pass.h rename to src/ge/graph/passes/print_op_pass.h diff --git a/ge/graph/passes/prune_pass.cc b/src/ge/graph/passes/prune_pass.cc similarity index 100% rename from ge/graph/passes/prune_pass.cc rename to src/ge/graph/passes/prune_pass.cc diff --git a/ge/graph/passes/prune_pass.h b/src/ge/graph/passes/prune_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/prune_pass.h rename to src/ge/graph/passes/prune_pass.h diff --git a/ge/graph/passes/ref_identity_delete_op_pass.cc b/src/ge/graph/passes/ref_identity_delete_op_pass.cc similarity index 100% rename from 
ge/graph/passes/ref_identity_delete_op_pass.cc rename to src/ge/graph/passes/ref_identity_delete_op_pass.cc diff --git a/ge/graph/passes/ref_identity_delete_op_pass.h b/src/ge/graph/passes/ref_identity_delete_op_pass.h similarity index 100% rename from ge/graph/passes/ref_identity_delete_op_pass.h rename to src/ge/graph/passes/ref_identity_delete_op_pass.h diff --git a/ge/graph/passes/remove_nodes_pass.cc b/src/ge/graph/passes/remove_nodes_pass.cc similarity index 100% rename from ge/graph/passes/remove_nodes_pass.cc rename to src/ge/graph/passes/remove_nodes_pass.cc diff --git a/ge/graph/passes/remove_nodes_pass.h b/src/ge/graph/passes/remove_nodes_pass.h similarity index 100% rename from ge/graph/passes/remove_nodes_pass.h rename to src/ge/graph/passes/remove_nodes_pass.h diff --git a/ge/graph/passes/replace_transshape_pass.cc b/src/ge/graph/passes/replace_transshape_pass.cc similarity index 100% rename from ge/graph/passes/replace_transshape_pass.cc rename to src/ge/graph/passes/replace_transshape_pass.cc diff --git a/ge/graph/passes/replace_transshape_pass.h b/src/ge/graph/passes/replace_transshape_pass.h similarity index 100% rename from ge/graph/passes/replace_transshape_pass.h rename to src/ge/graph/passes/replace_transshape_pass.h diff --git a/ge/graph/passes/replace_with_empty_const_pass.cc b/src/ge/graph/passes/replace_with_empty_const_pass.cc similarity index 100% rename from ge/graph/passes/replace_with_empty_const_pass.cc rename to src/ge/graph/passes/replace_with_empty_const_pass.cc diff --git a/ge/graph/passes/replace_with_empty_const_pass.h b/src/ge/graph/passes/replace_with_empty_const_pass.h similarity index 100% rename from ge/graph/passes/replace_with_empty_const_pass.h rename to src/ge/graph/passes/replace_with_empty_const_pass.h diff --git a/ge/graph/passes/reshape_recovery_pass.cc b/src/ge/graph/passes/reshape_recovery_pass.cc similarity index 100% rename from ge/graph/passes/reshape_recovery_pass.cc rename to 
src/ge/graph/passes/reshape_recovery_pass.cc diff --git a/ge/graph/passes/reshape_recovery_pass.h b/src/ge/graph/passes/reshape_recovery_pass.h similarity index 100% rename from ge/graph/passes/reshape_recovery_pass.h rename to src/ge/graph/passes/reshape_recovery_pass.h diff --git a/ge/graph/passes/reshape_remove_pass.cc b/src/ge/graph/passes/reshape_remove_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/reshape_remove_pass.cc rename to src/ge/graph/passes/reshape_remove_pass.cc diff --git a/ge/graph/passes/reshape_remove_pass.h b/src/ge/graph/passes/reshape_remove_pass.h similarity index 100% rename from ge/graph/passes/reshape_remove_pass.h rename to src/ge/graph/passes/reshape_remove_pass.h diff --git a/ge/graph/passes/resource_pair_add_control_pass.cc b/src/ge/graph/passes/resource_pair_add_control_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/resource_pair_add_control_pass.cc rename to src/ge/graph/passes/resource_pair_add_control_pass.cc diff --git a/ge/graph/passes/resource_pair_add_control_pass.h b/src/ge/graph/passes/resource_pair_add_control_pass.h similarity index 100% rename from ge/graph/passes/resource_pair_add_control_pass.h rename to src/ge/graph/passes/resource_pair_add_control_pass.h diff --git a/ge/graph/passes/resource_pair_remove_control_pass.cc b/src/ge/graph/passes/resource_pair_remove_control_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/resource_pair_remove_control_pass.cc rename to src/ge/graph/passes/resource_pair_remove_control_pass.cc diff --git a/ge/graph/passes/resource_pair_remove_control_pass.h b/src/ge/graph/passes/resource_pair_remove_control_pass.h similarity index 100% rename from ge/graph/passes/resource_pair_remove_control_pass.h rename to src/ge/graph/passes/resource_pair_remove_control_pass.h diff --git a/ge/graph/passes/same_transdata_breadth_fusion_pass.cc 
b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc similarity index 100% rename from ge/graph/passes/same_transdata_breadth_fusion_pass.cc rename to src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc diff --git a/ge/graph/passes/same_transdata_breadth_fusion_pass.h b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/same_transdata_breadth_fusion_pass.h rename to src/ge/graph/passes/same_transdata_breadth_fusion_pass.h diff --git a/ge/graph/passes/save_pass.cc b/src/ge/graph/passes/save_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/save_pass.cc rename to src/ge/graph/passes/save_pass.cc diff --git a/ge/graph/passes/save_pass.h b/src/ge/graph/passes/save_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/save_pass.h rename to src/ge/graph/passes/save_pass.h diff --git a/ge/graph/passes/set_input_output_offset_pass.cc b/src/ge/graph/passes/set_input_output_offset_pass.cc similarity index 100% rename from ge/graph/passes/set_input_output_offset_pass.cc rename to src/ge/graph/passes/set_input_output_offset_pass.cc diff --git a/ge/graph/passes/set_input_output_offset_pass.h b/src/ge/graph/passes/set_input_output_offset_pass.h similarity index 100% rename from ge/graph/passes/set_input_output_offset_pass.h rename to src/ge/graph/passes/set_input_output_offset_pass.h diff --git a/ge/graph/passes/shape_operate_op_remove_pass.cc b/src/ge/graph/passes/shape_operate_op_remove_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/shape_operate_op_remove_pass.cc rename to src/ge/graph/passes/shape_operate_op_remove_pass.cc diff --git a/ge/graph/passes/shape_operate_op_remove_pass.h b/src/ge/graph/passes/shape_operate_op_remove_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/shape_operate_op_remove_pass.h rename to 
src/ge/graph/passes/shape_operate_op_remove_pass.h diff --git a/ge/graph/passes/snapshot_pass.cc b/src/ge/graph/passes/snapshot_pass.cc similarity index 100% rename from ge/graph/passes/snapshot_pass.cc rename to src/ge/graph/passes/snapshot_pass.cc diff --git a/ge/graph/passes/snapshot_pass.h b/src/ge/graph/passes/snapshot_pass.h similarity index 100% rename from ge/graph/passes/snapshot_pass.h rename to src/ge/graph/passes/snapshot_pass.h diff --git a/ge/graph/passes/stop_gradient_pass.cc b/src/ge/graph/passes/stop_gradient_pass.cc similarity index 100% rename from ge/graph/passes/stop_gradient_pass.cc rename to src/ge/graph/passes/stop_gradient_pass.cc diff --git a/ge/graph/passes/stop_gradient_pass.h b/src/ge/graph/passes/stop_gradient_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/stop_gradient_pass.h rename to src/ge/graph/passes/stop_gradient_pass.h diff --git a/ge/graph/passes/subexpression_migration_pass.cc b/src/ge/graph/passes/subexpression_migration_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/subexpression_migration_pass.cc rename to src/ge/graph/passes/subexpression_migration_pass.cc diff --git a/ge/graph/passes/subexpression_migration_pass.h b/src/ge/graph/passes/subexpression_migration_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/subexpression_migration_pass.h rename to src/ge/graph/passes/subexpression_migration_pass.h diff --git a/ge/graph/passes/subgraph_pass.cc b/src/ge/graph/passes/subgraph_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/subgraph_pass.cc rename to src/ge/graph/passes/subgraph_pass.cc diff --git a/ge/graph/passes/subgraph_pass.h b/src/ge/graph/passes/subgraph_pass.h similarity index 100% rename from ge/graph/passes/subgraph_pass.h rename to src/ge/graph/passes/subgraph_pass.h diff --git a/ge/graph/passes/switch_data_edges_bypass.cc 
b/src/ge/graph/passes/switch_data_edges_bypass.cc similarity index 100% rename from ge/graph/passes/switch_data_edges_bypass.cc rename to src/ge/graph/passes/switch_data_edges_bypass.cc diff --git a/ge/graph/passes/switch_data_edges_bypass.h b/src/ge/graph/passes/switch_data_edges_bypass.h similarity index 100% rename from ge/graph/passes/switch_data_edges_bypass.h rename to src/ge/graph/passes/switch_data_edges_bypass.h diff --git a/ge/graph/passes/switch_dead_branch_elimination.cc b/src/ge/graph/passes/switch_dead_branch_elimination.cc similarity index 100% rename from ge/graph/passes/switch_dead_branch_elimination.cc rename to src/ge/graph/passes/switch_dead_branch_elimination.cc diff --git a/ge/graph/passes/switch_dead_branch_elimination.h b/src/ge/graph/passes/switch_dead_branch_elimination.h similarity index 100% rename from ge/graph/passes/switch_dead_branch_elimination.h rename to src/ge/graph/passes/switch_dead_branch_elimination.h diff --git a/ge/graph/passes/switch_logic_remove_pass.cc b/src/ge/graph/passes/switch_logic_remove_pass.cc similarity index 100% rename from ge/graph/passes/switch_logic_remove_pass.cc rename to src/ge/graph/passes/switch_logic_remove_pass.cc diff --git a/ge/graph/passes/switch_logic_remove_pass.h b/src/ge/graph/passes/switch_logic_remove_pass.h similarity index 100% rename from ge/graph/passes/switch_logic_remove_pass.h rename to src/ge/graph/passes/switch_logic_remove_pass.h diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/src/ge/graph/passes/switch_to_stream_switch_pass.cc similarity index 100% rename from ge/graph/passes/switch_to_stream_switch_pass.cc rename to src/ge/graph/passes/switch_to_stream_switch_pass.cc diff --git a/ge/graph/passes/switch_to_stream_switch_pass.h b/src/ge/graph/passes/switch_to_stream_switch_pass.h similarity index 100% rename from ge/graph/passes/switch_to_stream_switch_pass.h rename to src/ge/graph/passes/switch_to_stream_switch_pass.h diff --git 
a/ge/graph/passes/transop_breadth_fusion_pass.cc b/src/ge/graph/passes/transop_breadth_fusion_pass.cc similarity index 100% rename from ge/graph/passes/transop_breadth_fusion_pass.cc rename to src/ge/graph/passes/transop_breadth_fusion_pass.cc diff --git a/ge/graph/passes/transop_breadth_fusion_pass.h b/src/ge/graph/passes/transop_breadth_fusion_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/transop_breadth_fusion_pass.h rename to src/ge/graph/passes/transop_breadth_fusion_pass.h diff --git a/ge/graph/passes/transop_depth_fusion_pass.cc b/src/ge/graph/passes/transop_depth_fusion_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/transop_depth_fusion_pass.cc rename to src/ge/graph/passes/transop_depth_fusion_pass.cc diff --git a/ge/graph/passes/transop_depth_fusion_pass.h b/src/ge/graph/passes/transop_depth_fusion_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/transop_depth_fusion_pass.h rename to src/ge/graph/passes/transop_depth_fusion_pass.h diff --git a/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc b/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc similarity index 100% rename from ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc rename to src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc diff --git a/ge/graph/passes/transop_nearby_allreduce_fusion_pass.h b/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/transop_nearby_allreduce_fusion_pass.h rename to src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.h diff --git a/ge/graph/passes/transop_symmetry_elimination_pass.cc b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc similarity index 100% rename from ge/graph/passes/transop_symmetry_elimination_pass.cc rename to src/ge/graph/passes/transop_symmetry_elimination_pass.cc diff --git 
a/ge/graph/passes/transop_symmetry_elimination_pass.h b/src/ge/graph/passes/transop_symmetry_elimination_pass.h similarity index 100% rename from ge/graph/passes/transop_symmetry_elimination_pass.h rename to src/ge/graph/passes/transop_symmetry_elimination_pass.h diff --git a/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc similarity index 100% rename from ge/graph/passes/transop_without_reshape_fusion_pass.cc rename to src/ge/graph/passes/transop_without_reshape_fusion_pass.cc diff --git a/ge/graph/passes/transop_without_reshape_fusion_pass.h b/src/ge/graph/passes/transop_without_reshape_fusion_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/transop_without_reshape_fusion_pass.h rename to src/ge/graph/passes/transop_without_reshape_fusion_pass.h diff --git a/ge/graph/passes/transpose_transdata_pass.cc b/src/ge/graph/passes/transpose_transdata_pass.cc similarity index 100% rename from ge/graph/passes/transpose_transdata_pass.cc rename to src/ge/graph/passes/transpose_transdata_pass.cc diff --git a/ge/graph/passes/transpose_transdata_pass.h b/src/ge/graph/passes/transpose_transdata_pass.h similarity index 100% rename from ge/graph/passes/transpose_transdata_pass.h rename to src/ge/graph/passes/transpose_transdata_pass.h diff --git a/ge/graph/passes/unused_args_clean_pass.cc b/src/ge/graph/passes/unused_args_clean_pass.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/unused_args_clean_pass.cc rename to src/ge/graph/passes/unused_args_clean_pass.cc diff --git a/ge/graph/passes/unused_args_clean_pass.h b/src/ge/graph/passes/unused_args_clean_pass.h similarity index 100% rename from ge/graph/passes/unused_args_clean_pass.h rename to src/ge/graph/passes/unused_args_clean_pass.h diff --git a/ge/graph/passes/unused_const_pass.cc b/src/ge/graph/passes/unused_const_pass.cc similarity index 100% rename from 
ge/graph/passes/unused_const_pass.cc rename to src/ge/graph/passes/unused_const_pass.cc diff --git a/ge/graph/passes/unused_const_pass.h b/src/ge/graph/passes/unused_const_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/unused_const_pass.h rename to src/ge/graph/passes/unused_const_pass.h diff --git a/ge/graph/passes/unused_op_remove_pass.cc b/src/ge/graph/passes/unused_op_remove_pass.cc similarity index 100% rename from ge/graph/passes/unused_op_remove_pass.cc rename to src/ge/graph/passes/unused_op_remove_pass.cc diff --git a/ge/graph/passes/unused_op_remove_pass.h b/src/ge/graph/passes/unused_op_remove_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/unused_op_remove_pass.h rename to src/ge/graph/passes/unused_op_remove_pass.h diff --git a/ge/graph/passes/var_is_initialized_op_pass.cc b/src/ge/graph/passes/var_is_initialized_op_pass.cc similarity index 100% rename from ge/graph/passes/var_is_initialized_op_pass.cc rename to src/ge/graph/passes/var_is_initialized_op_pass.cc diff --git a/ge/graph/passes/var_is_initialized_op_pass.h b/src/ge/graph/passes/var_is_initialized_op_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/var_is_initialized_op_pass.h rename to src/ge/graph/passes/var_is_initialized_op_pass.h diff --git a/ge/graph/passes/variable_format_pass.cc b/src/ge/graph/passes/variable_format_pass.cc similarity index 100% rename from ge/graph/passes/variable_format_pass.cc rename to src/ge/graph/passes/variable_format_pass.cc diff --git a/ge/graph/passes/variable_format_pass.h b/src/ge/graph/passes/variable_format_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/variable_format_pass.h rename to src/ge/graph/passes/variable_format_pass.h diff --git a/ge/graph/passes/variable_op_pass.cc b/src/ge/graph/passes/variable_op_pass.cc similarity index 100% rename from ge/graph/passes/variable_op_pass.cc 
rename to src/ge/graph/passes/variable_op_pass.cc diff --git a/ge/graph/passes/variable_op_pass.h b/src/ge/graph/passes/variable_op_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/variable_op_pass.h rename to src/ge/graph/passes/variable_op_pass.h diff --git a/ge/graph/passes/variable_prepare_op_pass.cc b/src/ge/graph/passes/variable_prepare_op_pass.cc similarity index 100% rename from ge/graph/passes/variable_prepare_op_pass.cc rename to src/ge/graph/passes/variable_prepare_op_pass.cc diff --git a/ge/graph/passes/variable_prepare_op_pass.h b/src/ge/graph/passes/variable_prepare_op_pass.h similarity index 100% rename from ge/graph/passes/variable_prepare_op_pass.h rename to src/ge/graph/passes/variable_prepare_op_pass.h diff --git a/ge/graph/passes/variable_ref_delete_op_pass.cc b/src/ge/graph/passes/variable_ref_delete_op_pass.cc similarity index 100% rename from ge/graph/passes/variable_ref_delete_op_pass.cc rename to src/ge/graph/passes/variable_ref_delete_op_pass.cc diff --git a/ge/graph/passes/variable_ref_delete_op_pass.h b/src/ge/graph/passes/variable_ref_delete_op_pass.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/passes/variable_ref_delete_op_pass.h rename to src/ge/graph/passes/variable_ref_delete_op_pass.h diff --git a/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc similarity index 100% rename from ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc rename to src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc diff --git a/ge/graph/passes/variable_ref_useless_control_out_delete_pass.h b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.h similarity index 100% rename from ge/graph/passes/variable_ref_useless_control_out_delete_pass.h rename to src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.h diff --git 
a/ge/graph/preprocess/graph_preprocess.cc b/src/ge/graph/preprocess/graph_preprocess.cc similarity index 100% rename from ge/graph/preprocess/graph_preprocess.cc rename to src/ge/graph/preprocess/graph_preprocess.cc diff --git a/ge/graph/preprocess/graph_preprocess.h b/src/ge/graph/preprocess/graph_preprocess.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/preprocess/graph_preprocess.h rename to src/ge/graph/preprocess/graph_preprocess.h diff --git a/ge/graph/preprocess/insert_op/base_insert_op.h b/src/ge/graph/preprocess/insert_op/base_insert_op.h similarity index 100% rename from ge/graph/preprocess/insert_op/base_insert_op.h rename to src/ge/graph/preprocess/insert_op/base_insert_op.h diff --git a/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/preprocess/insert_op/ge_aipp_op.cc rename to src/ge/graph/preprocess/insert_op/ge_aipp_op.cc diff --git a/ge/graph/preprocess/insert_op/ge_aipp_op.h b/src/ge/graph/preprocess/insert_op/ge_aipp_op.h old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/preprocess/insert_op/ge_aipp_op.h rename to src/ge/graph/preprocess/insert_op/ge_aipp_op.h diff --git a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/graph/preprocess/insert_op/util_insert_aipp_op.cc rename to src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc diff --git a/ge/graph/preprocess/insert_op/util_insert_aipp_op.h b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.h similarity index 100% rename from ge/graph/preprocess/insert_op/util_insert_aipp_op.h rename to src/ge/graph/preprocess/insert_op/util_insert_aipp_op.h diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/src/ge/graph/preprocess/multi_batch_copy_graph.cc similarity index 100% rename 
from ge/graph/preprocess/multi_batch_copy_graph.cc rename to src/ge/graph/preprocess/multi_batch_copy_graph.cc diff --git a/ge/graph/preprocess/multi_batch_copy_graph.h b/src/ge/graph/preprocess/multi_batch_copy_graph.h similarity index 100% rename from ge/graph/preprocess/multi_batch_copy_graph.h rename to src/ge/graph/preprocess/multi_batch_copy_graph.h diff --git a/ge/graph/preprocess/multi_batch_options.cc b/src/ge/graph/preprocess/multi_batch_options.cc similarity index 100% rename from ge/graph/preprocess/multi_batch_options.cc rename to src/ge/graph/preprocess/multi_batch_options.cc diff --git a/ge/graph/preprocess/multi_batch_options.h b/src/ge/graph/preprocess/multi_batch_options.h similarity index 100% rename from ge/graph/preprocess/multi_batch_options.h rename to src/ge/graph/preprocess/multi_batch_options.h diff --git a/ge/host_cpu_engine/common/constant/constant.h b/src/ge/host_cpu_engine/common/constant/constant.h similarity index 100% rename from ge/host_cpu_engine/common/constant/constant.h rename to src/ge/host_cpu_engine/common/constant/constant.h diff --git a/ge/host_cpu_engine/engine/host_cpu_engine.cc b/src/ge/host_cpu_engine/engine/host_cpu_engine.cc similarity index 100% rename from ge/host_cpu_engine/engine/host_cpu_engine.cc rename to src/ge/host_cpu_engine/engine/host_cpu_engine.cc diff --git a/ge/host_cpu_engine/engine/host_cpu_engine.h b/src/ge/host_cpu_engine/engine/host_cpu_engine.h similarity index 100% rename from ge/host_cpu_engine/engine/host_cpu_engine.h rename to src/ge/host_cpu_engine/engine/host_cpu_engine.h diff --git a/ge/host_cpu_engine/module.mk b/src/ge/host_cpu_engine/module.mk similarity index 100% rename from ge/host_cpu_engine/module.mk rename to src/ge/host_cpu_engine/module.mk diff --git a/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc b/src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc similarity index 100% rename from 
ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc rename to src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc diff --git a/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h b/src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h similarity index 100% rename from ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h rename to src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h diff --git a/ge/host_cpu_engine/ops_kernel_store/op/host_op.cc b/src/ge/host_cpu_engine/ops_kernel_store/op/host_op.cc similarity index 100% rename from ge/host_cpu_engine/ops_kernel_store/op/host_op.cc rename to src/ge/host_cpu_engine/ops_kernel_store/op/host_op.cc diff --git a/ge/host_cpu_engine/ops_kernel_store/op/host_op.h b/src/ge/host_cpu_engine/ops_kernel_store/op/host_op.h similarity index 100% rename from ge/host_cpu_engine/ops_kernel_store/op/host_op.h rename to src/ge/host_cpu_engine/ops_kernel_store/op/host_op.h diff --git a/ge/host_cpu_engine/ops_kernel_store/op/op.h b/src/ge/host_cpu_engine/ops_kernel_store/op/op.h similarity index 100% rename from ge/host_cpu_engine/ops_kernel_store/op/op.h rename to src/ge/host_cpu_engine/ops_kernel_store/op/op.h diff --git a/ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc b/src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc similarity index 100% rename from ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc rename to src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc diff --git a/ge/host_cpu_engine/ops_kernel_store/op/op_factory.h b/src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.h similarity index 100% rename from ge/host_cpu_engine/ops_kernel_store/op/op_factory.h rename to src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.h diff --git a/src/ge/host_cpu_engine/proto/task.proto b/src/ge/host_cpu_engine/proto/task.proto new file mode 120000 index 00000000..36ae4847 --- /dev/null +++ b/src/ge/host_cpu_engine/proto/task.proto @@ -0,0 +1 
@@ +../../proto/task.proto \ No newline at end of file diff --git a/ge/host_kernels/add_kernel.cc b/src/ge/host_kernels/add_kernel.cc similarity index 100% rename from ge/host_kernels/add_kernel.cc rename to src/ge/host_kernels/add_kernel.cc diff --git a/ge/host_kernels/add_kernel.h b/src/ge/host_kernels/add_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/add_kernel.h rename to src/ge/host_kernels/add_kernel.h diff --git a/ge/host_kernels/broadcast_args_kernel.cc b/src/ge/host_kernels/broadcast_args_kernel.cc similarity index 100% rename from ge/host_kernels/broadcast_args_kernel.cc rename to src/ge/host_kernels/broadcast_args_kernel.cc diff --git a/ge/host_kernels/broadcast_args_kernel.h b/src/ge/host_kernels/broadcast_args_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/broadcast_args_kernel.h rename to src/ge/host_kernels/broadcast_args_kernel.h diff --git a/ge/host_kernels/broadcast_gradient_args_kernel.cc b/src/ge/host_kernels/broadcast_gradient_args_kernel.cc similarity index 100% rename from ge/host_kernels/broadcast_gradient_args_kernel.cc rename to src/ge/host_kernels/broadcast_gradient_args_kernel.cc diff --git a/ge/host_kernels/broadcast_gradient_args_kernel.h b/src/ge/host_kernels/broadcast_gradient_args_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/broadcast_gradient_args_kernel.h rename to src/ge/host_kernels/broadcast_gradient_args_kernel.h diff --git a/ge/host_kernels/cast_kernel.cc b/src/ge/host_kernels/cast_kernel.cc similarity index 100% rename from ge/host_kernels/cast_kernel.cc rename to src/ge/host_kernels/cast_kernel.cc diff --git a/ge/host_kernels/cast_kernel.h b/src/ge/host_kernels/cast_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/cast_kernel.h rename to src/ge/host_kernels/cast_kernel.h diff --git a/ge/host_kernels/concat_offset_kernel.cc 
b/src/ge/host_kernels/concat_offset_kernel.cc similarity index 100% rename from ge/host_kernels/concat_offset_kernel.cc rename to src/ge/host_kernels/concat_offset_kernel.cc diff --git a/ge/host_kernels/concat_offset_kernel.h b/src/ge/host_kernels/concat_offset_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/concat_offset_kernel.h rename to src/ge/host_kernels/concat_offset_kernel.h diff --git a/ge/host_kernels/concat_v2_kernel.cc b/src/ge/host_kernels/concat_v2_kernel.cc similarity index 100% rename from ge/host_kernels/concat_v2_kernel.cc rename to src/ge/host_kernels/concat_v2_kernel.cc diff --git a/ge/host_kernels/concat_v2_kernel.h b/src/ge/host_kernels/concat_v2_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/concat_v2_kernel.h rename to src/ge/host_kernels/concat_v2_kernel.h diff --git a/ge/host_kernels/dynamic_stitch_kernel.cc b/src/ge/host_kernels/dynamic_stitch_kernel.cc similarity index 100% rename from ge/host_kernels/dynamic_stitch_kernel.cc rename to src/ge/host_kernels/dynamic_stitch_kernel.cc diff --git a/ge/host_kernels/dynamic_stitch_kernel.h b/src/ge/host_kernels/dynamic_stitch_kernel.h similarity index 100% rename from ge/host_kernels/dynamic_stitch_kernel.h rename to src/ge/host_kernels/dynamic_stitch_kernel.h diff --git a/ge/host_kernels/empty_kernel.cc b/src/ge/host_kernels/empty_kernel.cc similarity index 100% rename from ge/host_kernels/empty_kernel.cc rename to src/ge/host_kernels/empty_kernel.cc diff --git a/ge/host_kernels/empty_kernel.h b/src/ge/host_kernels/empty_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/empty_kernel.h rename to src/ge/host_kernels/empty_kernel.h diff --git a/ge/host_kernels/expanddims_kernel.cc b/src/ge/host_kernels/expanddims_kernel.cc similarity index 100% rename from ge/host_kernels/expanddims_kernel.cc rename to src/ge/host_kernels/expanddims_kernel.cc diff --git 
a/ge/host_kernels/expanddims_kernel.h b/src/ge/host_kernels/expanddims_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/expanddims_kernel.h rename to src/ge/host_kernels/expanddims_kernel.h diff --git a/ge/host_kernels/fill_kernel.cc b/src/ge/host_kernels/fill_kernel.cc similarity index 100% rename from ge/host_kernels/fill_kernel.cc rename to src/ge/host_kernels/fill_kernel.cc diff --git a/ge/host_kernels/fill_kernel.h b/src/ge/host_kernels/fill_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/fill_kernel.h rename to src/ge/host_kernels/fill_kernel.h diff --git a/ge/host_kernels/floordiv_kernel.cc b/src/ge/host_kernels/floordiv_kernel.cc similarity index 100% rename from ge/host_kernels/floordiv_kernel.cc rename to src/ge/host_kernels/floordiv_kernel.cc diff --git a/ge/host_kernels/floordiv_kernel.h b/src/ge/host_kernels/floordiv_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/floordiv_kernel.h rename to src/ge/host_kernels/floordiv_kernel.h diff --git a/ge/host_kernels/floormod_kernel.cc b/src/ge/host_kernels/floormod_kernel.cc similarity index 100% rename from ge/host_kernels/floormod_kernel.cc rename to src/ge/host_kernels/floormod_kernel.cc diff --git a/ge/host_kernels/floormod_kernel.h b/src/ge/host_kernels/floormod_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/floormod_kernel.h rename to src/ge/host_kernels/floormod_kernel.h diff --git a/ge/host_kernels/gather_v2_kernel.cc b/src/ge/host_kernels/gather_v2_kernel.cc similarity index 100% rename from ge/host_kernels/gather_v2_kernel.cc rename to src/ge/host_kernels/gather_v2_kernel.cc diff --git a/ge/host_kernels/gather_v2_kernel.h b/src/ge/host_kernels/gather_v2_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/gather_v2_kernel.h rename to src/ge/host_kernels/gather_v2_kernel.h diff --git 
a/ge/host_kernels/greater_kernel.cc b/src/ge/host_kernels/greater_kernel.cc similarity index 100% rename from ge/host_kernels/greater_kernel.cc rename to src/ge/host_kernels/greater_kernel.cc diff --git a/ge/host_kernels/greater_kernel.h b/src/ge/host_kernels/greater_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/greater_kernel.h rename to src/ge/host_kernels/greater_kernel.h diff --git a/ge/host_kernels/identity_kernel.cc b/src/ge/host_kernels/identity_kernel.cc similarity index 100% rename from ge/host_kernels/identity_kernel.cc rename to src/ge/host_kernels/identity_kernel.cc diff --git a/ge/host_kernels/identity_kernel.h b/src/ge/host_kernels/identity_kernel.h similarity index 100% rename from ge/host_kernels/identity_kernel.h rename to src/ge/host_kernels/identity_kernel.h diff --git a/ge/host_kernels/kernel_utils.cc b/src/ge/host_kernels/kernel_utils.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/kernel_utils.cc rename to src/ge/host_kernels/kernel_utils.cc diff --git a/ge/host_kernels/kernel_utils.h b/src/ge/host_kernels/kernel_utils.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/kernel_utils.h rename to src/ge/host_kernels/kernel_utils.h diff --git a/ge/host_kernels/maximum_kernel.cc b/src/ge/host_kernels/maximum_kernel.cc similarity index 100% rename from ge/host_kernels/maximum_kernel.cc rename to src/ge/host_kernels/maximum_kernel.cc diff --git a/ge/host_kernels/maximum_kernel.h b/src/ge/host_kernels/maximum_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/maximum_kernel.h rename to src/ge/host_kernels/maximum_kernel.h diff --git a/ge/host_kernels/mul_kernel.cc b/src/ge/host_kernels/mul_kernel.cc similarity index 100% rename from ge/host_kernels/mul_kernel.cc rename to src/ge/host_kernels/mul_kernel.cc diff --git a/ge/host_kernels/mul_kernel.h b/src/ge/host_kernels/mul_kernel.h old mode 100755 
new mode 100644 similarity index 100% rename from ge/host_kernels/mul_kernel.h rename to src/ge/host_kernels/mul_kernel.h diff --git a/ge/host_kernels/pack_kernel.cc b/src/ge/host_kernels/pack_kernel.cc similarity index 100% rename from ge/host_kernels/pack_kernel.cc rename to src/ge/host_kernels/pack_kernel.cc diff --git a/ge/host_kernels/pack_kernel.h b/src/ge/host_kernels/pack_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/pack_kernel.h rename to src/ge/host_kernels/pack_kernel.h diff --git a/ge/host_kernels/permute_kernel.cc b/src/ge/host_kernels/permute_kernel.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/permute_kernel.cc rename to src/ge/host_kernels/permute_kernel.cc diff --git a/ge/host_kernels/permute_kernel.h b/src/ge/host_kernels/permute_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/permute_kernel.h rename to src/ge/host_kernels/permute_kernel.h diff --git a/ge/host_kernels/range_kernel.cc b/src/ge/host_kernels/range_kernel.cc similarity index 100% rename from ge/host_kernels/range_kernel.cc rename to src/ge/host_kernels/range_kernel.cc diff --git a/ge/host_kernels/range_kernel.h b/src/ge/host_kernels/range_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/range_kernel.h rename to src/ge/host_kernels/range_kernel.h diff --git a/ge/host_kernels/rank_kernel.cc b/src/ge/host_kernels/rank_kernel.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/rank_kernel.cc rename to src/ge/host_kernels/rank_kernel.cc diff --git a/ge/host_kernels/rank_kernel.h b/src/ge/host_kernels/rank_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/rank_kernel.h rename to src/ge/host_kernels/rank_kernel.h diff --git a/ge/host_kernels/reduce_prod_kernel.cc b/src/ge/host_kernels/reduce_prod_kernel.cc similarity index 100% rename from 
ge/host_kernels/reduce_prod_kernel.cc rename to src/ge/host_kernels/reduce_prod_kernel.cc diff --git a/ge/host_kernels/reduce_prod_kernel.h b/src/ge/host_kernels/reduce_prod_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/reduce_prod_kernel.h rename to src/ge/host_kernels/reduce_prod_kernel.h diff --git a/ge/host_kernels/reformat_kernel.cc b/src/ge/host_kernels/reformat_kernel.cc similarity index 100% rename from ge/host_kernels/reformat_kernel.cc rename to src/ge/host_kernels/reformat_kernel.cc diff --git a/ge/host_kernels/reformat_kernel.h b/src/ge/host_kernels/reformat_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/reformat_kernel.h rename to src/ge/host_kernels/reformat_kernel.h diff --git a/ge/host_kernels/reshape_kernel.cc b/src/ge/host_kernels/reshape_kernel.cc similarity index 100% rename from ge/host_kernels/reshape_kernel.cc rename to src/ge/host_kernels/reshape_kernel.cc diff --git a/ge/host_kernels/reshape_kernel.h b/src/ge/host_kernels/reshape_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/reshape_kernel.h rename to src/ge/host_kernels/reshape_kernel.h diff --git a/ge/host_kernels/rsqrt_kernel.cc b/src/ge/host_kernels/rsqrt_kernel.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/rsqrt_kernel.cc rename to src/ge/host_kernels/rsqrt_kernel.cc diff --git a/ge/host_kernels/rsqrt_kernel.h b/src/ge/host_kernels/rsqrt_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/rsqrt_kernel.h rename to src/ge/host_kernels/rsqrt_kernel.h diff --git a/ge/host_kernels/shape_kernel.cc b/src/ge/host_kernels/shape_kernel.cc similarity index 100% rename from ge/host_kernels/shape_kernel.cc rename to src/ge/host_kernels/shape_kernel.cc diff --git a/ge/host_kernels/shape_kernel.h b/src/ge/host_kernels/shape_kernel.h old mode 100755 new mode 100644 similarity index 100% rename 
from ge/host_kernels/shape_kernel.h rename to src/ge/host_kernels/shape_kernel.h diff --git a/ge/host_kernels/shape_n_kernel.cc b/src/ge/host_kernels/shape_n_kernel.cc similarity index 100% rename from ge/host_kernels/shape_n_kernel.cc rename to src/ge/host_kernels/shape_n_kernel.cc diff --git a/ge/host_kernels/shape_n_kernel.h b/src/ge/host_kernels/shape_n_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/shape_n_kernel.h rename to src/ge/host_kernels/shape_n_kernel.h diff --git a/ge/host_kernels/size_kernel.cc b/src/ge/host_kernels/size_kernel.cc similarity index 100% rename from ge/host_kernels/size_kernel.cc rename to src/ge/host_kernels/size_kernel.cc diff --git a/ge/host_kernels/size_kernel.h b/src/ge/host_kernels/size_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/size_kernel.h rename to src/ge/host_kernels/size_kernel.h diff --git a/ge/host_kernels/slice_d_kernel.cc b/src/ge/host_kernels/slice_d_kernel.cc similarity index 100% rename from ge/host_kernels/slice_d_kernel.cc rename to src/ge/host_kernels/slice_d_kernel.cc diff --git a/ge/host_kernels/slice_d_kernel.h b/src/ge/host_kernels/slice_d_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/slice_d_kernel.h rename to src/ge/host_kernels/slice_d_kernel.h diff --git a/ge/host_kernels/slice_kernel.cc b/src/ge/host_kernels/slice_kernel.cc similarity index 100% rename from ge/host_kernels/slice_kernel.cc rename to src/ge/host_kernels/slice_kernel.cc diff --git a/ge/host_kernels/slice_kernel.h b/src/ge/host_kernels/slice_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/slice_kernel.h rename to src/ge/host_kernels/slice_kernel.h diff --git a/ge/host_kernels/squeeze_kernel.cc b/src/ge/host_kernels/squeeze_kernel.cc similarity index 100% rename from ge/host_kernels/squeeze_kernel.cc rename to src/ge/host_kernels/squeeze_kernel.cc diff --git 
a/ge/host_kernels/squeeze_kernel.h b/src/ge/host_kernels/squeeze_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/squeeze_kernel.h rename to src/ge/host_kernels/squeeze_kernel.h diff --git a/ge/host_kernels/ssd_prior_box_kernel.cc b/src/ge/host_kernels/ssd_prior_box_kernel.cc similarity index 100% rename from ge/host_kernels/ssd_prior_box_kernel.cc rename to src/ge/host_kernels/ssd_prior_box_kernel.cc diff --git a/ge/host_kernels/ssd_prior_box_kernel.h b/src/ge/host_kernels/ssd_prior_box_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/ssd_prior_box_kernel.h rename to src/ge/host_kernels/ssd_prior_box_kernel.h diff --git a/ge/host_kernels/strided_slice_kernel.cc b/src/ge/host_kernels/strided_slice_kernel.cc similarity index 100% rename from ge/host_kernels/strided_slice_kernel.cc rename to src/ge/host_kernels/strided_slice_kernel.cc diff --git a/ge/host_kernels/strided_slice_kernel.h b/src/ge/host_kernels/strided_slice_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/strided_slice_kernel.h rename to src/ge/host_kernels/strided_slice_kernel.h diff --git a/ge/host_kernels/sub_kernel.cc b/src/ge/host_kernels/sub_kernel.cc similarity index 100% rename from ge/host_kernels/sub_kernel.cc rename to src/ge/host_kernels/sub_kernel.cc diff --git a/ge/host_kernels/sub_kernel.h b/src/ge/host_kernels/sub_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/sub_kernel.h rename to src/ge/host_kernels/sub_kernel.h diff --git a/ge/host_kernels/transdata_kernel.cc b/src/ge/host_kernels/transdata_kernel.cc similarity index 100% rename from ge/host_kernels/transdata_kernel.cc rename to src/ge/host_kernels/transdata_kernel.cc diff --git a/ge/host_kernels/transdata_kernel.h b/src/ge/host_kernels/transdata_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/transdata_kernel.h rename to 
src/ge/host_kernels/transdata_kernel.h diff --git a/ge/host_kernels/transpose_kernel.cc b/src/ge/host_kernels/transpose_kernel.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/transpose_kernel.cc rename to src/ge/host_kernels/transpose_kernel.cc diff --git a/ge/host_kernels/transpose_kernel.h b/src/ge/host_kernels/transpose_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/transpose_kernel.h rename to src/ge/host_kernels/transpose_kernel.h diff --git a/ge/host_kernels/unpack_kernel.cc b/src/ge/host_kernels/unpack_kernel.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/unpack_kernel.cc rename to src/ge/host_kernels/unpack_kernel.cc diff --git a/ge/host_kernels/unpack_kernel.h b/src/ge/host_kernels/unpack_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/host_kernels/unpack_kernel.h rename to src/ge/host_kernels/unpack_kernel.h diff --git a/ge/host_kernels/unsqueeze_kernel.cc b/src/ge/host_kernels/unsqueeze_kernel.cc similarity index 100% rename from ge/host_kernels/unsqueeze_kernel.cc rename to src/ge/host_kernels/unsqueeze_kernel.cc diff --git a/ge/host_kernels/unsqueeze_kernel.h b/src/ge/host_kernels/unsqueeze_kernel.h similarity index 100% rename from ge/host_kernels/unsqueeze_kernel.h rename to src/ge/host_kernels/unsqueeze_kernel.h diff --git a/ge/hybrid/common/npu_memory_allocator.cc b/src/ge/hybrid/common/npu_memory_allocator.cc similarity index 100% rename from ge/hybrid/common/npu_memory_allocator.cc rename to src/ge/hybrid/common/npu_memory_allocator.cc diff --git a/ge/hybrid/common/npu_memory_allocator.h b/src/ge/hybrid/common/npu_memory_allocator.h similarity index 100% rename from ge/hybrid/common/npu_memory_allocator.h rename to src/ge/hybrid/common/npu_memory_allocator.h diff --git a/ge/hybrid/common/tensor_value.cc b/src/ge/hybrid/common/tensor_value.cc similarity index 100% rename from 
ge/hybrid/common/tensor_value.cc rename to src/ge/hybrid/common/tensor_value.cc diff --git a/ge/hybrid/common/tensor_value.h b/src/ge/hybrid/common/tensor_value.h similarity index 100% rename from ge/hybrid/common/tensor_value.h rename to src/ge/hybrid/common/tensor_value.h diff --git a/ge/hybrid/executor/hybrid_execution_context.cc b/src/ge/hybrid/executor/hybrid_execution_context.cc similarity index 100% rename from ge/hybrid/executor/hybrid_execution_context.cc rename to src/ge/hybrid/executor/hybrid_execution_context.cc diff --git a/ge/hybrid/executor/hybrid_execution_context.h b/src/ge/hybrid/executor/hybrid_execution_context.h similarity index 100% rename from ge/hybrid/executor/hybrid_execution_context.h rename to src/ge/hybrid/executor/hybrid_execution_context.h diff --git a/ge/hybrid/executor/hybrid_model_async_executor.cc b/src/ge/hybrid/executor/hybrid_model_async_executor.cc similarity index 100% rename from ge/hybrid/executor/hybrid_model_async_executor.cc rename to src/ge/hybrid/executor/hybrid_model_async_executor.cc diff --git a/ge/hybrid/executor/hybrid_model_async_executor.h b/src/ge/hybrid/executor/hybrid_model_async_executor.h similarity index 100% rename from ge/hybrid/executor/hybrid_model_async_executor.h rename to src/ge/hybrid/executor/hybrid_model_async_executor.h diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/src/ge/hybrid/executor/hybrid_model_executor.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/executor/hybrid_model_executor.cc rename to src/ge/hybrid/executor/hybrid_model_executor.cc diff --git a/ge/hybrid/executor/hybrid_model_executor.h b/src/ge/hybrid/executor/hybrid_model_executor.h similarity index 100% rename from ge/hybrid/executor/hybrid_model_executor.h rename to src/ge/hybrid/executor/hybrid_model_executor.h diff --git a/ge/hybrid/executor/hybrid_profiler.cc b/src/ge/hybrid/executor/hybrid_profiler.cc similarity index 100% rename from ge/hybrid/executor/hybrid_profiler.cc rename 
to src/ge/hybrid/executor/hybrid_profiler.cc diff --git a/ge/hybrid/executor/hybrid_profiler.h b/src/ge/hybrid/executor/hybrid_profiler.h similarity index 100% rename from ge/hybrid/executor/hybrid_profiler.h rename to src/ge/hybrid/executor/hybrid_profiler.h diff --git a/ge/hybrid/executor/node_done_manager.cc b/src/ge/hybrid/executor/node_done_manager.cc similarity index 100% rename from ge/hybrid/executor/node_done_manager.cc rename to src/ge/hybrid/executor/node_done_manager.cc diff --git a/ge/hybrid/executor/node_done_manager.h b/src/ge/hybrid/executor/node_done_manager.h similarity index 100% rename from ge/hybrid/executor/node_done_manager.h rename to src/ge/hybrid/executor/node_done_manager.h diff --git a/ge/hybrid/executor/node_state.cc b/src/ge/hybrid/executor/node_state.cc similarity index 100% rename from ge/hybrid/executor/node_state.cc rename to src/ge/hybrid/executor/node_state.cc diff --git a/ge/hybrid/executor/node_state.h b/src/ge/hybrid/executor/node_state.h similarity index 100% rename from ge/hybrid/executor/node_state.h rename to src/ge/hybrid/executor/node_state.h diff --git a/ge/hybrid/executor/rt_callback_manager.cc b/src/ge/hybrid/executor/rt_callback_manager.cc similarity index 100% rename from ge/hybrid/executor/rt_callback_manager.cc rename to src/ge/hybrid/executor/rt_callback_manager.cc diff --git a/ge/hybrid/executor/rt_callback_manager.h b/src/ge/hybrid/executor/rt_callback_manager.h similarity index 100% rename from ge/hybrid/executor/rt_callback_manager.h rename to src/ge/hybrid/executor/rt_callback_manager.h diff --git a/ge/hybrid/executor/subgraph_context.cc b/src/ge/hybrid/executor/subgraph_context.cc similarity index 100% rename from ge/hybrid/executor/subgraph_context.cc rename to src/ge/hybrid/executor/subgraph_context.cc diff --git a/ge/hybrid/executor/subgraph_context.h b/src/ge/hybrid/executor/subgraph_context.h similarity index 100% rename from ge/hybrid/executor/subgraph_context.h rename to 
src/ge/hybrid/executor/subgraph_context.h diff --git a/ge/hybrid/executor/subgraph_executor.cc b/src/ge/hybrid/executor/subgraph_executor.cc similarity index 100% rename from ge/hybrid/executor/subgraph_executor.cc rename to src/ge/hybrid/executor/subgraph_executor.cc diff --git a/ge/hybrid/executor/subgraph_executor.h b/src/ge/hybrid/executor/subgraph_executor.h similarity index 100% rename from ge/hybrid/executor/subgraph_executor.h rename to src/ge/hybrid/executor/subgraph_executor.h diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/src/ge/hybrid/executor/worker/execution_engine.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/executor/worker/execution_engine.cc rename to src/ge/hybrid/executor/worker/execution_engine.cc diff --git a/ge/hybrid/executor/worker/execution_engine.h b/src/ge/hybrid/executor/worker/execution_engine.h similarity index 100% rename from ge/hybrid/executor/worker/execution_engine.h rename to src/ge/hybrid/executor/worker/execution_engine.h diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/src/ge/hybrid/executor/worker/shape_inference_engine.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/executor/worker/shape_inference_engine.cc rename to src/ge/hybrid/executor/worker/shape_inference_engine.cc diff --git a/ge/hybrid/executor/worker/shape_inference_engine.h b/src/ge/hybrid/executor/worker/shape_inference_engine.h similarity index 100% rename from ge/hybrid/executor/worker/shape_inference_engine.h rename to src/ge/hybrid/executor/worker/shape_inference_engine.h diff --git a/ge/hybrid/executor/worker/task_compile_engine.cc b/src/ge/hybrid/executor/worker/task_compile_engine.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/executor/worker/task_compile_engine.cc rename to src/ge/hybrid/executor/worker/task_compile_engine.cc diff --git a/ge/hybrid/executor/worker/task_compile_engine.h 
b/src/ge/hybrid/executor/worker/task_compile_engine.h similarity index 100% rename from ge/hybrid/executor/worker/task_compile_engine.h rename to src/ge/hybrid/executor/worker/task_compile_engine.h diff --git a/ge/hybrid/hybrid_davinci_model.cc b/src/ge/hybrid/hybrid_davinci_model.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/hybrid_davinci_model.cc rename to src/ge/hybrid/hybrid_davinci_model.cc diff --git a/ge/hybrid/hybrid_davinci_model.h b/src/ge/hybrid/hybrid_davinci_model.h similarity index 100% rename from ge/hybrid/hybrid_davinci_model.h rename to src/ge/hybrid/hybrid_davinci_model.h diff --git a/ge/hybrid/hybrid_davinci_model_stub.cc b/src/ge/hybrid/hybrid_davinci_model_stub.cc similarity index 100% rename from ge/hybrid/hybrid_davinci_model_stub.cc rename to src/ge/hybrid/hybrid_davinci_model_stub.cc diff --git a/ge/hybrid/model/graph_item.cc b/src/ge/hybrid/model/graph_item.cc similarity index 100% rename from ge/hybrid/model/graph_item.cc rename to src/ge/hybrid/model/graph_item.cc diff --git a/ge/hybrid/model/graph_item.h b/src/ge/hybrid/model/graph_item.h similarity index 100% rename from ge/hybrid/model/graph_item.h rename to src/ge/hybrid/model/graph_item.h diff --git a/ge/hybrid/model/hybrid_model.cc b/src/ge/hybrid/model/hybrid_model.cc similarity index 100% rename from ge/hybrid/model/hybrid_model.cc rename to src/ge/hybrid/model/hybrid_model.cc diff --git a/ge/hybrid/model/hybrid_model.h b/src/ge/hybrid/model/hybrid_model.h similarity index 100% rename from ge/hybrid/model/hybrid_model.h rename to src/ge/hybrid/model/hybrid_model.h diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/src/ge/hybrid/model/hybrid_model_builder.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/model/hybrid_model_builder.cc rename to src/ge/hybrid/model/hybrid_model_builder.cc diff --git a/ge/hybrid/model/hybrid_model_builder.h b/src/ge/hybrid/model/hybrid_model_builder.h similarity index 100% rename 
from ge/hybrid/model/hybrid_model_builder.h rename to src/ge/hybrid/model/hybrid_model_builder.h diff --git a/ge/hybrid/model/node_item.cc b/src/ge/hybrid/model/node_item.cc similarity index 100% rename from ge/hybrid/model/node_item.cc rename to src/ge/hybrid/model/node_item.cc diff --git a/ge/hybrid/model/node_item.h b/src/ge/hybrid/model/node_item.h similarity index 100% rename from ge/hybrid/model/node_item.h rename to src/ge/hybrid/model/node_item.h diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/aicore/aicore_node_executor.cc rename to src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.h b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/aicore/aicore_node_executor.h rename to src/ge/hybrid/node_executor/aicore/aicore_node_executor.h diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc similarity index 100% rename from ge/hybrid/node_executor/aicore/aicore_op_task.cc rename to src/ge/hybrid/node_executor/aicore/aicore_op_task.cc diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.h b/src/ge/hybrid/node_executor/aicore/aicore_op_task.h old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/aicore/aicore_op_task.h rename to src/ge/hybrid/node_executor/aicore/aicore_op_task.h diff --git a/ge/hybrid/node_executor/aicore/aicore_task_builder.cc b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/aicore/aicore_task_builder.cc rename to src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc diff --git 
a/ge/hybrid/node_executor/aicore/aicore_task_builder.h b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/aicore/aicore_task_builder.h rename to src/ge/hybrid/node_executor/aicore/aicore_task_builder.h diff --git a/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc b/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/aicore/aicore_task_compiler.cc rename to src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc diff --git a/ge/hybrid/node_executor/aicore/aicore_task_compiler.h b/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.h old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/aicore/aicore_task_compiler.h rename to src/ge/hybrid/node_executor/aicore/aicore_task_compiler.h diff --git a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc b/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc similarity index 100% rename from ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc rename to src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc diff --git a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h b/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h similarity index 100% rename from ge/hybrid/node_executor/aicpu/aicpu_ext_info.h rename to src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc rename to src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h similarity index 100% rename from ge/hybrid/node_executor/aicpu/aicpu_node_executor.h rename to 
src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc rename to src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h similarity index 100% rename from ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h rename to src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h diff --git a/ge/hybrid/node_executor/controlop/control_op_executor.cc b/src/ge/hybrid/node_executor/controlop/control_op_executor.cc similarity index 100% rename from ge/hybrid/node_executor/controlop/control_op_executor.cc rename to src/ge/hybrid/node_executor/controlop/control_op_executor.cc diff --git a/ge/hybrid/node_executor/controlop/control_op_executor.h b/src/ge/hybrid/node_executor/controlop/control_op_executor.h similarity index 100% rename from ge/hybrid/node_executor/controlop/control_op_executor.h rename to src/ge/hybrid/node_executor/controlop/control_op_executor.h diff --git a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc b/src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc rename to src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc diff --git a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.h b/src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.h similarity index 100% rename from ge/hybrid/node_executor/ge_local/ge_local_node_executor.h rename to src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.h diff --git 
a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.cc similarity index 100% rename from ge/hybrid/node_executor/hccl/hccl_node_executor.cc rename to src/ge/hybrid/node_executor/hccl/hccl_node_executor.cc diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.h b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h similarity index 100% rename from ge/hybrid/node_executor/hccl/hccl_node_executor.h rename to src/ge/hybrid/node_executor/hccl/hccl_node_executor.h diff --git a/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc b/src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc rename to src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc diff --git a/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h b/src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h similarity index 100% rename from ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h rename to src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h diff --git a/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc b/src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc rename to src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc diff --git a/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h b/src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h rename to src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h diff --git a/ge/hybrid/node_executor/host_cpu/kernel/kernel.h b/src/ge/hybrid/node_executor/host_cpu/kernel/kernel.h similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel/kernel.h rename to 
src/ge/hybrid/node_executor/host_cpu/kernel/kernel.h diff --git a/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc b/src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc rename to src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc diff --git a/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h b/src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h rename to src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h diff --git a/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc b/src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc rename to src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc diff --git a/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h b/src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h rename to src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h diff --git a/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc b/src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc rename to src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc diff --git a/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h b/src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h rename to src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h diff --git 
a/ge/hybrid/node_executor/host_cpu/kernel_factory.cc b/src/ge/hybrid/node_executor/host_cpu/kernel_factory.cc similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel_factory.cc rename to src/ge/hybrid/node_executor/host_cpu/kernel_factory.cc diff --git a/ge/hybrid/node_executor/host_cpu/kernel_factory.h b/src/ge/hybrid/node_executor/host_cpu/kernel_factory.h similarity index 100% rename from ge/hybrid/node_executor/host_cpu/kernel_factory.h rename to src/ge/hybrid/node_executor/host_cpu/kernel_factory.h diff --git a/ge/hybrid/node_executor/node_executor.cc b/src/ge/hybrid/node_executor/node_executor.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/node_executor.cc rename to src/ge/hybrid/node_executor/node_executor.cc diff --git a/ge/hybrid/node_executor/node_executor.h b/src/ge/hybrid/node_executor/node_executor.h similarity index 100% rename from ge/hybrid/node_executor/node_executor.h rename to src/ge/hybrid/node_executor/node_executor.h diff --git a/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc rename to src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc diff --git a/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h similarity index 100% rename from ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h rename to src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.cc b/src/ge/hybrid/node_executor/rts/rts_node_executor.cc similarity index 100% rename from ge/hybrid/node_executor/rts/rts_node_executor.cc rename to 
src/ge/hybrid/node_executor/rts/rts_node_executor.cc diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.h b/src/ge/hybrid/node_executor/rts/rts_node_executor.h similarity index 100% rename from ge/hybrid/node_executor/rts/rts_node_executor.h rename to src/ge/hybrid/node_executor/rts/rts_node_executor.h diff --git a/ge/hybrid/node_executor/task_context.cc b/src/ge/hybrid/node_executor/task_context.cc similarity index 100% rename from ge/hybrid/node_executor/task_context.cc rename to src/ge/hybrid/node_executor/task_context.cc diff --git a/ge/hybrid/node_executor/task_context.h b/src/ge/hybrid/node_executor/task_context.h similarity index 100% rename from ge/hybrid/node_executor/task_context.h rename to src/ge/hybrid/node_executor/task_context.h diff --git a/ge/inc/graph_pass.h b/src/ge/inc/graph_pass.h similarity index 100% rename from ge/inc/graph_pass.h rename to src/ge/inc/graph_pass.h diff --git a/ge/inc/kernel.h b/src/ge/inc/kernel.h similarity index 100% rename from ge/inc/kernel.h rename to src/ge/inc/kernel.h diff --git a/ge/inc/kernel_factory.h b/src/ge/inc/kernel_factory.h similarity index 100% rename from ge/inc/kernel_factory.h rename to src/ge/inc/kernel_factory.h diff --git a/ge/inc/pass.h b/src/ge/inc/pass.h similarity index 100% rename from ge/inc/pass.h rename to src/ge/inc/pass.h diff --git a/ge/inc/pass_manager.h b/src/ge/inc/pass_manager.h similarity index 100% rename from ge/inc/pass_manager.h rename to src/ge/inc/pass_manager.h diff --git a/ge/init/gelib.cc b/src/ge/init/gelib.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/init/gelib.cc rename to src/ge/init/gelib.cc diff --git a/ge/init/gelib.h b/src/ge/init/gelib.h similarity index 100% rename from ge/init/gelib.h rename to src/ge/init/gelib.h diff --git a/ge/ir_build/atc_ir_common.cc b/src/ge/ir_build/atc_ir_common.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/ir_build/atc_ir_common.cc rename to src/ge/ir_build/atc_ir_common.cc diff 
--git a/ge/ir_build/atc_ir_common.h b/src/ge/ir_build/atc_ir_common.h similarity index 100% rename from ge/ir_build/atc_ir_common.h rename to src/ge/ir_build/atc_ir_common.h diff --git a/ge/ir_build/ge_ir_build.cc b/src/ge/ir_build/ge_ir_build.cc similarity index 100% rename from ge/ir_build/ge_ir_build.cc rename to src/ge/ir_build/ge_ir_build.cc diff --git a/ge/model/ge_model.cc b/src/ge/model/ge_model.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/model/ge_model.cc rename to src/ge/model/ge_model.cc diff --git a/ge/model/ge_model.h b/src/ge/model/ge_model.h old mode 100755 new mode 100644 similarity index 100% rename from ge/model/ge_model.h rename to src/ge/model/ge_model.h diff --git a/ge/model/ge_root_model.cc b/src/ge/model/ge_root_model.cc similarity index 100% rename from ge/model/ge_root_model.cc rename to src/ge/model/ge_root_model.cc diff --git a/ge/model/ge_root_model.h b/src/ge/model/ge_root_model.h old mode 100755 new mode 100644 similarity index 100% rename from ge/model/ge_root_model.h rename to src/ge/model/ge_root_model.h diff --git a/ge/module.mk b/src/ge/module.mk old mode 100755 new mode 100644 similarity index 100% rename from ge/module.mk rename to src/ge/module.mk diff --git a/ge/omm/csa_interact.cc b/src/ge/omm/csa_interact.cc similarity index 100% rename from ge/omm/csa_interact.cc rename to src/ge/omm/csa_interact.cc diff --git a/ge/omm/csa_interact.h b/src/ge/omm/csa_interact.h similarity index 100% rename from ge/omm/csa_interact.h rename to src/ge/omm/csa_interact.h diff --git a/ge/opskernel_manager/ops_kernel_manager.cc b/src/ge/opskernel_manager/ops_kernel_manager.cc similarity index 100% rename from ge/opskernel_manager/ops_kernel_manager.cc rename to src/ge/opskernel_manager/ops_kernel_manager.cc diff --git a/ge/opskernel_manager/ops_kernel_manager.h b/src/ge/opskernel_manager/ops_kernel_manager.h similarity index 100% rename from ge/opskernel_manager/ops_kernel_manager.h rename to 
src/ge/opskernel_manager/ops_kernel_manager.h diff --git a/ge/opskernel_manager/optimizer_priority.pbtxt b/src/ge/opskernel_manager/optimizer_priority.pbtxt similarity index 100% rename from ge/opskernel_manager/optimizer_priority.pbtxt rename to src/ge/opskernel_manager/optimizer_priority.pbtxt diff --git a/src/ge/plugin/engine/CMakeLists.txt b/src/ge/plugin/engine/CMakeLists.txt new file mode 100644 index 00000000..a3f14ee2 --- /dev/null +++ b/src/ge/plugin/engine/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +# libengine.so +file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "*.cc" + ) + +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${GE_SOURCE_DIR}) +include_directories(${GE_SOURCE_DIR}/src) +include_directories(${GE_SOURCE_DIR}/src/ge) +include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/framework) +include_directories(${GE_SOURCE_DIR}/inc/framework/common) +include_directories(${GE_SOURCE_DIR}/inc/external) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${CMAKE_BINARY_DIR}) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) +include_directories(${GE_SOURCE_DIR}/build) + +######### libengine.so ############# +add_library(engine SHARED ${SRC_LIST}) +target_compile_definitions(engine PRIVATE + REUSE_MEMORY=1 + PLATFORM_CLOUD + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + Werror) +target_link_libraries(engine + ${slog} + rt + dl) diff --git a/ge/plugin/engine/dnnengines.cc b/src/ge/plugin/engine/dnnengines.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/plugin/engine/dnnengines.cc rename to src/ge/plugin/engine/dnnengines.cc diff --git a/ge/plugin/engine/dnnengines.h b/src/ge/plugin/engine/dnnengines.h similarity index 100% rename from ge/plugin/engine/dnnengines.h rename to src/ge/plugin/engine/dnnengines.h diff --git a/ge/plugin/engine/engine_manage.cc b/src/ge/plugin/engine/engine_manage.cc similarity index 100% rename from ge/plugin/engine/engine_manage.cc rename to src/ge/plugin/engine/engine_manage.cc diff --git a/ge/plugin/engine/engine_manage.h b/src/ge/plugin/engine/engine_manage.h similarity index 100% rename from ge/plugin/engine/engine_manage.h rename to src/ge/plugin/engine/engine_manage.h diff --git a/ge/plugin/engine/module.mk b/src/ge/plugin/engine/module.mk old mode 100755 new mode 100644 similarity index 100% rename from 
ge/plugin/engine/module.mk rename to src/ge/plugin/engine/module.mk diff --git a/ge/session/inner_session.cc b/src/ge/session/inner_session.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/session/inner_session.cc rename to src/ge/session/inner_session.cc diff --git a/ge/session/inner_session.h b/src/ge/session/inner_session.h similarity index 100% rename from ge/session/inner_session.h rename to src/ge/session/inner_session.h diff --git a/ge/session/omg.cc b/src/ge/session/omg.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/session/omg.cc rename to src/ge/session/omg.cc diff --git a/ge/session/session_manager.cc b/src/ge/session/session_manager.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/session/session_manager.cc rename to src/ge/session/session_manager.cc diff --git a/ge/session/session_manager.h b/src/ge/session/session_manager.h similarity index 100% rename from ge/session/session_manager.h rename to src/ge/session/session_manager.h diff --git a/ge/single_op/single_op.cc b/src/ge/single_op/single_op.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/single_op.cc rename to src/ge/single_op/single_op.cc diff --git a/ge/single_op/single_op.h b/src/ge/single_op/single_op.h old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/single_op.h rename to src/ge/single_op/single_op.h diff --git a/ge/single_op/single_op_manager.cc b/src/ge/single_op/single_op_manager.cc similarity index 100% rename from ge/single_op/single_op_manager.cc rename to src/ge/single_op/single_op_manager.cc diff --git a/ge/single_op/single_op_manager.h b/src/ge/single_op/single_op_manager.h similarity index 100% rename from ge/single_op/single_op_manager.h rename to src/ge/single_op/single_op_manager.h diff --git a/ge/single_op/single_op_model.cc b/src/ge/single_op/single_op_model.cc old mode 100755 new mode 100644 similarity index 100% rename from 
ge/single_op/single_op_model.cc rename to src/ge/single_op/single_op_model.cc diff --git a/ge/single_op/single_op_model.h b/src/ge/single_op/single_op_model.h old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/single_op_model.h rename to src/ge/single_op/single_op_model.h diff --git a/ge/single_op/stream_resource.cc b/src/ge/single_op/stream_resource.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/stream_resource.cc rename to src/ge/single_op/stream_resource.cc diff --git a/ge/single_op/stream_resource.h b/src/ge/single_op/stream_resource.h old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/stream_resource.h rename to src/ge/single_op/stream_resource.h diff --git a/ge/single_op/task/aicpu_kernel_task_builder.cc b/src/ge/single_op/task/aicpu_kernel_task_builder.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/task/aicpu_kernel_task_builder.cc rename to src/ge/single_op/task/aicpu_kernel_task_builder.cc diff --git a/ge/single_op/task/aicpu_kernel_task_builder.h b/src/ge/single_op/task/aicpu_kernel_task_builder.h old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/task/aicpu_kernel_task_builder.h rename to src/ge/single_op/task/aicpu_kernel_task_builder.h diff --git a/ge/single_op/task/aicpu_task_builder.cc b/src/ge/single_op/task/aicpu_task_builder.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/task/aicpu_task_builder.cc rename to src/ge/single_op/task/aicpu_task_builder.cc diff --git a/ge/single_op/task/aicpu_task_builder.h b/src/ge/single_op/task/aicpu_task_builder.h old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/task/aicpu_task_builder.h rename to src/ge/single_op/task/aicpu_task_builder.h diff --git a/ge/single_op/task/build_task_utils.cc b/src/ge/single_op/task/build_task_utils.cc similarity index 100% rename from 
ge/single_op/task/build_task_utils.cc rename to src/ge/single_op/task/build_task_utils.cc diff --git a/ge/single_op/task/build_task_utils.h b/src/ge/single_op/task/build_task_utils.h similarity index 100% rename from ge/single_op/task/build_task_utils.h rename to src/ge/single_op/task/build_task_utils.h diff --git a/ge/single_op/task/op_task.cc b/src/ge/single_op/task/op_task.cc old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/task/op_task.cc rename to src/ge/single_op/task/op_task.cc diff --git a/ge/single_op/task/op_task.h b/src/ge/single_op/task/op_task.h similarity index 100% rename from ge/single_op/task/op_task.h rename to src/ge/single_op/task/op_task.h diff --git a/ge/single_op/task/tbe_task_builder.cc b/src/ge/single_op/task/tbe_task_builder.cc similarity index 100% rename from ge/single_op/task/tbe_task_builder.cc rename to src/ge/single_op/task/tbe_task_builder.cc diff --git a/ge/single_op/task/tbe_task_builder.h b/src/ge/single_op/task/tbe_task_builder.h old mode 100755 new mode 100644 similarity index 100% rename from ge/single_op/task/tbe_task_builder.h rename to src/ge/single_op/task/tbe_task_builder.h diff --git a/ge/stub/Makefile b/src/ge/stub/Makefile similarity index 100% rename from ge/stub/Makefile rename to src/ge/stub/Makefile diff --git a/ge/stub/README b/src/ge/stub/README similarity index 100% rename from ge/stub/README rename to src/ge/stub/README diff --git a/ge/stub/README.md b/src/ge/stub/README.md similarity index 100% rename from ge/stub/README.md rename to src/ge/stub/README.md diff --git a/ge/stub/gen_stubapi.py b/src/ge/stub/gen_stubapi.py similarity index 100% rename from ge/stub/gen_stubapi.py rename to src/ge/stub/gen_stubapi.py diff --git a/ge/executor/proto/dump_task.proto b/src/proto/dump_task.proto similarity index 100% rename from ge/executor/proto/dump_task.proto rename to src/proto/dump_task.proto diff --git a/ge/proto/fusion_model.proto b/src/proto/fusion_model.proto old mode 100755 new 
mode 100644 similarity index 100% rename from ge/proto/fusion_model.proto rename to src/proto/fusion_model.proto diff --git a/ge/proto/fwk_adapter.proto b/src/proto/fwk_adapter.proto similarity index 100% rename from ge/proto/fwk_adapter.proto rename to src/proto/fwk_adapter.proto diff --git a/ge/client/proto/ge_api.proto b/src/proto/ge_api.proto similarity index 100% rename from ge/client/proto/ge_api.proto rename to src/proto/ge_api.proto diff --git a/ge/client/proto/ge_ir.proto b/src/proto/ge_ir.proto similarity index 100% rename from ge/client/proto/ge_ir.proto rename to src/proto/ge_ir.proto diff --git a/ge/client/proto/insert_op.proto b/src/proto/insert_op.proto similarity index 100% rename from ge/client/proto/insert_op.proto rename to src/proto/insert_op.proto diff --git a/ge/client/proto/om.proto b/src/proto/om.proto old mode 100755 new mode 100644 similarity index 100% rename from ge/client/proto/om.proto rename to src/proto/om.proto diff --git a/ge/common/proto/op_mapping_info.proto b/src/proto/op_mapping_info.proto similarity index 100% rename from ge/common/proto/op_mapping_info.proto rename to src/proto/op_mapping_info.proto diff --git a/ge/proto/optimizer_priority.proto b/src/proto/optimizer_priority.proto similarity index 100% rename from ge/proto/optimizer_priority.proto rename to src/proto/optimizer_priority.proto diff --git a/ge/client/proto/task.proto b/src/proto/task.proto similarity index 100% rename from ge/client/proto/task.proto rename to src/proto/task.proto diff --git a/tests/depends/cce/src/cce_stub.cc b/tests/depends/cce/src/cce_stub.cc index 03df3d0c..6ce332ad 100644 --- a/tests/depends/cce/src/cce_stub.cc +++ b/tests/depends/cce/src/cce_stub.cc @@ -528,6 +528,7 @@ uint32_t Fusion(ComputeGraphPtr model_graph, ComputeGraphPtr fusion_graph, kScop int stream_num = 1; int flag = 0; + // make_graph_nd(graph); NodePtr node_a = fusion_graph->AddNode(op_def_a); NodePtr node_b = fusion_graph->AddNode(op_def_b); diff --git 
a/tests/st/resnet50/resnet50_train.cc b/tests/st/resnet50/resnet50_train.cc index f1d1e58d..5e082df5 100644 --- a/tests/st/resnet50/resnet50_train.cc +++ b/tests/st/resnet50/resnet50_train.cc @@ -746,6 +746,7 @@ int TestBuildGraphTest(Func fun, Graph &graph, vector &inputs, vecto shapeTensor.SetTensorDesc(shape_desc); vector dataValuec; for (int i = 0; i < sizeshape; i++) { + // dataValuec.push_back((float)(i%255)); dataValuec.push_back(1); } @@ -763,6 +764,7 @@ int TestBuildGraphTest(Func fun, Graph &graph, vector &inputs, vecto } shapeTensor1.SetData((uint8_t *)dataValuec1.data(), 4 * sizeshape1); + // inputs.push_back(shapeTensor1); return 0; } diff --git a/tests/ut/common/graph/testcase/ge_graph/ge_model_unittest.cc b/tests/ut/common/graph/testcase/ge_graph/ge_model_unittest.cc index 496b47b9..07bd90f5 100644 --- a/tests/ut/common/graph/testcase/ge_graph/ge_model_unittest.cc +++ b/tests/ut/common/graph/testcase/ge_graph/ge_model_unittest.cc @@ -69,10 +69,12 @@ TEST_F(UtestGeModelUnittest, save_model_to_file_success) { ge::Graph ge_graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); string file_name = "model_data.pb"; setenv("DUMP_MODEL", "1", true); + // EXPECT_EQ(ge_graph.SaveToFile(file_name), GRAPH_FAILED); setenv("DUMP_MODEL", "0", true); } TEST_F(UtestGeModelUnittest, load_model_from_file_success) { ge::Graph ge_graph; string file_name = "model_data.pb"; + // EXPECT_EQ(ge_graph.LoadFromFile(file_name), GRAPH_SUCCESS); } diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 541df9a7..2e3edfd5 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -190,7 +190,6 @@ file(GLOB_RECURSE DISTINCT_GRAPH_LOAD_SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR} "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc" "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc" 
"${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc" - "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/model_exit_task_info.cc" "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" "${GE_SOURCE_DIR}/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "${GE_SOURCE_DIR}/src/ge/graph/load/output/output.cc" diff --git a/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc index d6b45647..f8deff7f 100644 --- a/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc +++ b/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc @@ -182,6 +182,8 @@ TEST_F(UtestModelManagerDavinciModel, contruct_modeldef_createfail) { ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_WINDOW, vector({1, 1})); ge::AttrUtils::SetListInt(op_desc, ge::POOLING_ATTR_STRIDE, vector({1, 1})); + // EXPECT_EQ(ge::SUCCESS, model.Init()); + model.GetEventList(); } @@ -198,6 +200,7 @@ TEST_F(UtestModelManagerDavinciModel, copy_input_data_to_model_fail) { input_data.blobs.push_back(data_buffer); model.op_list_.clear(); + // EXPECT_EQ(ge::PARAM_INVALID, model.CopyInputDataToModel(input_data.blobs, 0)); delete[](char *) data_buffer.data; } @@ -207,6 +210,7 @@ TEST_F(UtestModelManagerDavinciModel, streamnum_success) { DavinciModel *model = new DavinciModel(0, g_label_call_back); OmeTestOpUtils::InitModel(*model); + // EXPECT_EQ(ge::SUCCESS, model->Init()); EXPECT_EQ(0, model->StreamNum()); EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart()); @@ -222,6 +226,8 @@ TEST_F(UtestModelManagerDavinciModel, eventnum_success) { OmeTestOpUtils::InitModel(*model); + // EXPECT_EQ(ge::SUCCESS, model->Init()); + EXPECT_EQ(0, model->EventNum()); EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart()); @@ -235,6 +241,8 @@ TEST_F(UtestModelManagerDavinciModel, handlelist_success) { 
OmeTestOpUtils::InitModel(*model); + // EXPECT_EQ(ge::SUCCESS, model->Init()); + EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart()); EXPECT_EQ(ge::SUCCESS, model->ModelRunStop()); @@ -248,6 +256,8 @@ TEST_F(UtestModelManagerDavinciModel, eventlist_success) { OmeTestOpUtils::InitModel(*model); + // EXPECT_EQ(ge::SUCCESS, model->Init()); + EXPECT_EQ(true, model->GetEventList().empty()); EXPECT_EQ(ge::INTERNAL_ERROR, model->ModelRunStart()); @@ -272,6 +282,7 @@ TEST_F(UtestModelManagerDavinciModel, failed_reset_device) { TEST_F(UtestModelManagerDavinciModel, init_not_support_priority) { int32_t priority = 8; DavinciModel model(priority, g_label_call_back); + // EXPECT_EQ(ge::PARAM_INVALID, model.Init()); } // test GetInputOutputDescInfo @@ -335,6 +346,7 @@ TEST_F(UtestModelManagerDavinciModel, CopyTensorFromSrcVarNode_success) { NodePtr dst_node = graph->AddNode(op_desc_ptr); DavinciModel model(0, g_label_call_back); Status ret = model.CopyTensorFromSrcVarNode(src_node, dst_node); + // EXPECT_EQ(SUCCESS, ret); } TEST_F(UtestModelManagerDavinciModel, CopyVarData_graph_is_nullptr) { @@ -358,6 +370,7 @@ TEST_F(UtestModelManagerDavinciModel, copy_var_data_success) { DavinciModel model(0, g_label_call_back); Status ret = model.CopyVarData(graph); + // EXPECT_EQ(SUCCESS, ret); } TEST_F(UtestModelManagerDavinciModel, get_input_output_desc_info_without_data_op_list) { @@ -527,6 +540,7 @@ TEST_F(UtestModelManagerDavinciModel, get_flow_ctrl_op_list_success) { std::map flowctrl_op_index_internal_map; flowctrl_op_index_internal_map.insert(pair(1, 1)); model.flowctrl_op_index_internal_map_ = flowctrl_op_index_internal_map; + // EXPECT_EQ(flowctrl_op_index_internal_map_, model.GetFlowctrlOpList()); } // test SetFlowctrlOpList @@ -1190,8 +1204,10 @@ TEST_F(UtestModelManagerDavinciModel, profiling_model_success) { input_data.index = 0; input_data.model_id = 1; input_data.blobs.push_back(data_buffer); + // model.SinkModelProfile(&model); rtFreeHost(data.model_data); + // delete 
stream; delete[](char *) data_buffer.data; delete model_def; } diff --git a/tests/ut/ge/graph/load/new_model_manager_model_manager_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_model_manager_unittest.cc index 33a59589..b6174793 100644 --- a/tests/ut/ge/graph/load/new_model_manager_model_manager_unittest.cc +++ b/tests/ut/ge/graph/load/new_model_manager_model_manager_unittest.cc @@ -153,6 +153,20 @@ TEST_F(UtestModelManagerModelManager, case_load_model_encypt_not_match) { delete[](uint8_t *) data.model_data; } +#if 0 +TEST_F(UtestModelManagerModelManager, case_load_model_signature_failed) +{ + ModelManager mm; + ge::ModelData data; + GenUnencryptModelData(data); + + uint32_t model_id = 1; + MOCKER(&WBDecryptor::CheckSignature).stubs().will(returnValue(false)); + EXPECT_EQ(ge::PARAM_INVALID, mm.LoadModelOffline(model_id, data, UTEST_CALL_BACK_FUN)); + delete[](uint8_t*)data.model_data; +} +#endif + TEST_F(UtestModelManagerModelManager, case_load_model_encypt_type_unsupported) { ModelManager mm; ge::ModelData data; @@ -164,6 +178,87 @@ TEST_F(UtestModelManagerModelManager, case_load_model_encypt_type_unsupported) { delete[](uint8_t *) data.model_data; } +#if 0 +TEST_F(UtestModelManagerModelManager, case_load_model_header_len_failed) +{ + ModelManager mm; + ge::ModelData data; + GenEncryptModelData(data); + ModelFileHeader *header = (ModelFileHeader*)data.model_data; + data.model_len -= header->length; + header->length = 0; + uint32_t model_id = 1; + EXPECT_EQ(ge::PARAM_INVALID, mm.LoadModelOffline(model_id, data, UTEST_CALL_BACK_FUN)); + delete[](uint8_t*)data.model_data; +} +#endif + +#if 0 +TEST_F(UtestModelManagerModelManager, case_load_success) +{ + const char* model_file = "bin/llt/framework/domi/ut/omg/data/leakyrelu.dav"; + const char* json_file = "test.json"; + const char* key = "bin/llt/framework/domi/ut/omg/data/leakyrelu.dav.PASSCODE"; + + ge::ModelData model; + Status ret = ModelParserBase::LoadFromFile(model_file, key, 0, &model); + 
EXPECT_EQ(ge::SUCCESS, ret); + + ModelManager mm; + uint32_t model_id = 1; + ret = mm.LoadModelOffline(model_id, model, UTEST_CALL_BACK_FUN); + EXPECT_EQ(ge::SUCCESS, ret); + + if (model.model_data) + delete[](uint8_t*)model.model_data; +} +#endif + +#if 0 +TEST_F(UtestModelManagerModelManager, case_load_encrypt_model_signature_failed) +{ + ModelManager mm; + ge::ModelData data; + GenEncryptModelData(data); + uint32_t model_id = 1; + data.key; + EXPECT_EQ(ge::PARAM_INVALID, mm.LoadModelOffline(model_id, data, UTEST_CALL_BACK_FUN)); + delete[](uint8_t*)data.model_data; +} + +TEST_F(UtestModelManagerModelManager, case_load_encrypt_model_invalid_key_len) +{ + ModelManager mm; + ge::ModelData data; + GenEncryptModelData(data); + data.key = "0123456789abcdef0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0"; + uint32_t model_id = 1; + EXPECT_EQ(ge::PARAM_INVALID, mm.LoadModelOffline(model_id, data, UTEST_CALL_BACK_FUN)); + delete[](uint8_t*)data.model_data; +} + +TEST_F(UtestModelManagerModelManager, case_load_encrypt_model_invalid_key_char) +{ + ModelManager mm; + ge::ModelData data; + GenEncryptModelData(data); + data.key = "0123456789abcdef0123456789ABCDEF0123456789ABCDEF0123456789ABCDEG"; + uint32_t model_id = 1; + EXPECT_EQ(ge::PARAM_INVALID, mm.LoadModelOffline(model_id, data, UTEST_CALL_BACK_FUN)); + delete[](uint8_t*)data.model_data; +} + +TEST_F(UtestModelManagerModelManager, case_load_encrypt_model_load_failed) +{ + ModelManager mm; + ge::ModelData data; + GenEncryptModelData(data); + uint32_t model_id = 1; + EXPECT_EQ(ge::INTERNAL_ERROR, mm.LoadModelOffline(model_id, data, UTEST_CALL_BACK_FUN)); + delete[](uint8_t*)data.model_data; +} +#endif + shared_ptr LabelCallBack(new DModelListener()); // test HandleCommand diff --git a/tests/ut/ge/graph/load/new_op_test_utils.h b/tests/ut/ge/graph/load/new_op_test_utils.h index d492ee98..5e1e2ec1 100644 --- a/tests/ut/ge/graph/load/new_op_test_utils.h +++ b/tests/ut/ge/graph/load/new_op_test_utils.h @@ -76,6 +76,7 @@ 
class OmeTestOpUtils { return nullptr; } + // return std::make_shared(op_desc, nullptr); auto g = std::make_shared("g"); return g->AddNode(std::move(op_desc)); } @@ -402,6 +403,8 @@ class OmeTestOpDescBuilder { if (SUCCESS != res) { GELOGE(ge::FAILED, "Finish: GraphUtils::AddEdge failed"); } + // ge::NodePtr src_node = node->GetOwnerComputeGraph()->AddNodeFront(src_op_desc); + // node->AddLinkFrom(src_node); } { @@ -431,6 +434,8 @@ class OmeTestOpDescBuilder { vector weights_; int64_t eventId_ = -1; int64_t scopeid_ = -1; + + // std::shared_ptr graph_; }; #endif // OME_REBUILD_OME_OP_TEST_UTILS_H diff --git a/tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc b/tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc index 79e34a60..4e02af70 100644 --- a/tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/dimension_adjust_pass_unittest.cc @@ -122,6 +122,7 @@ TEST_F(UtestGraphPassesDimensionAdjustPass, node_get_original_type_failed) { std::shared_ptr pass = make_shared(); ge::Status ret = pass->Run(op_node); + // EXPECT_EQ(ge::SUCCESS, ret); } TEST_F(UtestGraphPassesDimensionAdjustPass, node_not_register_op) { diff --git a/tests/ut/ge/graph/passes/folding_kernel/strided_slice_kernel_unittest.cc b/tests/ut/ge/graph/passes/folding_kernel/strided_slice_kernel_unittest.cc index e3cb7649..0b16bf97 100644 --- a/tests/ut/ge/graph/passes/folding_kernel/strided_slice_kernel_unittest.cc +++ b/tests/ut/ge/graph/passes/folding_kernel/strided_slice_kernel_unittest.cc @@ -93,6 +93,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test2) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test3) { @@ -122,6 +123,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test3) { shared_ptr kernel = 
KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test4) { @@ -152,6 +154,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test4) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test5) { @@ -183,6 +186,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test5) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test6) { @@ -215,6 +219,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test6) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test7) { @@ -248,6 +253,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test7) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test8) { @@ -282,6 +288,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test8) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test9) { @@ -315,6 +322,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test9) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); 
ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test10) { @@ -349,6 +357,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test10) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test11) { @@ -383,6 +392,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test11) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test12) { @@ -417,6 +427,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test12) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test13) { @@ -451,6 +462,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test13) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test14) { @@ -485,6 +497,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test14) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test15) { @@ -519,6 +532,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test15) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = 
kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test16) { @@ -553,6 +567,7 @@ TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test16) { shared_ptr kernel = KernelFactory::Instance().Create(STRIDEDSLICE); ge::Status status = kernel->Compute(op_desc_ptr, input, outputs); + // EXPECT_EQ(PARAM_INVALID, status); } TEST_F(UtestGraphPassesFoldingKernelStridedSliceKernel, Test17) { diff --git a/tests/ut/ge/graph/passes/guarantee_const_pass_unittest.cc b/tests/ut/ge/graph/passes/guarantee_const_pass_unittest.cc index d5bafbeb..eaad3df7 100644 --- a/tests/ut/ge/graph/passes/guarantee_const_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/guarantee_const_pass_unittest.cc @@ -167,6 +167,7 @@ TEST_F(UtestGraphPassesGuaranteeConstPass, get_origenal_type_fail) { string type2 = "FrameworkOp"; node->GetOpDesc()->SetType(type2); ge::Status ret = guarantee_const_op_remove_pass_->Run(node); + // EXPECT_EQ(ge::SUCCESS, ret); } TEST_F(UtestGraphPassesGuaranteeConstPass, int32_success_6) { diff --git a/tests/ut/ge/graph/passes/identity_pass_unittest.cc b/tests/ut/ge/graph/passes/identity_pass_unittest.cc index b767afb3..eabc3b49 100644 --- a/tests/ut/ge/graph/passes/identity_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/identity_pass_unittest.cc @@ -135,6 +135,7 @@ TEST_F(UtestIdentityPass, succ) { string type2 = "FrameworkOp"; node->GetOpDesc()->SetType(type2); status = pass.Run(node); + // EXPECT_EQ(ge::SUCCESS, status); NodePtr node_err = AddNode(graph, "Identity", IDENTITY, 1, 2); status = pass.Run(node_err); diff --git a/tests/ut/ge/graph/passes/net_output_pass_unittest.cc b/tests/ut/ge/graph/passes/net_output_pass_unittest.cc index 41a5cca8..2655a403 100644 --- a/tests/ut/ge/graph/passes/net_output_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/net_output_pass_unittest.cc @@ -845,6 +845,7 @@ TEST_F(UtestGraphPassesNetOutputPass, out_node_remove_check_fail) { ge::NodePtr mul2 = 
compute_graph->FindNode("Mul2"); std::vector> output_nodes = {{mul1, 0}, {mul2, 0}}; compute_graph->SetGraphOutNodesInfo(output_nodes); + // compute_graph->RemoveNode(mul1); mul1->GetInDataAnchor(0)->UnlinkAll(); mul1->GetInDataAnchor(1)->UnlinkAll(); GraphUtils::RemoveNodeWithoutRelink(compute_graph, mul1); diff --git a/tests/ut/ge/graph/passes/placeholder_with_default_pass_unittest.cc b/tests/ut/ge/graph/passes/placeholder_with_default_pass_unittest.cc index aa49f6ad..b837bf25 100644 --- a/tests/ut/ge/graph/passes/placeholder_with_default_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/placeholder_with_default_pass_unittest.cc @@ -75,4 +75,5 @@ TEST_F(UtestPlaceholderWithDefaultPass, succ) { string type2 = "FrameworkOp"; node->GetOpDesc()->SetType(type2); pass.Run(node); + // EXPECT_EQ(ge::SUCCESS, status); } diff --git a/tests/ut/ge/graph/passes/prevent_gradient_pass_unittest.cc b/tests/ut/ge/graph/passes/prevent_gradient_pass_unittest.cc index d2d067c2..39a6cb6a 100644 --- a/tests/ut/ge/graph/passes/prevent_gradient_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/prevent_gradient_pass_unittest.cc @@ -75,4 +75,5 @@ TEST_F(UtestPreventGradientPass, succ) { string type2 = "FrameworkOp"; node->GetOpDesc()->SetType(type2); status = pass.Run(node); + // EXPECT_EQ(ge::SUCCESS, status); } diff --git a/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc b/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc index 04b2672a..12d35e1f 100644 --- a/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/reshape_remove_pass_unittest.cc @@ -178,5 +178,6 @@ TEST_F(UtestReshapeRemovePass, reshape_remove_without_const) { EXPECT_EQ(var1->GetOutDataNodes().at(0)->GetName(), "transdata1"); EXPECT_NE(const1, nullptr); EXPECT_EQ(const1->GetOutNodes().size(), 1); + // EXPECT_EQ(const1->GetOutDataNodes().at(0)->GetName(), "transdata2"); } } // namespace ge diff --git a/tests/ut/ge/graph/passes/snapshot_pass_unittest.cc 
b/tests/ut/ge/graph/passes/snapshot_pass_unittest.cc index f6b811bb..42b2c6ad 100644 --- a/tests/ut/ge/graph/passes/snapshot_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/snapshot_pass_unittest.cc @@ -78,4 +78,5 @@ TEST_F(UtestSnapshotPass, succ) { string type2 = "FrameworkOp"; snapshot->GetOpDesc()->SetType(type2); status = pass.Run(snapshot); + // EXPECT_EQ(ge::SUCCESS, status); } diff --git a/tests/ut/ge/graph/passes/stop_gradient_pass_unittest.cc b/tests/ut/ge/graph/passes/stop_gradient_pass_unittest.cc index edcdd18f..120a8753 100644 --- a/tests/ut/ge/graph/passes/stop_gradient_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/stop_gradient_pass_unittest.cc @@ -176,6 +176,7 @@ TEST_F(UtestGraphPassesStopGradientPass, get_origenal_type_fail) { string type2 = "FrameworkOp"; node->GetOpDesc()->SetType(type2); ge::Status ret = pass_->Run(node); + // EXPECT_EQ(ge::SUCCESS, ret); } TEST_F(UtestGraphPassesStopGradientPass, size_check_fail) { vector dims_vec_0 = {8, 2}; diff --git a/tests/ut/ge/graph/passes/switch_pass_unittest.cc b/tests/ut/ge/graph/passes/switch_pass_unittest.cc index 45f97aa6..0d78fd6d 100644 --- a/tests/ut/ge/graph/passes/switch_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/switch_pass_unittest.cc @@ -256,6 +256,7 @@ TEST_F(UtestGraphPassesSwitchPass, inactive_output_not_exists) { output_true_node_->GetOutDataAnchor(0)->UnlinkAll(); GraphUtils::RemoveNodeWithoutRelink(graph_, output_true_node_); switch_node_->GetOutDataAnchor(1)->UnlinkAll(); + // switch_node_->outDataAnchors_.pop_back(); /// input /// | @@ -393,6 +394,7 @@ TEST_F(UtestGraphPassesSwitchPass, dead_output_connected_to_merge) { /// Merge bool pred_value = true; BuildDefaultGraph(false, &pred_value); + // graph_->RemoveNode(output_false_node_); output_false_node_->GetOutDataAnchor(0)->UnlinkAll(); GraphUtils::RemoveNodeWithoutRelink(graph_, output_false_node_); switch_node_->GetOutDataAnchor(0)->UnlinkAll(); diff --git 
a/tests/ut/ge/graph/passes/unused_and_isolated_op_remove_pass_unittest.cc b/tests/ut/ge/graph/passes/unused_and_isolated_op_remove_pass_unittest.cc index 13b4e76c..cb174ebd 100644 --- a/tests/ut/ge/graph/passes/unused_and_isolated_op_remove_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/unused_and_isolated_op_remove_pass_unittest.cc @@ -106,6 +106,7 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv) { Status status = PassManager::Run(graph, passes); EXPECT_EQ(SUCCESS, status); NodePtr found_node0 = graph->FindNode("transpose1"); + // EXPECT_EQ(nullptr, found_node0); NodePtr found_node = graph->FindNode("conv1"); EXPECT_EQ(conv_node, found_node); } diff --git a/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc b/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc index 7bc32a6f..77428549 100644 --- a/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/variable_op_pass_unittest.cc @@ -343,6 +343,8 @@ bool BuildComputeGraph0(ge::ComputeGraphPtr &graph) { if (ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0), node_5d_to_4d_1->GetInDataAnchor(0)) != ge::SUCCESS) { + /// GELOGE(FAILED, "ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0), + /// node_5d_to_4d_1->GetInDataAnchor(0) ) Failed."); }; ge::GraphUtils::AddEdge(node_5d_to_4d_1->GetOutDataAnchor(0), node_ref->GetInDataAnchor(0)); @@ -393,6 +395,8 @@ bool BuildComputeGraph1(ge::ComputeGraphPtr &graph) { if (ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0), node_5d_to_4d_1->GetInDataAnchor(0)) != ge::SUCCESS) { + /// GELOGE(FAILED, "ge::GraphUtils::AddEdge(node_apply_monetum->GetOutDataAnchor(0), + /// node_5d_to_4d_1->GetInDataAnchor(0) ) Failed."); }; ge::GraphUtils::AddEdge(node_5d_to_4d_1->GetOutDataAnchor(0), node_ref->GetInDataAnchor(0)); diff --git a/third_party/patch/securec/0001-add-securec-cmake-script.patch b/third_party/patch/securec/0001-add-securec-cmake-script.patch deleted file mode 100644 index 
0fcf50c4..00000000 --- a/third_party/patch/securec/0001-add-securec-cmake-script.patch +++ /dev/null @@ -1,105 +0,0 @@ -From 455c9812d70646fe725896d597d6c953bf5a09ac Mon Sep 17 00:00:00 2001 -From: taoxiangdong -Date: Wed, 14 Oct 2020 22:14:01 +0800 -Subject: [PATCH] add securec cmake script - ---- - CMakeLists.txt | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++ - 1 file changed, 86 insertions(+) - create mode 100755 CMakeLists.txt - -diff --git a/CMakeLists.txt b/CMakeLists.txt -new file mode 100755 -index 0000000..9b91fb2 ---- /dev/null -+++ b/CMakeLists.txt -@@ -0,0 +1,86 @@ -+cmake_minimum_required(VERSION 3.14) -+project(Securec) -+file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} -+ "src/vsprintf_s.c" -+ "src/wmemmove_s.c" -+ "src/strncat_s.c" -+ "src/vsnprintf_s.c" -+ "src/fwscanf_s.c" -+ "src/scanf_s.c" -+ "src/strcat_s.c" -+ "src/sscanf_s.c" -+ "src/secureprintoutput_w.c" -+ "src/wmemcpy_s.c" -+ "src/wcsncat_s.c" -+ "src/secureprintoutput_a.c" -+ "src/secureinput_w.c" -+ "src/memcpy_s.c" -+ "src/fscanf_s.c" -+ "src/vswscanf_s.c" -+ "src/secureinput_a.c" -+ "src/sprintf_s.c" -+ "src/memmove_s.c" -+ "src/swscanf_s.c" -+ "src/snprintf_s.c" -+ "src/vscanf_s.c" -+ "src/vswprintf_s.c" -+ "src/wcscpy_s.c" -+ "src/vfwscanf_s.c" -+ "src/memset_s.c" -+ "src/wscanf_s.c" -+ "src/vwscanf_s.c" -+ "src/strtok_s.c" -+ "src/wcsncpy_s.c" -+ "src/vfscanf_s.c" -+ "src/vsscanf_s.c" -+ "src/wcstok_s.c" -+ "src/securecutil.c" -+ "src/gets_s.c" -+ "src/swprintf_s.c" -+ "src/strcpy_s.c" -+ "src/wcscat_s.c" -+ "src/strncpy_s.c" -+ ) -+ -+include_directories(./include) -+include_directories(./src) -+add_library(shared_c_sec SHARED ${SRC_LIST}) -+ -+target_compile_options(shared_c_sec PRIVATE -+ -I/usr/local/include -+ -Werror -+ -Wall -+ -O1 -+) -+target_compile_definitions(shared_c_sec PRIVATE -+ NDEBUG -+ SECUREC_SUPPORT_STRTOLD=1 -+ ) -+ -+add_library(static_c_sec STATIC ${SRC_LIST}) -+ -+target_compile_options(static_c_sec PRIVATE -+ -I/usr/local/include -+ -Werror 
-+ -Wall -+ -O1 -+) -+ -+target_compile_definitions(static_c_sec PRIVATE -+ NDEBUG -+ SECUREC_SUPPORT_STRTOLD=1 -+ ) -+ -+set_target_properties(static_c_sec -+ PROPERTIES -+ OUTPUT_NAME c_sec -+) -+set_target_properties(shared_c_sec -+ PROPERTIES -+ OUTPUT_NAME c_sec -+) -+install(TARGETS shared_c_sec static_c_sec OPTIONAL -+ DESTINATION lib) -+install(FILES "./include/securec.h" -+ "./include/securectype.h" -+ DESTINATION include) --- -2.17.1 - From 4679646fb5ff97684ad2c1311d303ae0baa54b7f Mon Sep 17 00:00:00 2001 From: lujiale Date: Sat, 31 Oct 2020 18:10:17 +0800 Subject: [PATCH 2/3] update src/common/graph/utils/mem_utils.h. --- src/common/graph/utils/mem_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/graph/utils/mem_utils.h b/src/common/graph/utils/mem_utils.h index 7e8dd9fd..7930ca0b 100644 --- a/src/common/graph/utils/mem_utils.h +++ b/src/common/graph/utils/mem_utils.h @@ -27,6 +27,6 @@ static inline std::shared_ptr<_Tp> MakeShared(_Args &&... __args) { std::shared_ptr<_Tp> ret(new (std::nothrow) _Tp_nc(std::forward<_Args>(__args)...)); return ret; } -} +} // namespace ge #endif // COMMON_GRAPH_UTILS_MEM_UTILS_H_ From 995cfcac259d5a701a983a58d5f2a1900f93e21b Mon Sep 17 00:00:00 2001 From: lujiale Date: Sat, 31 Oct 2020 18:47:06 +0800 Subject: [PATCH 3/3] update src/common/graph/utils/mem_utils.h. --- src/common/graph/utils/mem_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/graph/utils/mem_utils.h b/src/common/graph/utils/mem_utils.h index 7930ca0b..24bbc86c 100644 --- a/src/common/graph/utils/mem_utils.h +++ b/src/common/graph/utils/mem_utils.h @@ -27,6 +27,6 @@ static inline std::shared_ptr<_Tp> MakeShared(_Args &&... __args) { std::shared_ptr<_Tp> ret(new (std::nothrow) _Tp_nc(std::forward<_Args>(__args)...)); return ret; } -} // namespace ge +} // namespace ge #endif // COMMON_GRAPH_UTILS_MEM_UTILS_H_