From 7c19be97fd16cb75da83acb15a16b6406677c177 Mon Sep 17 00:00:00 2001
From: w00562650
Date: Tue, 24 Nov 2020 14:34:00 +0800
Subject: [PATCH 01/13] update cmake and src

---
 .gitmodules | 8 +
 CMakeLists.txt | 238 +-
 build.sh | 145 +-
 cmake/FindModule.cmake | 23 +
 cmake/external_libs/eigen.cmake | 22 -
 cmake/external_libs/gflags.cmake | 47 +
 cmake/external_libs/gtest.cmake | 24 -
 cmake/external_libs/json.cmake | 35 +-
 cmake/external_libs/onnx.cmake | 34 +-
 cmake/external_libs/protobuf.cmake | 63 -
 cmake/external_libs/protobuf_shared.cmake | 16 +-
 cmake/external_libs/protobuf_static.cmake | 51 +
 cmake/external_libs/protoc.cmake | 2 +-
 cmake/external_libs/securec.cmake | 73 +-
 cmake/ge_utils.cmake | 371 ---
 cmake/intf_pub_android.cmake | 52 +
 cmake/intf_pub_linux.cmake | 33 +
 cmake/intf_pub_windows.cmake | 24 +
 ge/CMakeLists.txt | 901 +++++++
 ge/README.md | 0
 {src/ge => ge}/analyzer/analyzer.cc | 0
 {src/ge => ge}/analyzer/analyzer.h | 0
 {src/ge => ge}/client/ge_api.cc | 0
 {src/ge => ge}/client/ge_prof.cc | 0
 {src/ge => ge}/client/module.mk | 0
 {src => ge/client}/proto/ge_api.proto | 0
 {src => ge/client}/proto/ge_ir.proto | 0
 {src => ge/client}/proto/insert_op.proto | 0
 {src => ge/client}/proto/om.proto | 0
 {src => ge/client}/proto/task.proto | 0
 ge/common/CMakeLists.txt | 171 ++
 {src/ge => ge}/common/auth/file_saver.cc | 0
 {src/ge => ge}/common/auth/file_saver.h | 0
 {src/ge => ge}/common/base64.h | 0
 {src/ge => ge}/common/context/ctx.cc | 0
 {src/ge => ge}/common/convert/pb2json.cc | 0
 {src/ge => ge}/common/convert/pb2json.h | 0
 {src/ge => ge}/common/cust_aicpu_kernel_store.cc | 0
 {src/ge => ge}/common/cust_aicpu_kernel_store.h | 0
 {src/ge => ge}/common/debug/memory_dumper.cc | 0
 {src/ge => ge}/common/debug/memory_dumper.h | 0
 {src/ge => ge}/common/dump/dump_manager.cc | 0
 {src/ge => ge}/common/dump/dump_manager.h | 0
 {src/ge => ge}/common/dump/dump_op.cc | 0
 {src/ge => ge}/common/dump/dump_op.h | 0
 {src/ge => ge}/common/dump/dump_properties.cc | 0
 {src/ge => ge}/common/dump/dump_properties.h | 0
 {src/ge => ge}/common/dump/dump_server.cc | 0
 {src/ge => ge}/common/fmk_error_codes.cc | 0
 .../formats/format_transfers/datatype_transfer.cc | 0
 .../formats/format_transfers/datatype_transfer.h | 0
 .../format_transfer_c1hwncoc0_hwcn.cc | 0
 .../format_transfer_c1hwncoc0_hwcn.h | 0
 .../format_transfer_dhwcn_fracz3D.cc | 0
 .../format_transfer_dhwcn_fracz3D.h | 0
 .../format_transfer_dhwnc_fracz3D_transpose.cc | 0
 .../format_transfer_dhwnc_fracz3D_transpose.h | 0
 .../format_transfers/format_transfer_fractal_nz.cc | 0
 .../format_transfers/format_transfer_fractal_nz.h | 0
 .../format_transfers/format_transfer_fractal_z.cc | 0
 .../format_transfers/format_transfer_fractal_z.h | 0
 .../format_transfers/format_transfer_fractal_zz.cc | 0
 .../format_transfers/format_transfer_fractal_zz.h | 0
 .../format_transfers/format_transfer_fracz_hwcn.cc | 0
 .../format_transfers/format_transfer_fracz_hwcn.h | 0
 .../format_transfers/format_transfer_fracz_nchw.cc | 0
 .../format_transfers/format_transfer_fracz_nchw.h | 0
 .../format_transfers/format_transfer_fracz_nhwc.cc | 0
 .../format_transfers/format_transfer_fracz_nhwc.h | 0
 .../format_transfer_hwcn_c1hwncoc0.cc | 0
 .../format_transfer_hwcn_c1hwncoc0.h | 0
 .../format_transfer_nc1hwc0_nchw.cc | 0
 .../format_transfer_nc1hwc0_nchw.h | 0
 .../format_transfer_nc1hwc0_nhwc.cc | 0
 .../format_transfer_nc1hwc0_nhwc.h | 0
 .../format_transfer_nchw_fz_c04.cc | 0
 .../format_transfers/format_transfer_nchw_fz_c04.h | 0
 .../format_transfer_nchw_nc1hwc0.cc | 0
 .../format_transfer_nchw_nc1hwc0.h |
0 .../format_transfer_nhwc_nc1hwc0.cc | 0 .../format_transfer_nhwc_nc1hwc0.h | 0 .../format_transfers/format_transfer_transpose.cc | 0 .../format_transfers/format_transfer_transpose.h | 0 {src/ge => ge}/common/formats/formats.cc | 0 {src/ge => ge}/common/formats/formats.h | 0 .../common/formats/utils/formats_definitions.h | 0 .../common/formats/utils/formats_trans_utils.cc | 0 .../common/formats/utils/formats_trans_utils.h | 0 {src/ge => ge}/common/fp16_t.cc | 0 {src/ge => ge}/common/fp16_t.h | 0 {src/ge => ge}/common/ge/datatype_util.cc | 0 {src/ge => ge}/common/ge/datatype_util.h | 0 {src/ge => ge}/common/ge/ge_util.h | 0 {src/ge => ge}/common/ge/op_tiling_manager.cc | 0 {src/ge => ge}/common/ge/op_tiling_manager.h | 0 {src/ge => ge}/common/ge/plugin_manager.cc | 0 {src/ge => ge}/common/ge/plugin_manager.h | 0 {src/ge => ge}/common/ge/tbe_plugin_manager.cc | 0 {src/ge => ge}/common/ge/tbe_plugin_manager.h | 0 {src/ge => ge}/common/ge_common.mk | 0 {src/ge => ge}/common/ge_format_util.cc | 0 {src/ge => ge}/common/helper/model_cache_helper.cc | 0 {src/ge => ge}/common/helper/model_cache_helper.h | 0 {src/ge => ge}/common/helper/model_helper.cc | 0 {src/ge => ge}/common/helper/om_file_helper.cc | 0 {src/ge => ge}/common/kernel_store.cc | 0 {src/ge => ge}/common/kernel_store.h | 0 {src/ge => ge}/common/math/fp16_math.cc | 0 {src/ge => ge}/common/math/fp16_math.h | 0 {src/ge => ge}/common/math/math_util.h | 0 {src/ge => ge}/common/math_util.h | 0 {src/ge => ge}/common/model_parser/base.cc | 0 {src/ge => ge}/common/model_parser/base.h | 0 {src/ge => ge}/common/model_saver.cc | 0 {src/ge => ge}/common/model_saver.h | 4 +- {src/ge => ge}/common/module.mk | 0 {src/ge => ge}/common/op/attr_value_util.cc | 0 {src/ge => ge}/common/op/ge_op_utils.cc | 0 .../common/profiling/profiling_manager.cc | 0 .../ge => ge}/common/profiling/profiling_manager.h | 0 {src/ge => ge}/common/properties_manager.cc | 0 {src/ge => ge}/common/properties_manager.h | 0 ge/common/proto/ge_ir.proto | 206 ++ ge/common/proto/insert_op.proto | 152 ++ ge/common/proto/om.proto | 401 +++ {src => ge/common}/proto/op_mapping_info.proto | 0 ge/common/proto/task.proto | 170 ++ ge/common/proto/tensorflow/attr_value.proto | 62 + ge/common/proto/tensorflow/function.proto | 100 + ge/common/proto/tensorflow/graph.proto | 56 + ge/common/proto/tensorflow/graph_library.proto | 14 + ge/common/proto/tensorflow/node_def.proto | 63 + ge/common/proto/tensorflow/op_def.proto | 164 ++ ge/common/proto/tensorflow/resource_handle.proto | 29 + ge/common/proto/tensorflow/tensor.proto | 94 + ge/common/proto/tensorflow/tensor_shape.proto | 45 + ge/common/proto/tensorflow/types.proto | 74 + ge/common/proto/tensorflow/versions.proto | 31 + {src/ge => ge}/common/singleton.h | 0 {src/ge => ge}/common/tbe_kernel_store.cc | 0 {src/ge => ge}/common/tbe_kernel_store.h | 0 {src/ge => ge}/common/thread_pool.cc | 0 {src/ge => ge}/common/thread_pool.h | 0 {src/ge => ge}/common/types.cc | 0 {src/ge => ge}/common/util.cc | 0 {src/ge => ge}/engine_manager/dnnengine_manager.cc | 0 {src/ge => ge}/engine_manager/dnnengine_manager.h | 0 {src/ge => ge}/engine_manager/engine_conf.json | 122 +- ge/executor/CMakeLists.txt | 113 + {src/ge => ge}/executor/ge_executor.cc | 0 {src/ge => ge}/executor/module.mk | 0 {src => ge/executor}/proto/dump_task.proto | 0 ge/executor/proto/ge_ir.proto | 206 ++ ge/executor/proto/insert_op.proto | 152 ++ ge/executor/proto/om.proto | 401 +++ ge/executor/proto/op_mapping_info.proto | 89 + ge/executor/proto/task.proto | 170 ++ {src/ge => 
ge}/ge_inference.mk | 0 ge/ge_local_engine/CMakeLists.txt | 116 + .../ge_local_engine/common/constant/constant.h | 0 .../ge_local_engine/engine/ge_local_engine.cc | 0 .../ge_local_engine/engine/ge_local_engine.h | 0 .../ge_local_engine/engine/host_cpu_engine.cc | 0 .../ge_local_engine/engine/host_cpu_engine.h | 0 {src/ge => ge}/ge_local_engine/module.mk | 0 .../ops_kernel_store/ge_local_ops_kernel_info.cc | 0 .../ops_kernel_store/ge_local_ops_kernel_info.h | 0 .../ops_kernel_store/op/ge_deleted_op.cc | 0 .../ops_kernel_store/op/ge_deleted_op.h | 0 .../ge_local_engine/ops_kernel_store/op/no_op.cc | 0 .../ge_local_engine/ops_kernel_store/op/no_op.h | 0 .../ge_local_engine/ops_kernel_store/op/op.cc | 0 .../ge_local_engine/ops_kernel_store/op/op.h | 0 .../ops_kernel_store/op/op_factory.cc | 0 .../ops_kernel_store/op/op_factory.h | 0 ge/ge_local_engine/proto/task.proto | 170 ++ {src/ge => ge}/ge_runner.mk | 0 ge/ge_runtime/CMakeLists.txt | 65 + {src/ge => ge}/ge_runtime/model_context.h | 0 {src/ge => ge}/ge_runtime/model_runner.cc | 0 ge/ge_runtime/module.mk | 66 + {src/ge => ge}/ge_runtime/output.cc | 0 {src/ge => ge}/ge_runtime/output.h | 0 {src/ge => ge}/ge_runtime/runtime_model.cc | 0 {src/ge => ge}/ge_runtime/runtime_model.h | 0 {src/ge => ge}/ge_runtime/task/aicpu_task.cc | 0 {src/ge => ge}/ge_runtime/task/aicpu_task.h | 0 {src/ge => ge}/ge_runtime/task/cce_task.cc | 0 {src/ge => ge}/ge_runtime/task/cce_task.h | 0 .../ge => ge}/ge_runtime/task/event_record_task.cc | 0 {src/ge => ge}/ge_runtime/task/event_record_task.h | 0 {src/ge => ge}/ge_runtime/task/event_wait_task.cc | 0 {src/ge => ge}/ge_runtime/task/event_wait_task.h | 0 {src/ge => ge}/ge_runtime/task/hccl_task.cc | 0 {src/ge => ge}/ge_runtime/task/hccl_task.h | 0 {src/ge => ge}/ge_runtime/task/label_goto_task.cc | 0 {src/ge => ge}/ge_runtime/task/label_goto_task.h | 0 {src/ge => ge}/ge_runtime/task/label_set_task.cc | 0 {src/ge => ge}/ge_runtime/task/label_set_task.h | 0 .../ge => ge}/ge_runtime/task/label_switch_task.cc | 0 {src/ge => ge}/ge_runtime/task/label_switch_task.h | 0 .../ge => ge}/ge_runtime/task/memcpy_async_task.cc | 0 {src/ge => ge}/ge_runtime/task/memcpy_async_task.h | 0 {src/ge => ge}/ge_runtime/task/profiler_task.cc | 0 {src/ge => ge}/ge_runtime/task/profiler_task.h | 0 .../ge_runtime/task/stream_active_task.cc | 0 .../ge => ge}/ge_runtime/task/stream_active_task.h | 0 .../ge_runtime/task/stream_switch_task.cc | 0 .../ge => ge}/ge_runtime/task/stream_switch_task.h | 0 {src/ge => ge}/ge_runtime/task/task.h | 0 {src/ge => ge}/ge_runtime/task/task_factory.h | 0 {src/ge => ge}/ge_runtime/task/tbe_task.cc | 0 {src/ge => ge}/ge_runtime/task/tbe_task.h | 0 {src/ge => ge}/generator/ge_generator.cc | 0 {src/ge => ge}/generator/generator_api.cc | 0 {src/ge => ge}/graph/build/graph_builder.cc | 0 {src/ge => ge}/graph/build/graph_builder.h | 0 {src/ge => ge}/graph/build/label_allocator.cc | 0 {src/ge => ge}/graph/build/label_allocator.h | 0 .../graph/build/logical_stream_allocator.cc | 0 .../graph/build/logical_stream_allocator.h | 0 ge/graph/build/memory/CMakeLists.txt | 38 + .../build/memory/binary_block_mem_assigner.cc | 0 .../graph/build/memory/binary_block_mem_assigner.h | 0 .../graph/build/memory/block_mem_assigner.cc | 0 .../graph/build/memory/block_mem_assigner.h | 0 .../graph/build/memory/graph_mem_assigner.cc | 0 .../graph/build/memory/graph_mem_assigner.h | 0 .../graph/build/memory/hybrid_mem_assigner.cc | 0 .../graph/build/memory/hybrid_mem_assigner.h | 0 .../graph/build/memory/max_block_mem_assigner.cc | 0 
.../graph/build/memory/max_block_mem_assigner.h | 0 {src/ge => ge}/graph/build/memory/mem_assigner.h | 0 .../graph/build/memory/memory_assigner.cc | 0 {src/ge => ge}/graph/build/memory/module.mk | 0 .../graph/build/memory/var_mem_assign_util.cc | 0 .../graph/build/memory/var_mem_assign_util.h | 0 {src/ge => ge}/graph/build/model_builder.cc | 0 {src/ge => ge}/graph/build/model_builder.h | 0 {src/ge => ge}/graph/build/run_context.cc | 0 {src/ge => ge}/graph/build/run_context.h | 0 {src/ge => ge}/graph/build/stream_allocator.cc | 0 {src/ge => ge}/graph/build/stream_allocator.h | 0 .../graph/build/stream_graph_optimizer.cc | 0 .../ge => ge}/graph/build/stream_graph_optimizer.h | 0 {src/ge => ge}/graph/build/task_generator.cc | 0 {src/ge => ge}/graph/build/task_generator.h | 0 {src/ge => ge}/graph/common/bcast.cc | 0 {src/ge => ge}/graph/common/bcast.h | 0 {src/ge => ge}/graph/common/ge_call_wrapper.h | 0 {src/ge => ge}/graph/common/local_context.cc | 0 {src/ge => ge}/graph/common/local_context.h | 0 {src/ge => ge}/graph/common/omg_util.cc | 0 {src/ge => ge}/graph/common/omg_util.h | 0 {src/ge => ge}/graph/common/transop_util.cc | 0 {src/ge => ge}/graph/common/transop_util.h | 0 {src/ge => ge}/graph/execute/graph_execute.cc | 0 {src/ge => ge}/graph/execute/graph_execute.h | 0 {src/ge => ge}/graph/label/case_label_maker.cc | 0 {src/ge => ge}/graph/label/case_label_maker.h | 0 {src/ge => ge}/graph/label/if_label_maker.cc | 0 {src/ge => ge}/graph/label/if_label_maker.h | 0 {src/ge => ge}/graph/label/label_maker.cc | 0 {src/ge => ge}/graph/label/label_maker.h | 0 {src/ge => ge}/graph/label/label_maker_factory.h | 0 .../graph/label/partitioned_call_label_maker.cc | 0 .../graph/label/partitioned_call_label_maker.h | 0 {src/ge => ge}/graph/label/while_label_maker.cc | 0 {src/ge => ge}/graph/label/while_label_maker.h | 0 {src/ge => ge}/graph/load/graph_loader.cc | 0 {src/ge => ge}/graph/load/graph_loader.h | 0 .../graph/load/new_model_manager/aipp_utils.cc | 0 .../graph/load/new_model_manager/aipp_utils.h | 0 .../load/new_model_manager/cpu_queue_schedule.cc | 0 .../load/new_model_manager/cpu_queue_schedule.h | 0 .../graph/load/new_model_manager/data_dumper.cc | 0 .../graph/load/new_model_manager/data_dumper.h | 0 .../graph/load/new_model_manager/data_inputer.cc | 0 .../graph/load/new_model_manager/data_inputer.h | 0 .../graph/load/new_model_manager/davinci_model.cc | 0 .../graph/load/new_model_manager/davinci_model.h | 0 .../load/new_model_manager/davinci_model_parser.cc | 0 .../load/new_model_manager/davinci_model_parser.h | 0 .../graph/load/new_model_manager/model_manager.cc | 0 .../graph/load/new_model_manager/model_manager.h | 0 .../graph/load/new_model_manager/model_utils.cc | 0 .../graph/load/new_model_manager/model_utils.h | 0 .../task_info/end_graph_task_info.cc | 0 .../task_info/end_graph_task_info.h | 0 .../task_info/event_record_task_info.cc | 0 .../task_info/event_record_task_info.h | 0 .../task_info/event_wait_task_info.cc | 0 .../task_info/event_wait_task_info.h | 0 .../task_info/fusion_start_task_info.cc | 0 .../task_info/fusion_start_task_info.h | 0 .../task_info/fusion_stop_task_info.cc | 0 .../task_info/fusion_stop_task_info.h | 0 .../new_model_manager/task_info/hccl_task_info.cc | 0 .../new_model_manager/task_info/hccl_task_info.h | 0 .../task_info/kernel_ex_task_info.cc | 0 .../task_info/kernel_ex_task_info.h | 0 .../task_info/kernel_task_info.cc | 0 .../new_model_manager/task_info/kernel_task_info.h | 0 .../task_info/label_goto_ex_task_info.cc | 0 
.../task_info/label_goto_ex_task_info.h | 0 .../task_info/label_set_task_info.cc | 0 .../task_info/label_set_task_info.h | 0 .../task_info/label_switch_by_index_task_info.cc | 0 .../task_info/label_switch_by_index_task_info.h | 0 .../task_info/memcpy_addr_async_task_info.cc | 0 .../task_info/memcpy_addr_async_task_info.h | 0 .../task_info/memcpy_async_task_info.cc | 0 .../task_info/memcpy_async_task_info.h | 0 .../task_info/profiler_trace_task_info.cc | 0 .../task_info/profiler_trace_task_info.h | 0 .../task_info/stream_active_task_info.cc | 0 .../task_info/stream_active_task_info.h | 0 .../task_info/stream_switch_task_info.cc | 0 .../task_info/stream_switch_task_info.h | 0 .../task_info/stream_switchn_task_info.cc | 0 .../task_info/stream_switchn_task_info.h | 0 .../task_info/super_kernel/super_kernel.cc | 0 .../task_info/super_kernel/super_kernel.h | 0 .../task_info/super_kernel/super_kernel_factory.cc | 0 .../task_info/super_kernel/super_kernel_factory.h | 0 .../load/new_model_manager/task_info/task_info.cc | 0 .../load/new_model_manager/task_info/task_info.h | 0 .../task_info/task_info_factory.h | 0 .../load/new_model_manager/tbe_handle_store.cc | 0 .../load/new_model_manager/tbe_handle_store.h | 0 .../load/new_model_manager/zero_copy_offset.cc | 0 .../load/new_model_manager/zero_copy_offset.h | 0 .../graph/load/new_model_manager/zero_copy_task.cc | 0 .../graph/load/new_model_manager/zero_copy_task.h | 0 {src/ge => ge}/graph/manager/block_memory.h | 0 .../graph/manager/graph_caching_allocator.cc | 0 .../graph/manager/graph_caching_allocator.h | 0 {src/ge => ge}/graph/manager/graph_context.cc | 0 {src/ge => ge}/graph/manager/graph_context.h | 0 {src/ge => ge}/graph/manager/graph_manager.cc | 0 {src/ge => ge}/graph/manager/graph_manager.h | 0 .../ge => ge}/graph/manager/graph_manager_utils.cc | 0 {src/ge => ge}/graph/manager/graph_manager_utils.h | 0 .../ge => ge}/graph/manager/graph_mem_allocator.cc | 0 {src/ge => ge}/graph/manager/graph_mem_allocator.h | 0 {src/ge => ge}/graph/manager/graph_var_manager.cc | 0 {src/ge => ge}/graph/manager/graph_var_manager.h | 0 {src/ge => ge}/graph/manager/host_mem_manager.cc | 0 {src/ge => ge}/graph/manager/host_mem_manager.h | 0 {src/ge => ge}/graph/manager/memory_api.cc | 0 .../graph/manager/model_manager/event_manager.cc | 0 .../graph/manager/model_manager/event_manager.h | 0 .../ge => ge}/graph/manager/rdma_pool_allocator.cc | 0 {src/ge => ge}/graph/manager/rdma_pool_allocator.h | 0 .../graph/manager/trans_var_data_utils.cc | 0 .../ge => ge}/graph/manager/trans_var_data_utils.h | 0 {src/ge => ge}/graph/manager/util/debug.cc | 0 {src/ge => ge}/graph/manager/util/debug.h | 0 {src/ge => ge}/graph/manager/util/hcom_util.cc | 0 {src/ge => ge}/graph/manager/util/hcom_util.h | 0 .../graph/manager/util/rt_context_util.cc | 0 .../ge => ge}/graph/manager/util/rt_context_util.h | 0 .../graph/manager/util/variable_accelerate_ctrl.cc | 0 .../graph/manager/util/variable_accelerate_ctrl.h | 0 {src/ge => ge}/graph/optimize/common/params.h | 0 {src/ge => ge}/graph/optimize/graph_optimize.cc | 0 {src/ge => ge}/graph/optimize/graph_optimize.h | 0 .../graph/optimize/mem_rw_conflict_optimize.cc | 0 .../optimize/optimizer/allreduce_fusion_pass.cc | 0 .../optimize/optimizer/allreduce_fusion_pass.h | 0 {src/ge => ge}/graph/optimize/summary_optimize.cc | 0 .../graph/partition/dynamic_shape_partition.cc | 0 .../graph/partition/dynamic_shape_partition.h | 0 {src/ge => ge}/graph/partition/engine_place.cc | 0 {src/ge => ge}/graph/partition/engine_place.h | 0 {src/ge => 
ge}/graph/partition/graph_partition.cc | 0 {src/ge => ge}/graph/partition/graph_partition.h | 0 {src/ge => ge}/graph/passes/addn_pass.cc | 0 {src/ge => ge}/graph/passes/addn_pass.h | 0 .../graph/passes/aicpu_constant_folding_pass.cc | 0 .../graph/passes/aicpu_constant_folding_pass.h | 0 {src/ge => ge}/graph/passes/assert_pass.cc | 0 {src/ge => ge}/graph/passes/assert_pass.h | 0 {src/ge => ge}/graph/passes/assign_pass.cc | 0 {src/ge => ge}/graph/passes/assign_pass.h | 0 .../graph/passes/atomic_addr_clean_pass.cc | 0 .../graph/passes/atomic_addr_clean_pass.h | 0 .../graph/passes/attach_stream_label_pass.cc | 0 .../graph/passes/attach_stream_label_pass.h | 0 {src/ge => ge}/graph/passes/base_pass.cc | 0 {src/ge => ge}/graph/passes/base_pass.h | 0 {src/ge => ge}/graph/passes/bitcast_pass.cc | 0 {src/ge => ge}/graph/passes/bitcast_pass.h | 0 {src/ge => ge}/graph/passes/cast_remove_pass.cc | 0 {src/ge => ge}/graph/passes/cast_remove_pass.h | 0 {src/ge => ge}/graph/passes/cast_translate_pass.cc | 0 {src/ge => ge}/graph/passes/cast_translate_pass.h | 0 .../common_subexpression_elimination_pass.cc | 0 .../passes/common_subexpression_elimination_pass.h | 0 {src/ge => ge}/graph/passes/compile_nodes_pass.cc | 0 {src/ge => ge}/graph/passes/compile_nodes_pass.h | 0 {src/ge => ge}/graph/passes/cond_pass.cc | 0 {src/ge => ge}/graph/passes/cond_pass.h | 0 {src/ge => ge}/graph/passes/cond_remove_pass.cc | 0 {src/ge => ge}/graph/passes/cond_remove_pass.h | 0 .../graph/passes/constant_folding_pass.cc | 0 .../ge => ge}/graph/passes/constant_folding_pass.h | 0 .../graph/passes/constant_fuse_same_pass.cc | 0 .../graph/passes/constant_fuse_same_pass.h | 0 .../ge => ge}/graph/passes/control_trigger_pass.cc | 0 {src/ge => ge}/graph/passes/control_trigger_pass.h | 0 .../graph/passes/ctrl_edge_transfer_pass.cc | 0 .../graph/passes/ctrl_edge_transfer_pass.h | 0 {src/ge => ge}/graph/passes/data_pass.cc | 0 {src/ge => ge}/graph/passes/data_pass.h | 0 .../graph/passes/dimension_adjust_pass.cc | 0 .../ge => ge}/graph/passes/dimension_adjust_pass.h | 0 .../graph/passes/dimension_compute_pass.cc | 0 .../graph/passes/dimension_compute_pass.h | 0 {src/ge => ge}/graph/passes/dropout_pass.cc | 0 {src/ge => ge}/graph/passes/dropout_pass.h | 0 .../passes/end_of_sequence_add_control_pass.cc | 0 .../passes/end_of_sequence_add_control_pass.h | 0 {src/ge => ge}/graph/passes/enter_pass.cc | 0 {src/ge => ge}/graph/passes/enter_pass.h | 0 {src/ge => ge}/graph/passes/flow_ctrl_pass.cc | 0 {src/ge => ge}/graph/passes/flow_ctrl_pass.h | 0 {src/ge => ge}/graph/passes/folding_pass.cc | 0 {src/ge => ge}/graph/passes/folding_pass.h | 0 {src/ge => ge}/graph/passes/for_pass.cc | 0 {src/ge => ge}/graph/passes/for_pass.h | 0 .../graph/passes/get_original_format_pass.cc | 0 .../graph/passes/get_original_format_pass.h | 0 .../graph/passes/global_step_insert_pass.cc | 0 .../graph/passes/global_step_insert_pass.h | 0 .../ge => ge}/graph/passes/guarantee_const_pass.cc | 0 {src/ge => ge}/graph/passes/guarantee_const_pass.h | 0 {src/ge => ge}/graph/passes/hccl_group_pass.cc | 0 {src/ge => ge}/graph/passes/hccl_group_pass.h | 0 {src/ge => ge}/graph/passes/hccl_memcpy_pass.cc | 0 {src/ge => ge}/graph/passes/hccl_memcpy_pass.h | 0 {src/ge => ge}/graph/passes/identity_pass.cc | 0 {src/ge => ge}/graph/passes/identity_pass.h | 1 - {src/ge => ge}/graph/passes/infershape_pass.cc | 0 {src/ge => ge}/graph/passes/infershape_pass.h | 0 .../input_output_connection_identify_pass.cc | 0 .../passes/input_output_connection_identify_pass.h | 0 
.../graph/passes/isolated_op_remove_pass.cc | 0 .../graph/passes/isolated_op_remove_pass.h | 0 {src/ge => ge}/graph/passes/iterator_op_pass.cc | 0 {src/ge => ge}/graph/passes/iterator_op_pass.h | 0 .../graph/passes/link_gen_mask_nodes_pass.cc | 0 .../graph/passes/link_gen_mask_nodes_pass.h | 0 {src/ge => ge}/graph/passes/mark_agnostic_pass.cc | 0 {src/ge => ge}/graph/passes/mark_agnostic_pass.h | 0 .../graph/passes/mark_graph_unknown_status_pass.cc | 0 .../graph/passes/mark_graph_unknown_status_pass.h | 0 {src/ge => ge}/graph/passes/mark_same_addr_pass.cc | 0 {src/ge => ge}/graph/passes/mark_same_addr_pass.h | 0 .../graph/passes/memcpy_addr_async_pass.cc | 0 .../graph/passes/memcpy_addr_async_pass.h | 0 {src/ge => ge}/graph/passes/merge_pass.cc | 0 {src/ge => ge}/graph/passes/merge_pass.h | 0 .../graph/passes/merge_to_stream_merge_pass.cc | 0 .../graph/passes/merge_to_stream_merge_pass.h | 0 .../graph/passes/multi_batch_clone_pass.cc | 0 .../graph/passes/multi_batch_clone_pass.h | 0 {src/ge => ge}/graph/passes/multi_batch_pass.cc | 0 {src/ge => ge}/graph/passes/multi_batch_pass.h | 0 {src/ge => ge}/graph/passes/net_output_pass.cc | 0 {src/ge => ge}/graph/passes/net_output_pass.h | 0 {src/ge => ge}/graph/passes/next_iteration_pass.cc | 0 {src/ge => ge}/graph/passes/next_iteration_pass.h | 0 .../graph/passes/no_use_reshape_remove_pass.cc | 0 .../graph/passes/no_use_reshape_remove_pass.h | 0 .../graph/passes/parallel_concat_start_op_pass.cc | 0 .../graph/passes/parallel_concat_start_op_pass.h | 0 {src/ge => ge}/graph/passes/pass_manager.cc | 0 {src/ge => ge}/graph/passes/pass_utils.cc | 0 {src/ge => ge}/graph/passes/pass_utils.h | 0 {src/ge => ge}/graph/passes/permute_pass.cc | 0 {src/ge => ge}/graph/passes/permute_pass.h | 0 .../graph/passes/placeholder_with_default_pass.cc | 0 .../graph/passes/placeholder_with_default_pass.h | 0 .../graph/passes/prevent_gradient_pass.cc | 0 .../ge => ge}/graph/passes/prevent_gradient_pass.h | 0 {src/ge => ge}/graph/passes/print_op_pass.cc | 0 {src/ge => ge}/graph/passes/print_op_pass.h | 0 {src/ge => ge}/graph/passes/prune_pass.cc | 0 {src/ge => ge}/graph/passes/prune_pass.h | 0 .../graph/passes/ref_identity_delete_op_pass.cc | 0 .../graph/passes/ref_identity_delete_op_pass.h | 0 {src/ge => ge}/graph/passes/remove_nodes_pass.cc | 0 {src/ge => ge}/graph/passes/remove_nodes_pass.h | 0 .../graph/passes/replace_transshape_pass.cc | 0 .../graph/passes/replace_transshape_pass.h | 0 .../graph/passes/replace_with_empty_const_pass.cc | 0 .../graph/passes/replace_with_empty_const_pass.h | 0 .../graph/passes/reshape_recovery_pass.cc | 0 .../ge => ge}/graph/passes/reshape_recovery_pass.h | 0 {src/ge => ge}/graph/passes/reshape_remove_pass.cc | 0 {src/ge => ge}/graph/passes/reshape_remove_pass.h | 0 .../graph/passes/resource_pair_add_control_pass.cc | 0 .../graph/passes/resource_pair_add_control_pass.h | 0 .../passes/resource_pair_remove_control_pass.cc | 0 .../passes/resource_pair_remove_control_pass.h | 0 .../passes/same_transdata_breadth_fusion_pass.cc | 0 .../passes/same_transdata_breadth_fusion_pass.h | 0 {src/ge => ge}/graph/passes/save_pass.cc | 0 {src/ge => ge}/graph/passes/save_pass.h | 0 .../graph/passes/set_input_output_offset_pass.cc | 0 .../graph/passes/set_input_output_offset_pass.h | 0 .../graph/passes/shape_operate_op_remove_pass.cc | 0 .../graph/passes/shape_operate_op_remove_pass.h | 0 {src/ge => ge}/graph/passes/snapshot_pass.cc | 0 {src/ge => ge}/graph/passes/snapshot_pass.h | 0 {src/ge => ge}/graph/passes/stop_gradient_pass.cc | 0 {src/ge => 
ge}/graph/passes/stop_gradient_pass.h | 0 .../graph/passes/subexpression_migration_pass.cc | 0 .../graph/passes/subexpression_migration_pass.h | 0 {src/ge => ge}/graph/passes/subgraph_pass.cc | 0 {src/ge => ge}/graph/passes/subgraph_pass.h | 0 .../graph/passes/switch_data_edges_bypass.cc | 0 .../graph/passes/switch_data_edges_bypass.h | 0 .../graph/passes/switch_dead_branch_elimination.cc | 0 .../graph/passes/switch_dead_branch_elimination.h | 0 .../graph/passes/switch_logic_remove_pass.cc | 0 .../graph/passes/switch_logic_remove_pass.h | 0 .../graph/passes/switch_to_stream_switch_pass.cc | 0 .../graph/passes/switch_to_stream_switch_pass.h | 0 .../graph/passes/transop_breadth_fusion_pass.cc | 0 .../graph/passes/transop_breadth_fusion_pass.h | 0 .../graph/passes/transop_depth_fusion_pass.cc | 0 .../graph/passes/transop_depth_fusion_pass.h | 0 .../passes/transop_nearby_allreduce_fusion_pass.cc | 0 .../passes/transop_nearby_allreduce_fusion_pass.h | 0 .../passes/transop_symmetry_elimination_pass.cc | 0 .../passes/transop_symmetry_elimination_pass.h | 0 .../passes/transop_without_reshape_fusion_pass.cc | 0 .../passes/transop_without_reshape_fusion_pass.h | 0 .../graph/passes/transpose_transdata_pass.cc | 0 .../graph/passes/transpose_transdata_pass.h | 2 +- .../graph/passes/unused_args_clean_pass.cc | 0 .../graph/passes/unused_args_clean_pass.h | 0 {src/ge => ge}/graph/passes/unused_const_pass.cc | 0 {src/ge => ge}/graph/passes/unused_const_pass.h | 0 .../graph/passes/unused_op_remove_pass.cc | 0 .../ge => ge}/graph/passes/unused_op_remove_pass.h | 0 .../graph/passes/var_is_initialized_op_pass.cc | 0 .../graph/passes/var_is_initialized_op_pass.h | 0 .../ge => ge}/graph/passes/variable_format_pass.cc | 0 {src/ge => ge}/graph/passes/variable_format_pass.h | 0 {src/ge => ge}/graph/passes/variable_op_pass.cc | 0 {src/ge => ge}/graph/passes/variable_op_pass.h | 0 .../graph/passes/variable_prepare_op_pass.cc | 0 .../graph/passes/variable_prepare_op_pass.h | 0 .../graph/passes/variable_ref_delete_op_pass.cc | 0 .../graph/passes/variable_ref_delete_op_pass.h | 0 ...variable_ref_useless_control_out_delete_pass.cc | 0 .../variable_ref_useless_control_out_delete_pass.h | 0 .../ge => ge}/graph/preprocess/graph_preprocess.cc | 0 {src/ge => ge}/graph/preprocess/graph_preprocess.h | 0 .../graph/preprocess/insert_op/base_insert_op.h | 0 .../graph/preprocess/insert_op/ge_aipp_op.cc | 0 .../graph/preprocess/insert_op/ge_aipp_op.h | 0 .../preprocess/insert_op/util_insert_aipp_op.cc | 0 .../preprocess/insert_op/util_insert_aipp_op.h | 0 .../graph/preprocess/multi_batch_copy_graph.cc | 0 .../graph/preprocess/multi_batch_copy_graph.h | 0 .../graph/preprocess/multi_batch_options.cc | 0 .../graph/preprocess/multi_batch_options.h | 0 ge/host_cpu_engine/CMakeLists.txt | 109 + .../host_cpu_engine/common/constant/constant.h | 0 .../host_cpu_engine/engine/host_cpu_engine.cc | 0 .../host_cpu_engine/engine/host_cpu_engine.h | 0 {src/ge => ge}/host_cpu_engine/module.mk | 0 .../ops_kernel_store/host_cpu_ops_kernel_info.cc | 0 .../ops_kernel_store/host_cpu_ops_kernel_info.h | 0 .../host_cpu_engine/ops_kernel_store/op/host_op.cc | 0 .../host_cpu_engine/ops_kernel_store/op/host_op.h | 0 .../host_cpu_engine/ops_kernel_store/op/op.h | 0 .../ops_kernel_store/op/op_factory.cc | 0 .../ops_kernel_store/op/op_factory.h | 0 ge/host_cpu_engine/proto/task.proto | 1 + {src/ge => ge}/host_kernels/add_kernel.cc | 0 {src/ge => ge}/host_kernels/add_kernel.h | 0 .../host_kernels/broadcast_args_kernel.cc | 0 .../ge => 
ge}/host_kernels/broadcast_args_kernel.h | 0 .../host_kernels/broadcast_gradient_args_kernel.cc | 0 .../host_kernels/broadcast_gradient_args_kernel.h | 0 {src/ge => ge}/host_kernels/cast_kernel.cc | 0 {src/ge => ge}/host_kernels/cast_kernel.h | 0 .../ge => ge}/host_kernels/concat_offset_kernel.cc | 0 {src/ge => ge}/host_kernels/concat_offset_kernel.h | 0 {src/ge => ge}/host_kernels/concat_v2_kernel.cc | 0 {src/ge => ge}/host_kernels/concat_v2_kernel.h | 0 .../host_kernels/dynamic_stitch_kernel.cc | 0 .../ge => ge}/host_kernels/dynamic_stitch_kernel.h | 0 {src/ge => ge}/host_kernels/empty_kernel.cc | 0 {src/ge => ge}/host_kernels/empty_kernel.h | 0 {src/ge => ge}/host_kernels/expanddims_kernel.cc | 0 {src/ge => ge}/host_kernels/expanddims_kernel.h | 0 {src/ge => ge}/host_kernels/fill_kernel.cc | 0 {src/ge => ge}/host_kernels/fill_kernel.h | 0 {src/ge => ge}/host_kernels/floordiv_kernel.cc | 0 {src/ge => ge}/host_kernels/floordiv_kernel.h | 0 {src/ge => ge}/host_kernels/floormod_kernel.cc | 0 {src/ge => ge}/host_kernels/floormod_kernel.h | 0 {src/ge => ge}/host_kernels/gather_v2_kernel.cc | 0 {src/ge => ge}/host_kernels/gather_v2_kernel.h | 0 {src/ge => ge}/host_kernels/greater_kernel.cc | 0 {src/ge => ge}/host_kernels/greater_kernel.h | 0 {src/ge => ge}/host_kernels/identity_kernel.cc | 0 {src/ge => ge}/host_kernels/identity_kernel.h | 0 {src/ge => ge}/host_kernels/kernel_utils.cc | 0 {src/ge => ge}/host_kernels/kernel_utils.h | 0 {src/ge => ge}/host_kernels/maximum_kernel.cc | 0 {src/ge => ge}/host_kernels/maximum_kernel.h | 0 {src/ge => ge}/host_kernels/mul_kernel.cc | 0 {src/ge => ge}/host_kernels/mul_kernel.h | 0 {src/ge => ge}/host_kernels/pack_kernel.cc | 0 {src/ge => ge}/host_kernels/pack_kernel.h | 0 {src/ge => ge}/host_kernels/permute_kernel.cc | 0 {src/ge => ge}/host_kernels/permute_kernel.h | 0 {src/ge => ge}/host_kernels/range_kernel.cc | 0 {src/ge => ge}/host_kernels/range_kernel.h | 0 {src/ge => ge}/host_kernels/rank_kernel.cc | 0 {src/ge => ge}/host_kernels/rank_kernel.h | 0 {src/ge => ge}/host_kernels/reduce_prod_kernel.cc | 0 {src/ge => ge}/host_kernels/reduce_prod_kernel.h | 0 {src/ge => ge}/host_kernels/reformat_kernel.cc | 0 {src/ge => ge}/host_kernels/reformat_kernel.h | 0 {src/ge => ge}/host_kernels/reshape_kernel.cc | 0 {src/ge => ge}/host_kernels/reshape_kernel.h | 0 {src/ge => ge}/host_kernels/rsqrt_kernel.cc | 0 {src/ge => ge}/host_kernels/rsqrt_kernel.h | 0 {src/ge => ge}/host_kernels/shape_kernel.cc | 0 {src/ge => ge}/host_kernels/shape_kernel.h | 0 {src/ge => ge}/host_kernels/shape_n_kernel.cc | 0 {src/ge => ge}/host_kernels/shape_n_kernel.h | 0 {src/ge => ge}/host_kernels/size_kernel.cc | 0 {src/ge => ge}/host_kernels/size_kernel.h | 0 {src/ge => ge}/host_kernels/slice_d_kernel.cc | 0 {src/ge => ge}/host_kernels/slice_d_kernel.h | 0 {src/ge => ge}/host_kernels/slice_kernel.cc | 0 {src/ge => ge}/host_kernels/slice_kernel.h | 0 {src/ge => ge}/host_kernels/squeeze_kernel.cc | 0 {src/ge => ge}/host_kernels/squeeze_kernel.h | 0 .../ge => ge}/host_kernels/ssd_prior_box_kernel.cc | 0 {src/ge => ge}/host_kernels/ssd_prior_box_kernel.h | 0 .../ge => ge}/host_kernels/strided_slice_kernel.cc | 0 {src/ge => ge}/host_kernels/strided_slice_kernel.h | 0 {src/ge => ge}/host_kernels/sub_kernel.cc | 0 {src/ge => ge}/host_kernels/sub_kernel.h | 0 {src/ge => ge}/host_kernels/transdata_kernel.cc | 0 {src/ge => ge}/host_kernels/transdata_kernel.h | 0 {src/ge => ge}/host_kernels/transpose_kernel.cc | 0 {src/ge => ge}/host_kernels/transpose_kernel.h | 0 {src/ge => 
ge}/host_kernels/unpack_kernel.cc | 0 {src/ge => ge}/host_kernels/unpack_kernel.h | 0 {src/ge => ge}/host_kernels/unsqueeze_kernel.cc | 0 {src/ge => ge}/host_kernels/unsqueeze_kernel.h | 0 .../hybrid/common/npu_memory_allocator.cc | 0 .../ge => ge}/hybrid/common/npu_memory_allocator.h | 0 {src/ge => ge}/hybrid/common/tensor_value.cc | 0 {src/ge => ge}/hybrid/common/tensor_value.h | 0 .../hybrid/executor/hybrid_execution_context.cc | 0 .../hybrid/executor/hybrid_execution_context.h | 0 .../hybrid/executor/hybrid_model_async_executor.cc | 0 .../hybrid/executor/hybrid_model_async_executor.h | 0 .../hybrid/executor/hybrid_model_executor.cc | 0 .../hybrid/executor/hybrid_model_executor.h | 0 {src/ge => ge}/hybrid/executor/hybrid_profiler.cc | 0 {src/ge => ge}/hybrid/executor/hybrid_profiler.h | 0 .../ge => ge}/hybrid/executor/node_done_manager.cc | 0 {src/ge => ge}/hybrid/executor/node_done_manager.h | 0 {src/ge => ge}/hybrid/executor/node_state.cc | 0 {src/ge => ge}/hybrid/executor/node_state.h | 0 .../hybrid/executor/rt_callback_manager.cc | 0 .../hybrid/executor/rt_callback_manager.h | 0 {src/ge => ge}/hybrid/executor/subgraph_context.cc | 0 {src/ge => ge}/hybrid/executor/subgraph_context.h | 0 .../ge => ge}/hybrid/executor/subgraph_executor.cc | 0 {src/ge => ge}/hybrid/executor/subgraph_executor.h | 0 .../hybrid/executor/worker/execution_engine.cc | 0 .../hybrid/executor/worker/execution_engine.h | 0 .../executor/worker/shape_inference_engine.cc | 0 .../executor/worker/shape_inference_engine.h | 0 .../hybrid/executor/worker/task_compile_engine.cc | 0 .../hybrid/executor/worker/task_compile_engine.h | 0 {src/ge => ge}/hybrid/hybrid_davinci_model.cc | 0 {src/ge => ge}/hybrid/hybrid_davinci_model.h | 0 {src/ge => ge}/hybrid/hybrid_davinci_model_stub.cc | 0 {src/ge => ge}/hybrid/model/graph_item.cc | 0 {src/ge => ge}/hybrid/model/graph_item.h | 0 {src/ge => ge}/hybrid/model/hybrid_model.cc | 0 {src/ge => ge}/hybrid/model/hybrid_model.h | 0 .../ge => ge}/hybrid/model/hybrid_model_builder.cc | 0 {src/ge => ge}/hybrid/model/hybrid_model_builder.h | 0 {src/ge => ge}/hybrid/model/node_item.cc | 0 {src/ge => ge}/hybrid/model/node_item.h | 0 .../node_executor/aicore/aicore_node_executor.cc | 0 .../node_executor/aicore/aicore_node_executor.h | 0 .../hybrid/node_executor/aicore/aicore_op_task.cc | 0 .../hybrid/node_executor/aicore/aicore_op_task.h | 0 .../node_executor/aicore/aicore_task_builder.cc | 0 .../node_executor/aicore/aicore_task_builder.h | 0 .../node_executor/aicore/aicore_task_compiler.cc | 0 .../node_executor/aicore/aicore_task_compiler.h | 0 .../hybrid/node_executor/aicpu/aicpu_ext_info.cc | 0 .../hybrid/node_executor/aicpu/aicpu_ext_info.h | 0 .../node_executor/aicpu/aicpu_node_executor.cc | 0 .../node_executor/aicpu/aicpu_node_executor.h | 0 .../compiledsubgraph/known_node_executor.cc | 0 .../compiledsubgraph/known_node_executor.h | 0 .../node_executor/controlop/control_op_executor.cc | 0 .../node_executor/controlop/control_op_executor.h | 0 .../ge_local/ge_local_node_executor.cc | 0 .../ge_local/ge_local_node_executor.h | 0 .../node_executor/hccl/hccl_node_executor.cc | 0 .../hybrid/node_executor/hccl/hccl_node_executor.h | 0 .../host_cpu/host_cpu_node_executor.cc | 0 .../host_cpu/host_cpu_node_executor.h | 0 .../node_executor/host_cpu/kernel/assign_kernel.cc | 0 .../node_executor/host_cpu/kernel/assign_kernel.h | 0 .../hybrid/node_executor/host_cpu/kernel/kernel.h | 0 .../node_executor/host_cpu/kernel/no_op_kernel.cc | 0 .../node_executor/host_cpu/kernel/no_op_kernel.h | 0 
.../host_cpu/kernel/random_uniform_kernel.cc | 0 .../host_cpu/kernel/random_uniform_kernel.h | 0 .../host_cpu/kernel/variable_kernel.cc | 0 .../host_cpu/kernel/variable_kernel.h | 0 .../node_executor/host_cpu/kernel_factory.cc | 0 .../hybrid/node_executor/host_cpu/kernel_factory.h | 0 .../hybrid/node_executor/node_executor.cc | 0 .../ge => ge}/hybrid/node_executor/node_executor.h | 0 .../partitioned_call_node_executor.cc | 0 .../partitioned_call_node_executor.h | 0 .../hybrid/node_executor/rts/rts_node_executor.cc | 0 .../hybrid/node_executor/rts/rts_node_executor.h | 0 .../ge => ge}/hybrid/node_executor/task_context.cc | 0 {src/ge => ge}/hybrid/node_executor/task_context.h | 0 {src/ge => ge}/inc/graph_pass.h | 0 {src/ge => ge}/inc/kernel.h | 0 {src/ge => ge}/inc/kernel_factory.h | 0 {src/ge => ge}/inc/pass.h | 0 {src/ge => ge}/inc/pass_manager.h | 0 {src/ge => ge}/init/gelib.cc | 0 {src/ge => ge}/init/gelib.h | 0 {src/ge => ge}/ir_build/atc_ir_common.cc | 0 {src/ge => ge}/ir_build/atc_ir_common.h | 0 {src/ge => ge}/ir_build/ge_ir_build.cc | 0 {src/ge => ge}/model/ge_model.cc | 0 {src/ge => ge}/model/ge_model.h | 0 {src/ge => ge}/model/ge_root_model.cc | 0 {src/ge => ge}/model/ge_root_model.h | 0 {src/ge => ge}/module.mk | 0 ge/offline/CMakeLists.txt | 81 + ge/offline/main.cc | 1334 ++++++++++ ge/offline/module.mk | 52 + ge/offline/proto/ge_ir.proto | 1 + ge/offline/proto/insert_op.proto | 1 + ge/offline/proto/om.proto | 1 + ge/offline/proto/task.proto | 1 + ge/offline/single_op_parser.cc | 503 ++++ ge/offline/single_op_parser.h | 78 + {src/ge => ge}/omm/csa_interact.cc | 0 {src/ge => ge}/omm/csa_interact.h | 25 +- .../opskernel_manager/ops_kernel_manager.cc | 0 .../opskernel_manager/ops_kernel_manager.h | 0 .../opskernel_manager/optimizer_priority.pbtxt | 0 ge/plugin/engine/CMakeLists.txt | 49 + {src/ge => ge}/plugin/engine/dnnengines.cc | 0 {src/ge => ge}/plugin/engine/dnnengines.h | 0 {src/ge => ge}/plugin/engine/engine_manage.cc | 0 {src/ge => ge}/plugin/engine/engine_manage.h | 0 {src/ge => ge}/plugin/engine/module.mk | 0 ge/proto/caffe/caffe.proto | 1821 +++++++++++++ ge/proto/dump_task.proto | 127 + {src => ge}/proto/fusion_model.proto | 0 {src => ge}/proto/fwk_adapter.proto | 0 ge/proto/ge_api.proto | 104 + ge/proto/ge_ir.proto | 206 ++ ge/proto/insert_op.proto | 152 ++ ge/proto/om.proto | 401 +++ ge/proto/op_mapping_info.proto | 89 + {src => ge}/proto/optimizer_priority.proto | 0 ge/proto/task.proto | 170 ++ ge/proto/tensorflow/attr_value.proto | 62 + ge/proto/tensorflow/function.proto | 100 + ge/proto/tensorflow/graph.proto | 56 + ge/proto/tensorflow/graph_library.proto | 14 + ge/proto/tensorflow/node_def.proto | 63 + ge/proto/tensorflow/op_def.proto | 164 ++ ge/proto/tensorflow/resource_handle.proto | 29 + ge/proto/tensorflow/tensor.proto | 94 + ge/proto/tensorflow/tensor_shape.proto | 45 + ge/proto/tensorflow/types.proto | 74 + ge/proto/tensorflow/versions.proto | 31 + {src/ge => ge}/session/inner_session.cc | 0 {src/ge => ge}/session/inner_session.h | 0 {src/ge => ge}/session/omg.cc | 0 ge/session/readme.txt | 3 + {src/ge => ge}/session/session_manager.cc | 0 {src/ge => ge}/session/session_manager.h | 0 {src/ge => ge}/single_op/single_op.cc | 0 {src/ge => ge}/single_op/single_op.h | 0 {src/ge => ge}/single_op/single_op_manager.cc | 0 {src/ge => ge}/single_op/single_op_manager.h | 0 {src/ge => ge}/single_op/single_op_model.cc | 0 {src/ge => ge}/single_op/single_op_model.h | 0 {src/ge => ge}/single_op/stream_resource.cc | 0 {src/ge => ge}/single_op/stream_resource.h | 0 
.../single_op/task/aicpu_kernel_task_builder.cc | 0 .../single_op/task/aicpu_kernel_task_builder.h | 0 .../ge => ge}/single_op/task/aicpu_task_builder.cc | 0 {src/ge => ge}/single_op/task/aicpu_task_builder.h | 0 {src/ge => ge}/single_op/task/build_task_utils.cc | 0 {src/ge => ge}/single_op/task/build_task_utils.h | 0 {src/ge => ge}/single_op/task/op_task.cc | 0 {src/ge => ge}/single_op/task/op_task.h | 0 {src/ge => ge}/single_op/task/tbe_task_builder.cc | 0 {src/ge => ge}/single_op/task/tbe_task_builder.h | 0 {src/ge => ge}/stub/Makefile | 0 {src/ge => ge}/stub/README | 0 {src/ge => ge}/stub/README.md | 0 {src/ge => ge}/stub/gen_stubapi.py | 2 +- inc/common/blocking_queue.h | 141 - inc/common/dynamic_aipp.h | 104 - inc/common/npu_error_define.h | 94 - inc/common/opskernel/ge_task_info.h | 74 - inc/common/opskernel/ops_kernel_info_store.h | 88 - inc/common/opskernel/ops_kernel_info_types.h | 66 - inc/common/optimizer/graph_optimizer.h | 71 - .../util/ai_core/common/aicore_util_attr_define.h | 41 - inc/common/util/ai_core/common/aicore_util_types.h | 118 - inc/common/util/ai_core/common/graph_comm.h | 107 - inc/common/util/ai_core/common/scope_allocator.h | 43 - .../param_calculate/aicore_param_calculator.h | 33 - .../param_calculate/tensorsize_calculator.h | 45 - inc/common/util/compress/compress.h | 37 - inc/common/util/compress/compress_weight.h | 33 - inc/common/util/error_manager/error_manager.h | 94 - inc/common/util/platform_info.h | 101 - inc/common/util/platform_info_def.h | 140 - inc/external/graph/attr_value.h | 75 - inc/external/graph/ge_error_codes.h | 38 - inc/external/graph/graph.h | 81 - inc/external/graph/inference_context.h | 76 - inc/external/graph/operator.h | 289 -- inc/external/graph/operator_factory.h | 68 - inc/external/graph/operator_reg.h | 376 --- inc/external/graph/tensor.h | 131 - inc/external/graph/types.h | 240 -- inc/external/register/register.h | 163 -- inc/external/register/register_error_codes.h | 39 - inc/external/register/register_fmk_types.h | 37 - inc/external/register/register_types.h | 59 - .../register/scope/scope_fusion_pass_register.h | 334 --- inc/framework/omg/parser/model_parser.h | 111 + inc/framework/omg/parser/op_parser.h | 92 + .../omg/parser/parser_api.h} | 27 +- inc/framework/omg/parser/parser_factory.h | 138 + inc/framework/omg/parser/parser_inner_ctx.h | 43 + inc/framework/omg/parser/weights_parser.h | 74 + inc/graph/anchor.h | 284 -- inc/graph/attr_value_serializable.h | 191 -- inc/graph/buffer.h | 82 - inc/graph/compute_graph.h | 308 --- inc/graph/debug/ge_attr_define.h | 1130 -------- inc/graph/def_types.h | 195 -- inc/graph/detail/any_map.h | 120 - inc/graph/detail/attributes_holder.h | 165 -- inc/graph/detail/model_serialize_imp.h | 93 - inc/graph/ge_attr_value.h | 343 --- inc/graph/ge_context.h | 46 - inc/graph/ge_global_options.h | 26 - inc/graph/ge_local_context.h | 44 - inc/graph/ge_tensor.h | 193 -- inc/graph/graph_util.h | 134 - inc/graph/model.h | 94 - inc/graph/model_serialize.h | 52 - inc/graph/node.h | 213 -- inc/graph/op_desc.h | 329 --- inc/graph/op_kernel_bin.h | 48 - inc/graph/operator_factory_impl.h | 56 - inc/graph/opsproto_manager.h | 46 - inc/graph/range_vistor.h | 57 - inc/graph/ref_relation.h | 79 - inc/graph/runtime_inference_context.h | 49 - inc/graph/shape_refiner.h | 40 - inc/graph/tuning_utils.h | 130 - inc/graph/usr_types.h | 133 - inc/graph/utils/anchor_utils.h | 45 - inc/graph/utils/attr_utils.h | 150 -- inc/graph/utils/graph_utils.h | 771 ------ inc/graph/utils/node_utils.h | 170 -- 
inc/graph/utils/op_desc_utils.h | 182 -- inc/graph/utils/tensor_adapter.h | 43 - inc/graph/utils/tensor_utils.h | 77 - inc/graph/utils/type_utils.h | 53 - src/common/graph/CMakeLists.txt | 81 - src/common/graph/anchor.cc | 371 --- src/common/graph/attr_value.cc | 38 - src/common/graph/buffer.cc | 112 - src/common/graph/compute_graph.cc | 1314 ---------- src/common/graph/debug/ge_log.h | 147 -- src/common/graph/debug/ge_op_types.h | 69 - src/common/graph/debug/ge_util.h | 274 -- src/common/graph/debug/graph_debug.cc | 246 -- src/common/graph/debug/graph_debug.h | 48 - src/common/graph/detail/attributes_holder.cc | 241 -- src/common/graph/format_refiner.cc | 508 ---- src/common/graph/format_refiner.h | 50 - src/common/graph/ge_attr_define.cc | 1086 -------- src/common/graph/ge_attr_value.cc | 1289 --------- src/common/graph/ge_tensor.cc | 1021 -------- src/common/graph/graph.cc | 384 --- src/common/graph/graph.mk | 294 --- src/common/graph/inference_context.cc | 112 - src/common/graph/model.cc | 190 -- src/common/graph/model_serialize.cc | 763 ------ src/common/graph/module.mk | 3 - src/common/graph/node.cc | 878 ------- src/common/graph/op_desc.cc | 1410 ---------- src/common/graph/op_imp.cc | 79 - src/common/graph/operator.cc | 1587 ----------- src/common/graph/operator_factory.cc | 48 - src/common/graph/operator_factory_impl.cc | 149 -- src/common/graph/opsproto/opsproto_manager.cc | 187 -- src/common/graph/option/ge_context.cc | 104 - src/common/graph/option/ge_local_context.cc | 60 - src/common/graph/ref_relation.cc | 455 ---- src/common/graph/runtime_inference_context.cc | 129 - src/common/graph/shape_refiner.cc | 688 ----- src/common/graph/stub/Makefile | 6 - src/common/graph/stub/gen_stubapi.py | 578 ---- src/common/graph/tensor.cc | 704 ----- src/common/graph/utils/anchor_utils.cc | 102 - src/common/graph/utils/ge_ir_utils.cc | 1178 --------- src/common/graph/utils/ge_ir_utils.h | 206 -- src/common/graph/utils/graph_utils.cc | 2767 -------------------- src/common/graph/utils/mem_utils.h | 32 - src/common/graph/utils/node_utils.cc | 1005 ------- src/common/graph/utils/op_desc_utils.cc | 825 ------ src/common/graph/utils/string_utils.h | 68 - src/common/graph/utils/tensor_utils.cc | 401 --- src/common/graph/utils/tuning_utils.cc | 684 ----- src/common/graph/utils/type_utils.cc | 448 ---- src/ge/CMakeLists.txt | 382 --- src/ge/client/CMakeLists.txt | 75 - src/ge/common/CMakeLists.txt | 104 - src/ge/executor/CMakeLists.txt | 127 - src/ge/ge_local_engine/CMakeLists.txt | 53 - src/ge/ge_runtime/CMakeLists.txt | 52 - src/ge/ge_runtime/proto/task.pb.h | 27 - src/ge/graph/build/memory/CMakeLists.txt | 51 - src/ge/host_cpu_engine/proto/task.proto | 1 - src/ge/plugin/engine/CMakeLists.txt | 45 - src/proto/onnx.proto | 569 ---- .../securec/0001-add-securec-cmake-script.patch | 105 + 961 files changed, 11598 insertions(+), 35256 deletions(-) create mode 100644 .gitmodules create mode 100644 cmake/FindModule.cmake delete mode 100644 cmake/external_libs/eigen.cmake create mode 100755 cmake/external_libs/gflags.cmake delete mode 100644 cmake/external_libs/gtest.cmake delete mode 100644 cmake/external_libs/protobuf.cmake create mode 100755 cmake/external_libs/protobuf_static.cmake delete mode 100644 cmake/ge_utils.cmake create mode 100644 cmake/intf_pub_android.cmake create mode 100644 cmake/intf_pub_linux.cmake create mode 100644 cmake/intf_pub_windows.cmake create mode 100755 ge/CMakeLists.txt create mode 100644 ge/README.md rename {src/ge => ge}/analyzer/analyzer.cc (100%) rename {src/ge => 
ge}/analyzer/analyzer.h (100%) rename {src/ge => ge}/client/ge_api.cc (100%) rename {src/ge => ge}/client/ge_prof.cc (100%) rename {src/ge => ge}/client/module.mk (100%) rename {src => ge/client}/proto/ge_api.proto (100%) rename {src => ge/client}/proto/ge_ir.proto (100%) rename {src => ge/client}/proto/insert_op.proto (100%) rename {src => ge/client}/proto/om.proto (100%) rename {src => ge/client}/proto/task.proto (100%) create mode 100644 ge/common/CMakeLists.txt rename {src/ge => ge}/common/auth/file_saver.cc (100%) rename {src/ge => ge}/common/auth/file_saver.h (100%) rename {src/ge => ge}/common/base64.h (100%) rename {src/ge => ge}/common/context/ctx.cc (100%) rename {src/ge => ge}/common/convert/pb2json.cc (100%) rename {src/ge => ge}/common/convert/pb2json.h (100%) rename {src/ge => ge}/common/cust_aicpu_kernel_store.cc (100%) rename {src/ge => ge}/common/cust_aicpu_kernel_store.h (100%) rename {src/ge => ge}/common/debug/memory_dumper.cc (100%) rename {src/ge => ge}/common/debug/memory_dumper.h (100%) rename {src/ge => ge}/common/dump/dump_manager.cc (100%) rename {src/ge => ge}/common/dump/dump_manager.h (100%) rename {src/ge => ge}/common/dump/dump_op.cc (100%) rename {src/ge => ge}/common/dump/dump_op.h (100%) rename {src/ge => ge}/common/dump/dump_properties.cc (100%) rename {src/ge => ge}/common/dump/dump_properties.h (100%) rename {src/ge => ge}/common/dump/dump_server.cc (100%) rename {src/ge => ge}/common/fmk_error_codes.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/datatype_transfer.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/datatype_transfer.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fractal_nz.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fractal_nz.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fractal_z.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fractal_z.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fractal_zz.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fractal_zz.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fracz_hwcn.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fracz_hwcn.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fracz_nchw.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fracz_nchw.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fracz_nhwc.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_fracz_nhwc.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h (100%) rename {src/ge => 
ge}/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_nchw_fz_c04.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_transpose.cc (100%) rename {src/ge => ge}/common/formats/format_transfers/format_transfer_transpose.h (100%) rename {src/ge => ge}/common/formats/formats.cc (100%) rename {src/ge => ge}/common/formats/formats.h (100%) rename {src/ge => ge}/common/formats/utils/formats_definitions.h (100%) rename {src/ge => ge}/common/formats/utils/formats_trans_utils.cc (100%) rename {src/ge => ge}/common/formats/utils/formats_trans_utils.h (100%) rename {src/ge => ge}/common/fp16_t.cc (100%) rename {src/ge => ge}/common/fp16_t.h (100%) rename {src/ge => ge}/common/ge/datatype_util.cc (100%) rename {src/ge => ge}/common/ge/datatype_util.h (100%) rename {src/ge => ge}/common/ge/ge_util.h (100%) rename {src/ge => ge}/common/ge/op_tiling_manager.cc (100%) rename {src/ge => ge}/common/ge/op_tiling_manager.h (100%) rename {src/ge => ge}/common/ge/plugin_manager.cc (100%) rename {src/ge => ge}/common/ge/plugin_manager.h (100%) rename {src/ge => ge}/common/ge/tbe_plugin_manager.cc (100%) rename {src/ge => ge}/common/ge/tbe_plugin_manager.h (100%) rename {src/ge => ge}/common/ge_common.mk (100%) rename {src/ge => ge}/common/ge_format_util.cc (100%) rename {src/ge => ge}/common/helper/model_cache_helper.cc (100%) rename {src/ge => ge}/common/helper/model_cache_helper.h (100%) rename {src/ge => ge}/common/helper/model_helper.cc (100%) rename {src/ge => ge}/common/helper/om_file_helper.cc (100%) rename {src/ge => ge}/common/kernel_store.cc (100%) rename {src/ge => ge}/common/kernel_store.h (100%) rename {src/ge => ge}/common/math/fp16_math.cc (100%) rename {src/ge => ge}/common/math/fp16_math.h (100%) rename {src/ge => ge}/common/math/math_util.h (100%) rename {src/ge => ge}/common/math_util.h (100%) rename {src/ge => ge}/common/model_parser/base.cc (100%) rename {src/ge => ge}/common/model_parser/base.h (100%) rename {src/ge => ge}/common/model_saver.cc (100%) rename {src/ge => ge}/common/model_saver.h (94%) rename {src/ge => ge}/common/module.mk (100%) rename {src/ge => ge}/common/op/attr_value_util.cc (100%) rename {src/ge => ge}/common/op/ge_op_utils.cc (100%) rename {src/ge => ge}/common/profiling/profiling_manager.cc (100%) rename {src/ge => ge}/common/profiling/profiling_manager.h (100%) rename {src/ge => ge}/common/properties_manager.cc (100%) rename {src/ge => ge}/common/properties_manager.h (100%) create mode 100644 ge/common/proto/ge_ir.proto create mode 100644 ge/common/proto/insert_op.proto create mode 100644 ge/common/proto/om.proto rename {src => ge/common}/proto/op_mapping_info.proto (100%) create mode 100644 ge/common/proto/task.proto create 
mode 100644 ge/common/proto/tensorflow/attr_value.proto create mode 100644 ge/common/proto/tensorflow/function.proto create mode 100644 ge/common/proto/tensorflow/graph.proto create mode 100644 ge/common/proto/tensorflow/graph_library.proto create mode 100644 ge/common/proto/tensorflow/node_def.proto create mode 100644 ge/common/proto/tensorflow/op_def.proto create mode 100644 ge/common/proto/tensorflow/resource_handle.proto create mode 100644 ge/common/proto/tensorflow/tensor.proto create mode 100644 ge/common/proto/tensorflow/tensor_shape.proto create mode 100644 ge/common/proto/tensorflow/types.proto create mode 100644 ge/common/proto/tensorflow/versions.proto rename {src/ge => ge}/common/singleton.h (100%) rename {src/ge => ge}/common/tbe_kernel_store.cc (100%) rename {src/ge => ge}/common/tbe_kernel_store.h (100%) rename {src/ge => ge}/common/thread_pool.cc (100%) rename {src/ge => ge}/common/thread_pool.h (100%) rename {src/ge => ge}/common/types.cc (100%) rename {src/ge => ge}/common/util.cc (100%) rename {src/ge => ge}/engine_manager/dnnengine_manager.cc (100%) rename {src/ge => ge}/engine_manager/dnnengine_manager.h (100%) rename {src/ge => ge}/engine_manager/engine_conf.json (95%) mode change 100755 => 100644 create mode 100755 ge/executor/CMakeLists.txt rename {src/ge => ge}/executor/ge_executor.cc (100%) rename {src/ge => ge}/executor/module.mk (100%) rename {src => ge/executor}/proto/dump_task.proto (100%) create mode 100644 ge/executor/proto/ge_ir.proto create mode 100644 ge/executor/proto/insert_op.proto create mode 100644 ge/executor/proto/om.proto create mode 100644 ge/executor/proto/op_mapping_info.proto create mode 100644 ge/executor/proto/task.proto rename {src/ge => ge}/ge_inference.mk (100%) create mode 100755 ge/ge_local_engine/CMakeLists.txt rename {src/ge => ge}/ge_local_engine/common/constant/constant.h (100%) rename {src/ge => ge}/ge_local_engine/engine/ge_local_engine.cc (100%) rename {src/ge => ge}/ge_local_engine/engine/ge_local_engine.h (100%) rename {src/ge => ge}/ge_local_engine/engine/host_cpu_engine.cc (100%) rename {src/ge => ge}/ge_local_engine/engine/host_cpu_engine.h (100%) rename {src/ge => ge}/ge_local_engine/module.mk (100%) rename {src/ge => ge}/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc (100%) rename {src/ge => ge}/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h (100%) rename {src/ge => ge}/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc (100%) rename {src/ge => ge}/ge_local_engine/ops_kernel_store/op/ge_deleted_op.h (100%) rename {src/ge => ge}/ge_local_engine/ops_kernel_store/op/no_op.cc (100%) rename {src/ge => ge}/ge_local_engine/ops_kernel_store/op/no_op.h (100%) rename {src/ge => ge}/ge_local_engine/ops_kernel_store/op/op.cc (100%) rename {src/ge => ge}/ge_local_engine/ops_kernel_store/op/op.h (100%) rename {src/ge => ge}/ge_local_engine/ops_kernel_store/op/op_factory.cc (100%) rename {src/ge => ge}/ge_local_engine/ops_kernel_store/op/op_factory.h (100%) create mode 100644 ge/ge_local_engine/proto/task.proto rename {src/ge => ge}/ge_runner.mk (100%) create mode 100644 ge/ge_runtime/CMakeLists.txt rename {src/ge => ge}/ge_runtime/model_context.h (100%) rename {src/ge => ge}/ge_runtime/model_runner.cc (100%) create mode 100644 ge/ge_runtime/module.mk rename {src/ge => ge}/ge_runtime/output.cc (100%) rename {src/ge => ge}/ge_runtime/output.h (100%) rename {src/ge => ge}/ge_runtime/runtime_model.cc (100%) rename {src/ge => ge}/ge_runtime/runtime_model.h (100%) rename {src/ge => 
ge}/ge_runtime/task/aicpu_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/aicpu_task.h (100%) rename {src/ge => ge}/ge_runtime/task/cce_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/cce_task.h (100%) rename {src/ge => ge}/ge_runtime/task/event_record_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/event_record_task.h (100%) rename {src/ge => ge}/ge_runtime/task/event_wait_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/event_wait_task.h (100%) rename {src/ge => ge}/ge_runtime/task/hccl_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/hccl_task.h (100%) rename {src/ge => ge}/ge_runtime/task/label_goto_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/label_goto_task.h (100%) rename {src/ge => ge}/ge_runtime/task/label_set_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/label_set_task.h (100%) rename {src/ge => ge}/ge_runtime/task/label_switch_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/label_switch_task.h (100%) rename {src/ge => ge}/ge_runtime/task/memcpy_async_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/memcpy_async_task.h (100%) rename {src/ge => ge}/ge_runtime/task/profiler_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/profiler_task.h (100%) rename {src/ge => ge}/ge_runtime/task/stream_active_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/stream_active_task.h (100%) rename {src/ge => ge}/ge_runtime/task/stream_switch_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/stream_switch_task.h (100%) rename {src/ge => ge}/ge_runtime/task/task.h (100%) rename {src/ge => ge}/ge_runtime/task/task_factory.h (100%) rename {src/ge => ge}/ge_runtime/task/tbe_task.cc (100%) rename {src/ge => ge}/ge_runtime/task/tbe_task.h (100%) rename {src/ge => ge}/generator/ge_generator.cc (100%) rename {src/ge => ge}/generator/generator_api.cc (100%) rename {src/ge => ge}/graph/build/graph_builder.cc (100%) rename {src/ge => ge}/graph/build/graph_builder.h (100%) rename {src/ge => ge}/graph/build/label_allocator.cc (100%) rename {src/ge => ge}/graph/build/label_allocator.h (100%) rename {src/ge => ge}/graph/build/logical_stream_allocator.cc (100%) rename {src/ge => ge}/graph/build/logical_stream_allocator.h (100%) create mode 100644 ge/graph/build/memory/CMakeLists.txt rename {src/ge => ge}/graph/build/memory/binary_block_mem_assigner.cc (100%) rename {src/ge => ge}/graph/build/memory/binary_block_mem_assigner.h (100%) rename {src/ge => ge}/graph/build/memory/block_mem_assigner.cc (100%) rename {src/ge => ge}/graph/build/memory/block_mem_assigner.h (100%) rename {src/ge => ge}/graph/build/memory/graph_mem_assigner.cc (100%) rename {src/ge => ge}/graph/build/memory/graph_mem_assigner.h (100%) rename {src/ge => ge}/graph/build/memory/hybrid_mem_assigner.cc (100%) rename {src/ge => ge}/graph/build/memory/hybrid_mem_assigner.h (100%) rename {src/ge => ge}/graph/build/memory/max_block_mem_assigner.cc (100%) rename {src/ge => ge}/graph/build/memory/max_block_mem_assigner.h (100%) rename {src/ge => ge}/graph/build/memory/mem_assigner.h (100%) rename {src/ge => ge}/graph/build/memory/memory_assigner.cc (100%) rename {src/ge => ge}/graph/build/memory/module.mk (100%) rename {src/ge => ge}/graph/build/memory/var_mem_assign_util.cc (100%) rename {src/ge => ge}/graph/build/memory/var_mem_assign_util.h (100%) rename {src/ge => ge}/graph/build/model_builder.cc (100%) rename {src/ge => ge}/graph/build/model_builder.h (100%) rename {src/ge => ge}/graph/build/run_context.cc (100%) rename {src/ge => ge}/graph/build/run_context.h (100%) rename {src/ge => 
ge}/graph/build/stream_allocator.cc (100%) rename {src/ge => ge}/graph/build/stream_allocator.h (100%) rename {src/ge => ge}/graph/build/stream_graph_optimizer.cc (100%) rename {src/ge => ge}/graph/build/stream_graph_optimizer.h (100%) rename {src/ge => ge}/graph/build/task_generator.cc (100%) rename {src/ge => ge}/graph/build/task_generator.h (100%) rename {src/ge => ge}/graph/common/bcast.cc (100%) rename {src/ge => ge}/graph/common/bcast.h (100%) rename {src/ge => ge}/graph/common/ge_call_wrapper.h (100%) rename {src/ge => ge}/graph/common/local_context.cc (100%) rename {src/ge => ge}/graph/common/local_context.h (100%) rename {src/ge => ge}/graph/common/omg_util.cc (100%) rename {src/ge => ge}/graph/common/omg_util.h (100%) rename {src/ge => ge}/graph/common/transop_util.cc (100%) rename {src/ge => ge}/graph/common/transop_util.h (100%) rename {src/ge => ge}/graph/execute/graph_execute.cc (100%) rename {src/ge => ge}/graph/execute/graph_execute.h (100%) rename {src/ge => ge}/graph/label/case_label_maker.cc (100%) rename {src/ge => ge}/graph/label/case_label_maker.h (100%) rename {src/ge => ge}/graph/label/if_label_maker.cc (100%) rename {src/ge => ge}/graph/label/if_label_maker.h (100%) rename {src/ge => ge}/graph/label/label_maker.cc (100%) rename {src/ge => ge}/graph/label/label_maker.h (100%) rename {src/ge => ge}/graph/label/label_maker_factory.h (100%) rename {src/ge => ge}/graph/label/partitioned_call_label_maker.cc (100%) rename {src/ge => ge}/graph/label/partitioned_call_label_maker.h (100%) rename {src/ge => ge}/graph/label/while_label_maker.cc (100%) rename {src/ge => ge}/graph/label/while_label_maker.h (100%) rename {src/ge => ge}/graph/load/graph_loader.cc (100%) rename {src/ge => ge}/graph/load/graph_loader.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/aipp_utils.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/aipp_utils.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/cpu_queue_schedule.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/cpu_queue_schedule.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/data_dumper.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/data_dumper.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/data_inputer.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/data_inputer.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/davinci_model.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/davinci_model.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/davinci_model_parser.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/davinci_model_parser.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/model_manager.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/model_manager.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/model_utils.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/model_utils.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/end_graph_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/end_graph_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/event_record_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/event_record_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/event_wait_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/event_wait_task_info.h (100%) rename {src/ge => 
ge}/graph/load/new_model_manager/task_info/fusion_start_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/fusion_start_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/fusion_stop_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/hccl_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/hccl_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/kernel_ex_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/kernel_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/kernel_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/label_set_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/label_set_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/memcpy_async_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/profiler_trace_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/stream_active_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/stream_active_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/stream_switch_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/stream_switch_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/stream_switchn_task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/task_info.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/task_info.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/task_info/task_info_factory.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/tbe_handle_store.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/tbe_handle_store.h (100%) rename {src/ge => ge}/graph/load/new_model_manager/zero_copy_offset.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/zero_copy_offset.h 
(100%) rename {src/ge => ge}/graph/load/new_model_manager/zero_copy_task.cc (100%) rename {src/ge => ge}/graph/load/new_model_manager/zero_copy_task.h (100%) rename {src/ge => ge}/graph/manager/block_memory.h (100%) rename {src/ge => ge}/graph/manager/graph_caching_allocator.cc (100%) rename {src/ge => ge}/graph/manager/graph_caching_allocator.h (100%) rename {src/ge => ge}/graph/manager/graph_context.cc (100%) rename {src/ge => ge}/graph/manager/graph_context.h (100%) rename {src/ge => ge}/graph/manager/graph_manager.cc (100%) rename {src/ge => ge}/graph/manager/graph_manager.h (100%) rename {src/ge => ge}/graph/manager/graph_manager_utils.cc (100%) rename {src/ge => ge}/graph/manager/graph_manager_utils.h (100%) rename {src/ge => ge}/graph/manager/graph_mem_allocator.cc (100%) rename {src/ge => ge}/graph/manager/graph_mem_allocator.h (100%) rename {src/ge => ge}/graph/manager/graph_var_manager.cc (100%) rename {src/ge => ge}/graph/manager/graph_var_manager.h (100%) rename {src/ge => ge}/graph/manager/host_mem_manager.cc (100%) rename {src/ge => ge}/graph/manager/host_mem_manager.h (100%) rename {src/ge => ge}/graph/manager/memory_api.cc (100%) rename {src/ge => ge}/graph/manager/model_manager/event_manager.cc (100%) rename {src/ge => ge}/graph/manager/model_manager/event_manager.h (100%) rename {src/ge => ge}/graph/manager/rdma_pool_allocator.cc (100%) rename {src/ge => ge}/graph/manager/rdma_pool_allocator.h (100%) rename {src/ge => ge}/graph/manager/trans_var_data_utils.cc (100%) rename {src/ge => ge}/graph/manager/trans_var_data_utils.h (100%) rename {src/ge => ge}/graph/manager/util/debug.cc (100%) rename {src/ge => ge}/graph/manager/util/debug.h (100%) rename {src/ge => ge}/graph/manager/util/hcom_util.cc (100%) rename {src/ge => ge}/graph/manager/util/hcom_util.h (100%) rename {src/ge => ge}/graph/manager/util/rt_context_util.cc (100%) rename {src/ge => ge}/graph/manager/util/rt_context_util.h (100%) rename {src/ge => ge}/graph/manager/util/variable_accelerate_ctrl.cc (100%) rename {src/ge => ge}/graph/manager/util/variable_accelerate_ctrl.h (100%) rename {src/ge => ge}/graph/optimize/common/params.h (100%) rename {src/ge => ge}/graph/optimize/graph_optimize.cc (100%) rename {src/ge => ge}/graph/optimize/graph_optimize.h (100%) rename {src/ge => ge}/graph/optimize/mem_rw_conflict_optimize.cc (100%) rename {src/ge => ge}/graph/optimize/optimizer/allreduce_fusion_pass.cc (100%) rename {src/ge => ge}/graph/optimize/optimizer/allreduce_fusion_pass.h (100%) rename {src/ge => ge}/graph/optimize/summary_optimize.cc (100%) rename {src/ge => ge}/graph/partition/dynamic_shape_partition.cc (100%) rename {src/ge => ge}/graph/partition/dynamic_shape_partition.h (100%) rename {src/ge => ge}/graph/partition/engine_place.cc (100%) rename {src/ge => ge}/graph/partition/engine_place.h (100%) rename {src/ge => ge}/graph/partition/graph_partition.cc (100%) rename {src/ge => ge}/graph/partition/graph_partition.h (100%) rename {src/ge => ge}/graph/passes/addn_pass.cc (100%) rename {src/ge => ge}/graph/passes/addn_pass.h (100%) rename {src/ge => ge}/graph/passes/aicpu_constant_folding_pass.cc (100%) rename {src/ge => ge}/graph/passes/aicpu_constant_folding_pass.h (100%) rename {src/ge => ge}/graph/passes/assert_pass.cc (100%) rename {src/ge => ge}/graph/passes/assert_pass.h (100%) rename {src/ge => ge}/graph/passes/assign_pass.cc (100%) rename {src/ge => ge}/graph/passes/assign_pass.h (100%) rename {src/ge => ge}/graph/passes/atomic_addr_clean_pass.cc (100%) rename {src/ge => 
ge}/graph/passes/atomic_addr_clean_pass.h (100%) rename {src/ge => ge}/graph/passes/attach_stream_label_pass.cc (100%) rename {src/ge => ge}/graph/passes/attach_stream_label_pass.h (100%) rename {src/ge => ge}/graph/passes/base_pass.cc (100%) rename {src/ge => ge}/graph/passes/base_pass.h (100%) rename {src/ge => ge}/graph/passes/bitcast_pass.cc (100%) rename {src/ge => ge}/graph/passes/bitcast_pass.h (100%) rename {src/ge => ge}/graph/passes/cast_remove_pass.cc (100%) rename {src/ge => ge}/graph/passes/cast_remove_pass.h (100%) rename {src/ge => ge}/graph/passes/cast_translate_pass.cc (100%) rename {src/ge => ge}/graph/passes/cast_translate_pass.h (100%) rename {src/ge => ge}/graph/passes/common_subexpression_elimination_pass.cc (100%) rename {src/ge => ge}/graph/passes/common_subexpression_elimination_pass.h (100%) rename {src/ge => ge}/graph/passes/compile_nodes_pass.cc (100%) rename {src/ge => ge}/graph/passes/compile_nodes_pass.h (100%) rename {src/ge => ge}/graph/passes/cond_pass.cc (100%) rename {src/ge => ge}/graph/passes/cond_pass.h (100%) rename {src/ge => ge}/graph/passes/cond_remove_pass.cc (100%) rename {src/ge => ge}/graph/passes/cond_remove_pass.h (100%) rename {src/ge => ge}/graph/passes/constant_folding_pass.cc (100%) rename {src/ge => ge}/graph/passes/constant_folding_pass.h (100%) rename {src/ge => ge}/graph/passes/constant_fuse_same_pass.cc (100%) rename {src/ge => ge}/graph/passes/constant_fuse_same_pass.h (100%) rename {src/ge => ge}/graph/passes/control_trigger_pass.cc (100%) rename {src/ge => ge}/graph/passes/control_trigger_pass.h (100%) rename {src/ge => ge}/graph/passes/ctrl_edge_transfer_pass.cc (100%) rename {src/ge => ge}/graph/passes/ctrl_edge_transfer_pass.h (100%) rename {src/ge => ge}/graph/passes/data_pass.cc (100%) rename {src/ge => ge}/graph/passes/data_pass.h (100%) rename {src/ge => ge}/graph/passes/dimension_adjust_pass.cc (100%) rename {src/ge => ge}/graph/passes/dimension_adjust_pass.h (100%) rename {src/ge => ge}/graph/passes/dimension_compute_pass.cc (100%) rename {src/ge => ge}/graph/passes/dimension_compute_pass.h (100%) rename {src/ge => ge}/graph/passes/dropout_pass.cc (100%) rename {src/ge => ge}/graph/passes/dropout_pass.h (100%) rename {src/ge => ge}/graph/passes/end_of_sequence_add_control_pass.cc (100%) rename {src/ge => ge}/graph/passes/end_of_sequence_add_control_pass.h (100%) rename {src/ge => ge}/graph/passes/enter_pass.cc (100%) rename {src/ge => ge}/graph/passes/enter_pass.h (100%) rename {src/ge => ge}/graph/passes/flow_ctrl_pass.cc (100%) rename {src/ge => ge}/graph/passes/flow_ctrl_pass.h (100%) rename {src/ge => ge}/graph/passes/folding_pass.cc (100%) rename {src/ge => ge}/graph/passes/folding_pass.h (100%) rename {src/ge => ge}/graph/passes/for_pass.cc (100%) rename {src/ge => ge}/graph/passes/for_pass.h (100%) rename {src/ge => ge}/graph/passes/get_original_format_pass.cc (100%) rename {src/ge => ge}/graph/passes/get_original_format_pass.h (100%) rename {src/ge => ge}/graph/passes/global_step_insert_pass.cc (100%) rename {src/ge => ge}/graph/passes/global_step_insert_pass.h (100%) rename {src/ge => ge}/graph/passes/guarantee_const_pass.cc (100%) rename {src/ge => ge}/graph/passes/guarantee_const_pass.h (100%) rename {src/ge => ge}/graph/passes/hccl_group_pass.cc (100%) rename {src/ge => ge}/graph/passes/hccl_group_pass.h (100%) rename {src/ge => ge}/graph/passes/hccl_memcpy_pass.cc (100%) rename {src/ge => ge}/graph/passes/hccl_memcpy_pass.h (100%) rename {src/ge => ge}/graph/passes/identity_pass.cc (100%) rename {src/ge => 
ge}/graph/passes/identity_pass.h (99%) rename {src/ge => ge}/graph/passes/infershape_pass.cc (100%) rename {src/ge => ge}/graph/passes/infershape_pass.h (100%) rename {src/ge => ge}/graph/passes/input_output_connection_identify_pass.cc (100%) rename {src/ge => ge}/graph/passes/input_output_connection_identify_pass.h (100%) rename {src/ge => ge}/graph/passes/isolated_op_remove_pass.cc (100%) rename {src/ge => ge}/graph/passes/isolated_op_remove_pass.h (100%) rename {src/ge => ge}/graph/passes/iterator_op_pass.cc (100%) rename {src/ge => ge}/graph/passes/iterator_op_pass.h (100%) rename {src/ge => ge}/graph/passes/link_gen_mask_nodes_pass.cc (100%) rename {src/ge => ge}/graph/passes/link_gen_mask_nodes_pass.h (100%) rename {src/ge => ge}/graph/passes/mark_agnostic_pass.cc (100%) rename {src/ge => ge}/graph/passes/mark_agnostic_pass.h (100%) rename {src/ge => ge}/graph/passes/mark_graph_unknown_status_pass.cc (100%) rename {src/ge => ge}/graph/passes/mark_graph_unknown_status_pass.h (100%) rename {src/ge => ge}/graph/passes/mark_same_addr_pass.cc (100%) rename {src/ge => ge}/graph/passes/mark_same_addr_pass.h (100%) rename {src/ge => ge}/graph/passes/memcpy_addr_async_pass.cc (100%) rename {src/ge => ge}/graph/passes/memcpy_addr_async_pass.h (100%) rename {src/ge => ge}/graph/passes/merge_pass.cc (100%) rename {src/ge => ge}/graph/passes/merge_pass.h (100%) rename {src/ge => ge}/graph/passes/merge_to_stream_merge_pass.cc (100%) rename {src/ge => ge}/graph/passes/merge_to_stream_merge_pass.h (100%) rename {src/ge => ge}/graph/passes/multi_batch_clone_pass.cc (100%) rename {src/ge => ge}/graph/passes/multi_batch_clone_pass.h (100%) rename {src/ge => ge}/graph/passes/multi_batch_pass.cc (100%) rename {src/ge => ge}/graph/passes/multi_batch_pass.h (100%) rename {src/ge => ge}/graph/passes/net_output_pass.cc (100%) rename {src/ge => ge}/graph/passes/net_output_pass.h (100%) rename {src/ge => ge}/graph/passes/next_iteration_pass.cc (100%) rename {src/ge => ge}/graph/passes/next_iteration_pass.h (100%) rename {src/ge => ge}/graph/passes/no_use_reshape_remove_pass.cc (100%) rename {src/ge => ge}/graph/passes/no_use_reshape_remove_pass.h (100%) rename {src/ge => ge}/graph/passes/parallel_concat_start_op_pass.cc (100%) rename {src/ge => ge}/graph/passes/parallel_concat_start_op_pass.h (100%) rename {src/ge => ge}/graph/passes/pass_manager.cc (100%) rename {src/ge => ge}/graph/passes/pass_utils.cc (100%) rename {src/ge => ge}/graph/passes/pass_utils.h (100%) rename {src/ge => ge}/graph/passes/permute_pass.cc (100%) rename {src/ge => ge}/graph/passes/permute_pass.h (100%) rename {src/ge => ge}/graph/passes/placeholder_with_default_pass.cc (100%) rename {src/ge => ge}/graph/passes/placeholder_with_default_pass.h (100%) rename {src/ge => ge}/graph/passes/prevent_gradient_pass.cc (100%) rename {src/ge => ge}/graph/passes/prevent_gradient_pass.h (100%) rename {src/ge => ge}/graph/passes/print_op_pass.cc (100%) rename {src/ge => ge}/graph/passes/print_op_pass.h (100%) rename {src/ge => ge}/graph/passes/prune_pass.cc (100%) rename {src/ge => ge}/graph/passes/prune_pass.h (100%) rename {src/ge => ge}/graph/passes/ref_identity_delete_op_pass.cc (100%) rename {src/ge => ge}/graph/passes/ref_identity_delete_op_pass.h (100%) rename {src/ge => ge}/graph/passes/remove_nodes_pass.cc (100%) rename {src/ge => ge}/graph/passes/remove_nodes_pass.h (100%) rename {src/ge => ge}/graph/passes/replace_transshape_pass.cc (100%) rename {src/ge => ge}/graph/passes/replace_transshape_pass.h (100%) rename {src/ge => 
ge}/graph/passes/replace_with_empty_const_pass.cc (100%) rename {src/ge => ge}/graph/passes/replace_with_empty_const_pass.h (100%) rename {src/ge => ge}/graph/passes/reshape_recovery_pass.cc (100%) rename {src/ge => ge}/graph/passes/reshape_recovery_pass.h (100%) rename {src/ge => ge}/graph/passes/reshape_remove_pass.cc (100%) rename {src/ge => ge}/graph/passes/reshape_remove_pass.h (100%) rename {src/ge => ge}/graph/passes/resource_pair_add_control_pass.cc (100%) rename {src/ge => ge}/graph/passes/resource_pair_add_control_pass.h (100%) rename {src/ge => ge}/graph/passes/resource_pair_remove_control_pass.cc (100%) rename {src/ge => ge}/graph/passes/resource_pair_remove_control_pass.h (100%) rename {src/ge => ge}/graph/passes/same_transdata_breadth_fusion_pass.cc (100%) rename {src/ge => ge}/graph/passes/same_transdata_breadth_fusion_pass.h (100%) rename {src/ge => ge}/graph/passes/save_pass.cc (100%) rename {src/ge => ge}/graph/passes/save_pass.h (100%) rename {src/ge => ge}/graph/passes/set_input_output_offset_pass.cc (100%) rename {src/ge => ge}/graph/passes/set_input_output_offset_pass.h (100%) rename {src/ge => ge}/graph/passes/shape_operate_op_remove_pass.cc (100%) rename {src/ge => ge}/graph/passes/shape_operate_op_remove_pass.h (100%) rename {src/ge => ge}/graph/passes/snapshot_pass.cc (100%) rename {src/ge => ge}/graph/passes/snapshot_pass.h (100%) rename {src/ge => ge}/graph/passes/stop_gradient_pass.cc (100%) rename {src/ge => ge}/graph/passes/stop_gradient_pass.h (100%) rename {src/ge => ge}/graph/passes/subexpression_migration_pass.cc (100%) rename {src/ge => ge}/graph/passes/subexpression_migration_pass.h (100%) rename {src/ge => ge}/graph/passes/subgraph_pass.cc (100%) rename {src/ge => ge}/graph/passes/subgraph_pass.h (100%) rename {src/ge => ge}/graph/passes/switch_data_edges_bypass.cc (100%) rename {src/ge => ge}/graph/passes/switch_data_edges_bypass.h (100%) rename {src/ge => ge}/graph/passes/switch_dead_branch_elimination.cc (100%) rename {src/ge => ge}/graph/passes/switch_dead_branch_elimination.h (100%) rename {src/ge => ge}/graph/passes/switch_logic_remove_pass.cc (100%) rename {src/ge => ge}/graph/passes/switch_logic_remove_pass.h (100%) rename {src/ge => ge}/graph/passes/switch_to_stream_switch_pass.cc (100%) rename {src/ge => ge}/graph/passes/switch_to_stream_switch_pass.h (100%) rename {src/ge => ge}/graph/passes/transop_breadth_fusion_pass.cc (100%) rename {src/ge => ge}/graph/passes/transop_breadth_fusion_pass.h (100%) rename {src/ge => ge}/graph/passes/transop_depth_fusion_pass.cc (100%) rename {src/ge => ge}/graph/passes/transop_depth_fusion_pass.h (100%) rename {src/ge => ge}/graph/passes/transop_nearby_allreduce_fusion_pass.cc (100%) rename {src/ge => ge}/graph/passes/transop_nearby_allreduce_fusion_pass.h (100%) rename {src/ge => ge}/graph/passes/transop_symmetry_elimination_pass.cc (100%) rename {src/ge => ge}/graph/passes/transop_symmetry_elimination_pass.h (100%) rename {src/ge => ge}/graph/passes/transop_without_reshape_fusion_pass.cc (100%) rename {src/ge => ge}/graph/passes/transop_without_reshape_fusion_pass.h (100%) rename {src/ge => ge}/graph/passes/transpose_transdata_pass.cc (100%) rename {src/ge => ge}/graph/passes/transpose_transdata_pass.h (100%) rename {src/ge => ge}/graph/passes/unused_args_clean_pass.cc (100%) rename {src/ge => ge}/graph/passes/unused_args_clean_pass.h (100%) rename {src/ge => ge}/graph/passes/unused_const_pass.cc (100%) rename {src/ge => ge}/graph/passes/unused_const_pass.h (100%) rename {src/ge => 
ge}/graph/passes/unused_op_remove_pass.cc (100%) rename {src/ge => ge}/graph/passes/unused_op_remove_pass.h (100%) rename {src/ge => ge}/graph/passes/var_is_initialized_op_pass.cc (100%) rename {src/ge => ge}/graph/passes/var_is_initialized_op_pass.h (100%) rename {src/ge => ge}/graph/passes/variable_format_pass.cc (100%) rename {src/ge => ge}/graph/passes/variable_format_pass.h (100%) rename {src/ge => ge}/graph/passes/variable_op_pass.cc (100%) rename {src/ge => ge}/graph/passes/variable_op_pass.h (100%) rename {src/ge => ge}/graph/passes/variable_prepare_op_pass.cc (100%) rename {src/ge => ge}/graph/passes/variable_prepare_op_pass.h (100%) rename {src/ge => ge}/graph/passes/variable_ref_delete_op_pass.cc (100%) rename {src/ge => ge}/graph/passes/variable_ref_delete_op_pass.h (100%) rename {src/ge => ge}/graph/passes/variable_ref_useless_control_out_delete_pass.cc (100%) rename {src/ge => ge}/graph/passes/variable_ref_useless_control_out_delete_pass.h (100%) rename {src/ge => ge}/graph/preprocess/graph_preprocess.cc (100%) rename {src/ge => ge}/graph/preprocess/graph_preprocess.h (100%) rename {src/ge => ge}/graph/preprocess/insert_op/base_insert_op.h (100%) rename {src/ge => ge}/graph/preprocess/insert_op/ge_aipp_op.cc (100%) rename {src/ge => ge}/graph/preprocess/insert_op/ge_aipp_op.h (100%) rename {src/ge => ge}/graph/preprocess/insert_op/util_insert_aipp_op.cc (100%) rename {src/ge => ge}/graph/preprocess/insert_op/util_insert_aipp_op.h (100%) rename {src/ge => ge}/graph/preprocess/multi_batch_copy_graph.cc (100%) rename {src/ge => ge}/graph/preprocess/multi_batch_copy_graph.h (100%) rename {src/ge => ge}/graph/preprocess/multi_batch_options.cc (100%) rename {src/ge => ge}/graph/preprocess/multi_batch_options.h (100%) create mode 100644 ge/host_cpu_engine/CMakeLists.txt rename {src/ge => ge}/host_cpu_engine/common/constant/constant.h (100%) rename {src/ge => ge}/host_cpu_engine/engine/host_cpu_engine.cc (100%) rename {src/ge => ge}/host_cpu_engine/engine/host_cpu_engine.h (100%) rename {src/ge => ge}/host_cpu_engine/module.mk (100%) rename {src/ge => ge}/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc (100%) rename {src/ge => ge}/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h (100%) rename {src/ge => ge}/host_cpu_engine/ops_kernel_store/op/host_op.cc (100%) rename {src/ge => ge}/host_cpu_engine/ops_kernel_store/op/host_op.h (100%) rename {src/ge => ge}/host_cpu_engine/ops_kernel_store/op/op.h (100%) rename {src/ge => ge}/host_cpu_engine/ops_kernel_store/op/op_factory.cc (100%) rename {src/ge => ge}/host_cpu_engine/ops_kernel_store/op/op_factory.h (100%) create mode 100644 ge/host_cpu_engine/proto/task.proto rename {src/ge => ge}/host_kernels/add_kernel.cc (100%) rename {src/ge => ge}/host_kernels/add_kernel.h (100%) rename {src/ge => ge}/host_kernels/broadcast_args_kernel.cc (100%) rename {src/ge => ge}/host_kernels/broadcast_args_kernel.h (100%) rename {src/ge => ge}/host_kernels/broadcast_gradient_args_kernel.cc (100%) rename {src/ge => ge}/host_kernels/broadcast_gradient_args_kernel.h (100%) rename {src/ge => ge}/host_kernels/cast_kernel.cc (100%) rename {src/ge => ge}/host_kernels/cast_kernel.h (100%) rename {src/ge => ge}/host_kernels/concat_offset_kernel.cc (100%) rename {src/ge => ge}/host_kernels/concat_offset_kernel.h (100%) rename {src/ge => ge}/host_kernels/concat_v2_kernel.cc (100%) rename {src/ge => ge}/host_kernels/concat_v2_kernel.h (100%) rename {src/ge => ge}/host_kernels/dynamic_stitch_kernel.cc (100%) rename {src/ge => 
ge}/host_kernels/dynamic_stitch_kernel.h (100%) rename {src/ge => ge}/host_kernels/empty_kernel.cc (100%) rename {src/ge => ge}/host_kernels/empty_kernel.h (100%) rename {src/ge => ge}/host_kernels/expanddims_kernel.cc (100%) rename {src/ge => ge}/host_kernels/expanddims_kernel.h (100%) rename {src/ge => ge}/host_kernels/fill_kernel.cc (100%) rename {src/ge => ge}/host_kernels/fill_kernel.h (100%) rename {src/ge => ge}/host_kernels/floordiv_kernel.cc (100%) rename {src/ge => ge}/host_kernels/floordiv_kernel.h (100%) rename {src/ge => ge}/host_kernels/floormod_kernel.cc (100%) rename {src/ge => ge}/host_kernels/floormod_kernel.h (100%) rename {src/ge => ge}/host_kernels/gather_v2_kernel.cc (100%) rename {src/ge => ge}/host_kernels/gather_v2_kernel.h (100%) rename {src/ge => ge}/host_kernels/greater_kernel.cc (100%) rename {src/ge => ge}/host_kernels/greater_kernel.h (100%) rename {src/ge => ge}/host_kernels/identity_kernel.cc (100%) rename {src/ge => ge}/host_kernels/identity_kernel.h (100%) rename {src/ge => ge}/host_kernels/kernel_utils.cc (100%) rename {src/ge => ge}/host_kernels/kernel_utils.h (100%) rename {src/ge => ge}/host_kernels/maximum_kernel.cc (100%) rename {src/ge => ge}/host_kernels/maximum_kernel.h (100%) rename {src/ge => ge}/host_kernels/mul_kernel.cc (100%) rename {src/ge => ge}/host_kernels/mul_kernel.h (100%) rename {src/ge => ge}/host_kernels/pack_kernel.cc (100%) rename {src/ge => ge}/host_kernels/pack_kernel.h (100%) rename {src/ge => ge}/host_kernels/permute_kernel.cc (100%) rename {src/ge => ge}/host_kernels/permute_kernel.h (100%) rename {src/ge => ge}/host_kernels/range_kernel.cc (100%) rename {src/ge => ge}/host_kernels/range_kernel.h (100%) rename {src/ge => ge}/host_kernels/rank_kernel.cc (100%) rename {src/ge => ge}/host_kernels/rank_kernel.h (100%) rename {src/ge => ge}/host_kernels/reduce_prod_kernel.cc (100%) rename {src/ge => ge}/host_kernels/reduce_prod_kernel.h (100%) rename {src/ge => ge}/host_kernels/reformat_kernel.cc (100%) rename {src/ge => ge}/host_kernels/reformat_kernel.h (100%) rename {src/ge => ge}/host_kernels/reshape_kernel.cc (100%) rename {src/ge => ge}/host_kernels/reshape_kernel.h (100%) rename {src/ge => ge}/host_kernels/rsqrt_kernel.cc (100%) rename {src/ge => ge}/host_kernels/rsqrt_kernel.h (100%) rename {src/ge => ge}/host_kernels/shape_kernel.cc (100%) rename {src/ge => ge}/host_kernels/shape_kernel.h (100%) rename {src/ge => ge}/host_kernels/shape_n_kernel.cc (100%) rename {src/ge => ge}/host_kernels/shape_n_kernel.h (100%) rename {src/ge => ge}/host_kernels/size_kernel.cc (100%) rename {src/ge => ge}/host_kernels/size_kernel.h (100%) rename {src/ge => ge}/host_kernels/slice_d_kernel.cc (100%) rename {src/ge => ge}/host_kernels/slice_d_kernel.h (100%) rename {src/ge => ge}/host_kernels/slice_kernel.cc (100%) rename {src/ge => ge}/host_kernels/slice_kernel.h (100%) rename {src/ge => ge}/host_kernels/squeeze_kernel.cc (100%) rename {src/ge => ge}/host_kernels/squeeze_kernel.h (100%) rename {src/ge => ge}/host_kernels/ssd_prior_box_kernel.cc (100%) rename {src/ge => ge}/host_kernels/ssd_prior_box_kernel.h (100%) rename {src/ge => ge}/host_kernels/strided_slice_kernel.cc (100%) rename {src/ge => ge}/host_kernels/strided_slice_kernel.h (100%) rename {src/ge => ge}/host_kernels/sub_kernel.cc (100%) rename {src/ge => ge}/host_kernels/sub_kernel.h (100%) rename {src/ge => ge}/host_kernels/transdata_kernel.cc (100%) rename {src/ge => ge}/host_kernels/transdata_kernel.h (100%) rename {src/ge => ge}/host_kernels/transpose_kernel.cc (100%) 
rename {src/ge => ge}/host_kernels/transpose_kernel.h (100%) rename {src/ge => ge}/host_kernels/unpack_kernel.cc (100%) rename {src/ge => ge}/host_kernels/unpack_kernel.h (100%) rename {src/ge => ge}/host_kernels/unsqueeze_kernel.cc (100%) rename {src/ge => ge}/host_kernels/unsqueeze_kernel.h (100%) rename {src/ge => ge}/hybrid/common/npu_memory_allocator.cc (100%) rename {src/ge => ge}/hybrid/common/npu_memory_allocator.h (100%) rename {src/ge => ge}/hybrid/common/tensor_value.cc (100%) rename {src/ge => ge}/hybrid/common/tensor_value.h (100%) rename {src/ge => ge}/hybrid/executor/hybrid_execution_context.cc (100%) rename {src/ge => ge}/hybrid/executor/hybrid_execution_context.h (100%) rename {src/ge => ge}/hybrid/executor/hybrid_model_async_executor.cc (100%) rename {src/ge => ge}/hybrid/executor/hybrid_model_async_executor.h (100%) rename {src/ge => ge}/hybrid/executor/hybrid_model_executor.cc (100%) rename {src/ge => ge}/hybrid/executor/hybrid_model_executor.h (100%) rename {src/ge => ge}/hybrid/executor/hybrid_profiler.cc (100%) rename {src/ge => ge}/hybrid/executor/hybrid_profiler.h (100%) rename {src/ge => ge}/hybrid/executor/node_done_manager.cc (100%) rename {src/ge => ge}/hybrid/executor/node_done_manager.h (100%) rename {src/ge => ge}/hybrid/executor/node_state.cc (100%) rename {src/ge => ge}/hybrid/executor/node_state.h (100%) rename {src/ge => ge}/hybrid/executor/rt_callback_manager.cc (100%) rename {src/ge => ge}/hybrid/executor/rt_callback_manager.h (100%) rename {src/ge => ge}/hybrid/executor/subgraph_context.cc (100%) rename {src/ge => ge}/hybrid/executor/subgraph_context.h (100%) rename {src/ge => ge}/hybrid/executor/subgraph_executor.cc (100%) rename {src/ge => ge}/hybrid/executor/subgraph_executor.h (100%) rename {src/ge => ge}/hybrid/executor/worker/execution_engine.cc (100%) rename {src/ge => ge}/hybrid/executor/worker/execution_engine.h (100%) rename {src/ge => ge}/hybrid/executor/worker/shape_inference_engine.cc (100%) rename {src/ge => ge}/hybrid/executor/worker/shape_inference_engine.h (100%) rename {src/ge => ge}/hybrid/executor/worker/task_compile_engine.cc (100%) rename {src/ge => ge}/hybrid/executor/worker/task_compile_engine.h (100%) rename {src/ge => ge}/hybrid/hybrid_davinci_model.cc (100%) rename {src/ge => ge}/hybrid/hybrid_davinci_model.h (100%) rename {src/ge => ge}/hybrid/hybrid_davinci_model_stub.cc (100%) rename {src/ge => ge}/hybrid/model/graph_item.cc (100%) rename {src/ge => ge}/hybrid/model/graph_item.h (100%) rename {src/ge => ge}/hybrid/model/hybrid_model.cc (100%) rename {src/ge => ge}/hybrid/model/hybrid_model.h (100%) rename {src/ge => ge}/hybrid/model/hybrid_model_builder.cc (100%) rename {src/ge => ge}/hybrid/model/hybrid_model_builder.h (100%) rename {src/ge => ge}/hybrid/model/node_item.cc (100%) rename {src/ge => ge}/hybrid/model/node_item.h (100%) rename {src/ge => ge}/hybrid/node_executor/aicore/aicore_node_executor.cc (100%) rename {src/ge => ge}/hybrid/node_executor/aicore/aicore_node_executor.h (100%) rename {src/ge => ge}/hybrid/node_executor/aicore/aicore_op_task.cc (100%) rename {src/ge => ge}/hybrid/node_executor/aicore/aicore_op_task.h (100%) rename {src/ge => ge}/hybrid/node_executor/aicore/aicore_task_builder.cc (100%) rename {src/ge => ge}/hybrid/node_executor/aicore/aicore_task_builder.h (100%) rename {src/ge => ge}/hybrid/node_executor/aicore/aicore_task_compiler.cc (100%) rename {src/ge => ge}/hybrid/node_executor/aicore/aicore_task_compiler.h (100%) rename {src/ge => ge}/hybrid/node_executor/aicpu/aicpu_ext_info.cc 
(100%) rename {src/ge => ge}/hybrid/node_executor/aicpu/aicpu_ext_info.h (100%) rename {src/ge => ge}/hybrid/node_executor/aicpu/aicpu_node_executor.cc (100%) rename {src/ge => ge}/hybrid/node_executor/aicpu/aicpu_node_executor.h (100%) rename {src/ge => ge}/hybrid/node_executor/compiledsubgraph/known_node_executor.cc (100%) rename {src/ge => ge}/hybrid/node_executor/compiledsubgraph/known_node_executor.h (100%) rename {src/ge => ge}/hybrid/node_executor/controlop/control_op_executor.cc (100%) rename {src/ge => ge}/hybrid/node_executor/controlop/control_op_executor.h (100%) rename {src/ge => ge}/hybrid/node_executor/ge_local/ge_local_node_executor.cc (100%) rename {src/ge => ge}/hybrid/node_executor/ge_local/ge_local_node_executor.h (100%) rename {src/ge => ge}/hybrid/node_executor/hccl/hccl_node_executor.cc (100%) rename {src/ge => ge}/hybrid/node_executor/hccl/hccl_node_executor.h (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/host_cpu_node_executor.h (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel/assign_kernel.h (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel/kernel.h (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel/variable_kernel.h (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel_factory.cc (100%) rename {src/ge => ge}/hybrid/node_executor/host_cpu/kernel_factory.h (100%) rename {src/ge => ge}/hybrid/node_executor/node_executor.cc (100%) rename {src/ge => ge}/hybrid/node_executor/node_executor.h (100%) rename {src/ge => ge}/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc (100%) rename {src/ge => ge}/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h (100%) rename {src/ge => ge}/hybrid/node_executor/rts/rts_node_executor.cc (100%) rename {src/ge => ge}/hybrid/node_executor/rts/rts_node_executor.h (100%) rename {src/ge => ge}/hybrid/node_executor/task_context.cc (100%) rename {src/ge => ge}/hybrid/node_executor/task_context.h (100%) rename {src/ge => ge}/inc/graph_pass.h (100%) rename {src/ge => ge}/inc/kernel.h (100%) rename {src/ge => ge}/inc/kernel_factory.h (100%) rename {src/ge => ge}/inc/pass.h (100%) rename {src/ge => ge}/inc/pass_manager.h (100%) rename {src/ge => ge}/init/gelib.cc (100%) rename {src/ge => ge}/init/gelib.h (100%) rename {src/ge => ge}/ir_build/atc_ir_common.cc (100%) rename {src/ge => ge}/ir_build/atc_ir_common.h (100%) rename {src/ge => ge}/ir_build/ge_ir_build.cc (100%) rename {src/ge => ge}/model/ge_model.cc (100%) rename {src/ge => ge}/model/ge_model.h (100%) rename {src/ge => ge}/model/ge_root_model.cc (100%) rename {src/ge => ge}/model/ge_root_model.h (100%) rename {src/ge => ge}/module.mk (100%) create mode 100644 ge/offline/CMakeLists.txt create mode 100644 ge/offline/main.cc create mode 100644 ge/offline/module.mk create mode 100644 ge/offline/proto/ge_ir.proto create mode 100644 ge/offline/proto/insert_op.proto create mode 100644 
ge/offline/proto/om.proto create mode 100644 ge/offline/proto/task.proto create mode 100644 ge/offline/single_op_parser.cc create mode 100644 ge/offline/single_op_parser.h rename {src/ge => ge}/omm/csa_interact.cc (100%) rename {src/ge => ge}/omm/csa_interact.h (85%) rename {src/ge => ge}/opskernel_manager/ops_kernel_manager.cc (100%) rename {src/ge => ge}/opskernel_manager/ops_kernel_manager.h (100%) rename {src/ge => ge}/opskernel_manager/optimizer_priority.pbtxt (100%) mode change 100755 => 100644 create mode 100644 ge/plugin/engine/CMakeLists.txt rename {src/ge => ge}/plugin/engine/dnnengines.cc (100%) rename {src/ge => ge}/plugin/engine/dnnengines.h (100%) rename {src/ge => ge}/plugin/engine/engine_manage.cc (100%) rename {src/ge => ge}/plugin/engine/engine_manage.h (100%) rename {src/ge => ge}/plugin/engine/module.mk (100%) create mode 100644 ge/proto/caffe/caffe.proto create mode 100644 ge/proto/dump_task.proto rename {src => ge}/proto/fusion_model.proto (100%) rename {src => ge}/proto/fwk_adapter.proto (100%) create mode 100644 ge/proto/ge_api.proto create mode 100644 ge/proto/ge_ir.proto create mode 100644 ge/proto/insert_op.proto create mode 100644 ge/proto/om.proto create mode 100644 ge/proto/op_mapping_info.proto rename {src => ge}/proto/optimizer_priority.proto (100%) create mode 100644 ge/proto/task.proto create mode 100644 ge/proto/tensorflow/attr_value.proto create mode 100644 ge/proto/tensorflow/function.proto create mode 100644 ge/proto/tensorflow/graph.proto create mode 100644 ge/proto/tensorflow/graph_library.proto create mode 100644 ge/proto/tensorflow/node_def.proto create mode 100644 ge/proto/tensorflow/op_def.proto create mode 100644 ge/proto/tensorflow/resource_handle.proto create mode 100644 ge/proto/tensorflow/tensor.proto create mode 100644 ge/proto/tensorflow/tensor_shape.proto create mode 100644 ge/proto/tensorflow/types.proto create mode 100644 ge/proto/tensorflow/versions.proto rename {src/ge => ge}/session/inner_session.cc (100%) rename {src/ge => ge}/session/inner_session.h (100%) rename {src/ge => ge}/session/omg.cc (100%) create mode 100644 ge/session/readme.txt rename {src/ge => ge}/session/session_manager.cc (100%) rename {src/ge => ge}/session/session_manager.h (100%) rename {src/ge => ge}/single_op/single_op.cc (100%) rename {src/ge => ge}/single_op/single_op.h (100%) rename {src/ge => ge}/single_op/single_op_manager.cc (100%) rename {src/ge => ge}/single_op/single_op_manager.h (100%) rename {src/ge => ge}/single_op/single_op_model.cc (100%) rename {src/ge => ge}/single_op/single_op_model.h (100%) rename {src/ge => ge}/single_op/stream_resource.cc (100%) rename {src/ge => ge}/single_op/stream_resource.h (100%) rename {src/ge => ge}/single_op/task/aicpu_kernel_task_builder.cc (100%) rename {src/ge => ge}/single_op/task/aicpu_kernel_task_builder.h (100%) rename {src/ge => ge}/single_op/task/aicpu_task_builder.cc (100%) rename {src/ge => ge}/single_op/task/aicpu_task_builder.h (100%) rename {src/ge => ge}/single_op/task/build_task_utils.cc (100%) rename {src/ge => ge}/single_op/task/build_task_utils.h (100%) rename {src/ge => ge}/single_op/task/op_task.cc (100%) rename {src/ge => ge}/single_op/task/op_task.h (100%) rename {src/ge => ge}/single_op/task/tbe_task_builder.cc (100%) rename {src/ge => ge}/single_op/task/tbe_task_builder.h (100%) rename {src/ge => ge}/stub/Makefile (100%) rename {src/ge => ge}/stub/README (100%) rename {src/ge => ge}/stub/README.md (100%) mode change 100755 => 100644 rename {src/ge => ge}/stub/gen_stubapi.py (99%) delete mode 
100644 inc/common/blocking_queue.h delete mode 100644 inc/common/dynamic_aipp.h delete mode 100644 inc/common/npu_error_define.h delete mode 100644 inc/common/opskernel/ge_task_info.h delete mode 100644 inc/common/opskernel/ops_kernel_info_store.h delete mode 100644 inc/common/opskernel/ops_kernel_info_types.h delete mode 100644 inc/common/optimizer/graph_optimizer.h delete mode 100644 inc/common/util/ai_core/common/aicore_util_attr_define.h delete mode 100644 inc/common/util/ai_core/common/aicore_util_types.h delete mode 100644 inc/common/util/ai_core/common/graph_comm.h delete mode 100644 inc/common/util/ai_core/common/scope_allocator.h delete mode 100644 inc/common/util/ai_core/param_calculate/aicore_param_calculator.h delete mode 100644 inc/common/util/ai_core/param_calculate/tensorsize_calculator.h delete mode 100644 inc/common/util/compress/compress.h delete mode 100644 inc/common/util/compress/compress_weight.h delete mode 100644 inc/common/util/error_manager/error_manager.h delete mode 100644 inc/common/util/platform_info.h delete mode 100644 inc/common/util/platform_info_def.h delete mode 100644 inc/external/graph/attr_value.h delete mode 100644 inc/external/graph/ge_error_codes.h delete mode 100644 inc/external/graph/graph.h delete mode 100644 inc/external/graph/inference_context.h delete mode 100644 inc/external/graph/operator.h delete mode 100644 inc/external/graph/operator_factory.h delete mode 100644 inc/external/graph/operator_reg.h delete mode 100644 inc/external/graph/tensor.h delete mode 100644 inc/external/graph/types.h delete mode 100644 inc/external/register/register.h delete mode 100644 inc/external/register/register_error_codes.h delete mode 100644 inc/external/register/register_fmk_types.h delete mode 100644 inc/external/register/register_types.h delete mode 100644 inc/external/register/scope/scope_fusion_pass_register.h create mode 100644 inc/framework/omg/parser/model_parser.h create mode 100644 inc/framework/omg/parser/op_parser.h rename inc/{common/optimizer/graph_optimizer_types.h => framework/omg/parser/parser_api.h} (59%) create mode 100644 inc/framework/omg/parser/parser_factory.h create mode 100644 inc/framework/omg/parser/parser_inner_ctx.h create mode 100644 inc/framework/omg/parser/weights_parser.h delete mode 100644 inc/graph/anchor.h delete mode 100644 inc/graph/attr_value_serializable.h delete mode 100644 inc/graph/buffer.h delete mode 100644 inc/graph/compute_graph.h delete mode 100644 inc/graph/debug/ge_attr_define.h delete mode 100644 inc/graph/def_types.h delete mode 100644 inc/graph/detail/any_map.h delete mode 100644 inc/graph/detail/attributes_holder.h delete mode 100644 inc/graph/detail/model_serialize_imp.h delete mode 100644 inc/graph/ge_attr_value.h delete mode 100644 inc/graph/ge_context.h delete mode 100644 inc/graph/ge_global_options.h delete mode 100644 inc/graph/ge_local_context.h delete mode 100644 inc/graph/ge_tensor.h delete mode 100644 inc/graph/graph_util.h delete mode 100644 inc/graph/model.h delete mode 100644 inc/graph/model_serialize.h delete mode 100644 inc/graph/node.h delete mode 100644 inc/graph/op_desc.h delete mode 100644 inc/graph/op_kernel_bin.h delete mode 100644 inc/graph/operator_factory_impl.h delete mode 100644 inc/graph/opsproto_manager.h delete mode 100644 inc/graph/range_vistor.h delete mode 100644 inc/graph/ref_relation.h delete mode 100644 inc/graph/runtime_inference_context.h delete mode 100644 inc/graph/shape_refiner.h delete mode 100644 inc/graph/tuning_utils.h delete mode 100644 inc/graph/usr_types.h 
delete mode 100644 inc/graph/utils/anchor_utils.h delete mode 100644 inc/graph/utils/attr_utils.h delete mode 100644 inc/graph/utils/graph_utils.h delete mode 100644 inc/graph/utils/node_utils.h delete mode 100644 inc/graph/utils/op_desc_utils.h delete mode 100644 inc/graph/utils/tensor_adapter.h delete mode 100644 inc/graph/utils/tensor_utils.h delete mode 100644 inc/graph/utils/type_utils.h delete mode 100755 src/common/graph/CMakeLists.txt delete mode 100644 src/common/graph/anchor.cc delete mode 100644 src/common/graph/attr_value.cc delete mode 100644 src/common/graph/buffer.cc delete mode 100644 src/common/graph/compute_graph.cc delete mode 100644 src/common/graph/debug/ge_log.h delete mode 100644 src/common/graph/debug/ge_op_types.h delete mode 100644 src/common/graph/debug/ge_util.h delete mode 100644 src/common/graph/debug/graph_debug.cc delete mode 100644 src/common/graph/debug/graph_debug.h delete mode 100644 src/common/graph/detail/attributes_holder.cc delete mode 100644 src/common/graph/format_refiner.cc delete mode 100644 src/common/graph/format_refiner.h delete mode 100644 src/common/graph/ge_attr_define.cc delete mode 100644 src/common/graph/ge_attr_value.cc delete mode 100644 src/common/graph/ge_tensor.cc delete mode 100644 src/common/graph/graph.cc delete mode 100644 src/common/graph/graph.mk delete mode 100644 src/common/graph/inference_context.cc delete mode 100644 src/common/graph/model.cc delete mode 100644 src/common/graph/model_serialize.cc delete mode 100644 src/common/graph/module.mk delete mode 100644 src/common/graph/node.cc delete mode 100644 src/common/graph/op_desc.cc delete mode 100644 src/common/graph/op_imp.cc delete mode 100644 src/common/graph/operator.cc delete mode 100644 src/common/graph/operator_factory.cc delete mode 100644 src/common/graph/operator_factory_impl.cc delete mode 100644 src/common/graph/opsproto/opsproto_manager.cc delete mode 100644 src/common/graph/option/ge_context.cc delete mode 100644 src/common/graph/option/ge_local_context.cc delete mode 100644 src/common/graph/ref_relation.cc delete mode 100644 src/common/graph/runtime_inference_context.cc delete mode 100644 src/common/graph/shape_refiner.cc delete mode 100644 src/common/graph/stub/Makefile delete mode 100644 src/common/graph/stub/gen_stubapi.py delete mode 100644 src/common/graph/tensor.cc delete mode 100644 src/common/graph/utils/anchor_utils.cc delete mode 100644 src/common/graph/utils/ge_ir_utils.cc delete mode 100644 src/common/graph/utils/ge_ir_utils.h delete mode 100644 src/common/graph/utils/graph_utils.cc delete mode 100644 src/common/graph/utils/mem_utils.h delete mode 100644 src/common/graph/utils/node_utils.cc delete mode 100644 src/common/graph/utils/op_desc_utils.cc delete mode 100644 src/common/graph/utils/string_utils.h delete mode 100644 src/common/graph/utils/tensor_utils.cc delete mode 100644 src/common/graph/utils/tuning_utils.cc delete mode 100644 src/common/graph/utils/type_utils.cc delete mode 100755 src/ge/CMakeLists.txt delete mode 100755 src/ge/client/CMakeLists.txt delete mode 100755 src/ge/common/CMakeLists.txt delete mode 100755 src/ge/executor/CMakeLists.txt delete mode 100755 src/ge/ge_local_engine/CMakeLists.txt delete mode 100755 src/ge/ge_runtime/CMakeLists.txt delete mode 100644 src/ge/ge_runtime/proto/task.pb.h delete mode 100644 src/ge/graph/build/memory/CMakeLists.txt delete mode 120000 src/ge/host_cpu_engine/proto/task.proto delete mode 100644 src/ge/plugin/engine/CMakeLists.txt delete mode 100644 src/proto/onnx.proto create mode 100644 
third_party/patch/securec/0001-add-securec-cmake-script.patch diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..a2b1f260 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,8 @@ +[submodule "metadef"] + path = metadef + url = https://gitee.com/ascend/metadef.git + branch = master +[submodule "parser"] + path = parser + url = https://gitee.com/ascend/parser.git + branch = master diff --git a/CMakeLists.txt b/CMakeLists.txt index 971b4156..9a9a7a9d 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,137 +1,133 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - cmake_minimum_required(VERSION 3.14) project (GraphEngine[CXX]) -set(CMAKE_CXX_STANDARD 17) -add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) -set(GE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) -set(GE_PROTO_DIR ${GE_SOURCE_DIR}/src) +set(GE_CODE_DIR ${CMAKE_CURRENT_LIST_DIR}) +set(CMAKE_SKIP_INSTALL_ALL_DEPENDENCY TRUE) if (NOT BUILD_PATH) set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") endif() -# architecture: aarch64 or x86_64 -message(STATUS "System architecture: ${CMAKE_HOST_SYSTEM_PROCESSOR}") -# system: euleros or ubuntu -if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - execute_process( - COMMAND bash "-c" "cat /etc/os-release | grep ^ID= | awk -F '=' '{print $2}'" - OUTPUT_VARIABLE SYSTEM_TYPE - ) - MESSAGE(STATUS "System type: ${SYSTEM_TYPE}.") -endif() -# download json headers, rather than whole repository -include(${GE_SOURCE_DIR}/cmake/ge_utils.cmake) -include(${GE_SOURCE_DIR}/cmake/external_libs/json.cmake) -include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake) -include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake) -#include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) -include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf_shared.cmake) -include(${GE_SOURCE_DIR}/cmake/external_libs/protoc.cmake) -include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake) -include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake) -set(CMAKE_SKIP_RPATH TRUE) +option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." 
FALSE) -# for CPU/GPU mode, find c_sec and slog from local prebuild -if(NOT ENABLE_D AND NOT GE_ONLY) - set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}) - find_library(slog libslog.so ${GE_PREBUILD_PATH}) -# if D_LINK_PATH is set in environment variables, search libraries in given path -elseif(DEFINED ENV{D_LINK_PATH}) - # D_LINK_PATH is set - set(GE_LIB_PATH $ENV{D_LINK_PATH}) - set(GE_SYS_ARCH "") - if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") - # x86 ubuntu - set(GE_SYS_ARCH "x86_64") - elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") - # arm euleros - set(GE_SYS_ARCH "aarch64") - else() - message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") - endif() - set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) - find_library(slog libslog.so ${GE_LIB_PATH}) - find_library(mmpa libmmpa.so ${GE_LIB_PATH}) - find_library(runtime libruntime.so ${GE_LIB_PATH}) - find_library(msprof libmsprofiler.a ${GE_LIB_PATH}) - find_library(register libregister.so ${GE_LIB_PATH}) - find_library(hccl libhccl.so ${GE_LIB_PATH}) - find_library(resource libresource.so ${GE_LIB_PATH}) - find_library(error_manager liberror_manager.so ${GE_LIB_PATH}) - find_library(adump_server libadump_server.a ${GE_LIB_PATH}) -else() - # Ascend mode - if(DEFINED ENV{ASCEND_CUSTOM_PATH}) - set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) - else() - set(ASCEND_DIR /usr/local/Ascend) - endif() - set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64/common) - set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) - find_library(slog libslog.so ${ASCEND_DRIVER_DIR}) - find_library(mmpa libmmpa.so ${ASCEND_DRIVER_DIR}) - find_library(msprof libmsprofiler.a ${ASCEND_RUNTIME_DIR}) +if (ENABLE_OPEN_SRC) + set(HI_PYTHON python3.7) - find_library(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) - find_library(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) - find_library(register libregister.so ${ASCEND_RUNTIME_DIR}) - find_library(resource libresource.so ${ASCEND_RUNTIME_DIR}) - find_library(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) - find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) -endif() + include(cmake/external_libs/protobuf_shared.cmake) + include(cmake/external_libs/protobuf_static.cmake) + include(cmake/external_libs/protoc.cmake) + include(cmake/external_libs/gflags.cmake) + include(cmake/external_libs/securec.cmake) + include(cmake/external_libs/json.cmake) + include(cmake/FindModule.cmake) + include(cmake/intf_pub_linux.cmake) -# add compile flags -if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") - message("Build in Debug mode") - set(CMAKE_C_FLAGS "-O0 -g -Wall -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe -fPIC ${CMAKE_C_FLAGS}") - set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe -fPIC ${CMAKE_CXX_FLAGS}") - if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic") + # for CPU/GPU mode, find c_sec and slog from local prebuild + #if(NOT ENABLE_D AND NOT GE_ONLY) + # set(GE_PREBUILD_PATH ${GE_CODE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}) + # find_module(slog libslog.so ${GE_PREBUILD_PATH}) + # if D_LINK_PATH is set in environment variables, search libraries in given path + if(DEFINED ENV{D_LINK_PATH}) + # D_LINK_PATH is set + set(GE_LIB_PATH $ENV{D_LINK_PATH}) + set(GE_SYS_ARCH "") + if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") + # x86 ubuntu + set(GE_SYS_ARCH 
"x86_64") + elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") + # arm euleros + set(GE_SYS_ARCH "aarch64") + else() + message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") + endif() + set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) + set(STATIC_ACL_LIB ${GE_LIB_PATH}) + find_module(slog libslog.so ${GE_LIB_PATH}) + find_module(mmpa libmmpa.so ${GE_LIB_PATH}) + find_module(msprof libmsprof.so ${GE_LIB_PATH}) + find_module(hccl libhccl.so ${GE_LIB_PATH}) + find_module(adump_server libadump_server.a ${GE_LIB_PATH}) + find_module(runtime libruntime.so ${GE_LIB_PATH}) + find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH}) + find_module(resource libresource.so ${GE_LIB_PATH}) + find_module(error_manager liberror_manager.so ${GE_LIB_PATH}) + find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) + find_module(error_manager_static liberror_manager.a ${GE_LIB_PATH}) + find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) + else() + if(DEFINED ENV{ASCEND_CUSTOM_PATH}) + set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) + else() + set(ASCEND_DIR /usr/local/Ascend) + endif() + set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) + set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) + set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) + set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) + set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) + set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) + set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) + find_module(slog libslog.so ${ASCEND_ATC_DIR}) + find_module(mmpa libmmpa.so ${ASCEND_ATC_DIR}) + if(PLATFORM STREQUAL "train") + find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) + find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) + find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) + find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) + find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) + find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) + find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR}) + find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) + if(PRODUCT STREQUAL "flr3") + message(FATAL_ERROR "This platform is not supported in train mode, build terminated") + endif() + elseif(PLATFORM STREQUAL "inference") + find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) + find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) + find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) + find_module(resource libresource.so ${ASCEND_ATC_DIR}) + find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) + find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) + find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) + if(PRODUCT STREQUAL "flr3") + find_module(msprof libmsprof.so ${ASCEND_DRIVER_SHARE_DIR}) + elseif(PRODUCT STREQUAL "flr1") + find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) + find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) + elseif(PRODUCT STREQUAL "flr2") + # flr2 ascend_hal_stub limsprof ? 
+ else() + find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) + find_module(msprof libmsprof.so ${ASCEND_DRIVER_DIR}) + endif() + elseif(PLATFORM STREQUAL "all") + find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) + find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) + find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) + find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) + find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) + find_module(resource libresource.so ${ASCEND_ATC_DIR}) + find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) + find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) + find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) + find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) + else() + message(FATAL_ERROR "PLATFORM param is invalid, should be train or inference, build terminated") + endif() endif() -else() - set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe ${CMAKE_C_FLAGS}") - set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -pipe ${CMAKE_CXX_FLAGS}") -endif () -# force __FILE__ to show relative path of file, from source directory, as cmake project makes __FILE__ absolute directory -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst $(realpath ${CMAKE_SOURCE_DIR})/,,$(abspath $<))\"' -Wno-builtin-macro-redefined") + set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) + set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser) + set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..) -# compile libraries from following directories -# libgraph is compiled in any situation -add_subdirectory(${GE_SOURCE_DIR}/src/common/graph) -if(ENABLE_D) - # if MindSpore compiles in D mode, compile the following libraries - add_subdirectory(${GE_SOURCE_DIR}/src/ge/common) - add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime) -elseif(GE_ONLY) - # standalone GraphEngine compiles all following libraries - add_subdirectory(${GE_SOURCE_DIR}/src/ge/common) - add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime) - add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_local_engine) - add_subdirectory(${GE_SOURCE_DIR}/src/ge/graph/build/memory) - add_subdirectory(${GE_SOURCE_DIR}/src/ge/) - add_subdirectory(${GE_SOURCE_DIR}/src/ge/plugin/engine) + add_subdirectory(metadef) + add_subdirectory(parser) + #add_subdirectory(metadef/graph) + #add_subdirectory(metadef/register) +else() + set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef) + set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) + set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..) endif() -# if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST) -# add_subdirectory(tests) -# endif() - +add_subdirectory(ge) diff --git a/build.sh b/build.sh index 5227f21f..b693ba74 100644 --- a/build.sh +++ b/build.sh @@ -23,7 +23,7 @@ export BUILD_PATH="${BASEPATH}/build/" usage() { echo "Usage:" - echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c]" + echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c] [-S on|off]" echo "" echo "Options:" echo " -h Print usage" @@ -32,10 +32,23 @@ usage() echo " -j[n] Set the number of threads used for building GraphEngine, default is 8" echo " -t Build and execute ut" echo " -c Build ut with coverage tag" + echo " -p Build inference or train" echo " -v Display build command" + echo " -S Enable enable download cmake compile dependency from gitee , default off" echo "to be continued ..." 
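# Illustrative invocations of the new options (thread counts are examples only,
# not defaults enforced by this script):
#   bash build.sh -p inference -S on -j16   # open-source inference build, dependencies fetched from gitee
#   bash build.sh -p train -g normal -j8    # training build against a local Ascend install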
} +# check value of input is 'on' or 'off' +# usage: check_on_off arg_value arg_name +check_on_off() +{ + if [[ "X$1" != "Xon" && "X$1" != "Xoff" ]]; then + echo "Invalid value $1 for option -$2" + usage + exit 1 + fi +} + # parse and set options checkopts() { @@ -46,8 +59,11 @@ checkopts() ENABLE_GE_ST="off" ENABLE_GE_COV="off" GE_ONLY="on" + PLATFORM="train" + PRODUCT="normal" + ENABLE_GITEE="off" # Process the options - while getopts 'ustchj:v' opt + while getopts 'ustchj:p:g:vS:' opt do OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') case "${opt}" in @@ -77,6 +93,17 @@ checkopts() v) VERBOSE="VERBOSE=1" ;; + p) + PLATFORM=$OPTARG + ;; + g) + PRODUCT=$OPTARG + ;; + S) + check_on_off $OPTARG S + ENABLE_GITEE="$OPTARG" + echo "enable download from gitee" + ;; *) echo "Undefined option: ${opt}" usage @@ -86,6 +113,9 @@ checkopts() } checkopts "$@" +git submodule update --init metadef +git submodule update --init parser + mk_dir() { local create_dir="$1" # the target to make @@ -100,8 +130,8 @@ echo "---------------- GraphEngine build start ----------------" build_graphengine() { echo "create build directory and build GraphEngine"; - mk_dir "${BUILD_PATH}/graphengine" - cd "${BUILD_PATH}/graphengine" + mk_dir "${BUILD_PATH}" + cd "${BUILD_PATH}" CMAKE_ARGS="-DBUILD_PATH=$BUILD_PATH -DGE_ONLY=$GE_ONLY" if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then @@ -117,17 +147,45 @@ build_graphengine() CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON" fi + if [[ "X$ENABLE_GITEE" = "Xon" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GITEE=ON" + fi + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_OPEN_SRC=True -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH} -DPLATFORM=${PLATFORM} -DPRODUCT=${PRODUCT}" echo "${CMAKE_ARGS}" - cmake ${CMAKE_ARGS} ../.. - make ${VERBOSE} -j${THREAD_NUM} + cmake ${CMAKE_ARGS} .. + if [ $? -ne 0 ] + then + echo "execute command: cmake ${CMAKE_ARGS} .. failed." + return 1 + fi + COMMON_TARGET="ge_common engine fmk_parser parser_common _caffe_parser fmk_onnx_parser graph register engine_conf.json optimizer_priority.pbtxt " + TARGET=${COMMON_TARGET} + if [ "x${PLATFORM}" = "xtrain" ] + then + TARGET="ge_runner ge_local_engine host_cpu_engine ${TARGET}" + elif [ "x${PLATFORM}" = "xinference" ] + then + TARGET="ge_compiler atc_ge_local_engine atc_host_cpu_engine atc opensrc_ascendcl ${TARGET}" + elif [ "x${PLATFORM}" = "xall" ] + then + # build all the target + TARGET="" + fi + + make ${VERBOSE} ${TARGET} -j${THREAD_NUM} && make install + if [ $? -ne 0 ] + then + echo "execute command: make ${VERBOSE} -j${THREAD_NUM} && make install failed." + return 1 + fi echo "GraphEngine build success!" 
} g++ -v -build_graphengine -echo "---------------- GraphEngine build finished ----------------" mk_dir ${OUTPUT_PATH} -cp -rf "${BUILD_PATH}/graphengine/"*.so "${OUTPUT_PATH}" -rm -rf "${OUTPUT_PATH}/"libproto* +build_graphengine || { echo "GraphEngine build failed."; return; } +echo "---------------- GraphEngine build finished ----------------" +#cp -rf "${BUILD_PATH}/graphengine/"*.so "${OUTPUT_PATH}" +#rm -rf "${OUTPUT_PATH}/"libproto* rm -f ${OUTPUT_PATH}/libgmock*.so rm -f ${OUTPUT_PATH}/libgtest*.so rm -f ${OUTPUT_PATH}/lib*_stub.so @@ -175,43 +233,82 @@ echo "---------------- GraphEngine output generated ----------------" generate_package() { cd "${BASEPATH}" + + GRAPHENGINE_LIB_PATH="lib" + ACL_PATH="acllib/lib64" FWK_PATH="fwkacllib/lib64" ATC_PATH="atc/lib64" + ATC_BIN_PATH="atc/bin" NNENGINE_PATH="plugin/nnengine/ge_config" OPSKERNEL_PATH="plugin/opskernel" - ATC_LIB=("libc_sec.so" "libge_common.so" "libge_compiler.so" "libgraph.so") - FWK_LIB=("libge_common.so" "libge_runner.so" "libgraph.so") + ATC_LIB=("libc_sec.so" "libge_common.so" "libge_compiler.so" "libgraph.so" "libregister.so") + FWK_LIB=("libge_common.so" "libge_runner.so" "libgraph.so" "libregister.so") + PLUGIN_OPSKERNEL=("libge_local_engine.so" "libge_local_opskernel_builder.so" "libhost_cpu_engine.so" "libhost_cpu_opskernel_builder.so" "optimizer_priority.pbtxt") + PARSER_LIB=("lib_caffe_parser.so" "libfmk_onnx_parser.so" "libfmk_parser.so" "libparser_common.so") rm -rf ${OUTPUT_PATH:?}/${FWK_PATH}/ + rm -rf ${OUTPUT_PATH:?}/${ACL_PATH}/ rm -rf ${OUTPUT_PATH:?}/${ATC_PATH}/ + rm -rf ${OUTPUT_PATH:?}/${ATC_BIN_PATH}/ + mk_dir "${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH}" mk_dir "${OUTPUT_PATH}/${FWK_PATH}/${OPSKERNEL_PATH}" mk_dir "${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH}" mk_dir "${OUTPUT_PATH}/${ATC_PATH}/${OPSKERNEL_PATH}" + mk_dir "${OUTPUT_PATH}/${ACL_PATH}" + mk_dir "${OUTPUT_PATH}/${ATC_BIN_PATH}" + + cd "${OUTPUT_PATH}" - find output/ -name graphengine_lib.tar -exec rm {} \; - cp src/ge/engine_manager/engine_conf.json ${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH} - cp src/ge/engine_manager/engine_conf.json ${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH} + find ./ -name graphengine_lib.tar -exec rm {} \; - find output/ -maxdepth 1 -name libengine.so -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH}/../ \; - find output/ -maxdepth 1 -name libengine.so -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH}/../ \; + cp ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH}/engine_conf.json ${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH} + cp ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH}/engine_conf.json ${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH} - find output/ -maxdepth 1 -name libge_local_engine.so -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${OPSKERNEL_PATH} \; - find output/ -maxdepth 1 -name libge_local_engine.so -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${OPSKERNEL_PATH} \; + find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name libengine.so -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${NNENGINE_PATH}/../ \; + find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name libengine.so -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${NNENGINE_PATH}/../ \; - cd "${OUTPUT_PATH}" - for lib in "${ATC_LIB[@]}"; + MAX_DEPTH=1 + if [ "x${PLATFORM}" = "xall" ] || [ "x${PLATFORM}" = "xinference" ] + then + MAX_DEPTH=2 + fi + for lib in "${PLUGIN_OPSKERNEL[@]}"; + do + find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth ${MAX_DEPTH} -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${OPSKERNEL_PATH} \; + find 
${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth ${MAX_DEPTH} -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${OPSKERNEL_PATH} \; + done + + for lib in "${PARSER_LIB[@]}"; do - cp "$lib" "${OUTPUT_PATH}/${ATC_PATH}" + find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH} \; + find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; done for lib in "${FWK_LIB[@]}"; do - cp "$lib" "${OUTPUT_PATH}/${FWK_PATH}" + find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH} \; done - tar -cf graphengine_lib.tar fwkacllib/ atc/ + for lib in "${ATC_LIB[@]}"; + do + find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; + done + + find ./bin -name atc -exec cp {} "${OUTPUT_PATH}/${ATC_BIN_PATH}" \; + find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "libascendcl.so" -exec cp -f {} ${OUTPUT_PATH}/${ACL_PATH} \; + + if [ "x${PLATFORM}" = "xtrain" ] + then + tar -cf graphengine_lib.tar fwkacllib + elif [ "x${PLATFORM}" = "xinference" ] + then + tar -cf graphengine_lib.tar acllib atc + elif [ "x${PLATFORM}" = "xall" ] + then + tar -cf graphengine_lib.tar fwkacllib acllib atc + fi } if [[ "X$ENABLE_GE_UT" = "Xoff" ]]; then diff --git a/cmake/FindModule.cmake b/cmake/FindModule.cmake new file mode 100644 index 00000000..74a63634 --- /dev/null +++ b/cmake/FindModule.cmake @@ -0,0 +1,23 @@ +#[[ + module - the name of export imported target + name - find the library name + path - find the library path +#]] +function(find_module module name path) + if (TARGET ${module}) + return() + endif() + find_library(${module}_LIBRARY_DIR NAMES ${name} NAMES_PER_DIR PATHS ${path} + PATH_SUFFIXES lib + ) + + message(STATUS "find ${name} location ${${module}_LIBRARY_DIR}") + if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND") + message(FATAL_ERROR "${name} not found in ${path}") + endif() + + add_library(${module} SHARED IMPORTED) + set_target_properties(${module} PROPERTIES + IMPORTED_LOCATION ${${module}_LIBRARY_DIR} + ) +endfunction() diff --git a/cmake/external_libs/eigen.cmake b/cmake/external_libs/eigen.cmake deleted file mode 100644 index 5cdfc346..00000000 --- a/cmake/external_libs/eigen.cmake +++ /dev/null @@ -1,22 +0,0 @@ -set(Eigen3_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") -set(Eigen3_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") -set(Eigen3_NS "ge_") - -if (ENABLE_GITEE) - set(REQ_URL "https://gitee.com/mirrors/eigen-git-mirrorsource/repository/archive/3.3.7.tar.gz") - set(MD5 "cf6552a5d90c1aca4b5e0b011f65ea93") -else() - set(REQ_URL "https://gitlab.com/libeigen/eigen/-/archive/3.3.7/eigen-3.3.7.tar.gz") - set(MD5 "9e30f67e8531477de4117506fe44669b") -endif () - -graphengine_add_pkg(Eigen3 - VER 3.3.7 - URL ${REQ_URL} - MD5 ${MD5} - CMAKE_OPTION -DBUILD_TESTING=OFF) - -find_package(Eigen3 3.3.7 REQUIRED ${GE_FIND_NO_DEFAULT_PATH}) -set_property(TARGET Eigen3::Eigen PROPERTY IMPORTED_GLOBAL TRUE) -add_library(graphengine::eigen ALIAS Eigen3::Eigen) -include_directories(${EIGEN3_INCLUDE_DIRS}) diff --git a/cmake/external_libs/gflags.cmake b/cmake/external_libs/gflags.cmake new file mode 100755 index 00000000..0294192e --- /dev/null +++ b/cmake/external_libs/gflags.cmake @@ -0,0 +1,47 @@ +if (HAVE_GFLAGS) + return() +endif() + +include(ExternalProject) +#set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output) + +if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR + 
(${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) + set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) + message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") +endif() + +if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") + set(MD5 "") +else() + set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") + set(MD5 "") +endif () + +ExternalProject_Add(gflags_build + URL ${REQ_URL} + #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz + #SOURCE_DIR ${GE_CODE_DIR}/../third_party/gflags/src/gflags-2.2.2 + CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags + BUILD_COMMAND $(MAKE) + INSTALL_COMMAND $(MAKE) install + EXCLUDE_FROM_ALL TRUE +) + +set(GFLAGS_PKG_DIR ${CMAKE_INSTALL_PREFIX}/gflags) + +add_library(gflags_static STATIC IMPORTED) + +set_target_properties(gflags_static PROPERTIES + IMPORTED_LOCATION ${GFLAGS_PKG_DIR}/lib/libgflags.a +) + +add_library(gflags INTERFACE) +target_include_directories(gflags INTERFACE ${GFLAGS_PKG_DIR}/include) +target_link_libraries(gflags INTERFACE gflags_static) + +add_dependencies(gflags gflags_build) + +#set(HAVE_GFLAGS TRUE CACHE BOOL "gflags build add") +set(HAVE_GFLAGS TRUE) diff --git a/cmake/external_libs/gtest.cmake b/cmake/external_libs/gtest.cmake deleted file mode 100644 index 5e175fd2..00000000 --- a/cmake/external_libs/gtest.cmake +++ /dev/null @@ -1,24 +0,0 @@ -set(ge_gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") -set(ge_gtest_CFLAGS "-D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") - -if (ENABLE_GITEE) - set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") - set(MD5 "89e13ca1aa48d370719d58010b83f62c") -else() - set(REQ_URL "https://github.com/google/googletest/archive/release-1.8.0.tar.gz") - set(MD5 "16877098823401d1bf2ed7891d7dce36") -endif () - -graphengine_add_pkg(ge_gtest - VER 1.8.0 - LIBS gtest gtest_main - URL ${REQ_URL} - MD5 ${MD5} - CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON - -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON) - -add_library(graphengine::gtest ALIAS ge_gtest::gtest) -add_library(graphengine::gtest_main ALIAS ge_gtest::gtest_main) -include_directories(${ge_gtest_INC}) -file(COPY ${ge_gtest_INC}/../lib/libgtest.so DESTINATION ${CMAKE_SOURCE_DIR}/build/graphengine) -file(COPY ${ge_gtest_INC}/../lib/libgtest_main.so DESTINATION ${CMAKE_SOURCE_DIR}/build/graphengine) diff --git a/cmake/external_libs/json.cmake b/cmake/external_libs/json.cmake index f2ae5310..ce473d4b 100644 --- a/cmake/external_libs/json.cmake +++ b/cmake/external_libs/json.cmake @@ -1,20 +1,33 @@ -set(nlohmann_json_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") -set(nlohmann_json_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") +if (HAVE_JSON) + return() +endif() +include(ExternalProject) + +set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) if (ENABLE_GITEE) set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") - set(INCLUDE "./include") + set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include") else() set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") set(MD5 "0dc903888211db3a0f170304cd9f3a89") - set(INCLUDE "./") + set(JSON_INCLUDE_DIR ${JSON_SRC_DIR}) 
endif () +ExternalProject_Add(json_build + URL ${REQ_URL} + #URL /home/txd/workspace/cloud_code/pkg/include.zip + SOURCE_DIR ${JSON_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + EXCLUDE_FROM_ALL TRUE +) + + +add_library(json INTERFACE) +target_include_directories(json INTERFACE ${JSON_INCLUDE_DIR}) +add_dependencies(json json_build) -graphengine_add_pkg(ge_nlohmann_json - VER 3.6.1 - HEAD_ONLY ${INCLUDE} - URL ${REQ_URL} - MD5 ${MD5}) -include_directories(${ge_nlohmann_json_INC}) -add_library(graphengine::json ALIAS ge_nlohmann_json) \ No newline at end of file +#set(HAVE_JSON TRUE CACHE BOOL "json build add") +set(HAVE_JSON TRUE) diff --git a/cmake/external_libs/onnx.cmake b/cmake/external_libs/onnx.cmake index a092f964..9dadb544 100644 --- a/cmake/external_libs/onnx.cmake +++ b/cmake/external_libs/onnx.cmake @@ -1,3 +1,11 @@ +include(ExternalProject) + +#set(ONNX_SRC_DIR /home/txd/workspace/cloud_code/graphengine/build/graphengine/open_source/onnx) +#set(ONNX_PROTO ${ONNX_SRC_DIR}/onnx/onnx.proto) +set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx) +set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) +file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) + if (ENABLE_GITEE) set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") set(MD5 "1bdbcecdd68ea8392630467646776e02") @@ -6,8 +14,24 @@ else() set(MD5 "512f2779d6215d4a36f366b6b9acdf1e") endif () -graphengine_add_pkg(onnx - VER 1.6.0 - HEAD_ONLY ./ - URL ${REQ_URL} - MD5 ${MD5}) +ExternalProject_Add(onnx + URL ${REQ_URL} + #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz + #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 + #SOURCE_DIR ${ONNX_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + #INSTALL_COMMAND "" + INSTALL_COMMAND ${CMAKE_COMMAND} -E copy /onnx/onnx.proto ${ONNX_PROTO_FILE} + #BUILD_ALWAYS TRUE + EXCLUDE_FROM_ALL TRUE +) + +macro(onnx_protobuf_generate comp c_var h_var) + add_custom_command(OUTPUT ${ONNX_PROTO_FILE} + DEPENDS onnx + ) + ge_protobuf_generate(${comp} ${c_var} ${h_var} ${ONNX_PROTO_FILE}) +endmacro() + + diff --git a/cmake/external_libs/protobuf.cmake b/cmake/external_libs/protobuf.cmake deleted file mode 100644 index 8be594c7..00000000 --- a/cmake/external_libs/protobuf.cmake +++ /dev/null @@ -1,63 +0,0 @@ -if (NOT TARGET protobuf::protobuf) -set(protobuf_USE_STATIC_LIBS ON) -set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") -set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") -set(_ge_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) -string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") - -if (ENABLE_GITEE) - set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") - set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") -else() - set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") - set(MD5 "3d9e32700639618a4d2d342c99d4507a") -endif () - -graphengine_add_pkg(protobuf - VER 3.8.0 - LIBS protobuf - EXE protoc - URL ${REQ_URL} - MD5 ${MD5} - CMAKE_PATH ../cmake/ - CMAKE_OPTION -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF) -set(CMAKE_CXX_FLAGS ${_ge_tmp_CMAKE_CXX_FLAGS}) -endif() -add_library(graphengine::protobuf ALIAS protobuf::protobuf) -set(PROTOBUF_LIBRARY protobuf::protobuf) -include_directories(${protobuf_INC}) -include_directories(${protobuf_DIRPATH}/src) - -function(ge_protobuf_generate comp c_var h_var) - if(NOT 
ARGN) - message(SEND_ERROR "Error: ge_protobuf_generate() called without any proto files") - return() - endif() - - set(${c_var}) - set(${h_var}) - - foreach(file ${ARGN}) - get_filename_component(abs_file ${file} ABSOLUTE) - get_filename_component(file_name ${file} NAME_WE) - get_filename_component(file_dir ${abs_file} PATH) - - list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/proto/${comp}/proto/${file_name}.pb.cc") - list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/proto/${comp}/proto/${file_name}.pb.h") - - add_custom_command( - OUTPUT "${CMAKE_BINARY_DIR}/proto/${comp}/proto/${file_name}.pb.cc" - "${CMAKE_BINARY_DIR}/proto/${comp}/proto/${file_name}.pb.h" - WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} - COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/proto/${comp}/proto" - COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/proto/${comp}/proto ${abs_file} - DEPENDS protobuf::protoc ${abs_file} - COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM ) - endforeach() - - set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) - set(${c_var} ${${c_var}} PARENT_SCOPE) - set(${h_var} ${${h_var}} PARENT_SCOPE) - -endfunction() diff --git a/cmake/external_libs/protobuf_shared.cmake b/cmake/external_libs/protobuf_shared.cmake index a2b922fa..f30aefe8 100644 --- a/cmake/external_libs/protobuf_shared.cmake +++ b/cmake/external_libs/protobuf_shared.cmake @@ -7,17 +7,25 @@ include(GNUInstallDirs) if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) - set(CMAKE_INSTALL_PREFIX ${GE_SOURCE_DIR}/output CACHE STRING "path for install()" FORCE) + set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") endif() +if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") + set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") +else() + set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") + set(MD5 "3d9e32700639618a4d2d342c99d4507a") +endif () + set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") ExternalProject_Add(protobuf_build - URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz + URL ${REQ_URL} #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz - #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 - #DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E copy_directory ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 + #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 + #DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E copy_directory ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 #CONFIGURE_COMMAND ${CMAKE_COMMAND} #-DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} #-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} diff --git a/cmake/external_libs/protobuf_static.cmake b/cmake/external_libs/protobuf_static.cmake new file mode 100755 index 00000000..57f4fd05 --- /dev/null +++ b/cmake/external_libs/protobuf_static.cmake @@ -0,0 +1,51 @@ +include(ExternalProject) +include(GNUInstallDirs) +#set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output) + +if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR + (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) + set(CMAKE_INSTALL_PREFIX 
${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) + message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") +endif() + +if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") + set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") +else() + set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") + set(MD5 "3d9e32700639618a4d2d342c99d4507a") +endif () + +set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") +set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") +set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) +ExternalProject_Add(protobuf_static_build + URL ${REQ_URL} + #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz + #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 + CONFIGURE_COMMAND ${CMAKE_COMMAND} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} + -DCMAKE_LINKER=${CMAKE_LINKER} + -DCMAKE_AR=${CMAKE_AR} + -DCMAKE_RANLIB=${CMAKE_RANLIB} + -Dprotobuf_WITH_ZLIB=OFF + -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${PROTOBUF_STATIC_PKG_DIR} /cmake + BUILD_COMMAND $(MAKE) + INSTALL_COMMAND $(MAKE) install + EXCLUDE_FROM_ALL TRUE +) +include(GNUInstallDirs) + +add_library(protobuf_static_lib STATIC IMPORTED) + +set_target_properties(protobuf_static_lib PROPERTIES + IMPORTED_LOCATION ${PROTOBUF_STATIC_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/libprotobuf.a +) + +add_library(protobuf_static INTERFACE) +target_include_directories(protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include) +target_link_libraries(protobuf_static INTERFACE protobuf_static_lib) + +add_dependencies(protobuf_static protobuf_static_build) diff --git a/cmake/external_libs/protoc.cmake b/cmake/external_libs/protoc.cmake index ed63afed..d1d329f2 100644 --- a/cmake/external_libs/protoc.cmake +++ b/cmake/external_libs/protoc.cmake @@ -8,7 +8,7 @@ include(GNUInstallDirs) if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) - set(CMAKE_INSTALL_PREFIX ${GE_SOURCE_DIR}/output CACHE STRING "path for install()" FORCE) + set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") endif() diff --git a/cmake/external_libs/securec.cmake b/cmake/external_libs/securec.cmake index 2fbf8b80..0bd62ab2 100644 --- a/cmake/external_libs/securec.cmake +++ b/cmake/external_libs/securec.cmake @@ -1,11 +1,62 @@ -graphengine_add_pkg(securec - VER 1.1.10 - URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz - MD5 193f0ca5246c1dd84920db34d2d8249f - LIBS c_sec - PATCHES ${GE_SOURCE_DIR}/third_party/patch/securec/securec.patch001 - CMAKE_OPTION "-DCMAKE_BUILD_TYPE=Release" - ) -include_directories(${securec_INC}) -file(COPY ${securec_INC}/../lib/libc_sec.so DESTINATION ${CMAKE_SOURCE_DIR}/build/graphengine) -add_library(graphengine::securec ALIAS securec::c_sec) \ No newline at end of file +if (HAVE_C_SEC) + return() +endif() + +include(ExternalProject) + +if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR + (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) + set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output 
CACHE STRING "path for install()" FORCE) + message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") +endif() + +ExternalProject_Add(c_sec_build + URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz + #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz + #SOURCE_DIR ${GE_CODE_DIR}/../libc_sec + PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch + CONFIGURE_COMMAND ${CMAKE_COMMAND} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_LINKER=${CMAKE_LINKER} + -DCMAKE_AR=${CMAKE_AR} + -DCMAKE_RANLIB=${CMAKE_RANLIB} + -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/c_sec + BUILD_COMMAND $(MAKE) + INSTALL_COMMAND $(MAKE) install + EXCLUDE_FROM_ALL TRUE +) + +set(C_SEC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/c_sec) + +add_library(c_sec SHARED IMPORTED) + +file(MAKE_DIRECTORY ${C_SEC_PKG_DIR}/include) + +set_target_properties(c_sec PROPERTIES + IMPORTED_LOCATION ${C_SEC_PKG_DIR}/lib/libc_sec.so +) + +target_include_directories(c_sec INTERFACE ${C_SEC_PKG_DIR}/include) + +add_dependencies(c_sec c_sec_build) + +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(FILES ${C_SEC_PKG_DIR}/lib/libc_sec.so OPTIONAL + DESTINATION ${INSTALL_LIBRARY_DIR}) + +add_library(c_sec_static_lib STATIC IMPORTED) +set_target_properties(c_sec_static_lib PROPERTIES + IMPORTED_LOCATION ${C_SEC_PKG_DIR}/lib/libc_sec.a +) + +add_library(c_sec_static INTERFACE) +target_include_directories(c_sec_static INTERFACE ${C_SEC_PKG_DIR}/include) +target_link_libraries(c_sec_static INTERFACE c_sec_static_lib) + +add_dependencies(c_sec_static c_sec_build) + +#set(HAVE_C_SEC TRUE CACHE BOOL "c_sec build add") +set(HAVE_C_SEC TRUE) diff --git a/cmake/ge_utils.cmake b/cmake/ge_utils.cmake deleted file mode 100644 index 75480ded..00000000 --- a/cmake/ge_utils.cmake +++ /dev/null @@ -1,371 +0,0 @@ -include(FetchContent) -set(FETCHCONTENT_QUIET OFF) - -function(graphengine_add_submodule_obj des_submodule_objs sub_dir submodule_name_obj) - - add_subdirectory(${sub_dir}) - - if(NOT TARGET ${submodule_name_obj}) - message(FATAL_ERROR "Can not find submodule '${submodule_name_obj}'. in ${CMAKE_CURRENT_LIST_FILE}") - endif() - if("$" IN_LIST ${des_submodule_objs}) - message(FATAL_ERROR "submodule '${submodule_name_obj}' added more than once. 
in ${CMAKE_CURRENT_LIST_FILE}") - endif() - - set(${des_submodule_objs} ${${des_submodule_objs}} $ PARENT_SCOPE) - -endfunction() - -if (DEFINED ENV{MSLIBS_CACHE_PATH}) - set(_MS_LIB_CACHE $ENV{MSLIBS_CACHE_PATH}) -else() - set(_MS_LIB_CACHE ${CMAKE_BINARY_DIR}/.mslib) -endif () -message("MS LIBS CACHE PATH: ${_MS_LIB_CACHE}") - -if (NOT EXISTS ${_MS_LIB_CACHE}) - file(MAKE_DIRECTORY ${_MS_LIB_CACHE}) -endif () - -if (DEFINED ENV{MSLIBS_SERVER}) - set(LOCAL_LIBS_SERVER $ENV{MSLIBS_SERVER}) - message("LOCAL_LIBS_SERVER: ${LOCAL_LIBS_SERVER}") -endif () - -include(ProcessorCount) -ProcessorCount(N) -if (JOBS) - set(THNUM ${JOBS}) -else() - set(JOBS 8) - if (${JOBS} GREATER ${N}) - set(THNUM ${N}) - endif() -endif () -message("set make thread num: ${THNUM}") - -if(LOCAL_LIBS_SERVER) - if (NOT ENV{no_proxy}) - set(ENV{no_proxy} "${LOCAL_LIBS_SERVER}") - else() - string(FIND $ENV{no_proxy} ${LOCAL_LIBS_SERVER} IP_POS) - if (${IP_POS} EQUAL -1) - set(ENV{no_proxy} "$ENV{no_proxy},${LOCAL_LIBS_SERVER}") - endif () - endif () -endif() - -function(__download_pkg pkg_name pkg_url pkg_md5) - - if(LOCAL_LIBS_SERVER) - get_filename_component(_URL_FILE_NAME ${pkg_url} NAME) - set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${_URL_FILE_NAME}" ${pkg_url}) - endif() - - FetchContent_Declare( - ${pkg_name} - URL ${pkg_url} - URL_HASH MD5=${pkg_md5} - ) - FetchContent_GetProperties(${pkg_name}) - message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}") - if(NOT ${pkg_name}_POPULATED) - FetchContent_Populate(${pkg_name}) - set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) - endif() - -endfunction() - -function(__download_pkg_with_git pkg_name pkg_url pkg_git_commit pkg_md5) - - if(LOCAL_LIBS_SERVER) - set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${pkg_git_commit}") - FetchContent_Declare( - ${pkg_name} - URL ${pkg_url} - URL_HASH MD5=${pkg_md5} - ) - else() - FetchContent_Declare( - ${pkg_name} - GIT_REPOSITORY ${pkg_url} - GIT_TAG ${pkg_git_commit}) - endif() - FetchContent_GetProperties(${pkg_name}) - message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}") - if(NOT ${pkg_name}_POPULATED) - FetchContent_Populate(${pkg_name}) - set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) - endif() - -endfunction() - - -function(__find_pkg_then_add_target pkg_name pkg_exe) - - unset(${pkg_name}_LIBS) - - message("_FIND:${${pkg_name}_BASE_DIR}") - - if(pkg_exe) - find_program(${pkg_exe}_EXE ${pkg_exe} PATHS ${${pkg_name}_BASE_DIR}/bin NO_DEFAULT_PATH) - if(NOT ${pkg_exe}_EXE) - return() - endif() - add_executable(${pkg_name}::${pkg_exe} IMPORTED GLOBAL) - set_target_properties(${pkg_name}::${pkg_exe} PROPERTIES - IMPORTED_LOCATION ${${pkg_exe}_EXE} - ) - message("found ${${pkg_exe}_EXE}") - endif() - - foreach(_LIB_NAME ${ARGN}) - set(_LIB_SEARCH_NAME ${_LIB_NAME}) - set(_LIB_TYPE SHARED) - if (${pkg_name}_USE_STATIC_LIBS) - set(_LIB_SEARCH_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}") - set(_LIB_TYPE STATIC) - endif () - set(${_LIB_NAME}_LIB ${_LIB_NAME}_LIB-NOTFOUND) - find_library(${_LIB_NAME}_LIB ${_LIB_SEARCH_NAME} PATHS ${${pkg_name}_BASE_DIR}/lib NO_DEFAULT_PATH) - if(NOT ${_LIB_NAME}_LIB) - return() - endif() - add_library(${pkg_name}::${_LIB_NAME} ${_LIB_TYPE} IMPORTED GLOBAL) - set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES - INTERFACE_INCLUDE_DIRECTORIES "${${pkg_name}_BASE_DIR}/include" - IMPORTED_LOCATION ${${_LIB_NAME}_LIB} - ) - list(APPEND ${pkg_name}_LIBS 
${pkg_name}::${_LIB_NAME}) - message("found ${${_LIB_NAME}_LIB}") - STRING( REGEX REPLACE "(.+)/(.+)" "\\1" LIBPATH ${${_LIB_NAME}_LIB}) - set(${pkg_name}_LIBPATH ${LIBPATH} CACHE STRING INTERNAL) - endforeach(_LIB_NAME) - - set(${pkg_name}_LIBS ${${pkg_name}_LIBS} PARENT_SCOPE) -endfunction() - -function(__exec_cmd) - set(options ) - set(oneValueArgs WORKING_DIRECTORY) - set(multiValueArgs COMMAND) - - cmake_parse_arguments(EXEC "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) - - execute_process(COMMAND ${EXEC_COMMAND} - WORKING_DIRECTORY ${EXEC_WORKING_DIRECTORY} - RESULT_VARIABLE RESULT) - if(NOT RESULT EQUAL "0") - message(FATAL_ERROR "error! when ${EXEC_COMMAND} in ${EXEC_WORKING_DIRECTORY}") - endif() -endfunction() - -function(__check_patches pkg_patches) - # check patches - if (PKG_PATCHES) - file(TOUCH ${_MS_LIB_CACHE}/${pkg_name}_patch.md5) - file(READ ${_MS_LIB_CACHE}/${pkg_name}_patch.md5 ${pkg_name}_PATCHES_MD5) - - message("patches md5:${${pkg_name}_PATCHES_MD5}") - - set(${pkg_name}_PATCHES_NEW_MD5 ) - foreach(_PATCH ${PKG_PATCHES}) - file(MD5 ${_PATCH} _PF_MD5) - set(${pkg_name}_PATCHES_NEW_MD5 "${${pkg_name}_PATCHES_NEW_MD5},${_PF_MD5}") - endforeach(_PATCH) - - if (NOT ${pkg_name}_PATCHES_MD5 STREQUAL ${pkg_name}_PATCHES_NEW_MD5) - set(${pkg_name}_PATCHES ${PKG_PATCHES}) - file(REMOVE_RECURSE "${_MS_LIB_CACHE}/${pkg_name}-subbuild") - file(WRITE ${_MS_LIB_CACHE}/${pkg_name}_patch.md5 ${${pkg_name}_PATCHES_NEW_MD5}) - message("patches changed : ${${pkg_name}_PATCHES_NEW_MD5}") - endif () - endif () -endfunction() - -set(GE_FIND_NO_DEFAULT_PATH NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH NO_SYSTEM_ENVIRONMENT_PATH - NO_CMAKE_BUILDS_PATH NO_CMAKE_PACKAGE_REGISTRY NO_CMAKE_SYSTEM_PATH - NO_CMAKE_SYSTEM_PACKAGE_REGISTRY) -set(GE_FIND_NO_DEFAULT_PATH ${GE_FIND_NO_DEFAULT_PATH} PARENT_SCOPE) - -function(graphengine_add_pkg pkg_name ) - set(options ) - set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH) - set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES) - cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) - - if (NOT PKG_CMAKE_PATH) - set(PKG_CMAKE_PATH ..) 
- endif () - - set(__FIND_PKG_NAME ${pkg_name}) - string(TOLOWER ${pkg_name} pkg_name) - message("pkg name:${__FIND_PKG_NAME},${pkg_name}") - - set(${pkg_name}_PATCHES_HASH ) - foreach(_PATCH ${PKG_PATCHES}) - file(MD5 ${_PATCH} _PF_MD5) - set(${pkg_name}_PATCHES_HASH "${${pkg_name}_PATCHES_HASH},${_PF_MD5}") - endforeach(_PATCH) - - # check options - set(${pkg_name}_CONFIG_TXT - "${CMAKE_CXX_COMPILER_VERSION}-${CMAKE_C_COMPILER_VERSION} - ${ARGN} - ${${pkg_name}_USE_STATIC_LIBS}- ${${pkg_name}_PATCHES_HASH} - ${${pkg_name}_CXXFLAGS}--${${pkg_name}_CFLAGS}--${${pkg_name}_LDFLAGS}") - string(REPLACE ";" "-" ${pkg_name}_CONFIG_TXT ${${pkg_name}_CONFIG_TXT}) - string(MD5 ${pkg_name}_CONFIG_HASH ${${pkg_name}_CONFIG_TXT}) - - message("${pkg_name} config hash: ${${pkg_name}_CONFIG_HASH}") - - set(${pkg_name}_BASE_DIR ${_MS_LIB_CACHE}/${pkg_name}_${${pkg_name}_CONFIG_HASH}) - set(${pkg_name}_DIRPATH ${${pkg_name}_BASE_DIR} CACHE STRING INTERNAL) - - if(EXISTS ${${pkg_name}_BASE_DIR}/options.txt AND PKG_HEAD_ONLY) - set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE) - add_library(${pkg_name} INTERFACE) - target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC}) - return() - endif () - - if(NOT PKG_EXE) - set(PKG_EXE 0) - endif() - - set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR}) - set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR} PARENT_SCOPE) - - if (PKG_LIBS) - __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIBS}) - if(${pkg_name}_LIBS) - set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) - message("Found libs: ${${pkg_name}_LIBS}") - return() - endif() - elseif(NOT PKG_HEAD_ONLY) - find_package(${__FIND_PKG_NAME} ${PKG_VER} ${GE_FIND_NO_DEFAULT_PATH}) - if (${__FIND_PKG_NAME}_FOUND) - set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) - message("Found pkg: ${__FIND_PKG_NAME}") - return() - endif () - endif () - - if (NOT PKG_DIR) - if (PKG_GIT_REPOSITORY) - __download_pkg_with_git(${pkg_name} ${PKG_GIT_REPOSITORY} ${PKG_GIT_TAG} ${PKG_MD5}) - else() - __download_pkg(${pkg_name} ${PKG_URL} ${PKG_MD5}) - endif() - else() - set(${pkg_name}_SOURCE_DIR ${PKG_DIR}) - endif () - file(WRITE ${${pkg_name}_BASE_DIR}/options.txt ${${pkg_name}_CONFIG_TXT}) - message("${pkg_name}_SOURCE_DIR : ${${pkg_name}_SOURCE_DIR}") - - foreach(_PATCH_FILE ${PKG_PATCHES}) - message("patching ${${pkg_name}_SOURCE_DIR} -p1 < ${_PATCH_FILE}") - execute_process(COMMAND patch -p1 INPUT_FILE ${_PATCH_FILE} - WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR} - RESULT_VARIABLE Result) - if(NOT Result EQUAL "0") - message(FATAL_ERROR "Failed patch: ${_PATCH_FILE}") - endif() - endforeach(_PATCH_FILE) - - file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600) - if(NOT ${pkg_name}_LOCK_RET EQUAL "0") - message(FATAL_ERROR "error! 
when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}") - endif() - - if(${pkg_name}_SOURCE_DIR) - if (PKG_HEAD_ONLY) - file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*) - file(COPY ${${pkg_name}_SOURCE_SUBDIRS} DESTINATION ${${pkg_name}_BASE_DIR}) - set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE) - add_library(${pkg_name} INTERFACE) - target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC}) - - elseif (PKG_CMAKE_OPTION) - # in cmake - file(MAKE_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) - if (${pkg_name}_CFLAGS) - set(${pkg_name}_CMAKE_CFLAGS "-DCMAKE_C_FLAGS=${${pkg_name}_CFLAGS}") - endif () - if (${pkg_name}_CXXFLAGS) - set(${pkg_name}_CMAKE_CXXFLAGS "-DCMAKE_CXX_FLAGS=${${pkg_name}_CXXFLAGS}") - endif () - - if (${pkg_name}_LDFLAGS) - if (${pkg_name}_USE_STATIC_LIBS) - #set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_STATIC_LINKER_FLAGS=${${pkg_name}_LDFLAGS}") - else() - set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_SHARED_LINKER_FLAGS=${${pkg_name}_LDFLAGS}") - endif () - endif () - - __exec_cmd(COMMAND ${CMAKE_COMMAND} ${PKG_CMAKE_OPTION} -G ${CMAKE_GENERATOR} - ${${pkg_name}_CMAKE_CFLAGS} ${${pkg_name}_CMAKE_CXXFLAGS} ${${pkg_name}_CMAKE_LDFLAGS} - -DCMAKE_INSTALL_PREFIX=${${pkg_name}_BASE_DIR} ${PKG_CMAKE_PATH} - WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) - - __exec_cmd(COMMAND ${CMAKE_COMMAND} --build . --target install -- -j${THNUM} - WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) - - else() - if (${pkg_name}_CFLAGS) - set(${pkg_name}_MAKE_CFLAGS "CFLAGS=${${pkg_name}_CFLAGS}") - endif () - if (${pkg_name}_CXXFLAGS) - set(${pkg_name}_MAKE_CXXFLAGS "CXXFLAGS=${${pkg_name}_CXXFLAGS}") - endif () - if (${pkg_name}_LDFLAGS) - set(${pkg_name}_MAKE_LDFLAGS "LDFLAGS=${${pkg_name}_LDFLAGS}") - endif () - # in configure && make - if (PKG_PRE_CONFIGURE_COMMAND) - __exec_cmd(COMMAND ${PKG_PRE_CONFIGURE_COMMAND} - WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) - endif () - - if (PKG_CONFIGURE_COMMAND) - __exec_cmd(COMMAND ${PKG_CONFIGURE_COMMAND} - ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS} - --prefix=${${pkg_name}_BASE_DIR} - WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) - endif () - set(${pkg_name}_BUILD_OPTION ${PKG_BUILD_OPTION}) - if (NOT PKG_CONFIGURE_COMMAND) - set(${pkg_name}_BUILD_OPTION ${${pkg_name}_BUILD_OPTION} - ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS}) - endif () - # build - __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_BUILD_OPTION} -j${THNUM} - WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) - - if (PKG_INSTALL_INCS OR PKG_INSTALL_LIBS) - file(GLOB ${pkg_name}_INSTALL_INCS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_INCS}) - file(GLOB ${pkg_name}_INSTALL_LIBS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_LIBS}) - file(COPY ${${pkg_name}_INSTALL_INCS} DESTINATION ${${pkg_name}_BASE_DIR}/include) - file(COPY ${${pkg_name}_INSTALL_LIBS} DESTINATION ${${pkg_name}_BASE_DIR}/lib) - else() - __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} install WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) - endif () - endif () - endif() - - if (PKG_LIBS) - __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIBS}) - set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) - if(NOT ${pkg_name}_LIBS) - message(FATAL_ERROR "Can not find pkg: ${pkg_name}") - endif() - else() - find_package(${__FIND_PKG_NAME} ${PKG_VER} QUIET) - if (${__FIND_PKG_NAME}_FOUND) - set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) - message("Found pkg: 
${${__FIND_PKG_NAME}_LIBRARIES}") - return() - endif () - endif () -endfunction() diff --git a/cmake/intf_pub_android.cmake b/cmake/intf_pub_android.cmake new file mode 100644 index 00000000..153d5764 --- /dev/null +++ b/cmake/intf_pub_android.cmake @@ -0,0 +1,52 @@ + +add_library(intf_pub INTERFACE) + +target_compile_options(intf_pub INTERFACE + -Wall + -fPIC + -fstack-protector-strong +) +target_compile_definitions(intf_pub INTERFACE + $<$:_GLIBCXX_USE_CXX11_ABI=0> + $<$:CFG_BUILD_NDEBUG> + $<$:CFG_BUILD_DEBUG> + WIN64=1 + LINUX=0 +) +target_link_options(intf_pub INTERFACE + -Wl,-z,relro + -Wl,-z,now + -Wl,-z,noexecstack + $<$:-Wl,--build-id=none> +) +target_link_directories(intf_pub INTERFACE +) + +add_library(intf_ccec INTERFACE) +target_compile_options(intf_ccec INTERFACE + -mcpu=cortex-a73 + --target=aarch64-linux-android29 + --sysroot=${HCC_PATH}/../sysroot + -L${HCC_PATH}/../lib/gcc/aarch64-linux-android/4.9.x + -Wall + -fPIC + -fstack-protector-strong +) +target_compile_definitions(intf_ccec INTERFACE + $<$:_GLIBCXX_USE_CXX11_ABI=0> + $<$:CFG_BUILD_NDEBUG> + $<$:CFG_BUILD_DEBUG> +) + +target_link_options(intf_ccec INTERFACE + -mcpu=cortex-a73 + --target=aarch64-linux-android29 + --sysroot=${HCC_PATH}/../sysroot + -L${HCC_PATH}/../lib/gcc/aarch64-linux-android/4.9.x + -Wl,-cce-host-android + -Wl,-z,relro + -Wl,-z,now + -Wl,-z,noexecstack + $<$:-Wl,--build-id=none> +) + diff --git a/cmake/intf_pub_linux.cmake b/cmake/intf_pub_linux.cmake new file mode 100644 index 00000000..40c6bca9 --- /dev/null +++ b/cmake/intf_pub_linux.cmake @@ -0,0 +1,33 @@ +if (HAVE_PUB) + return() +endif() + +add_library(intf_pub INTERFACE) + +target_compile_options(intf_pub INTERFACE + -Wall + -fPIC + $,-fstack-protector-all,-fstack-protector-strong> + $<$:-std=c++11> +) +target_compile_definitions(intf_pub INTERFACE + _GLIBCXX_USE_CXX11_ABI=0 + $<$:CFG_BUILD_NDEBUG> + $<$:CFG_BUILD_DEBUG> + WIN64=1 + LINUX=0 +) +target_link_options(intf_pub INTERFACE + -Wl,-z,relro + -Wl,-z,now + -Wl,-z,noexecstack + $<$:-Wl,--build-id=none> +) +target_link_directories(intf_pub INTERFACE +) +target_link_libraries(intf_pub INTERFACE + -lpthread +) + +#set(HAVE_PUB TRUE CACHE BOOL "pub add") +set(HAVE_PUB TRUE) diff --git a/cmake/intf_pub_windows.cmake b/cmake/intf_pub_windows.cmake new file mode 100644 index 00000000..19e37283 --- /dev/null +++ b/cmake/intf_pub_windows.cmake @@ -0,0 +1,24 @@ + +add_library(intf_pub INTERFACE) + +target_compile_options(intf_pub INTERFACE + -Wall + -fPIC + $,-fstack-protector-all,-fstack-protector-strong> + $<$:-std=c++11> +) +target_compile_definitions(intf_pub INTERFACE + $<$:_GLIBCXX_USE_CXX11_ABI=0> + OS_TYPE=WIN64 + WIN64=1 + LINUX=0 + $<$:CFG_BUILD_NDEBUG> + $<$:CFG_BUILD_DEBUG> +) +target_link_options(intf_pub INTERFACE + $<$:-Wl,--build-id=none> +) +target_link_directories(intf_pub INTERFACE +) +target_link_libraries(intf_pub INTERFACE +) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt new file mode 100755 index 00000000..cd4d0c92 --- /dev/null +++ b/ge/CMakeLists.txt @@ -0,0 +1,901 @@ +add_subdirectory(common) +add_subdirectory(plugin/engine) +add_subdirectory(graph/build/memory) +add_subdirectory(ge_local_engine) +add_subdirectory(host_cpu_engine) +add_subdirectory(executor) +add_subdirectory(offline) + +set(PROTO_LIST + "${METADEF_DIR}/proto/fusion_model.proto" + "${GE_CODE_DIR}/ge/proto/optimizer_priority.proto" +) + +set(PROTO_CLIENT_LIST + "${METADEF_DIR}/proto/ge_api.proto" +) + +set(PROTO_HEADER_LIST + "${METADEF_DIR}/proto/om.proto" + "${METADEF_DIR}/proto/task.proto" + 
"${METADEF_DIR}/proto/insert_op.proto" + "${METADEF_DIR}/proto/ge_ir.proto" + "${METADEF_DIR}/proto/fwk_adapter.proto" + "${METADEF_DIR}/proto/op_mapping_info.proto" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) +protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) + +############ libge_runner.so ############ +set(TRAIN_SRC_LIST + "common/formats/format_transfers/datatype_transfer.cc" + "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" + "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" + "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" + "common/formats/format_transfers/format_transfer_fractal_nz.cc" + "common/formats/format_transfers/format_transfer_fractal_z.cc" + "common/formats/format_transfers/format_transfer_fractal_zz.cc" + "common/formats/format_transfers/format_transfer_fracz_hwcn.cc" + "common/formats/format_transfers/format_transfer_fracz_nchw.cc" + "common/formats/format_transfers/format_transfer_fracz_nhwc.cc" + "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" + "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" + "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" + "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" + "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" + "common/formats/format_transfers/format_transfer_transpose.cc" + "common/formats/formats.cc" + "common/formats/utils/formats_trans_utils.cc" + "common/fp16_t.cc" + "common/ge/plugin_manager.cc" + "common/ge/op_tiling_manager.cc" + "common/helper/model_cache_helper.cc" + "common/profiling/profiling_manager.cc" + "common/dump/dump_manager.cc" + "common/dump/dump_properties.cc" + "common/dump/dump_op.cc" + "engine_manager/dnnengine_manager.cc" + "ge_local_engine/engine/host_cpu_engine.cc" + "generator/ge_generator.cc" + "generator/generator_api.cc" + "graph/build/graph_builder.cc" + "graph/build/label_allocator.cc" + "graph/build/logical_stream_allocator.cc" + "graph/build/model_builder.cc" + "graph/build/run_context.cc" + "graph/build/stream_allocator.cc" + "graph/build/stream_graph_optimizer.cc" + "graph/build/task_generator.cc" + "graph/common/bcast.cc" + "graph/common/local_context.cc" + "graph/common/omg_util.cc" + "graph/common/transop_util.cc" + "graph/execute/graph_execute.cc" + "graph/label/case_label_maker.cc" + "graph/label/if_label_maker.cc" + "graph/label/label_maker.cc" + "graph/label/partitioned_call_label_maker.cc" + "graph/label/while_label_maker.cc" + "graph/load/graph_loader.cc" + "graph/load/new_model_manager/cpu_queue_schedule.cc" + "graph/load/new_model_manager/data_dumper.cc" + "graph/load/new_model_manager/data_inputer.cc" + "graph/load/new_model_manager/davinci_model.cc" + "graph/load/new_model_manager/davinci_model_parser.cc" + "graph/load/new_model_manager/model_manager.cc" + "graph/load/new_model_manager/model_utils.cc" + "graph/load/new_model_manager/aipp_utils.cc" + "graph/load/new_model_manager/task_info/end_graph_task_info.cc" + "graph/load/new_model_manager/task_info/event_record_task_info.cc" + "graph/load/new_model_manager/task_info/event_wait_task_info.cc" + "graph/load/new_model_manager/task_info/fusion_start_task_info.cc" + "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" + "graph/load/new_model_manager/task_info/hccl_task_info.cc" + "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" + 
"graph/load/new_model_manager/task_info/kernel_task_info.cc" + "graph/load/new_model_manager/task_info/label_set_task_info.cc" + "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" + "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" + "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" + "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" + "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" + "graph/load/new_model_manager/task_info/stream_active_task_info.cc" + "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" + "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" + "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" + "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" + "graph/load/new_model_manager/task_info/task_info.cc" + "graph/load/new_model_manager/tbe_handle_store.cc" + "graph/load/new_model_manager/zero_copy_task.cc" + "graph/load/new_model_manager/zero_copy_offset.cc" + "graph/manager/graph_context.cc" + "graph/manager/graph_manager.cc" + "graph/manager/graph_manager_utils.cc" + "graph/manager/graph_mem_allocator.cc" + "graph/manager/graph_caching_allocator.cc" + "graph/manager/graph_var_manager.cc" + "graph/manager/host_mem_manager.cc" + "graph/manager/rdma_pool_allocator.cc" + "graph/manager/memory_api.cc" + "graph/manager/model_manager/event_manager.cc" + "graph/manager/trans_var_data_utils.cc" + "graph/manager/util/debug.cc" + "graph/manager/util/hcom_util.cc" + "graph/manager/util/rt_context_util.cc" + "graph/manager/util/variable_accelerate_ctrl.cc" + "graph/optimize/graph_optimize.cc" + "graph/optimize/mem_rw_conflict_optimize.cc" + "graph/optimize/summary_optimize.cc" + "graph/partition/engine_place.cc" + "graph/partition/graph_partition.cc" + "graph/passes/addn_pass.cc" + "graph/passes/aicpu_constant_folding_pass.cc" + "graph/passes/assert_pass.cc" + "graph/passes/input_output_connection_identify_pass.cc" + "graph/passes/atomic_addr_clean_pass.cc" + "graph/passes/mark_same_addr_pass.cc" + "graph/passes/mark_graph_unknown_status_pass.cc" + "graph/passes/mark_agnostic_pass.cc" + "graph/partition/dynamic_shape_partition.cc" + "graph/passes/base_pass.cc" + "graph/passes/bitcast_pass.cc" + "graph/passes/cast_remove_pass.cc" + "graph/passes/cast_translate_pass.cc" + "graph/passes/common_subexpression_elimination_pass.cc" + "graph/passes/transop_symmetry_elimination_pass.cc" + "graph/passes/compile_nodes_pass.cc" + "graph/passes/constant_folding_pass.cc" + "graph/passes/constant_fuse_same_pass.cc" + "graph/passes/control_trigger_pass.cc" + "graph/passes/dimension_adjust_pass.cc" + "graph/passes/dimension_compute_pass.cc" + "graph/passes/dropout_pass.cc" + "graph/passes/hccl_group_pass.cc" + "graph/passes/enter_pass.cc" + "graph/passes/assign_pass.cc" + "graph/passes/flow_ctrl_pass.cc" + "graph/passes/global_step_insert_pass.cc" + "host_kernels/transpose_kernel.cc" + "host_kernels/add_kernel.cc" + "host_kernels/broadcast_args_kernel.cc" + "host_kernels/broadcast_gradient_args_kernel.cc" + "host_kernels/cast_kernel.cc" + "host_kernels/concat_offset_kernel.cc" + "host_kernels/concat_v2_kernel.cc" + "host_kernels/dynamic_stitch_kernel.cc" + "host_kernels/identity_kernel.cc" + "host_kernels/empty_kernel.cc" + "host_kernels/expanddims_kernel.cc" + "host_kernels/fill_kernel.cc" + "host_kernels/floordiv_kernel.cc" + "host_kernels/floormod_kernel.cc" + "host_kernels/gather_v2_kernel.cc" + 
"host_kernels/greater_kernel.cc" + "host_kernels/kernel_utils.cc" + "host_kernels/maximum_kernel.cc" + "host_kernels/mul_kernel.cc" + "host_kernels/pack_kernel.cc" + "host_kernels/permute_kernel.cc" + "host_kernels/range_kernel.cc" + "host_kernels/rank_kernel.cc" + "host_kernels/reduce_prod_kernel.cc" + "host_kernels/reshape_kernel.cc" + "host_kernels/rsqrt_kernel.cc" + "host_kernels/shape_kernel.cc" + "host_kernels/shape_n_kernel.cc" + "host_kernels/size_kernel.cc" + "host_kernels/slice_d_kernel.cc" + "host_kernels/slice_kernel.cc" + "host_kernels/squeeze_kernel.cc" + "host_kernels/unsqueeze_kernel.cc" + "host_kernels/ssd_prior_box_kernel.cc" + "host_kernels/strided_slice_kernel.cc" + "host_kernels/sub_kernel.cc" + "host_kernels/transdata_kernel.cc" + "host_kernels/unpack_kernel.cc" + "graph/passes/folding_pass.cc" + "graph/passes/get_original_format_pass.cc" + "graph/passes/guarantee_const_pass.cc" + "graph/passes/hccl_memcpy_pass.cc" + "graph/passes/identity_pass.cc" + "graph/passes/ref_identity_delete_op_pass.cc" + "graph/passes/infershape_pass.cc" + "graph/passes/isolated_op_remove_pass.cc" + "graph/passes/iterator_op_pass.cc" + "graph/passes/link_gen_mask_nodes_pass.cc" + "graph/passes/merge_pass.cc" + "graph/passes/multi_batch_pass.cc" + "graph/passes/multi_batch_clone_pass.cc" + "graph/passes/subexpression_migration_pass.cc" + "graph/passes/unused_args_clean_pass.cc" + "graph/passes/net_output_pass.cc" + "graph/passes/next_iteration_pass.cc" + "graph/passes/no_use_reshape_remove_pass.cc" + "graph/passes/pass_manager.cc" + "graph/passes/pass_utils.cc" + "graph/passes/permute_pass.cc" + "graph/passes/placeholder_with_default_pass.cc" + "graph/passes/prevent_gradient_pass.cc" + "graph/passes/print_op_pass.cc" + "graph/passes/prune_pass.cc" + "graph/passes/ctrl_edge_transfer_pass.cc" + "graph/passes/replace_with_empty_const_pass.cc" + "graph/passes/reshape_remove_pass.cc" + "graph/passes/reshape_recovery_pass.cc" + "graph/passes/resource_pair_add_control_pass.cc" + "graph/passes/resource_pair_remove_control_pass.cc" + "graph/passes/same_transdata_breadth_fusion_pass.cc" + "graph/passes/save_pass.cc" + "graph/passes/shape_operate_op_remove_pass.cc" + "graph/passes/snapshot_pass.cc" + "graph/passes/stop_gradient_pass.cc" + "graph/passes/subgraph_pass.cc" + "graph/passes/data_pass.cc" + "graph/passes/switch_data_edges_bypass.cc" + "graph/passes/switch_logic_remove_pass.cc" + "graph/passes/merge_to_stream_merge_pass.cc" + "graph/passes/switch_to_stream_switch_pass.cc" + "graph/passes/attach_stream_label_pass.cc" + "graph/passes/switch_dead_branch_elimination.cc" + "graph/passes/replace_transshape_pass.cc" + "graph/passes/transop_breadth_fusion_pass.cc" + "graph/passes/transop_depth_fusion_pass.cc" + "graph/passes/transop_nearby_allreduce_fusion_pass.cc" + "graph/passes/transop_without_reshape_fusion_pass.cc" + "graph/passes/transpose_transdata_pass.cc" + "graph/passes/unused_const_pass.cc" + "graph/passes/unused_op_remove_pass.cc" + "graph/passes/var_is_initialized_op_pass.cc" + "graph/passes/parallel_concat_start_op_pass.cc" + "graph/passes/cond_pass.cc" + "graph/passes/cond_remove_pass.cc" + "graph/passes/for_pass.cc" + "graph/passes/variable_format_pass.cc" + "graph/passes/variable_op_pass.cc" + "graph/passes/variable_prepare_op_pass.cc" + "graph/passes/variable_ref_delete_op_pass.cc" + "graph/passes/variable_ref_useless_control_out_delete_pass.cc" + "graph/passes/end_of_sequence_add_control_pass.cc" + "graph/passes/memcpy_addr_async_pass.cc" + 
"graph/passes/set_input_output_offset_pass.cc" + "graph/preprocess/graph_preprocess.cc" + "graph/preprocess/insert_op/ge_aipp_op.cc" + "graph/preprocess/insert_op/util_insert_aipp_op.cc" + "graph/preprocess/multi_batch_options.cc" + "graph/preprocess/multi_batch_copy_graph.cc" + "init/gelib.cc" + "model/ge_model.cc" + "model/ge_root_model.cc" + "omm/csa_interact.cc" + "opskernel_manager/ops_kernel_manager.cc" + "session/inner_session.cc" + "session/session_manager.cc" + "single_op/single_op.cc" + "single_op/single_op_manager.cc" + "single_op/single_op_model.cc" + "single_op/stream_resource.cc" + "single_op/task/build_task_utils.cc" + "single_op/task/op_task.cc" + "single_op/task/tbe_task_builder.cc" + "single_op/task/aicpu_task_builder.cc" + "single_op/task/aicpu_kernel_task_builder.cc" + "hybrid/common/tensor_value.cc" + "hybrid/common/npu_memory_allocator.cc" + "hybrid/executor/rt_callback_manager.cc" + "hybrid/executor/node_state.cc" + "hybrid/executor/node_done_manager.cc" + "hybrid/executor/hybrid_profiler.cc" + "hybrid/executor/hybrid_model_executor.cc" + "hybrid/executor/hybrid_model_async_executor.cc" + "hybrid/executor/hybrid_execution_context.cc" + "hybrid/executor/subgraph_context.cc" + "hybrid/executor/subgraph_executor.cc" + "hybrid/executor/worker/task_compile_engine.cc" + "hybrid/executor/worker/shape_inference_engine.cc" + "hybrid/executor/worker/execution_engine.cc" + "hybrid/model/hybrid_model.cc" + "hybrid/model/hybrid_model_builder.cc" + "hybrid/model/node_item.cc" + "hybrid/model/graph_item.cc" + "hybrid/node_executor/aicore/aicore_node_executor.cc" + "hybrid/node_executor/aicore/aicore_op_task.cc" + "hybrid/node_executor/aicore/aicore_task_builder.cc" + "hybrid/node_executor/aicore/aicore_task_compiler.cc" + "hybrid/node_executor/aicpu/aicpu_ext_info.cc" + "hybrid/node_executor/aicpu/aicpu_node_executor.cc" + "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" + "hybrid/node_executor/ge_local/ge_local_node_executor.cc" + "hybrid/node_executor/host_cpu/host_cpu_node_executor.cc" + "hybrid/node_executor/host_cpu/kernel_factory.cc" + "hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc" + "hybrid/node_executor/host_cpu/kernel/variable_kernel.cc" + "hybrid/node_executor/host_cpu/kernel/assign_kernel.cc" + "hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc" + "hybrid/node_executor/controlop/control_op_executor.cc" + "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" + "hybrid/node_executor/hccl/hccl_node_executor.cc" + "hybrid/node_executor/rts/rts_node_executor.cc" + "hybrid/node_executor/node_executor.cc" + "hybrid/node_executor/task_context.cc" + "hybrid/hybrid_davinci_model.cc" + "executor/ge_executor.cc" + "client/ge_api.cc" + "client/ge_prof.cc" + "analyzer/analyzer.cc" +) + +add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) + +target_compile_definitions(ge_runner PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + DAVINCI_SUPPORT_PROFILING + REUSE_MEMORY=1 + FMK_SUPPORT_DUMP + DAVINCI_CLOUD +) + +target_compile_options(ge_runner PRIVATE + -O2 +) + +target_include_directories(ge_runner PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/ge/analyzer + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${GE_CODE_DIR}/inc/framework/common + ${METADEF_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + 
${GE_CODE_DIR}/../inc/external + ${GE_CODE_DIR}/../inc/cce + ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external + #### blue zone + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + ${GE_CODE_DIR}/third_party/fwkacllib/inc + ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain +) + +target_link_libraries(ge_runner + $ + ge_memory + adump_server + msprofiler + -Wl,--no-as-needed + graph + ge_common + protobuf + register + c_sec + slog + mmpa + msprof + runtime + resource + error_manager + ascend_hal_stub + -Wl,--as-needed + json + -lrt + -ldl +) + +############ libge_compiler.so ############ +set(INFER_SRC_LIST + "graph/manager/trans_var_data_utils.cc" + "omm/csa_interact.cc" + "common/fp16_t.cc" + "common/formats/utils/formats_trans_utils.cc" + "common/formats/format_transfers/datatype_transfer.cc" + "common/formats/format_transfers/format_transfer_transpose.cc" + "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" + "common/formats/format_transfers/format_transfer_fractal_z.cc" + "common/formats/format_transfers/format_transfer_fractal_nz.cc" + "common/formats/format_transfers/format_transfer_fractal_zz.cc" + "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" + "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" + "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" + "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" + "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" + "common/formats/format_transfers/format_transfer_fracz_nchw.cc" + "common/formats/format_transfers/format_transfer_fracz_nhwc.cc" + "common/formats/format_transfers/format_transfer_fracz_hwcn.cc" + "common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" + "common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" + "common/formats/format_transfers/format_transfer_nchw_fz_c04.cc" + "common/formats/formats.cc" + "common/profiling/profiling_manager.cc" + "common/dump/dump_properties.cc" + "common/dump/dump_manager.cc" + "common/dump/dump_op.cc" + "common/dump/dump_server.cc" + "common/helper/model_cache_helper.cc" + "ge_local_engine/engine/host_cpu_engine.cc" + "common/ge/plugin_manager.cc" + "common/ge/op_tiling_manager.cc" + "init/gelib.cc" + "session/inner_session.cc" + "session/session_manager.cc" + "engine_manager/dnnengine_manager.cc" + "opskernel_manager/ops_kernel_manager.cc" + "graph/manager/graph_manager.cc" + "graph/manager/graph_manager_utils.cc" + "graph/manager/graph_context.cc" + "graph/preprocess/graph_preprocess.cc" + "graph/preprocess/multi_batch_options.cc" + "graph/preprocess/multi_batch_copy_graph.cc" + "graph/execute/graph_execute.cc" + "graph/load/graph_loader.cc" + "graph/optimize/graph_optimize.cc" + "graph/optimize/mem_rw_conflict_optimize.cc" + "graph/optimize/summary_optimize.cc" + "graph/build/graph_builder.cc" + "graph/partition/engine_place.cc" + "graph/partition/graph_partition.cc" + "graph/partition/dynamic_shape_partition.cc" + "generator/ge_generator.cc" + "generator/generator_api.cc" + "graph/manager/graph_var_manager.cc" + "graph/manager/host_mem_manager.cc" + "graph/manager/rdma_pool_allocator.cc" + "graph/manager/graph_mem_allocator.cc" + "graph/manager/graph_caching_allocator.cc" + "model/ge_model.cc" + "model/ge_root_model.cc" + "graph/common/transop_util.cc" + "graph/passes/pass_manager.cc" + "graph/passes/resource_pair_add_control_pass.cc" + "graph/passes/resource_pair_remove_control_pass.cc" + "graph/passes/pass_utils.cc" + "graph/passes/base_pass.cc" 
+ "graph/passes/bitcast_pass.cc" + "graph/passes/constant_folding_pass.cc" + "graph/passes/aicpu_constant_folding_pass.cc" + "graph/passes/reshape_remove_pass.cc" + "graph/passes/reshape_recovery_pass.cc" + "graph/passes/transop_breadth_fusion_pass.cc" + "graph/passes/transop_depth_fusion_pass.cc" + "graph/passes/transop_nearby_allreduce_fusion_pass.cc" + "graph/passes/same_transdata_breadth_fusion_pass.cc" + "graph/passes/transop_without_reshape_fusion_pass.cc" + "graph/passes/compile_nodes_pass.cc" + "graph/passes/variable_prepare_op_pass.cc" + "graph/passes/variable_ref_delete_op_pass.cc" + "graph/passes/variable_ref_useless_control_out_delete_pass.cc" + "graph/passes/subgraph_pass.cc" + "graph/passes/data_pass.cc" + "graph/passes/net_output_pass.cc" + "graph/passes/replace_transshape_pass.cc" + "graph/passes/constant_fuse_same_pass.cc" + "graph/passes/print_op_pass.cc" + "graph/passes/no_use_reshape_remove_pass.cc" + "graph/passes/iterator_op_pass.cc" + "graph/passes/input_output_connection_identify_pass.cc" + "graph/passes/atomic_addr_clean_pass.cc" + "graph/passes/mark_same_addr_pass.cc" + "graph/passes/mark_graph_unknown_status_pass.cc" + "graph/passes/mark_agnostic_pass.cc" + "graph/common/omg_util.cc" + "graph/common/bcast.cc" + "graph/common/local_context.cc" + "graph/passes/dimension_compute_pass.cc" + "graph/passes/dimension_adjust_pass.cc" + "graph/passes/get_original_format_pass.cc" + "graph/passes/shape_operate_op_remove_pass.cc" + "graph/passes/unused_op_remove_pass.cc" + "graph/passes/assert_pass.cc" + "graph/passes/dropout_pass.cc" + "graph/passes/infershape_pass.cc" + "graph/passes/unused_const_pass.cc" + "graph/passes/isolated_op_remove_pass.cc" + "graph/passes/permute_pass.cc" + "graph/passes/ctrl_edge_transfer_pass.cc" + "graph/passes/end_of_sequence_add_control_pass.cc" + "host_kernels/broadcast_gradient_args_kernel.cc" + "host_kernels/greater_kernel.cc" + "host_kernels/gather_v2_kernel.cc" + "host_kernels/maximum_kernel.cc" + "host_kernels/floormod_kernel.cc" + "host_kernels/floordiv_kernel.cc" + "host_kernels/range_kernel.cc" + "host_kernels/shape_kernel.cc" + "host_kernels/size_kernel.cc" + "host_kernels/shape_n_kernel.cc" + "host_kernels/rank_kernel.cc" + "host_kernels/broadcast_args_kernel.cc" + "host_kernels/fill_kernel.cc" + "host_kernels/empty_kernel.cc" + "host_kernels/expanddims_kernel.cc" + "host_kernels/reshape_kernel.cc" + "host_kernels/squeeze_kernel.cc" + "host_kernels/unsqueeze_kernel.cc" + "host_kernels/kernel_utils.cc" + "host_kernels/cast_kernel.cc" + "host_kernels/transdata_kernel.cc" + "host_kernels/unpack_kernel.cc" + "host_kernels/transpose_kernel.cc" + "host_kernels/permute_kernel.cc" + "host_kernels/pack_kernel.cc" + "host_kernels/concat_v2_kernel.cc" + "host_kernels/concat_offset_kernel.cc" + "host_kernels/strided_slice_kernel.cc" + "host_kernels/ssd_prior_box_kernel.cc" + "host_kernels/add_kernel.cc" + "host_kernels/sub_kernel.cc" + "host_kernels/mul_kernel.cc" + "host_kernels/reduce_prod_kernel.cc" + "host_kernels/rsqrt_kernel.cc" + "host_kernels/slice_kernel.cc" + "host_kernels/slice_d_kernel.cc" + "host_kernels/dynamic_stitch_kernel.cc" + "host_kernels/identity_kernel.cc" + "graph/passes/stop_gradient_pass.cc" + "graph/passes/prevent_gradient_pass.cc" + "graph/passes/identity_pass.cc" + "graph/passes/ref_identity_delete_op_pass.cc" + "graph/passes/placeholder_with_default_pass.cc" + "graph/passes/snapshot_pass.cc" + "graph/passes/guarantee_const_pass.cc" + "graph/passes/var_is_initialized_op_pass.cc" + 
"graph/passes/parallel_concat_start_op_pass.cc" + "graph/passes/folding_pass.cc" + "graph/passes/cast_translate_pass.cc" + "graph/passes/prune_pass.cc" + "graph/passes/merge_to_stream_merge_pass.cc" + "graph/passes/switch_to_stream_switch_pass.cc" + "graph/passes/attach_stream_label_pass.cc" + "graph/passes/multi_batch_pass.cc" + "graph/passes/multi_batch_clone_pass.cc" + "graph/passes/subexpression_migration_pass.cc" + "graph/passes/unused_args_clean_pass.cc" + "graph/passes/next_iteration_pass.cc" + "graph/passes/control_trigger_pass.cc" + "graph/passes/cond_pass.cc" + "graph/passes/cond_remove_pass.cc" + "graph/passes/for_pass.cc" + "graph/passes/enter_pass.cc" + "graph/passes/assign_pass.cc" + "graph/passes/addn_pass.cc" + "graph/passes/common_subexpression_elimination_pass.cc" + "graph/passes/transop_symmetry_elimination_pass.cc" + "graph/passes/save_pass.cc" + "graph/passes/switch_dead_branch_elimination.cc" + "graph/passes/switch_logic_remove_pass.cc" + "graph/passes/switch_data_edges_bypass.cc" + "graph/passes/merge_pass.cc" + "graph/passes/variable_format_pass.cc" + "graph/passes/variable_op_pass.cc" + "graph/passes/cast_remove_pass.cc" + "graph/passes/transpose_transdata_pass.cc" + "graph/passes/hccl_memcpy_pass.cc" + "graph/passes/flow_ctrl_pass.cc" + "graph/passes/global_step_insert_pass.cc" + "graph/passes/link_gen_mask_nodes_pass.cc" + "graph/passes/replace_with_empty_const_pass.cc" + "graph/passes/hccl_group_pass.cc" + "graph/passes/memcpy_addr_async_pass.cc" + "graph/passes/set_input_output_offset_pass.cc" + "graph/manager/model_manager/event_manager.cc" + "graph/manager/util/rt_context_util.cc" + "graph/manager/util/variable_accelerate_ctrl.cc" + "graph/manager/util/debug.cc" + "graph/load/new_model_manager/model_manager.cc" + "graph/load/new_model_manager/data_inputer.cc" + "graph/load/new_model_manager/davinci_model.cc" + "graph/load/new_model_manager/davinci_model_parser.cc" + "graph/load/new_model_manager/model_utils.cc" + "graph/load/new_model_manager/aipp_utils.cc" + "graph/load/new_model_manager/tbe_handle_store.cc" + "graph/load/new_model_manager/cpu_queue_schedule.cc" + "graph/load/new_model_manager/zero_copy_task.cc" + "graph/load/new_model_manager/zero_copy_offset.cc" + "graph/load/new_model_manager/data_dumper.cc" + "graph/load/new_model_manager/task_info/task_info.cc" + "graph/load/new_model_manager/task_info/event_record_task_info.cc" + "graph/load/new_model_manager/task_info/event_wait_task_info.cc" + "graph/load/new_model_manager/task_info/fusion_start_task_info.cc" + "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" + "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" + "graph/load/new_model_manager/task_info/kernel_task_info.cc" + "graph/load/new_model_manager/task_info/label_set_task_info.cc" + "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" + "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" + "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" + "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" + "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" + "graph/load/new_model_manager/task_info/stream_active_task_info.cc" + "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" + "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" + "graph/load/new_model_manager/task_info/end_graph_task_info.cc" + "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" + 
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" + "single_op/task/op_task.cc" + "single_op/task/build_task_utils.cc" + "single_op/task/tbe_task_builder.cc" + "single_op/task/aicpu_task_builder.cc" + "single_op/task/aicpu_kernel_task_builder.cc" + "single_op/single_op.cc" + "single_op/single_op_model.cc" + "single_op/stream_resource.cc" + "single_op/single_op_manager.cc" + "hybrid/hybrid_davinci_model_stub.cc" + "ir_build/ge_ir_build.cc" + "ir_build/atc_ir_common.cc" + "graph/preprocess/insert_op/ge_aipp_op.cc" + "graph/preprocess/insert_op/util_insert_aipp_op.cc" + "hybrid/node_executor/aicpu/aicpu_ext_info.cc" + "graph/build/model_builder.cc" + "graph/build/task_generator.cc" + "graph/build/stream_allocator.cc" + "graph/build/logical_stream_allocator.cc" + "graph/build/stream_graph_optimizer.cc" + "graph/build/run_context.cc" + "graph/build/label_allocator.cc" + "graph/label/label_maker.cc" + "graph/label/if_label_maker.cc" + "graph/label/case_label_maker.cc" + "graph/label/while_label_maker.cc" + "graph/label/partitioned_call_label_maker.cc" + "analyzer/analyzer.cc" +) + +add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) + +target_compile_definitions(ge_compiler PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + REUSE_MEMORY=1 + FMK_SUPPORT_DUMP + FMK_HOST_INFER + COMPILE_OMG_PACKAGE +) + +target_compile_options(ge_compiler PRIVATE + -O2 +) + +target_include_directories(ge_compiler PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/ge/analyzer + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${GE_CODE_DIR}/inc/framework/common + ${METADEF_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + ${GE_CODE_DIR}/../inc/external + ${GE_CODE_DIR}/../inc/cce + ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external + #### blue zone #### + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + ${GE_CODE_DIR}/third_party/fwkacllib/inc + ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain +) + +target_link_libraries(ge_compiler + $ + ge_memory + -Wl,--no-as-needed + graph + ge_common + protobuf + register + c_sec + error_manager + slog + mmpa + runtime_compile + resource + -Wl,--as-needed + json + -lrt + -ldl +) + +############ libascendcl.so ############ +file(GENERATE OUTPUT ${CMAKE_BINARY_DIR}/dummy.c CONTENT "") +#add_library(dummy_obj OBJECT ${CMAKE_BINARY_DIR}/dummy.c) +#set(DUMMY_OBJ $) + +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ascendcl_object) + +if(EXISTS ${STATIC_ACL_LIB}/libascendcl.a) + execute_process( + COMMAND ar x ${STATIC_ACL_LIB}/libascendcl.a + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ascendcl_object + ) + file(GLOB OBJECT_LIST ${CMAKE_CURRENT_BINARY_DIR}/ascendcl_object/*.o) +else() + set(OBJECT_LIST ${CMAKE_BINARY_DIR}/dummy.c) +endif() + +add_library(opensrc_ascendcl SHARED + ${OBJECT_LIST} +) +target_compile_options(opensrc_ascendcl PRIVATE + -O2 + -fvisibility=hidden +) +target_link_options(opensrc_ascendcl PRIVATE + -rdynamic + -Wl,--allow-multiple-definition + -Wl,-z,muldefs + -Wl,-Bsymbolic + -Wl,--exclude-libs,ALL +) +target_link_libraries(opensrc_ascendcl PRIVATE + -Wl,--whole-archive + ge_executor + ge_common_static + graph_static + protobuf_static + register_static + error_manager_static + adump_server + msprofiler + -Wl,--no-whole-archive + -Wl,--no-as-needed + c_sec + runtime + mmpa + slog + msprof + ascend_hal_stub + 
-Wl,--as-needed + -ldl + json +) + +set_target_properties(opensrc_ascendcl PROPERTIES + OUTPUT_NAME ascendcl +) + +################################################################## +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_prof.cc + COMMAND echo "Generating stub files." + && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/stub/gen_stubapi.py ${GE_CODE_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} + && mv ge_ir_build.cc stub_ge_ir_build.cc + && mv ge_api.cc stub_ge_api.cc + && mv ge_prof.cc stub_ge_prof.cc + && echo "Generating stub files end." + #WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + #DEPENDS stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} +) + +add_custom_target(ge_stub + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_prof.cc +) + +################################################################## +############ stub/libge_compiler.so ############ +add_library(atc_stub_ge_compiler SHARED + stub_ge_ir_build.cc +) + +add_dependencies(atc_stub_ge_compiler ge_stub) + +target_link_libraries(atc_stub_ge_compiler PRIVATE + $ +) + +set_target_properties(atc_stub_ge_compiler PROPERTIES + OUTPUT_NAME ge_compiler + LIBRARY_OUTPUT_DIRECTORY atc_stub +) + +target_include_directories(atc_stub_ge_compiler PRIVATE + ${GE_CODE_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/ge/analyzer + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/framework + ${GE_CODE_DIR}/inc/framework/common + ${GE_CODE_DIR}/inc/external + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + #### yellow zone #### + ${GE_CODE_DIR}/../inc/cce + ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external + #### blue zone #### + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include +) + +############ stub/libge_runner.so ############ +add_library(fwk_stub_ge_runner SHARED + stub_ge_api.cc + stub_ge_prof.cc +) + +add_dependencies(fwk_stub_ge_runner ge_stub) + +target_link_libraries(fwk_stub_ge_runner PRIVATE + $ +) + +set_target_properties(fwk_stub_ge_runner PROPERTIES + OUTPUT_NAME ge_runner + LIBRARY_OUTPUT_DIRECTORY fwk_stub +) + +target_include_directories(fwk_stub_ge_runner PRIVATE + ${GE_CODE_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/ge/analyzer + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${GE_CODE_DIR}/inc/framework/common + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + #### yellow zone #### + ${GE_CODE_DIR}/../inc/cce + ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external + #### blue zone #### + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include +) + +############################################################### +add_custom_target( + engine_conf.json ALL + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/engine_conf.json +) +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/engine_conf.json + COMMAND cp ${CMAKE_CURRENT_LIST_DIR}/engine_manager/engine_conf.json ${CMAKE_CURRENT_BINARY_DIR}/ +) + +############################################################### +add_custom_target( + optimizer_priority.pbtxt ALL + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt +) +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt + COMMAND cp ${CMAKE_CURRENT_LIST_DIR}/opskernel_manager/optimizer_priority.pbtxt ${CMAKE_CURRENT_BINARY_DIR}/ +) + 
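The stub libraries generated above appear to serve as link-time stand-ins for the public GE API, while the install rules that follow place the full libraries under lib/ and the stubs under lib/stub. The snippet below is a minimal consumer sketch of that layout and is not part of this patch: the install prefix /usr/local/Ascend, the include/ header location, the GE_INSTALL_PREFIX variable, and the target name my_app are assumptions made only for illustration.

cmake_minimum_required(VERSION 3.14)
project(ge_runner_consumer CXX)

# Assumed install prefix for the GE package; adjust to the actual deployment path.
set(GE_INSTALL_PREFIX "/usr/local/Ascend" CACHE PATH "GE install prefix (assumed)")

add_executable(my_app main.cc)

# Public headers such as ge/ge_api.h are assumed to be installed under include/.
target_include_directories(my_app PRIVATE
  ${GE_INSTALL_PREFIX}/include
)

# Link against the stub libge_runner.so produced by fwk_stub_ge_runner (installed
# to lib/stub); the loader resolves the full libge_runner.so from lib/ at run time.
target_link_directories(my_app PRIVATE
  ${GE_INSTALL_PREFIX}/lib/stub
)
target_link_libraries(my_app PRIVATE
  ge_runner
)

# Make sure the directory holding the real library is searched at run time.
set_target_properties(my_app PROPERTIES
  BUILD_RPATH ${GE_INSTALL_PREFIX}/lib
  INSTALL_RPATH ${GE_INSTALL_PREFIX}/lib
)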
+############################################################### + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS ge_runner ge_compiler opensrc_ascendcl OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) + +install(TARGETS atc_stub_ge_compiler fwk_stub_ge_runner OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/stub +) + +install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/engine_conf.json + ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt OPTIONAL + DESTINATION ${INSTALL_LIBRARY_DIR} +) diff --git a/ge/README.md b/ge/README.md new file mode 100644 index 00000000..e69de29b diff --git a/src/ge/analyzer/analyzer.cc b/ge/analyzer/analyzer.cc similarity index 100% rename from src/ge/analyzer/analyzer.cc rename to ge/analyzer/analyzer.cc diff --git a/src/ge/analyzer/analyzer.h b/ge/analyzer/analyzer.h similarity index 100% rename from src/ge/analyzer/analyzer.h rename to ge/analyzer/analyzer.h diff --git a/src/ge/client/ge_api.cc b/ge/client/ge_api.cc similarity index 100% rename from src/ge/client/ge_api.cc rename to ge/client/ge_api.cc diff --git a/src/ge/client/ge_prof.cc b/ge/client/ge_prof.cc similarity index 100% rename from src/ge/client/ge_prof.cc rename to ge/client/ge_prof.cc diff --git a/src/ge/client/module.mk b/ge/client/module.mk similarity index 100% rename from src/ge/client/module.mk rename to ge/client/module.mk diff --git a/src/proto/ge_api.proto b/ge/client/proto/ge_api.proto similarity index 100% rename from src/proto/ge_api.proto rename to ge/client/proto/ge_api.proto diff --git a/src/proto/ge_ir.proto b/ge/client/proto/ge_ir.proto similarity index 100% rename from src/proto/ge_ir.proto rename to ge/client/proto/ge_ir.proto diff --git a/src/proto/insert_op.proto b/ge/client/proto/insert_op.proto similarity index 100% rename from src/proto/insert_op.proto rename to ge/client/proto/insert_op.proto diff --git a/src/proto/om.proto b/ge/client/proto/om.proto similarity index 100% rename from src/proto/om.proto rename to ge/client/proto/om.proto diff --git a/src/proto/task.proto b/ge/client/proto/task.proto similarity index 100% rename from src/proto/task.proto rename to ge/client/proto/task.proto diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt new file mode 100644 index 00000000..685a6fe2 --- /dev/null +++ b/ge/common/CMakeLists.txt @@ -0,0 +1,171 @@ +set(PROTO_LIST + "${METADEF_DIR}/proto/om.proto" + "${METADEF_DIR}/proto/ge_ir.proto" + "${METADEF_DIR}/proto/insert_op.proto" + "${METADEF_DIR}/proto/task.proto" + "${METADEF_DIR}/proto/tensorflow/attr_value.proto" + "${METADEF_DIR}/proto/tensorflow/function.proto" + "${METADEF_DIR}/proto/tensorflow/graph.proto" + "${METADEF_DIR}/proto/tensorflow/node_def.proto" + "${METADEF_DIR}/proto/tensorflow/op_def.proto" + "${METADEF_DIR}/proto/tensorflow/resource_handle.proto" + "${METADEF_DIR}/proto/tensorflow/tensor.proto" + "${METADEF_DIR}/proto/tensorflow/tensor_shape.proto" + "${METADEF_DIR}/proto/tensorflow/types.proto" + "${METADEF_DIR}/proto/tensorflow/versions.proto" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +set(SRC_LIST + "context/ctx.cc" + "model_saver.cc" + "ge/datatype_util.cc" + "helper/om_file_helper.cc" + "helper/model_helper.cc" + "../model/ge_model.cc" + "auth/file_saver.cc" + "fp16_t.cc" + "math/fp16_math.cc" + "debug/memory_dumper.cc" + "formats/utils/formats_trans_utils.cc" + "dump/dump_properties.cc" + "formats/format_transfers/datatype_transfer.cc" + "formats/format_transfers/format_transfer_transpose.cc" + 
"formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" + "formats/format_transfers/format_transfer_fractal_z.cc" + "formats/format_transfers/format_transfer_fractal_nz.cc" + "formats/format_transfers/format_transfer_fractal_zz.cc" + "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" + "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" + "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" + "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" + "formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" + "formats/format_transfers/format_transfer_fracz_nchw.cc" + "formats/format_transfers/format_transfer_fracz_nhwc.cc" + "formats/format_transfers/format_transfer_fracz_hwcn.cc" + "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" + "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" + "formats/format_transfers/format_transfer_nchw_fz_c04.cc" + "formats/formats.cc" + "ge_format_util.cc" + "fmk_error_codes.cc" + "util.cc" + "properties_manager.cc" + "types.cc" + "model_parser/base.cc" + "kernel_store.cc" + "tbe_kernel_store.cc" + "cust_aicpu_kernel_store.cc" + "op/attr_value_util.cc" + "op/ge_op_utils.cc" + "thread_pool.cc" + "ge/tbe_plugin_manager.cc" +) + +############ libge_common.so ############ +add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) +target_compile_definitions(ge_common PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + HOST_VISIBILITY + FMK_SUPPORT_DUMP + OS_CENTOS +) + +target_compile_options(ge_common PRIVATE + -fvisibility=hidden + -O2 + -Werror +) + +target_include_directories(ge_common PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/ge/common + ${GE_CODE_DIR}/ge/common/op + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_DEPEND_DIR}/inc + ${GE_DEPEND_DIR}/inc/cce + #### blue zone #### + #${GE_DEPEND_DIR}/include + ${GE_CODE_DIR}/third_party/fwkacllib/inc + ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain +) + +target_link_libraries(ge_common PRIVATE + $ + -Wl,--no-as-needed + graph + protobuf + register + c_sec + error_manager + slog + mmpa + -Wl,--as-needed + json + -lrt + -ldl +) + +############ libge_common.a ############ +add_library(ge_common_static STATIC ${SRC_LIST} ${PROTO_HDRS}) +target_compile_definitions(ge_common_static PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + HOST_VISIBILITY + FMK_SUPPORT_DUMP + OS_CENTOS +) + +target_compile_options(ge_common_static PRIVATE + -fvisibility=hidden + -O2 + -Werror +) + +target_include_directories(ge_common_static PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/ge/common + ${GE_CODE_DIR}/ge/common/op + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_DEPEND_DIR}/inc + ${GE_DEPEND_DIR}/inc/cce + #### blue zone #### + #${GE_DEPEND_DIR}/include + ${GE_CODE_DIR}/third_party/fwkacllib/inc + ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain +) + +target_link_libraries(ge_common_static PRIVATE + $ + protobuf + json + c_sec + -lrt + -ldl +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS ge_common OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) diff 
--git a/src/ge/common/auth/file_saver.cc b/ge/common/auth/file_saver.cc similarity index 100% rename from src/ge/common/auth/file_saver.cc rename to ge/common/auth/file_saver.cc diff --git a/src/ge/common/auth/file_saver.h b/ge/common/auth/file_saver.h similarity index 100% rename from src/ge/common/auth/file_saver.h rename to ge/common/auth/file_saver.h diff --git a/src/ge/common/base64.h b/ge/common/base64.h similarity index 100% rename from src/ge/common/base64.h rename to ge/common/base64.h diff --git a/src/ge/common/context/ctx.cc b/ge/common/context/ctx.cc similarity index 100% rename from src/ge/common/context/ctx.cc rename to ge/common/context/ctx.cc diff --git a/src/ge/common/convert/pb2json.cc b/ge/common/convert/pb2json.cc similarity index 100% rename from src/ge/common/convert/pb2json.cc rename to ge/common/convert/pb2json.cc diff --git a/src/ge/common/convert/pb2json.h b/ge/common/convert/pb2json.h similarity index 100% rename from src/ge/common/convert/pb2json.h rename to ge/common/convert/pb2json.h diff --git a/src/ge/common/cust_aicpu_kernel_store.cc b/ge/common/cust_aicpu_kernel_store.cc similarity index 100% rename from src/ge/common/cust_aicpu_kernel_store.cc rename to ge/common/cust_aicpu_kernel_store.cc diff --git a/src/ge/common/cust_aicpu_kernel_store.h b/ge/common/cust_aicpu_kernel_store.h similarity index 100% rename from src/ge/common/cust_aicpu_kernel_store.h rename to ge/common/cust_aicpu_kernel_store.h diff --git a/src/ge/common/debug/memory_dumper.cc b/ge/common/debug/memory_dumper.cc similarity index 100% rename from src/ge/common/debug/memory_dumper.cc rename to ge/common/debug/memory_dumper.cc diff --git a/src/ge/common/debug/memory_dumper.h b/ge/common/debug/memory_dumper.h similarity index 100% rename from src/ge/common/debug/memory_dumper.h rename to ge/common/debug/memory_dumper.h diff --git a/src/ge/common/dump/dump_manager.cc b/ge/common/dump/dump_manager.cc similarity index 100% rename from src/ge/common/dump/dump_manager.cc rename to ge/common/dump/dump_manager.cc diff --git a/src/ge/common/dump/dump_manager.h b/ge/common/dump/dump_manager.h similarity index 100% rename from src/ge/common/dump/dump_manager.h rename to ge/common/dump/dump_manager.h diff --git a/src/ge/common/dump/dump_op.cc b/ge/common/dump/dump_op.cc similarity index 100% rename from src/ge/common/dump/dump_op.cc rename to ge/common/dump/dump_op.cc diff --git a/src/ge/common/dump/dump_op.h b/ge/common/dump/dump_op.h similarity index 100% rename from src/ge/common/dump/dump_op.h rename to ge/common/dump/dump_op.h diff --git a/src/ge/common/dump/dump_properties.cc b/ge/common/dump/dump_properties.cc similarity index 100% rename from src/ge/common/dump/dump_properties.cc rename to ge/common/dump/dump_properties.cc diff --git a/src/ge/common/dump/dump_properties.h b/ge/common/dump/dump_properties.h similarity index 100% rename from src/ge/common/dump/dump_properties.h rename to ge/common/dump/dump_properties.h diff --git a/src/ge/common/dump/dump_server.cc b/ge/common/dump/dump_server.cc similarity index 100% rename from src/ge/common/dump/dump_server.cc rename to ge/common/dump/dump_server.cc diff --git a/src/ge/common/fmk_error_codes.cc b/ge/common/fmk_error_codes.cc similarity index 100% rename from src/ge/common/fmk_error_codes.cc rename to ge/common/fmk_error_codes.cc diff --git a/src/ge/common/formats/format_transfers/datatype_transfer.cc b/ge/common/formats/format_transfers/datatype_transfer.cc similarity index 100% rename from 
src/ge/common/formats/format_transfers/datatype_transfer.cc rename to ge/common/formats/format_transfers/datatype_transfer.cc diff --git a/src/ge/common/formats/format_transfers/datatype_transfer.h b/ge/common/formats/format_transfers/datatype_transfer.h similarity index 100% rename from src/ge/common/formats/format_transfers/datatype_transfer.h rename to ge/common/formats/format_transfers/datatype_transfer.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc b/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc rename to ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h b/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h rename to ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc b/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc rename to ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h b/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h rename to ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc b/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc rename to ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h b/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h rename to ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc rename to ge/common/formats/format_transfers/format_transfer_fractal_nz.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.h b/ge/common/formats/format_transfers/format_transfer_fractal_nz.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fractal_nz.h rename to ge/common/formats/format_transfers/format_transfer_fractal_nz.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/ge/common/formats/format_transfers/format_transfer_fractal_z.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc rename to ge/common/formats/format_transfers/format_transfer_fractal_z.cc diff --git 
a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.h b/ge/common/formats/format_transfers/format_transfer_fractal_z.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fractal_z.h rename to ge/common/formats/format_transfers/format_transfer_fractal_z.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc rename to ge/common/formats/format_transfers/format_transfer_fractal_zz.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.h b/ge/common/formats/format_transfers/format_transfer_fractal_zz.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fractal_zz.h rename to ge/common/formats/format_transfers/format_transfer_fractal_zz.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc rename to ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h rename to ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc b/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc rename to ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h b/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h rename to ge/common/formats/format_transfers/format_transfer_fracz_nchw.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc b/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc rename to ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h b/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h rename to ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc rename to ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h b/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h rename to 
ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc rename to ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h rename to ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc rename to ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h rename to ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc b/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc rename to ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h b/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h rename to ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc rename to ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h b/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h rename to ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc rename to ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h rename to ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h diff --git a/src/ge/common/formats/format_transfers/format_transfer_transpose.cc b/ge/common/formats/format_transfers/format_transfer_transpose.cc similarity index 
100% rename from src/ge/common/formats/format_transfers/format_transfer_transpose.cc rename to ge/common/formats/format_transfers/format_transfer_transpose.cc diff --git a/src/ge/common/formats/format_transfers/format_transfer_transpose.h b/ge/common/formats/format_transfers/format_transfer_transpose.h similarity index 100% rename from src/ge/common/formats/format_transfers/format_transfer_transpose.h rename to ge/common/formats/format_transfers/format_transfer_transpose.h diff --git a/src/ge/common/formats/formats.cc b/ge/common/formats/formats.cc similarity index 100% rename from src/ge/common/formats/formats.cc rename to ge/common/formats/formats.cc diff --git a/src/ge/common/formats/formats.h b/ge/common/formats/formats.h similarity index 100% rename from src/ge/common/formats/formats.h rename to ge/common/formats/formats.h diff --git a/src/ge/common/formats/utils/formats_definitions.h b/ge/common/formats/utils/formats_definitions.h similarity index 100% rename from src/ge/common/formats/utils/formats_definitions.h rename to ge/common/formats/utils/formats_definitions.h diff --git a/src/ge/common/formats/utils/formats_trans_utils.cc b/ge/common/formats/utils/formats_trans_utils.cc similarity index 100% rename from src/ge/common/formats/utils/formats_trans_utils.cc rename to ge/common/formats/utils/formats_trans_utils.cc diff --git a/src/ge/common/formats/utils/formats_trans_utils.h b/ge/common/formats/utils/formats_trans_utils.h similarity index 100% rename from src/ge/common/formats/utils/formats_trans_utils.h rename to ge/common/formats/utils/formats_trans_utils.h diff --git a/src/ge/common/fp16_t.cc b/ge/common/fp16_t.cc similarity index 100% rename from src/ge/common/fp16_t.cc rename to ge/common/fp16_t.cc diff --git a/src/ge/common/fp16_t.h b/ge/common/fp16_t.h similarity index 100% rename from src/ge/common/fp16_t.h rename to ge/common/fp16_t.h diff --git a/src/ge/common/ge/datatype_util.cc b/ge/common/ge/datatype_util.cc similarity index 100% rename from src/ge/common/ge/datatype_util.cc rename to ge/common/ge/datatype_util.cc diff --git a/src/ge/common/ge/datatype_util.h b/ge/common/ge/datatype_util.h similarity index 100% rename from src/ge/common/ge/datatype_util.h rename to ge/common/ge/datatype_util.h diff --git a/src/ge/common/ge/ge_util.h b/ge/common/ge/ge_util.h similarity index 100% rename from src/ge/common/ge/ge_util.h rename to ge/common/ge/ge_util.h diff --git a/src/ge/common/ge/op_tiling_manager.cc b/ge/common/ge/op_tiling_manager.cc similarity index 100% rename from src/ge/common/ge/op_tiling_manager.cc rename to ge/common/ge/op_tiling_manager.cc diff --git a/src/ge/common/ge/op_tiling_manager.h b/ge/common/ge/op_tiling_manager.h similarity index 100% rename from src/ge/common/ge/op_tiling_manager.h rename to ge/common/ge/op_tiling_manager.h diff --git a/src/ge/common/ge/plugin_manager.cc b/ge/common/ge/plugin_manager.cc similarity index 100% rename from src/ge/common/ge/plugin_manager.cc rename to ge/common/ge/plugin_manager.cc diff --git a/src/ge/common/ge/plugin_manager.h b/ge/common/ge/plugin_manager.h similarity index 100% rename from src/ge/common/ge/plugin_manager.h rename to ge/common/ge/plugin_manager.h diff --git a/src/ge/common/ge/tbe_plugin_manager.cc b/ge/common/ge/tbe_plugin_manager.cc similarity index 100% rename from src/ge/common/ge/tbe_plugin_manager.cc rename to ge/common/ge/tbe_plugin_manager.cc diff --git a/src/ge/common/ge/tbe_plugin_manager.h b/ge/common/ge/tbe_plugin_manager.h similarity index 100% rename from 
src/ge/common/ge/tbe_plugin_manager.h rename to ge/common/ge/tbe_plugin_manager.h diff --git a/src/ge/common/ge_common.mk b/ge/common/ge_common.mk similarity index 100% rename from src/ge/common/ge_common.mk rename to ge/common/ge_common.mk diff --git a/src/ge/common/ge_format_util.cc b/ge/common/ge_format_util.cc similarity index 100% rename from src/ge/common/ge_format_util.cc rename to ge/common/ge_format_util.cc diff --git a/src/ge/common/helper/model_cache_helper.cc b/ge/common/helper/model_cache_helper.cc similarity index 100% rename from src/ge/common/helper/model_cache_helper.cc rename to ge/common/helper/model_cache_helper.cc diff --git a/src/ge/common/helper/model_cache_helper.h b/ge/common/helper/model_cache_helper.h similarity index 100% rename from src/ge/common/helper/model_cache_helper.h rename to ge/common/helper/model_cache_helper.h diff --git a/src/ge/common/helper/model_helper.cc b/ge/common/helper/model_helper.cc similarity index 100% rename from src/ge/common/helper/model_helper.cc rename to ge/common/helper/model_helper.cc diff --git a/src/ge/common/helper/om_file_helper.cc b/ge/common/helper/om_file_helper.cc similarity index 100% rename from src/ge/common/helper/om_file_helper.cc rename to ge/common/helper/om_file_helper.cc diff --git a/src/ge/common/kernel_store.cc b/ge/common/kernel_store.cc similarity index 100% rename from src/ge/common/kernel_store.cc rename to ge/common/kernel_store.cc diff --git a/src/ge/common/kernel_store.h b/ge/common/kernel_store.h similarity index 100% rename from src/ge/common/kernel_store.h rename to ge/common/kernel_store.h diff --git a/src/ge/common/math/fp16_math.cc b/ge/common/math/fp16_math.cc similarity index 100% rename from src/ge/common/math/fp16_math.cc rename to ge/common/math/fp16_math.cc diff --git a/src/ge/common/math/fp16_math.h b/ge/common/math/fp16_math.h similarity index 100% rename from src/ge/common/math/fp16_math.h rename to ge/common/math/fp16_math.h diff --git a/src/ge/common/math/math_util.h b/ge/common/math/math_util.h similarity index 100% rename from src/ge/common/math/math_util.h rename to ge/common/math/math_util.h diff --git a/src/ge/common/math_util.h b/ge/common/math_util.h similarity index 100% rename from src/ge/common/math_util.h rename to ge/common/math_util.h diff --git a/src/ge/common/model_parser/base.cc b/ge/common/model_parser/base.cc similarity index 100% rename from src/ge/common/model_parser/base.cc rename to ge/common/model_parser/base.cc diff --git a/src/ge/common/model_parser/base.h b/ge/common/model_parser/base.h similarity index 100% rename from src/ge/common/model_parser/base.h rename to ge/common/model_parser/base.h diff --git a/src/ge/common/model_saver.cc b/ge/common/model_saver.cc similarity index 100% rename from src/ge/common/model_saver.cc rename to ge/common/model_saver.cc diff --git a/src/ge/common/model_saver.h b/ge/common/model_saver.h similarity index 94% rename from src/ge/common/model_saver.h rename to ge/common/model_saver.h index 4c4fcdf1..411d5e35 100644 --- a/src/ge/common/model_saver.h +++ b/ge/common/model_saver.h @@ -22,8 +22,8 @@ #include "framework/common/types.h" /** - * Provide read and write operations for offline model files - */ +* Provide read and write operations for offline model files +*/ namespace ge { using Json = nlohmann::json; diff --git a/src/ge/common/module.mk b/ge/common/module.mk similarity index 100% rename from src/ge/common/module.mk rename to ge/common/module.mk diff --git a/src/ge/common/op/attr_value_util.cc 
b/ge/common/op/attr_value_util.cc similarity index 100% rename from src/ge/common/op/attr_value_util.cc rename to ge/common/op/attr_value_util.cc diff --git a/src/ge/common/op/ge_op_utils.cc b/ge/common/op/ge_op_utils.cc similarity index 100% rename from src/ge/common/op/ge_op_utils.cc rename to ge/common/op/ge_op_utils.cc diff --git a/src/ge/common/profiling/profiling_manager.cc b/ge/common/profiling/profiling_manager.cc similarity index 100% rename from src/ge/common/profiling/profiling_manager.cc rename to ge/common/profiling/profiling_manager.cc diff --git a/src/ge/common/profiling/profiling_manager.h b/ge/common/profiling/profiling_manager.h similarity index 100% rename from src/ge/common/profiling/profiling_manager.h rename to ge/common/profiling/profiling_manager.h diff --git a/src/ge/common/properties_manager.cc b/ge/common/properties_manager.cc similarity index 100% rename from src/ge/common/properties_manager.cc rename to ge/common/properties_manager.cc diff --git a/src/ge/common/properties_manager.h b/ge/common/properties_manager.h similarity index 100% rename from src/ge/common/properties_manager.h rename to ge/common/properties_manager.h diff --git a/ge/common/proto/ge_ir.proto b/ge/common/proto/ge_ir.proto new file mode 100644 index 00000000..87886c84 --- /dev/null +++ b/ge/common/proto/ge_ir.proto @@ -0,0 +1,206 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. 
+ DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs +{ + string name = 1; + map attr = 2; +} + +// Shape / dimension description, using row-major order +message ShapeDef +{ + repeated int64 dim = 1; // Size of each dimension +} + +// Multidimensional data description +message TensorDescriptor +{ + string name = 1; // Optional parameter, tensor name + + DataType dtype = 2; // tensor datatype + ShapeDef shape = 3; // Shape / dimension + string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" + + bool has_out_attr = 9; + int64 size = 10; + int64 weight_size = 11; + bool reuse_input = 12; + bool output_tensor = 13; + string device_type = 14; + bool input_tensor =15; + int64 real_dim_cnt = 16; + int64 reuse_input_index = 17; + int64 data_offset = 18; + int64 cmps_size = 19; + string cmps_tab = 20; + int64 cmps_tab_offset = 21; + + map attr = 5; // Set of extra parameter fields +} + +// GeTensor definition +message TensorDef +{ + TensorDescriptor desc = 1; // Tensor description + bytes data = 2; // Tensor data +} + + +// Operator description +message OpDef +{ + string name = 1; // name + string type = 2; // type + + repeated string input = 5; // input original op name + outgoing index. 
op_name:index + + map attr = 10; // Set of operator parameter fields + + bool has_out_attr = 20; + int64 id = 21; + int64 stream_id =22; + repeated string input_name = 23; + repeated string src_name = 24; + repeated int64 src_index = 25; + repeated string dst_name = 26; + repeated int64 dst_index = 27; + repeated int64 input_i = 28; + repeated int64 output_i = 29; + repeated int64 workspace = 30; + repeated int64 workspace_bytes = 31; + repeated bool is_input_const = 32; + repeated TensorDescriptor input_desc = 33; + repeated TensorDescriptor output_desc = 34; + repeated string subgraph_name = 35; +} + +// Graph definition +message GraphDef +{ + string name = 1; // name + + repeated string input = 4; // Graph input + repeated string output = 5; // Graph output + + repeated OpDef op = 6; // List of operators + + map attr = 11; // Extended field +} + +// model definition +message ModelDef +{ + string name = 1; // name + uint32 version = 2; // IR Proto verion + string custom_version = 3; // User model version number, passed in by user + + repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef + + map attr = 11; // Extended field +} + diff --git a/ge/common/proto/insert_op.proto b/ge/common/proto/insert_op.proto new file mode 100644 index 00000000..a059e122 --- /dev/null +++ b/ge/common/proto/insert_op.proto @@ -0,0 +1,152 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +syntax = "proto3"; + +package domi; + +message InsertNewOps { + repeated AippOpParams aipp_op = 1; + repeated MultiShapeOpParams multi_shape_op = 2; +} + +message AippOpParams { + enum InputFormat { + UNDEFINED = 0; + YUV420SP_U8 = 1; + XRGB8888_U8 = 2; + RGB888_U8 = 3; + YUV400_U8 = 4; + NC1HWC0DI_FP16 = 5; + NC1HWC0DI_S8 = 6; + ARGB8888_U8 = 7; + YUYV_U8 = 8; + YUV422SP_U8 = 9; + AYUV444_U8 = 10; + RAW10 = 11; + RAW12 = 12; + RAW16 = 13; + RAW24 = 14; + RGB16 = 15; + RGB20 = 16; + RGB24 = 17; + RGB8_IR = 18; + RGB16_IR = 19; + RGB24_IR = 20; + } + + enum AippMode { + undefined = 0; + static = 1; + dynamic = 2; + } + + // AIPP mode: distinguishes static AIPP from dynamic AIPP + AippMode aipp_mode = 1; + + // related_input_rank is required; it is an integer with range >= 0 and <= the number of input Data operators, default 0. + // It specifies which model input is processed by AIPP; for example, if the model has two inputs and the second one needs AIPP, set related_input_rank to 1. + uint32 related_input_rank = 2; + + // input_edge_idx is optional; it is an integer with value >= 0. + // It allows different AIPP processing for different outputs of the Data operator; if it is not configured, AIPP is applied by default to all output edges of the model input specified by related_input_rank. + // The configured value must be <= the number of output edges of the Data operator. + repeated uint32 input_edge_idx = 3; + + // [Begin] dynamic AIPP parameters, ignored when static AIPP is configured + uint32 max_src_image_size = 4; + + // Whether rotation is supported. Disabled by default; enabling rotation costs extra memory and performance. + bool support_rotation = 5; + + // [End] dynamic AIPP parameters + + + // [Begin] static AIPP parameters, ignored when dynamic AIPP is configured + InputFormat input_format = 51; + bool csc_switch = 52; + float cpadding_value = 53; + bool rbuv_swap_switch = 54; + bool ax_swap_switch = 55; + bool single_line_mode = 56; + + int32 src_image_size_w = 57; + int32 src_image_size_h = 58; + + bool crop = 59; + int32 load_start_pos_w = 60; + int32 load_start_pos_h = 61; + int32 crop_size_w = 62; + int32 crop_size_h = 63; + + bool resize = 64; + int32 resize_output_w = 65; + int32 resize_output_h = 66; + + bool padding = 67; + int32 left_padding_size = 68; + int32 right_padding_size = 69; + int32 top_padding_size = 70; + int32 bottom_padding_size = 71; + + int32 mean_chn_0 = 10; + int32 mean_chn_1 = 11; + int32 mean_chn_2 = 12; + int32 mean_chn_3 = 19; + float min_chn_0 = 13; + float min_chn_1 = 14; + float min_chn_2 = 15; + float min_chn_3 = 20; + repeated float var_reci_chn_0 = 16; + repeated float var_reci_chn_1 = 17; + repeated float var_reci_chn_2 = 18; + repeated float var_reci_chn_3 = 21; + + repeated int32 matrix_r0c0 = 30; + repeated int32 matrix_r0c1 = 31; + repeated int32 matrix_r0c2 = 32; + repeated int32 matrix_r1c0 = 33; + repeated int32 matrix_r1c1 = 34; + repeated int32 matrix_r1c2 = 35; + repeated int32 matrix_r2c0 = 36; + repeated int32 matrix_r2c1 = 37; + repeated int32 matrix_r2c2 = 38; + repeated int32 output_bias_0 = 39; + repeated int32 output_bias_1 = 40; + repeated int32 output_bias_2 = 41; + repeated int32 input_bias_0 = 42; + repeated int32 input_bias_1 = 43; + repeated int32 input_bias_2 = 44; + + // [End] static AIPP parameters + + // The n number that is used for raw/rgbir data into f16 transformation. + // The transformation equation is x/(2^n). If set to 0, no transform is performed. 
+ uint32 raw_rgbir_to_f16_n = 45; +} + +message MultiShapeOpParams { + enum MultiShapeMode { + batch = 0; // dynamic batch + resolution = 1; // dynamic resolution, reserved for extension + } + + MultiShapeMode mode = 1; // operator mode + uint32 related_input_rank = 2; // which model input the new operator is inserted on + + + repeated uint32 batch_list = 11; // batch_list values; the number of entries must be between 2 and 8 +} diff --git a/ge/common/proto/om.proto b/ge/common/proto/om.proto new file mode 100644 index 00000000..dd992191 --- /dev/null +++ b/ge/common/proto/om.proto @@ -0,0 +1,401 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 
scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 
cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. + CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/src/proto/op_mapping_info.proto b/ge/common/proto/op_mapping_info.proto similarity index 100% rename from src/proto/op_mapping_info.proto rename to ge/common/proto/op_mapping_info.proto diff --git a/ge/common/proto/task.proto b/ge/common/proto/task.proto new file mode 100644 index 00000000..50ea061b --- /dev/null +++ b/ge/common/proto/task.proto @@ -0,0 +1,170 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/ge/common/proto/tensorflow/attr_value.proto b/ge/common/proto/tensorflow/attr_value.proto new file mode 100644 index 00000000..1cc67d62 --- /dev/null +++ b/ge/common/proto/tensorflow/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = 
"AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/ge/common/proto/tensorflow/function.proto b/ge/common/proto/tensorflow/function.proto new file mode 100644 index 00000000..075897c6 --- /dev/null +++ b/ge/common/proto/tensorflow/function.proto @@ -0,0 +1,100 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. + reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). 
+ // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. +} diff --git a/ge/common/proto/tensorflow/graph.proto b/ge/common/proto/tensorflow/graph.proto new file mode 100644 index 00000000..d639a7d6 --- /dev/null +++ b/ge/common/proto/tensorflow/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "node_def.proto"; +import "function.proto"; +import "versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. 
See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/ge/common/proto/tensorflow/graph_library.proto b/ge/common/proto/tensorflow/graph_library.proto new file mode 100644 index 00000000..e393d38d --- /dev/null +++ b/ge/common/proto/tensorflow/graph_library.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package domi.tensorflow; + +import "graph.proto"; + +message GeGraphDef { + string name = 1; + GraphDef graph = 2; +} + +message GraphDefLibrary { + repeated GeGraphDef graph_def = 1; +}; \ No newline at end of file diff --git a/ge/common/proto/tensorflow/node_def.proto b/ge/common/proto/tensorflow/node_def.proto new file mode 100644 index 00000000..b9bc97ee --- /dev/null +++ b/ge/common/proto/tensorflow/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. 
+ // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // Add some examples here showing best practices. + map attr = 5; +}; diff --git a/ge/common/proto/tensorflow/op_def.proto b/ge/common/proto/tensorflow/op_def.proto new file mode 100644 index 00000000..3485d045 --- /dev/null +++ b/ge/common/proto/tensorflow/op_def.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. 
+ // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). 
+ // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/ge/common/proto/tensorflow/resource_handle.proto b/ge/common/proto/tensorflow/resource_handle.proto new file mode 100644 index 00000000..a3452351 --- /dev/null +++ b/ge/common/proto/tensorflow/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/ge/common/proto/tensorflow/tensor.proto b/ge/common/proto/tensorflow/tensor.proto new file mode 100644 index 00000000..d0a4d024 --- /dev/null +++ b/ge/common/proto/tensorflow/tensor.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. 
The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/ge/common/proto/tensorflow/tensor_shape.proto b/ge/common/proto/tensorflow/tensor_shape.proto new file mode 100644 index 00000000..4225a2e3 --- /dev/null +++ b/ge/common/proto/tensorflow/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package domi.tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. 
This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/ge/common/proto/tensorflow/types.proto b/ge/common/proto/tensorflow/types.proto new file mode 100644 index 00000000..ba7a72b3 --- /dev/null +++ b/ge/common/proto/tensorflow/types.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/ge/common/proto/tensorflow/versions.proto b/ge/common/proto/tensorflow/versions.proto new file mode 100644 index 00000000..48061218 --- /dev/null +++ b/ge/common/proto/tensorflow/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). 
A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +}; diff --git a/src/ge/common/singleton.h b/ge/common/singleton.h similarity index 100% rename from src/ge/common/singleton.h rename to ge/common/singleton.h diff --git a/src/ge/common/tbe_kernel_store.cc b/ge/common/tbe_kernel_store.cc similarity index 100% rename from src/ge/common/tbe_kernel_store.cc rename to ge/common/tbe_kernel_store.cc diff --git a/src/ge/common/tbe_kernel_store.h b/ge/common/tbe_kernel_store.h similarity index 100% rename from src/ge/common/tbe_kernel_store.h rename to ge/common/tbe_kernel_store.h diff --git a/src/ge/common/thread_pool.cc b/ge/common/thread_pool.cc similarity index 100% rename from src/ge/common/thread_pool.cc rename to ge/common/thread_pool.cc diff --git a/src/ge/common/thread_pool.h b/ge/common/thread_pool.h similarity index 100% rename from src/ge/common/thread_pool.h rename to ge/common/thread_pool.h diff --git a/src/ge/common/types.cc b/ge/common/types.cc similarity index 100% rename from src/ge/common/types.cc rename to ge/common/types.cc diff --git a/src/ge/common/util.cc b/ge/common/util.cc similarity index 100% rename from src/ge/common/util.cc rename to ge/common/util.cc diff --git a/src/ge/engine_manager/dnnengine_manager.cc b/ge/engine_manager/dnnengine_manager.cc similarity index 100% rename from src/ge/engine_manager/dnnengine_manager.cc rename to ge/engine_manager/dnnengine_manager.cc diff --git a/src/ge/engine_manager/dnnengine_manager.h b/ge/engine_manager/dnnengine_manager.h similarity index 100% rename from src/ge/engine_manager/dnnengine_manager.h rename to ge/engine_manager/dnnengine_manager.h diff --git a/src/ge/engine_manager/engine_conf.json b/ge/engine_manager/engine_conf.json old mode 100755 new mode 100644 similarity index 95% rename from src/ge/engine_manager/engine_conf.json rename to ge/engine_manager/engine_conf.json index 82360562..4a767fb8 --- a/src/ge/engine_manager/engine_conf.json +++ b/ge/engine_manager/engine_conf.json @@ -1,61 +1,61 @@ -{ - "schedule_units": [ - { - "id": "TS_1", - "name": "1980_hwts", - "ex_attrs": "", - "cal_engines": [ - { - "id": "DNN_VM_HOST_CPU", - "name": "HOST_CPU", - "independent": false, - "skip_assign_stream": true, - "attach": true - }, - { - "id": "DNN_VM_GE_LOCAL", - "name": "GE_LOCAL", - "independent": false, - "skip_assign_stream": true, - "attach": true - }, - { - "id": "AIcoreEngine", - "name": "AICORE", - "independent": false, - "skip_assign_stream": false, - "attach":false - }, - - { - "id": "VectorEngine", - "name": "VECTORCORE", - "independent": false, - "skip_assign_stream": false, - "attach":false - }, - { - "id": "DNN_VM_AICPU", - "name": "AICPU", - "independent": false, - "skip_assign_stream": false, - "attach": true - }, - { - "id": "DNN_HCCL", - "name": "HCCL", - "independent": true, - "skip_assign_stream": false, - "attach": false - }, - { - "id": "DNN_VM_RTS", - "name": "RTS", - "independent": false, - "skip_assign_stream": false, - "attach": true - } - ] - } - ] -} +{ + "schedule_units": [ + { + "id": "TS_1", + "name": "1980_hwts", + "ex_attrs": "", + "cal_engines": [ + { + "id": "DNN_VM_HOST_CPU", + "name": 
"HOST_CPU", + "independent": false, + "skip_assign_stream": true, + "attach": true + }, + { + "id": "DNN_VM_GE_LOCAL", + "name": "GE_LOCAL", + "independent": false, + "skip_assign_stream": true, + "attach": true + }, + { + "id": "AIcoreEngine", + "name": "AICORE", + "independent": false, + "skip_assign_stream": false, + "attach":false + }, + + { + "id": "VectorEngine", + "name": "VECTORCORE", + "independent": false, + "skip_assign_stream": false, + "attach":false + }, + { + "id": "DNN_VM_AICPU", + "name": "AICPU", + "independent": false, + "skip_assign_stream": false, + "attach": true + }, + { + "id": "DNN_HCCL", + "name": "HCCL", + "independent": true, + "skip_assign_stream": false, + "attach": false + }, + { + "id": "DNN_VM_RTS", + "name": "RTS", + "independent": false, + "skip_assign_stream": false, + "attach": true + } + ] + } + ] +} diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt new file mode 100755 index 00000000..f247fd46 --- /dev/null +++ b/ge/executor/CMakeLists.txt @@ -0,0 +1,113 @@ +set(PROTO_LIST + "${METADEF_DIR}/proto/om.proto" + "${METADEF_DIR}/proto/ge_ir.proto" + "${METADEF_DIR}/proto/insert_op.proto" + "${METADEF_DIR}/proto/task.proto" + "${METADEF_DIR}/proto/op_mapping_info.proto" + "${METADEF_DIR}/proto/dump_task.proto" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +set(SRC_LIST + "ge_executor.cc" + "../common/profiling/profiling_manager.cc" + "../common/ge/plugin_manager.cc" + "../common/ge/op_tiling_manager.cc" + "../common/dump/dump_properties.cc" + "../common/dump/dump_manager.cc" + "../common/dump/dump_op.cc" + "../graph/load/graph_loader.cc" + "../graph/execute/graph_execute.cc" + "../omm/csa_interact.cc" + "../graph/manager/graph_manager_utils.cc" + "../graph/manager/graph_var_manager.cc" + "../graph/manager/graph_mem_allocator.cc" + "../graph/manager/graph_caching_allocator.cc" + "../graph/manager/trans_var_data_utils.cc" + "../graph/manager/util/debug.cc" + "../graph/manager/rdma_pool_allocator.cc" + "../hybrid/node_executor/aicpu/aicpu_ext_info.cc" + "../model/ge_model.cc" + "../model/ge_root_model.cc" + "../graph/load/new_model_manager/davinci_model.cc" + "../graph/load/new_model_manager/davinci_model_parser.cc" + "../graph/load/new_model_manager/model_manager.cc" + "../graph/load/new_model_manager/tbe_handle_store.cc" + "../graph/load/new_model_manager/cpu_queue_schedule.cc" + "../graph/load/new_model_manager/model_utils.cc" + "../graph/load/new_model_manager/aipp_utils.cc" + "../graph/load/new_model_manager/data_inputer.cc" + "../graph/load/new_model_manager/data_dumper.cc" + "../graph/load/new_model_manager/zero_copy_task.cc" + "../graph/load/new_model_manager/zero_copy_offset.cc" + "../graph/load/new_model_manager/task_info/task_info.cc" + "../graph/load/new_model_manager/task_info/event_record_task_info.cc" + "../graph/load/new_model_manager/task_info/event_wait_task_info.cc" + "../graph/load/new_model_manager/task_info/fusion_start_task_info.cc" + "../graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" + "../graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" + "../graph/load/new_model_manager/task_info/kernel_task_info.cc" + "../graph/load/new_model_manager/task_info/label_set_task_info.cc" + "../graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" + "../graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" + "../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" + "../graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" 
+ "../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" + "../graph/load/new_model_manager/task_info/stream_active_task_info.cc" + "../graph/load/new_model_manager/task_info/stream_switch_task_info.cc" + "../graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" + "../graph/load/new_model_manager/task_info/end_graph_task_info.cc" + "../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" + "../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" + "../single_op/single_op_manager.cc" + "../single_op/single_op_model.cc" + "../single_op/single_op.cc" + "../single_op/stream_resource.cc" + "../single_op/task/op_task.cc" + "../single_op/task/build_task_utils.cc" + "../single_op/task/tbe_task_builder.cc" + "../single_op/task/aicpu_task_builder.cc" + "../single_op/task/aicpu_kernel_task_builder.cc" + "../hybrid/hybrid_davinci_model_stub.cc" +) + +######## libge_executor.a ######## +add_library(ge_executor STATIC ${SRC_LIST} ${PROTO_HDRS}) + +target_compile_options(ge_executor PRIVATE + -Werror + -O2 +) + +target_compile_definitions(ge_executor PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + DAVINCI_SUPPORT_PROFILING +) + +target_include_directories(ge_executor PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + ${GE_CODE_DIR}/../inc/cce + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc +) + +target_link_libraries(ge_executor PRIVATE + $ + json + protobuf + c_sec + -lrt + -ldl +) diff --git a/src/ge/executor/ge_executor.cc b/ge/executor/ge_executor.cc similarity index 100% rename from src/ge/executor/ge_executor.cc rename to ge/executor/ge_executor.cc diff --git a/src/ge/executor/module.mk b/ge/executor/module.mk similarity index 100% rename from src/ge/executor/module.mk rename to ge/executor/module.mk diff --git a/src/proto/dump_task.proto b/ge/executor/proto/dump_task.proto similarity index 100% rename from src/proto/dump_task.proto rename to ge/executor/proto/dump_task.proto diff --git a/ge/executor/proto/ge_ir.proto b/ge/executor/proto/ge_ir.proto new file mode 100644 index 00000000..87886c84 --- /dev/null +++ b/ge/executor/proto/ge_ir.proto @@ -0,0 +1,206 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. 
+ DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs +{ + string name = 1; + map attr = 2; +} + +// Shape / dimension description, using row-major order +message ShapeDef +{ + repeated int64 dim = 1; // Size of each dimension +} + +// Multidimensional data description +message TensorDescriptor +{ + string name = 1; // Optional parameter, tensor name + + DataType dtype = 2; // tensor datatype + ShapeDef shape = 3; // Shape / dimension + string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" + + bool has_out_attr = 9; + int64 size = 10; + int64 weight_size = 11; + bool reuse_input = 12; + bool output_tensor = 13; + string device_type = 14; + bool input_tensor =15; + int64 real_dim_cnt = 16; + int64 reuse_input_index = 17; + int64 data_offset = 18; + int64 cmps_size = 19; + string cmps_tab = 20; + int64 cmps_tab_offset = 21; + + map attr = 5; // Set of extra parameter fields +} + +// GeTensor definition +message TensorDef +{ + TensorDescriptor desc = 1; // Tensor description + bytes data = 2; // Tensor data +} + + +// Operator description +message OpDef +{ + string name = 1; // name + string type = 2; // type + + repeated string input = 5; // input original op name + outgoing index. 
op_name:index
+
+  map attr = 10;  // Set of operator parameter fields
+
+  bool has_out_attr = 20;
+  int64 id = 21;
+  int64 stream_id = 22;
+  repeated string input_name = 23;
+  repeated string src_name = 24;
+  repeated int64 src_index = 25;
+  repeated string dst_name = 26;
+  repeated int64 dst_index = 27;
+  repeated int64 input_i = 28;
+  repeated int64 output_i = 29;
+  repeated int64 workspace = 30;
+  repeated int64 workspace_bytes = 31;
+  repeated bool is_input_const = 32;
+  repeated TensorDescriptor input_desc = 33;
+  repeated TensorDescriptor output_desc = 34;
+  repeated string subgraph_name = 35;
+}
+
+// Graph definition
+message GraphDef
+{
+  string name = 1;             // name
+
+  repeated string input = 4;   // Graph input
+  repeated string output = 5;  // Graph output
+
+  repeated OpDef op = 6;       // List of operators
+
+  map attr = 11;               // Extended field
+}
+
+// model definition
+message ModelDef
+{
+  string name = 1;              // name
+  uint32 version = 2;           // IR Proto version
+  string custom_version = 3;    // User model version number, passed in by user
+
+  repeated GraphDef graph = 7;  // Graph definition; graph[0] represents the main graph of the model
+
+  map attr = 11;                // Extended field
+}
+
diff --git a/ge/executor/proto/insert_op.proto b/ge/executor/proto/insert_op.proto
new file mode 100644
index 00000000..a059e122
--- /dev/null
+++ b/ge/executor/proto/insert_op.proto
@@ -0,0 +1,152 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package domi;
+
+message InsertNewOps {
+  repeated AippOpParams aipp_op = 1;
+  repeated MultiShapeOpParams multi_shape_op = 2;
+}
+
+message AippOpParams {
+  enum InputFormat {
+    UNDEFINED = 0;
+    YUV420SP_U8 = 1;
+    XRGB8888_U8 = 2;
+    RGB888_U8 = 3;
+    YUV400_U8 = 4;
+    NC1HWC0DI_FP16 = 5;
+    NC1HWC0DI_S8 = 6;
+    ARGB8888_U8 = 7;
+    YUYV_U8 = 8;
+    YUV422SP_U8 = 9;
+    AYUV444_U8 = 10;
+    RAW10 = 11;
+    RAW12 = 12;
+    RAW16 = 13;
+    RAW24 = 14;
+    RGB16 = 15;
+    RGB20 = 16;
+    RGB24 = 17;
+    RGB8_IR = 18;
+    RGB16_IR = 19;
+    RGB24_IR = 20;
+  }
+
+  enum AippMode {
+    undefined = 0;
+    static = 1;
+    dynamic = 2;
+  }
+
+  // AIPP mode, distinguishes static AIPP from dynamic AIPP
+  AippMode aipp_mode = 1;
+
+  // related_input_rank is required; it is an integer in the range [0, number of input Data operators], with a default of 0.
+  // It specifies which model input is processed by AIPP; for example, if the model has two inputs and AIPP should be applied to the second one, set related_input_rank to 1.
+  uint32 related_input_rank = 2;
+
+  // input_edge_idx is optional; it is an integer >= 0.
+  // It allows different AIPP processing for different outputs of the Data operator; if it is not set, AIPP is applied by default to all output edges of the model input specified by related_input_rank.
+  // The configured value must not exceed the number of output edges of the Data operator.
+  repeated uint32 input_edge_idx = 3;
+
+  // [Begin] dynamic AIPP parameters, ignored when static AIPP is configured
+  uint32 max_src_image_size = 4;
+
+  // Whether rotation is supported. Disabled by default; enabling rotation incurs extra memory and performance overhead.
+  bool support_rotation = 5;
+
+  // [End] dynamic AIPP parameters
+
+
+  // [Begin] static AIPP parameters, ignored when dynamic AIPP is configured
+  InputFormat input_format = 51;
+  bool csc_switch = 52;
+  float cpadding_value = 53;
+  bool rbuv_swap_switch = 54;
+  bool ax_swap_switch = 55;
+  bool single_line_mode = 56;
+
+  int32 src_image_size_w = 57;
+  int32 src_image_size_h = 58;
+
+  bool crop = 59;
+  int32 load_start_pos_w = 60;
+  int32 load_start_pos_h = 61;
+  int32 crop_size_w = 62;
+  int32 crop_size_h = 63;
+
+  bool resize = 64;
+  int32 resize_output_w = 65;
+  int32 resize_output_h = 66;
+
+  bool padding = 67;
+  int32 left_padding_size = 68;
+  int32 right_padding_size = 69;
+  int32 top_padding_size = 70;
+  int32 bottom_padding_size = 71;
+
+  int32 mean_chn_0 = 10;
+  int32 mean_chn_1 = 11;
+  int32 mean_chn_2 = 12;
+  int32 mean_chn_3 = 19;
+  float min_chn_0 = 13;
+  float min_chn_1 = 14;
+  float min_chn_2 = 15;
+  float min_chn_3 = 20;
+  repeated float var_reci_chn_0 = 16;
+  repeated float var_reci_chn_1 = 17;
+  repeated float var_reci_chn_2 = 18;
+  repeated float var_reci_chn_3 = 21;
+
+  repeated int32 matrix_r0c0 = 30;
+  repeated int32 matrix_r0c1 = 31;
+  repeated int32 matrix_r0c2 = 32;
+  repeated int32 matrix_r1c0 = 33;
+  repeated int32 matrix_r1c1 = 34;
+  repeated int32 matrix_r1c2 = 35;
+  repeated int32 matrix_r2c0 = 36;
+  repeated int32 matrix_r2c1 = 37;
+  repeated int32 matrix_r2c2 = 38;
+  repeated int32 output_bias_0 = 39;
+  repeated int32 output_bias_1 = 40;
+  repeated int32 output_bias_2 = 41;
+  repeated int32 input_bias_0 = 42;
+  repeated int32 input_bias_1 = 43;
+  repeated int32 input_bias_2 = 44;
+
+  // [End] static AIPP parameters
+
+  // The n number that is used for raw/rgbir data into f16 transformation.
+  // The transformation equation is x/(2^n). If set to 0, no transform is performed.
+  uint32 raw_rgbir_to_f16_n = 45;
+}
+
+message MultiShapeOpParams {
+  enum MultiShapeMode {
+    batch = 0;       // dynamic batch
+    resolution = 1;  // dynamic resolution, reserved for extension
+  }
+
+  MultiShapeMode mode = 1;        // operator mode
+  uint32 related_input_rank = 2;  // which input the newly added operator is inserted on
+
+
+  repeated uint32 batch_list = 11;  // batch_list values; the number of batch_list entries must be between 2 and 8
+}
diff --git a/ge/executor/proto/om.proto b/ge/executor/proto/om.proto
new file mode 100644
index 00000000..dd992191
--- /dev/null
+++ b/ge/executor/proto/om.proto
@@ -0,0 +1,401 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package domi;
+
+enum TargetType
+{
+  MINI = 0;
+  TINY = 1;
+  LITE = 2;
+}
+
+// offline model
+message ModelDef {
+  string name = 1;
+  uint32 version = 2;
+
+  uint64 memory_size = 10;
+  uint32 stream_num = 11;
+  uint32 event_num = 12;
+  uint64 weight_size = 13;
+  uint32 label_num = 15;
+  repeated OpDef op = 20;
+  TargetType target_type = 23;
+
+  map attr = 30;
+};
+
+// operator define
+message OpDef {
+  string name = 1;
+  string type = 2;
+
+  uint32 id = 3;
+  uint32 stream_id = 4;
+
+  repeated string input_name = 5;
+
+  repeated string src_name = 8;
+  repeated int32 src_index = 9;
+  repeated int64 input = 10;
+  repeated int64 output = 11;
+  repeated TensorDescriptor input_desc = 12;
+  repeated TensorDescriptor output_desc = 13;
+  repeated WeightDef weights = 14;
+  repeated string dst_name = 15;
+  repeated int32 dst_index = 16;
+
+  repeated int64 workspace = 20;
+  repeated uint32 workspace_bytes = 21;
+
+  repeated string weight_name = 22;
+  repeated bool is_input_const = 23;
+
+  map attr = 30;
+
+  QuantizeFactorParams quantize_factor = 31;
+
+  oneof op_params {
+    // start at 100 here
+    SendOpParams sender_param = 100;
+    RecvOpParams receiver_param = 200;
+    ConvolutionOpParams convolution_param = 300;
+    PoolingOpParams pooling_param = 400;
+    EltwiseOpParams eltwise_param = 500;
+    BatchNormOpParams batchnorm_param = 600;
+    ScaleOpParams scale_param = 700;
+    FullConnectionOpParams full_connection_param = 800;
+    SoftmaxOpParams softmax_param = 900;
+    ActivationOpParams activation_param = 1000;
+    ReshapeOpParams reshape_param = 1100;
+  }
+};
+
+message SendOpParams {
+  uint32 event_id = 1;
+};
+
+message RecvOpParams {
+  uint32 event_id = 1;
+};
+
+enum QuantizeScaleType
+{
+  VECTOR_SCALE = 0;
+  SCALAR_SCALE = 1;
+}
+
+enum QuantizeScaleMode
+{
+  NORMAL_MODE = 0;
+  SQRT_MODE = 1;
+}
+
+enum QuantizeAlgorithm
+{
+  NON_OFFSET_ALGO = 0;
+  HALF_OFFSET_ALGO = 1;
+  ALL_OFFSET_ALGO = 2;
+}
+message QuantizeFactor
+{
+  QuantizeScaleMode scale_mode = 1;
+  bytes scale_value = 2;
+  int64 scale_offset = 3;
+  bytes offset_data_value = 4;
+  int64 offset_data_offset = 5;
+  bytes offset_weight_value = 6;
+  int64 offset_weight_offset = 7;
+  bytes offset_pad_value = 8;
+  int64 offset_pad_offset = 9;
+};
+
+message QuantizeCalcFactor
+{
+  bytes offsetw = 1;
+  int64 offsetw_offset = 2;
+  bytes offsetd = 3;
+  int64 offsetd_offset = 4;
+  bytes scalereq = 5;
+
int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 
cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. + CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/ge/executor/proto/op_mapping_info.proto b/ge/executor/proto/op_mapping_info.proto new file mode 100644 index 00000000..7b84a115 --- /dev/null +++ b/ge/executor/proto/op_mapping_info.proto @@ -0,0 +1,89 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
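
To make the domi om.proto schema above concrete, here is a minimal sketch (illustrative only, not taken from this patch) of how a weight blob could be described and serialized through the protoc-generated C++ API; the generated header name om.pb.h and the numeric format/data-type codes are assumptions.

#include <string>

#include "om.pb.h"  // assumed: header generated by protoc from om.proto

// Describe a small 16x16 weight tensor and serialize it to the protobuf wire format.
std::string BuildWeightBlob(const std::string &raw_bytes) {
  domi::WeightDef weight;
  weight.set_format(0);        // assumed format code
  weight.set_data_type(0);     // assumed data type code
  weight.mutable_shape()->add_dim(16);
  weight.mutable_shape()->add_dim(16);
  weight.set_data(raw_bytes);  // raw weight payload (bytes field)
  weight.set_data_offset(0);

  std::string serialized;
  weight.SerializeToString(&serialized);
  return serialized;
}
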
+ */ + +syntax = "proto3"; +package aicpu.dump; + +message Shape { + repeated uint64 dim = 1; +} + +message Output { + int32 data_type = 1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + string original_name = 5; + int32 original_output_index = 6; + int32 original_output_data_type = 7; + int32 original_output_format = 8; + uint64 size = 9; +} + +message Input { + int32 data_type =1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + uint64 address = 2; + uint64 size = 3; +} + +message Op { + string op_name = 1; + string op_type = 2; +} + +message Task { + uint32 task_id = 1; + uint32 stream_id = 2; + Op op = 3; + repeated Output output = 4; + bool end_graph = 5; + repeated Input input = 6; + repeated OpBuffer buffer = 7; +} + +message OpMappingInfo { + string dump_path = 1; + oneof model_name_param { + string model_name = 2; + } + oneof model_id_param { + uint32 model_id = 3; + } + oneof step_id { + uint64 step_id_addr = 4; + } + oneof iterations_per_loop { + uint64 iterations_per_loop_addr = 5; + } + oneof loop_cond { + uint64 loop_cond_addr = 6; + } + uint32 flag = 7; // 0x01 load, 0x00 unload + repeated Task task = 8; + string dump_step = 9; +} \ No newline at end of file diff --git a/ge/executor/proto/task.proto b/ge/executor/proto/task.proto new file mode 100644 index 00000000..50ea061b --- /dev/null +++ b/ge/executor/proto/task.proto @@ -0,0 +1,170 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
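
The aicpu.dump.OpMappingInfo message above is the contract for configuring operator dump. As a hedged sketch (illustrative only, not part of the patch; the generated header name op_mapping_info.pb.h is assumed), a load-time dump configuration could be assembled like this, with flag = 0x01 meaning load as the schema comment states.

#include <string>

#include "op_mapping_info.pb.h"  // assumed: header generated by protoc from op_mapping_info.proto

// Build a dump configuration describing one task of a freshly loaded model.
std::string BuildDumpConfig() {
  aicpu::dump::OpMappingInfo info;
  info.set_dump_path("/tmp/npu_dump");   // illustrative path
  info.set_model_name("example_model");  // oneof model_name_param
  info.set_model_id(1);                  // oneof model_id_param
  info.set_flag(0x01);                   // 0x01 load, 0x00 unload
  info.set_dump_step("0");

  aicpu::dump::Task *task = info.add_task();
  task->set_task_id(0);
  task->set_stream_id(0);
  task->mutable_op()->set_op_name("conv1");   // illustrative op
  task->mutable_op()->set_op_type("Conv2D");

  std::string serialized;
  info.SerializeToString(&serialized);
  return serialized;
}
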
+ */ + +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/src/ge/ge_inference.mk b/ge/ge_inference.mk similarity index 100% rename from src/ge/ge_inference.mk rename to ge/ge_inference.mk diff --git a/ge/ge_local_engine/CMakeLists.txt b/ge/ge_local_engine/CMakeLists.txt new file mode 100755 index 00000000..1c45e399 --- /dev/null +++ b/ge/ge_local_engine/CMakeLists.txt @@ -0,0 +1,116 @@ 
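
For the domi task protocol defined in task.proto above, the following is a minimal sketch (illustrative only, not from the patch; the header name task.pb.h and the numeric type values are assumptions) of how one kernel task might be appended to a ModelTaskDef through the generated C++ API.

#include <string>

#include "task.pb.h"  // assumed: header generated by protoc from task.proto

// Append a single kernel task to a model task definition.
void AddKernelTask(domi::ModelTaskDef &model_task) {
  domi::TaskDef *task = model_task.add_task();
  task->set_id(0);
  task->set_type(0);      // assumed numeric task type
  task->set_stream_id(0);

  domi::KernelDef *kernel = task->mutable_kernel();
  kernel->set_stub_func("kernel_stub_example");  // illustrative stub symbol
  kernel->set_block_dim(32);
  kernel->set_args(std::string(64, '\0'));       // placeholder args buffer
  kernel->set_args_size(64);
  kernel->mutable_context()->set_op_index(0);
}
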
+set(PROTO_LIST + "${METADEF_DIR}/proto/task.proto" +) + +set(SRC_LIST + "engine/ge_local_engine.cc" + "ops_kernel_store/ge_local_ops_kernel_info.cc" + "ops_kernel_store/op/op_factory.cc" + "ops_kernel_store/op/op.cc" + "ops_kernel_store/op/ge_deleted_op.cc" + "ops_kernel_store/op/no_op.cc" +) + +set(OPS_KERNEL_SRC_LIST + "ops_kernel_store/op/op_factory.cc" + "ops_kernel_store/op/op.cc" + "ops_kernel_store/op/ge_deleted_op.cc" + "ops_kernel_store/op/no_op.cc" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +############ libge_local_engine.so ############ +add_library(ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) + +target_compile_options(ge_local_engine PRIVATE + -Werror +) + +target_include_directories(ge_local_engine PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc +) + +target_link_libraries(ge_local_engine PRIVATE + $ + -Wl,--no-as-needed + graph + protobuf + register + c_sec + slog + runtime + -Wl,--as-needed +) + +######### atclib/libge_local_engine.so ############# +add_library(atc_ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) + +target_compile_options(atc_ge_local_engine PRIVATE + -Werror +) + +target_compile_definitions(atc_ge_local_engine PRIVATE + COMPILE_OMG_PACKAGE +) + +target_include_directories(atc_ge_local_engine PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc +) + +target_link_libraries(atc_ge_local_engine PRIVATE + $ + -Wl,--no-as-needed + graph + protobuf + register + c_sec + slog + runtime_compile + -Wl,--as-needed +) + +set_target_properties(atc_ge_local_engine PROPERTIES + OUTPUT_NAME ge_local_engine + LIBRARY_OUTPUT_DIRECTORY atclib +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS ge_local_engine OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) + +install(TARGETS atc_ge_local_engine OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/atclib +) diff --git a/src/ge/ge_local_engine/common/constant/constant.h b/ge/ge_local_engine/common/constant/constant.h similarity index 100% rename from src/ge/ge_local_engine/common/constant/constant.h rename to ge/ge_local_engine/common/constant/constant.h diff --git a/src/ge/ge_local_engine/engine/ge_local_engine.cc b/ge/ge_local_engine/engine/ge_local_engine.cc similarity index 100% rename from src/ge/ge_local_engine/engine/ge_local_engine.cc rename to ge/ge_local_engine/engine/ge_local_engine.cc diff --git a/src/ge/ge_local_engine/engine/ge_local_engine.h b/ge/ge_local_engine/engine/ge_local_engine.h similarity index 100% rename from src/ge/ge_local_engine/engine/ge_local_engine.h rename to ge/ge_local_engine/engine/ge_local_engine.h diff --git a/src/ge/ge_local_engine/engine/host_cpu_engine.cc b/ge/ge_local_engine/engine/host_cpu_engine.cc similarity index 100% rename from 
src/ge/ge_local_engine/engine/host_cpu_engine.cc rename to ge/ge_local_engine/engine/host_cpu_engine.cc diff --git a/src/ge/ge_local_engine/engine/host_cpu_engine.h b/ge/ge_local_engine/engine/host_cpu_engine.h similarity index 100% rename from src/ge/ge_local_engine/engine/host_cpu_engine.h rename to ge/ge_local_engine/engine/host_cpu_engine.h diff --git a/src/ge/ge_local_engine/module.mk b/ge/ge_local_engine/module.mk similarity index 100% rename from src/ge/ge_local_engine/module.mk rename to ge/ge_local_engine/module.mk diff --git a/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc similarity index 100% rename from src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc rename to ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc diff --git a/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h similarity index 100% rename from src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h rename to ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc similarity index 100% rename from src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc rename to ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.h b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.h similarity index 100% rename from src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.h rename to ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.h diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/no_op.cc b/ge/ge_local_engine/ops_kernel_store/op/no_op.cc similarity index 100% rename from src/ge/ge_local_engine/ops_kernel_store/op/no_op.cc rename to ge/ge_local_engine/ops_kernel_store/op/no_op.cc diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/no_op.h b/ge/ge_local_engine/ops_kernel_store/op/no_op.h similarity index 100% rename from src/ge/ge_local_engine/ops_kernel_store/op/no_op.h rename to ge/ge_local_engine/ops_kernel_store/op/no_op.h diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/op.cc b/ge/ge_local_engine/ops_kernel_store/op/op.cc similarity index 100% rename from src/ge/ge_local_engine/ops_kernel_store/op/op.cc rename to ge/ge_local_engine/ops_kernel_store/op/op.cc diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/op.h b/ge/ge_local_engine/ops_kernel_store/op/op.h similarity index 100% rename from src/ge/ge_local_engine/ops_kernel_store/op/op.h rename to ge/ge_local_engine/ops_kernel_store/op/op.h diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/op_factory.cc b/ge/ge_local_engine/ops_kernel_store/op/op_factory.cc similarity index 100% rename from src/ge/ge_local_engine/ops_kernel_store/op/op_factory.cc rename to ge/ge_local_engine/ops_kernel_store/op/op_factory.cc diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/op_factory.h b/ge/ge_local_engine/ops_kernel_store/op/op_factory.h similarity index 100% rename from src/ge/ge_local_engine/ops_kernel_store/op/op_factory.h rename to ge/ge_local_engine/ops_kernel_store/op/op_factory.h diff --git a/ge/ge_local_engine/proto/task.proto b/ge/ge_local_engine/proto/task.proto new file mode 100644 index 00000000..50ea061b --- /dev/null +++ b/ge/ge_local_engine/proto/task.proto @@ -0,0 +1,170 @@ +/** + * Copyright 2019-2020 Huawei 
Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + 
uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/src/ge/ge_runner.mk b/ge/ge_runner.mk similarity index 100% rename from src/ge/ge_runner.mk rename to ge/ge_runner.mk diff --git a/ge/ge_runtime/CMakeLists.txt b/ge/ge_runtime/CMakeLists.txt new file mode 100644 index 00000000..b4c7fe9e --- /dev/null +++ b/ge/ge_runtime/CMakeLists.txt @@ -0,0 +1,65 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +############ libge_runtime.so ############ +set(GE_SRC_LIST + "model_runner.cc" + "runtime_model.cc" + "output.cc" + "task/*.cc" +) + +add_library(ge_runtime SHARED ${GE_SRC_LIST}) + +target_compile_options(ge_runtime PRIVATE + -Werror + -O2 +) + +target_compile_definitions(ge_runtime PUBLIC + PROTOBUF_INLINE_NOT_IN_HEADERS=0 +) + +target_include_directories(ge_runtime PRIVATE + ${TOP_DIR} + ${TOP_DIR}/inc + ${TOP_DIR}/inc/graph + ${TOP_DIR}/inc/external + ${TOP_DIR}/inc/framework + ${TOP_DIR}/inc/framework/common + ${TOP_DIR}/inc/framework/ge_runtime + ${TOP_DIR}/inc/cce + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge +) + +target_link_libraries(ge_runtime PRIVATE + $ + -Wl,--no-as-needed + graph + slog + runtime + c_sec + -Wl,--as-needed + -lrt + -ldl +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS ge_runtime OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) diff --git a/src/ge/ge_runtime/model_context.h b/ge/ge_runtime/model_context.h similarity index 100% rename from src/ge/ge_runtime/model_context.h rename to ge/ge_runtime/model_context.h diff --git a/src/ge/ge_runtime/model_runner.cc b/ge/ge_runtime/model_runner.cc similarity index 100% rename from src/ge/ge_runtime/model_runner.cc rename to ge/ge_runtime/model_runner.cc diff --git a/ge/ge_runtime/module.mk b/ge/ge_runtime/module.mk new file mode 100644 index 00000000..43d81bfa --- /dev/null +++ b/ge/ge_runtime/module.mk @@ -0,0 +1,66 @@ +LOCAL_PATH := $(call my-dir) + +# task.proto is old task, add it for ops_kernel_info_store +local_ge_runtime_src_files := \ + model_runner.cc \ + runtime_model.cc \ + output.cc \ + task/aicpu_task.cc \ + task/cce_task.cc \ + task/tbe_task.cc \ + task/event_record_task.cc \ + task/event_wait_task.cc \ + task/stream_active_task.cc \ + task/stream_switch_task.cc \ + task/hccl_task.cc \ + task/memcpy_async_task.cc \ + task/profiler_task.cc \ + +local_ge_runtime_include := \ + $(LOCAL_PATH)/ \ + $(TOPDIR)libc_sec/include \ + $(TOPDIR)inc/external \ + $(TOPDIR)inc/external/graph \ + $(TOPDIR)inc/framework \ + $(TOPDIR)inc/graph \ + $(TOPDIR)inc \ + $(LOCAL_PATH)/../ \ + third_party/protobuf/include + +local_ge_runtime_shared_library := \ + libruntime \ + libslog \ + libc_sec + +local_ge_runtime_ldflags := -lrt -ldl + 
+# compile device libge_runtime +include $(CLEAR_VARS) + +LOCAL_MODULE := libge_runtime +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -O2 +LOCAL_CFLAGS += -Werror +LOCAL_SRC_FILES := $(local_ge_runtime_src_files) +LOCAL_C_INCLUDES := $(local_ge_runtime_include) +LOCAL_SHARED_LIBRARIES := $(local_ge_runtime_shared_library) +LOCAL_LDFLAGS += $(local_ge_runtime_ldflags) + +include $(BUILD_SHARED_LIBRARY) + +# compile host libge_runtime +include $(CLEAR_VARS) + +LOCAL_MODULE := libge_runtime +LOCAL_CFLAGS += -Werror +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 +ifeq ($(DEBUG), 1) + LOCAL_CFLAGS += -g -O0 +else + LOCAL_CFLAGS += -O2 +endif +LOCAL_SRC_FILES := $(local_ge_runtime_src_files) +LOCAL_C_INCLUDES := $(local_ge_runtime_include) +LOCAL_SHARED_LIBRARIES := $(local_ge_runtime_shared_library) +LOCAL_LDFLAGS += $(local_ge_runtime_ldflags) + +include $(BUILD_HOST_SHARED_LIBRARY) diff --git a/src/ge/ge_runtime/output.cc b/ge/ge_runtime/output.cc similarity index 100% rename from src/ge/ge_runtime/output.cc rename to ge/ge_runtime/output.cc diff --git a/src/ge/ge_runtime/output.h b/ge/ge_runtime/output.h similarity index 100% rename from src/ge/ge_runtime/output.h rename to ge/ge_runtime/output.h diff --git a/src/ge/ge_runtime/runtime_model.cc b/ge/ge_runtime/runtime_model.cc similarity index 100% rename from src/ge/ge_runtime/runtime_model.cc rename to ge/ge_runtime/runtime_model.cc diff --git a/src/ge/ge_runtime/runtime_model.h b/ge/ge_runtime/runtime_model.h similarity index 100% rename from src/ge/ge_runtime/runtime_model.h rename to ge/ge_runtime/runtime_model.h diff --git a/src/ge/ge_runtime/task/aicpu_task.cc b/ge/ge_runtime/task/aicpu_task.cc similarity index 100% rename from src/ge/ge_runtime/task/aicpu_task.cc rename to ge/ge_runtime/task/aicpu_task.cc diff --git a/src/ge/ge_runtime/task/aicpu_task.h b/ge/ge_runtime/task/aicpu_task.h similarity index 100% rename from src/ge/ge_runtime/task/aicpu_task.h rename to ge/ge_runtime/task/aicpu_task.h diff --git a/src/ge/ge_runtime/task/cce_task.cc b/ge/ge_runtime/task/cce_task.cc similarity index 100% rename from src/ge/ge_runtime/task/cce_task.cc rename to ge/ge_runtime/task/cce_task.cc diff --git a/src/ge/ge_runtime/task/cce_task.h b/ge/ge_runtime/task/cce_task.h similarity index 100% rename from src/ge/ge_runtime/task/cce_task.h rename to ge/ge_runtime/task/cce_task.h diff --git a/src/ge/ge_runtime/task/event_record_task.cc b/ge/ge_runtime/task/event_record_task.cc similarity index 100% rename from src/ge/ge_runtime/task/event_record_task.cc rename to ge/ge_runtime/task/event_record_task.cc diff --git a/src/ge/ge_runtime/task/event_record_task.h b/ge/ge_runtime/task/event_record_task.h similarity index 100% rename from src/ge/ge_runtime/task/event_record_task.h rename to ge/ge_runtime/task/event_record_task.h diff --git a/src/ge/ge_runtime/task/event_wait_task.cc b/ge/ge_runtime/task/event_wait_task.cc similarity index 100% rename from src/ge/ge_runtime/task/event_wait_task.cc rename to ge/ge_runtime/task/event_wait_task.cc diff --git a/src/ge/ge_runtime/task/event_wait_task.h b/ge/ge_runtime/task/event_wait_task.h similarity index 100% rename from src/ge/ge_runtime/task/event_wait_task.h rename to ge/ge_runtime/task/event_wait_task.h diff --git a/src/ge/ge_runtime/task/hccl_task.cc b/ge/ge_runtime/task/hccl_task.cc similarity index 100% rename from src/ge/ge_runtime/task/hccl_task.cc rename to ge/ge_runtime/task/hccl_task.cc diff --git a/src/ge/ge_runtime/task/hccl_task.h b/ge/ge_runtime/task/hccl_task.h similarity index 
100% rename from src/ge/ge_runtime/task/hccl_task.h rename to ge/ge_runtime/task/hccl_task.h diff --git a/src/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc similarity index 100% rename from src/ge/ge_runtime/task/label_goto_task.cc rename to ge/ge_runtime/task/label_goto_task.cc diff --git a/src/ge/ge_runtime/task/label_goto_task.h b/ge/ge_runtime/task/label_goto_task.h similarity index 100% rename from src/ge/ge_runtime/task/label_goto_task.h rename to ge/ge_runtime/task/label_goto_task.h diff --git a/src/ge/ge_runtime/task/label_set_task.cc b/ge/ge_runtime/task/label_set_task.cc similarity index 100% rename from src/ge/ge_runtime/task/label_set_task.cc rename to ge/ge_runtime/task/label_set_task.cc diff --git a/src/ge/ge_runtime/task/label_set_task.h b/ge/ge_runtime/task/label_set_task.h similarity index 100% rename from src/ge/ge_runtime/task/label_set_task.h rename to ge/ge_runtime/task/label_set_task.h diff --git a/src/ge/ge_runtime/task/label_switch_task.cc b/ge/ge_runtime/task/label_switch_task.cc similarity index 100% rename from src/ge/ge_runtime/task/label_switch_task.cc rename to ge/ge_runtime/task/label_switch_task.cc diff --git a/src/ge/ge_runtime/task/label_switch_task.h b/ge/ge_runtime/task/label_switch_task.h similarity index 100% rename from src/ge/ge_runtime/task/label_switch_task.h rename to ge/ge_runtime/task/label_switch_task.h diff --git a/src/ge/ge_runtime/task/memcpy_async_task.cc b/ge/ge_runtime/task/memcpy_async_task.cc similarity index 100% rename from src/ge/ge_runtime/task/memcpy_async_task.cc rename to ge/ge_runtime/task/memcpy_async_task.cc diff --git a/src/ge/ge_runtime/task/memcpy_async_task.h b/ge/ge_runtime/task/memcpy_async_task.h similarity index 100% rename from src/ge/ge_runtime/task/memcpy_async_task.h rename to ge/ge_runtime/task/memcpy_async_task.h diff --git a/src/ge/ge_runtime/task/profiler_task.cc b/ge/ge_runtime/task/profiler_task.cc similarity index 100% rename from src/ge/ge_runtime/task/profiler_task.cc rename to ge/ge_runtime/task/profiler_task.cc diff --git a/src/ge/ge_runtime/task/profiler_task.h b/ge/ge_runtime/task/profiler_task.h similarity index 100% rename from src/ge/ge_runtime/task/profiler_task.h rename to ge/ge_runtime/task/profiler_task.h diff --git a/src/ge/ge_runtime/task/stream_active_task.cc b/ge/ge_runtime/task/stream_active_task.cc similarity index 100% rename from src/ge/ge_runtime/task/stream_active_task.cc rename to ge/ge_runtime/task/stream_active_task.cc diff --git a/src/ge/ge_runtime/task/stream_active_task.h b/ge/ge_runtime/task/stream_active_task.h similarity index 100% rename from src/ge/ge_runtime/task/stream_active_task.h rename to ge/ge_runtime/task/stream_active_task.h diff --git a/src/ge/ge_runtime/task/stream_switch_task.cc b/ge/ge_runtime/task/stream_switch_task.cc similarity index 100% rename from src/ge/ge_runtime/task/stream_switch_task.cc rename to ge/ge_runtime/task/stream_switch_task.cc diff --git a/src/ge/ge_runtime/task/stream_switch_task.h b/ge/ge_runtime/task/stream_switch_task.h similarity index 100% rename from src/ge/ge_runtime/task/stream_switch_task.h rename to ge/ge_runtime/task/stream_switch_task.h diff --git a/src/ge/ge_runtime/task/task.h b/ge/ge_runtime/task/task.h similarity index 100% rename from src/ge/ge_runtime/task/task.h rename to ge/ge_runtime/task/task.h diff --git a/src/ge/ge_runtime/task/task_factory.h b/ge/ge_runtime/task/task_factory.h similarity index 100% rename from src/ge/ge_runtime/task/task_factory.h rename to 
ge/ge_runtime/task/task_factory.h diff --git a/src/ge/ge_runtime/task/tbe_task.cc b/ge/ge_runtime/task/tbe_task.cc similarity index 100% rename from src/ge/ge_runtime/task/tbe_task.cc rename to ge/ge_runtime/task/tbe_task.cc diff --git a/src/ge/ge_runtime/task/tbe_task.h b/ge/ge_runtime/task/tbe_task.h similarity index 100% rename from src/ge/ge_runtime/task/tbe_task.h rename to ge/ge_runtime/task/tbe_task.h diff --git a/src/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc similarity index 100% rename from src/ge/generator/ge_generator.cc rename to ge/generator/ge_generator.cc diff --git a/src/ge/generator/generator_api.cc b/ge/generator/generator_api.cc similarity index 100% rename from src/ge/generator/generator_api.cc rename to ge/generator/generator_api.cc diff --git a/src/ge/graph/build/graph_builder.cc b/ge/graph/build/graph_builder.cc similarity index 100% rename from src/ge/graph/build/graph_builder.cc rename to ge/graph/build/graph_builder.cc diff --git a/src/ge/graph/build/graph_builder.h b/ge/graph/build/graph_builder.h similarity index 100% rename from src/ge/graph/build/graph_builder.h rename to ge/graph/build/graph_builder.h diff --git a/src/ge/graph/build/label_allocator.cc b/ge/graph/build/label_allocator.cc similarity index 100% rename from src/ge/graph/build/label_allocator.cc rename to ge/graph/build/label_allocator.cc diff --git a/src/ge/graph/build/label_allocator.h b/ge/graph/build/label_allocator.h similarity index 100% rename from src/ge/graph/build/label_allocator.h rename to ge/graph/build/label_allocator.h diff --git a/src/ge/graph/build/logical_stream_allocator.cc b/ge/graph/build/logical_stream_allocator.cc similarity index 100% rename from src/ge/graph/build/logical_stream_allocator.cc rename to ge/graph/build/logical_stream_allocator.cc diff --git a/src/ge/graph/build/logical_stream_allocator.h b/ge/graph/build/logical_stream_allocator.h similarity index 100% rename from src/ge/graph/build/logical_stream_allocator.h rename to ge/graph/build/logical_stream_allocator.h diff --git a/ge/graph/build/memory/CMakeLists.txt b/ge/graph/build/memory/CMakeLists.txt new file mode 100644 index 00000000..c568f2fe --- /dev/null +++ b/ge/graph/build/memory/CMakeLists.txt @@ -0,0 +1,38 @@ +set(SRC_LIST + "memory_assigner.cc" + "graph_mem_assigner.cc" + "binary_block_mem_assigner.cc" + "block_mem_assigner.cc" + "hybrid_mem_assigner.cc" + "max_block_mem_assigner.cc" + "var_mem_assign_util.cc" +) + +############ libge_memory.a ############ +add_library(ge_memory STATIC ${SRC_LIST}) + +target_compile_options(ge_memory PRIVATE + -Werror + -O2 +) + +target_link_libraries(ge_memory PRIVATE + $ + protobuf + c_sec +) + +target_include_directories(ge_memory PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${GE_CODE_DIR}/inc/framework + #### yellow zone #### + ${GE_CODE_DIR}/../inc + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc +) diff --git a/src/ge/graph/build/memory/binary_block_mem_assigner.cc b/ge/graph/build/memory/binary_block_mem_assigner.cc similarity index 100% rename from src/ge/graph/build/memory/binary_block_mem_assigner.cc rename to ge/graph/build/memory/binary_block_mem_assigner.cc diff --git a/src/ge/graph/build/memory/binary_block_mem_assigner.h b/ge/graph/build/memory/binary_block_mem_assigner.h similarity index 100% rename from src/ge/graph/build/memory/binary_block_mem_assigner.h rename to 
ge/graph/build/memory/binary_block_mem_assigner.h diff --git a/src/ge/graph/build/memory/block_mem_assigner.cc b/ge/graph/build/memory/block_mem_assigner.cc similarity index 100% rename from src/ge/graph/build/memory/block_mem_assigner.cc rename to ge/graph/build/memory/block_mem_assigner.cc diff --git a/src/ge/graph/build/memory/block_mem_assigner.h b/ge/graph/build/memory/block_mem_assigner.h similarity index 100% rename from src/ge/graph/build/memory/block_mem_assigner.h rename to ge/graph/build/memory/block_mem_assigner.h diff --git a/src/ge/graph/build/memory/graph_mem_assigner.cc b/ge/graph/build/memory/graph_mem_assigner.cc similarity index 100% rename from src/ge/graph/build/memory/graph_mem_assigner.cc rename to ge/graph/build/memory/graph_mem_assigner.cc diff --git a/src/ge/graph/build/memory/graph_mem_assigner.h b/ge/graph/build/memory/graph_mem_assigner.h similarity index 100% rename from src/ge/graph/build/memory/graph_mem_assigner.h rename to ge/graph/build/memory/graph_mem_assigner.h diff --git a/src/ge/graph/build/memory/hybrid_mem_assigner.cc b/ge/graph/build/memory/hybrid_mem_assigner.cc similarity index 100% rename from src/ge/graph/build/memory/hybrid_mem_assigner.cc rename to ge/graph/build/memory/hybrid_mem_assigner.cc diff --git a/src/ge/graph/build/memory/hybrid_mem_assigner.h b/ge/graph/build/memory/hybrid_mem_assigner.h similarity index 100% rename from src/ge/graph/build/memory/hybrid_mem_assigner.h rename to ge/graph/build/memory/hybrid_mem_assigner.h diff --git a/src/ge/graph/build/memory/max_block_mem_assigner.cc b/ge/graph/build/memory/max_block_mem_assigner.cc similarity index 100% rename from src/ge/graph/build/memory/max_block_mem_assigner.cc rename to ge/graph/build/memory/max_block_mem_assigner.cc diff --git a/src/ge/graph/build/memory/max_block_mem_assigner.h b/ge/graph/build/memory/max_block_mem_assigner.h similarity index 100% rename from src/ge/graph/build/memory/max_block_mem_assigner.h rename to ge/graph/build/memory/max_block_mem_assigner.h diff --git a/src/ge/graph/build/memory/mem_assigner.h b/ge/graph/build/memory/mem_assigner.h similarity index 100% rename from src/ge/graph/build/memory/mem_assigner.h rename to ge/graph/build/memory/mem_assigner.h diff --git a/src/ge/graph/build/memory/memory_assigner.cc b/ge/graph/build/memory/memory_assigner.cc similarity index 100% rename from src/ge/graph/build/memory/memory_assigner.cc rename to ge/graph/build/memory/memory_assigner.cc diff --git a/src/ge/graph/build/memory/module.mk b/ge/graph/build/memory/module.mk similarity index 100% rename from src/ge/graph/build/memory/module.mk rename to ge/graph/build/memory/module.mk diff --git a/src/ge/graph/build/memory/var_mem_assign_util.cc b/ge/graph/build/memory/var_mem_assign_util.cc similarity index 100% rename from src/ge/graph/build/memory/var_mem_assign_util.cc rename to ge/graph/build/memory/var_mem_assign_util.cc diff --git a/src/ge/graph/build/memory/var_mem_assign_util.h b/ge/graph/build/memory/var_mem_assign_util.h similarity index 100% rename from src/ge/graph/build/memory/var_mem_assign_util.h rename to ge/graph/build/memory/var_mem_assign_util.h diff --git a/src/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc similarity index 100% rename from src/ge/graph/build/model_builder.cc rename to ge/graph/build/model_builder.cc diff --git a/src/ge/graph/build/model_builder.h b/ge/graph/build/model_builder.h similarity index 100% rename from src/ge/graph/build/model_builder.h rename to ge/graph/build/model_builder.h diff --git 
a/src/ge/graph/build/run_context.cc b/ge/graph/build/run_context.cc similarity index 100% rename from src/ge/graph/build/run_context.cc rename to ge/graph/build/run_context.cc diff --git a/src/ge/graph/build/run_context.h b/ge/graph/build/run_context.h similarity index 100% rename from src/ge/graph/build/run_context.h rename to ge/graph/build/run_context.h diff --git a/src/ge/graph/build/stream_allocator.cc b/ge/graph/build/stream_allocator.cc similarity index 100% rename from src/ge/graph/build/stream_allocator.cc rename to ge/graph/build/stream_allocator.cc diff --git a/src/ge/graph/build/stream_allocator.h b/ge/graph/build/stream_allocator.h similarity index 100% rename from src/ge/graph/build/stream_allocator.h rename to ge/graph/build/stream_allocator.h diff --git a/src/ge/graph/build/stream_graph_optimizer.cc b/ge/graph/build/stream_graph_optimizer.cc similarity index 100% rename from src/ge/graph/build/stream_graph_optimizer.cc rename to ge/graph/build/stream_graph_optimizer.cc diff --git a/src/ge/graph/build/stream_graph_optimizer.h b/ge/graph/build/stream_graph_optimizer.h similarity index 100% rename from src/ge/graph/build/stream_graph_optimizer.h rename to ge/graph/build/stream_graph_optimizer.h diff --git a/src/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc similarity index 100% rename from src/ge/graph/build/task_generator.cc rename to ge/graph/build/task_generator.cc diff --git a/src/ge/graph/build/task_generator.h b/ge/graph/build/task_generator.h similarity index 100% rename from src/ge/graph/build/task_generator.h rename to ge/graph/build/task_generator.h diff --git a/src/ge/graph/common/bcast.cc b/ge/graph/common/bcast.cc similarity index 100% rename from src/ge/graph/common/bcast.cc rename to ge/graph/common/bcast.cc diff --git a/src/ge/graph/common/bcast.h b/ge/graph/common/bcast.h similarity index 100% rename from src/ge/graph/common/bcast.h rename to ge/graph/common/bcast.h diff --git a/src/ge/graph/common/ge_call_wrapper.h b/ge/graph/common/ge_call_wrapper.h similarity index 100% rename from src/ge/graph/common/ge_call_wrapper.h rename to ge/graph/common/ge_call_wrapper.h diff --git a/src/ge/graph/common/local_context.cc b/ge/graph/common/local_context.cc similarity index 100% rename from src/ge/graph/common/local_context.cc rename to ge/graph/common/local_context.cc diff --git a/src/ge/graph/common/local_context.h b/ge/graph/common/local_context.h similarity index 100% rename from src/ge/graph/common/local_context.h rename to ge/graph/common/local_context.h diff --git a/src/ge/graph/common/omg_util.cc b/ge/graph/common/omg_util.cc similarity index 100% rename from src/ge/graph/common/omg_util.cc rename to ge/graph/common/omg_util.cc diff --git a/src/ge/graph/common/omg_util.h b/ge/graph/common/omg_util.h similarity index 100% rename from src/ge/graph/common/omg_util.h rename to ge/graph/common/omg_util.h diff --git a/src/ge/graph/common/transop_util.cc b/ge/graph/common/transop_util.cc similarity index 100% rename from src/ge/graph/common/transop_util.cc rename to ge/graph/common/transop_util.cc diff --git a/src/ge/graph/common/transop_util.h b/ge/graph/common/transop_util.h similarity index 100% rename from src/ge/graph/common/transop_util.h rename to ge/graph/common/transop_util.h diff --git a/src/ge/graph/execute/graph_execute.cc b/ge/graph/execute/graph_execute.cc similarity index 100% rename from src/ge/graph/execute/graph_execute.cc rename to ge/graph/execute/graph_execute.cc diff --git a/src/ge/graph/execute/graph_execute.h 
b/ge/graph/execute/graph_execute.h similarity index 100% rename from src/ge/graph/execute/graph_execute.h rename to ge/graph/execute/graph_execute.h diff --git a/src/ge/graph/label/case_label_maker.cc b/ge/graph/label/case_label_maker.cc similarity index 100% rename from src/ge/graph/label/case_label_maker.cc rename to ge/graph/label/case_label_maker.cc diff --git a/src/ge/graph/label/case_label_maker.h b/ge/graph/label/case_label_maker.h similarity index 100% rename from src/ge/graph/label/case_label_maker.h rename to ge/graph/label/case_label_maker.h diff --git a/src/ge/graph/label/if_label_maker.cc b/ge/graph/label/if_label_maker.cc similarity index 100% rename from src/ge/graph/label/if_label_maker.cc rename to ge/graph/label/if_label_maker.cc diff --git a/src/ge/graph/label/if_label_maker.h b/ge/graph/label/if_label_maker.h similarity index 100% rename from src/ge/graph/label/if_label_maker.h rename to ge/graph/label/if_label_maker.h diff --git a/src/ge/graph/label/label_maker.cc b/ge/graph/label/label_maker.cc similarity index 100% rename from src/ge/graph/label/label_maker.cc rename to ge/graph/label/label_maker.cc diff --git a/src/ge/graph/label/label_maker.h b/ge/graph/label/label_maker.h similarity index 100% rename from src/ge/graph/label/label_maker.h rename to ge/graph/label/label_maker.h diff --git a/src/ge/graph/label/label_maker_factory.h b/ge/graph/label/label_maker_factory.h similarity index 100% rename from src/ge/graph/label/label_maker_factory.h rename to ge/graph/label/label_maker_factory.h diff --git a/src/ge/graph/label/partitioned_call_label_maker.cc b/ge/graph/label/partitioned_call_label_maker.cc similarity index 100% rename from src/ge/graph/label/partitioned_call_label_maker.cc rename to ge/graph/label/partitioned_call_label_maker.cc diff --git a/src/ge/graph/label/partitioned_call_label_maker.h b/ge/graph/label/partitioned_call_label_maker.h similarity index 100% rename from src/ge/graph/label/partitioned_call_label_maker.h rename to ge/graph/label/partitioned_call_label_maker.h diff --git a/src/ge/graph/label/while_label_maker.cc b/ge/graph/label/while_label_maker.cc similarity index 100% rename from src/ge/graph/label/while_label_maker.cc rename to ge/graph/label/while_label_maker.cc diff --git a/src/ge/graph/label/while_label_maker.h b/ge/graph/label/while_label_maker.h similarity index 100% rename from src/ge/graph/label/while_label_maker.h rename to ge/graph/label/while_label_maker.h diff --git a/src/ge/graph/load/graph_loader.cc b/ge/graph/load/graph_loader.cc similarity index 100% rename from src/ge/graph/load/graph_loader.cc rename to ge/graph/load/graph_loader.cc diff --git a/src/ge/graph/load/graph_loader.h b/ge/graph/load/graph_loader.h similarity index 100% rename from src/ge/graph/load/graph_loader.h rename to ge/graph/load/graph_loader.h diff --git a/src/ge/graph/load/new_model_manager/aipp_utils.cc b/ge/graph/load/new_model_manager/aipp_utils.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/aipp_utils.cc rename to ge/graph/load/new_model_manager/aipp_utils.cc diff --git a/src/ge/graph/load/new_model_manager/aipp_utils.h b/ge/graph/load/new_model_manager/aipp_utils.h similarity index 100% rename from src/ge/graph/load/new_model_manager/aipp_utils.h rename to ge/graph/load/new_model_manager/aipp_utils.h diff --git a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc b/ge/graph/load/new_model_manager/cpu_queue_schedule.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc 
rename to ge/graph/load/new_model_manager/cpu_queue_schedule.cc diff --git a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h b/ge/graph/load/new_model_manager/cpu_queue_schedule.h similarity index 100% rename from src/ge/graph/load/new_model_manager/cpu_queue_schedule.h rename to ge/graph/load/new_model_manager/cpu_queue_schedule.h diff --git a/src/ge/graph/load/new_model_manager/data_dumper.cc b/ge/graph/load/new_model_manager/data_dumper.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/data_dumper.cc rename to ge/graph/load/new_model_manager/data_dumper.cc diff --git a/src/ge/graph/load/new_model_manager/data_dumper.h b/ge/graph/load/new_model_manager/data_dumper.h similarity index 100% rename from src/ge/graph/load/new_model_manager/data_dumper.h rename to ge/graph/load/new_model_manager/data_dumper.h diff --git a/src/ge/graph/load/new_model_manager/data_inputer.cc b/ge/graph/load/new_model_manager/data_inputer.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/data_inputer.cc rename to ge/graph/load/new_model_manager/data_inputer.cc diff --git a/src/ge/graph/load/new_model_manager/data_inputer.h b/ge/graph/load/new_model_manager/data_inputer.h similarity index 100% rename from src/ge/graph/load/new_model_manager/data_inputer.h rename to ge/graph/load/new_model_manager/data_inputer.h diff --git a/src/ge/graph/load/new_model_manager/davinci_model.cc b/ge/graph/load/new_model_manager/davinci_model.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/davinci_model.cc rename to ge/graph/load/new_model_manager/davinci_model.cc diff --git a/src/ge/graph/load/new_model_manager/davinci_model.h b/ge/graph/load/new_model_manager/davinci_model.h similarity index 100% rename from src/ge/graph/load/new_model_manager/davinci_model.h rename to ge/graph/load/new_model_manager/davinci_model.h diff --git a/src/ge/graph/load/new_model_manager/davinci_model_parser.cc b/ge/graph/load/new_model_manager/davinci_model_parser.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/davinci_model_parser.cc rename to ge/graph/load/new_model_manager/davinci_model_parser.cc diff --git a/src/ge/graph/load/new_model_manager/davinci_model_parser.h b/ge/graph/load/new_model_manager/davinci_model_parser.h similarity index 100% rename from src/ge/graph/load/new_model_manager/davinci_model_parser.h rename to ge/graph/load/new_model_manager/davinci_model_parser.h diff --git a/src/ge/graph/load/new_model_manager/model_manager.cc b/ge/graph/load/new_model_manager/model_manager.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/model_manager.cc rename to ge/graph/load/new_model_manager/model_manager.cc diff --git a/src/ge/graph/load/new_model_manager/model_manager.h b/ge/graph/load/new_model_manager/model_manager.h similarity index 100% rename from src/ge/graph/load/new_model_manager/model_manager.h rename to ge/graph/load/new_model_manager/model_manager.h diff --git a/src/ge/graph/load/new_model_manager/model_utils.cc b/ge/graph/load/new_model_manager/model_utils.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/model_utils.cc rename to ge/graph/load/new_model_manager/model_utils.cc diff --git a/src/ge/graph/load/new_model_manager/model_utils.h b/ge/graph/load/new_model_manager/model_utils.h similarity index 100% rename from src/ge/graph/load/new_model_manager/model_utils.h rename to ge/graph/load/new_model_manager/model_utils.h diff --git 
a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc b/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc rename to ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h b/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h rename to ge/graph/load/new_model_manager/task_info/end_graph_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc b/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc rename to ge/graph/load/new_model_manager/task_info/event_record_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/event_record_task_info.h b/ge/graph/load/new_model_manager/task_info/event_record_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/event_record_task_info.h rename to ge/graph/load/new_model_manager/task_info/event_record_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc b/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc rename to ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.h b/ge/graph/load/new_model_manager/task_info/event_wait_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.h rename to ge/graph/load/new_model_manager/task_info/event_wait_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc b/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc rename to ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.h b/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.h rename to ge/graph/load/new_model_manager/task_info/fusion_start_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc b/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc rename to ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.h b/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.h rename to ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc rename 
to ge/graph/load/new_model_manager/task_info/hccl_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h b/ge/graph/load/new_model_manager/task_info/hccl_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h rename to ge/graph/load/new_model_manager/task_info/hccl_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc b/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc rename to ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h b/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h rename to ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc b/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc rename to ge/graph/load/new_model_manager/task_info/kernel_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h b/ge/graph/load/new_model_manager/task_info/kernel_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h rename to ge/graph/load/new_model_manager/task_info/kernel_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc b/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc rename to ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h b/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h rename to ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc b/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc rename to ge/graph/load/new_model_manager/task_info/label_set_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.h b/ge/graph/load/new_model_manager/task_info/label_set_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/label_set_task_info.h rename to ge/graph/load/new_model_manager/task_info/label_set_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc b/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc rename to ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h b/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h similarity index 100% rename from 
src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h rename to ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc b/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc rename to ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h b/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h rename to ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc b/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc rename to ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h b/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h rename to ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc b/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc rename to ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.h b/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.h rename to ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc b/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc rename to ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.h b/ge/graph/load/new_model_manager/task_info/stream_active_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.h rename to ge/graph/load/new_model_manager/task_info/stream_active_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc b/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc rename to ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h b/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h rename to 
ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc b/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc rename to ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h b/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h rename to ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc b/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc rename to ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h b/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h rename to ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc b/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc rename to ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h b/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h rename to ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info.cc b/ge/graph/load/new_model_manager/task_info/task_info.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/task_info.cc rename to ge/graph/load/new_model_manager/task_info/task_info.cc diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info.h b/ge/graph/load/new_model_manager/task_info/task_info.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/task_info.h rename to ge/graph/load/new_model_manager/task_info/task_info.h diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h b/ge/graph/load/new_model_manager/task_info/task_info_factory.h similarity index 100% rename from src/ge/graph/load/new_model_manager/task_info/task_info_factory.h rename to ge/graph/load/new_model_manager/task_info/task_info_factory.h diff --git a/src/ge/graph/load/new_model_manager/tbe_handle_store.cc b/ge/graph/load/new_model_manager/tbe_handle_store.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/tbe_handle_store.cc rename to ge/graph/load/new_model_manager/tbe_handle_store.cc diff --git a/src/ge/graph/load/new_model_manager/tbe_handle_store.h b/ge/graph/load/new_model_manager/tbe_handle_store.h similarity index 100% rename from 
src/ge/graph/load/new_model_manager/tbe_handle_store.h rename to ge/graph/load/new_model_manager/tbe_handle_store.h diff --git a/src/ge/graph/load/new_model_manager/zero_copy_offset.cc b/ge/graph/load/new_model_manager/zero_copy_offset.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/zero_copy_offset.cc rename to ge/graph/load/new_model_manager/zero_copy_offset.cc diff --git a/src/ge/graph/load/new_model_manager/zero_copy_offset.h b/ge/graph/load/new_model_manager/zero_copy_offset.h similarity index 100% rename from src/ge/graph/load/new_model_manager/zero_copy_offset.h rename to ge/graph/load/new_model_manager/zero_copy_offset.h diff --git a/src/ge/graph/load/new_model_manager/zero_copy_task.cc b/ge/graph/load/new_model_manager/zero_copy_task.cc similarity index 100% rename from src/ge/graph/load/new_model_manager/zero_copy_task.cc rename to ge/graph/load/new_model_manager/zero_copy_task.cc diff --git a/src/ge/graph/load/new_model_manager/zero_copy_task.h b/ge/graph/load/new_model_manager/zero_copy_task.h similarity index 100% rename from src/ge/graph/load/new_model_manager/zero_copy_task.h rename to ge/graph/load/new_model_manager/zero_copy_task.h diff --git a/src/ge/graph/manager/block_memory.h b/ge/graph/manager/block_memory.h similarity index 100% rename from src/ge/graph/manager/block_memory.h rename to ge/graph/manager/block_memory.h diff --git a/src/ge/graph/manager/graph_caching_allocator.cc b/ge/graph/manager/graph_caching_allocator.cc similarity index 100% rename from src/ge/graph/manager/graph_caching_allocator.cc rename to ge/graph/manager/graph_caching_allocator.cc diff --git a/src/ge/graph/manager/graph_caching_allocator.h b/ge/graph/manager/graph_caching_allocator.h similarity index 100% rename from src/ge/graph/manager/graph_caching_allocator.h rename to ge/graph/manager/graph_caching_allocator.h diff --git a/src/ge/graph/manager/graph_context.cc b/ge/graph/manager/graph_context.cc similarity index 100% rename from src/ge/graph/manager/graph_context.cc rename to ge/graph/manager/graph_context.cc diff --git a/src/ge/graph/manager/graph_context.h b/ge/graph/manager/graph_context.h similarity index 100% rename from src/ge/graph/manager/graph_context.h rename to ge/graph/manager/graph_context.h diff --git a/src/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc similarity index 100% rename from src/ge/graph/manager/graph_manager.cc rename to ge/graph/manager/graph_manager.cc diff --git a/src/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h similarity index 100% rename from src/ge/graph/manager/graph_manager.h rename to ge/graph/manager/graph_manager.h diff --git a/src/ge/graph/manager/graph_manager_utils.cc b/ge/graph/manager/graph_manager_utils.cc similarity index 100% rename from src/ge/graph/manager/graph_manager_utils.cc rename to ge/graph/manager/graph_manager_utils.cc diff --git a/src/ge/graph/manager/graph_manager_utils.h b/ge/graph/manager/graph_manager_utils.h similarity index 100% rename from src/ge/graph/manager/graph_manager_utils.h rename to ge/graph/manager/graph_manager_utils.h diff --git a/src/ge/graph/manager/graph_mem_allocator.cc b/ge/graph/manager/graph_mem_allocator.cc similarity index 100% rename from src/ge/graph/manager/graph_mem_allocator.cc rename to ge/graph/manager/graph_mem_allocator.cc diff --git a/src/ge/graph/manager/graph_mem_allocator.h b/ge/graph/manager/graph_mem_allocator.h similarity index 100% rename from src/ge/graph/manager/graph_mem_allocator.h rename to 
ge/graph/manager/graph_mem_allocator.h diff --git a/src/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc similarity index 100% rename from src/ge/graph/manager/graph_var_manager.cc rename to ge/graph/manager/graph_var_manager.cc diff --git a/src/ge/graph/manager/graph_var_manager.h b/ge/graph/manager/graph_var_manager.h similarity index 100% rename from src/ge/graph/manager/graph_var_manager.h rename to ge/graph/manager/graph_var_manager.h diff --git a/src/ge/graph/manager/host_mem_manager.cc b/ge/graph/manager/host_mem_manager.cc similarity index 100% rename from src/ge/graph/manager/host_mem_manager.cc rename to ge/graph/manager/host_mem_manager.cc diff --git a/src/ge/graph/manager/host_mem_manager.h b/ge/graph/manager/host_mem_manager.h similarity index 100% rename from src/ge/graph/manager/host_mem_manager.h rename to ge/graph/manager/host_mem_manager.h diff --git a/src/ge/graph/manager/memory_api.cc b/ge/graph/manager/memory_api.cc similarity index 100% rename from src/ge/graph/manager/memory_api.cc rename to ge/graph/manager/memory_api.cc diff --git a/src/ge/graph/manager/model_manager/event_manager.cc b/ge/graph/manager/model_manager/event_manager.cc similarity index 100% rename from src/ge/graph/manager/model_manager/event_manager.cc rename to ge/graph/manager/model_manager/event_manager.cc diff --git a/src/ge/graph/manager/model_manager/event_manager.h b/ge/graph/manager/model_manager/event_manager.h similarity index 100% rename from src/ge/graph/manager/model_manager/event_manager.h rename to ge/graph/manager/model_manager/event_manager.h diff --git a/src/ge/graph/manager/rdma_pool_allocator.cc b/ge/graph/manager/rdma_pool_allocator.cc similarity index 100% rename from src/ge/graph/manager/rdma_pool_allocator.cc rename to ge/graph/manager/rdma_pool_allocator.cc diff --git a/src/ge/graph/manager/rdma_pool_allocator.h b/ge/graph/manager/rdma_pool_allocator.h similarity index 100% rename from src/ge/graph/manager/rdma_pool_allocator.h rename to ge/graph/manager/rdma_pool_allocator.h diff --git a/src/ge/graph/manager/trans_var_data_utils.cc b/ge/graph/manager/trans_var_data_utils.cc similarity index 100% rename from src/ge/graph/manager/trans_var_data_utils.cc rename to ge/graph/manager/trans_var_data_utils.cc diff --git a/src/ge/graph/manager/trans_var_data_utils.h b/ge/graph/manager/trans_var_data_utils.h similarity index 100% rename from src/ge/graph/manager/trans_var_data_utils.h rename to ge/graph/manager/trans_var_data_utils.h diff --git a/src/ge/graph/manager/util/debug.cc b/ge/graph/manager/util/debug.cc similarity index 100% rename from src/ge/graph/manager/util/debug.cc rename to ge/graph/manager/util/debug.cc diff --git a/src/ge/graph/manager/util/debug.h b/ge/graph/manager/util/debug.h similarity index 100% rename from src/ge/graph/manager/util/debug.h rename to ge/graph/manager/util/debug.h diff --git a/src/ge/graph/manager/util/hcom_util.cc b/ge/graph/manager/util/hcom_util.cc similarity index 100% rename from src/ge/graph/manager/util/hcom_util.cc rename to ge/graph/manager/util/hcom_util.cc diff --git a/src/ge/graph/manager/util/hcom_util.h b/ge/graph/manager/util/hcom_util.h similarity index 100% rename from src/ge/graph/manager/util/hcom_util.h rename to ge/graph/manager/util/hcom_util.h diff --git a/src/ge/graph/manager/util/rt_context_util.cc b/ge/graph/manager/util/rt_context_util.cc similarity index 100% rename from src/ge/graph/manager/util/rt_context_util.cc rename to ge/graph/manager/util/rt_context_util.cc diff --git 
a/src/ge/graph/manager/util/rt_context_util.h b/ge/graph/manager/util/rt_context_util.h similarity index 100% rename from src/ge/graph/manager/util/rt_context_util.h rename to ge/graph/manager/util/rt_context_util.h diff --git a/src/ge/graph/manager/util/variable_accelerate_ctrl.cc b/ge/graph/manager/util/variable_accelerate_ctrl.cc similarity index 100% rename from src/ge/graph/manager/util/variable_accelerate_ctrl.cc rename to ge/graph/manager/util/variable_accelerate_ctrl.cc diff --git a/src/ge/graph/manager/util/variable_accelerate_ctrl.h b/ge/graph/manager/util/variable_accelerate_ctrl.h similarity index 100% rename from src/ge/graph/manager/util/variable_accelerate_ctrl.h rename to ge/graph/manager/util/variable_accelerate_ctrl.h diff --git a/src/ge/graph/optimize/common/params.h b/ge/graph/optimize/common/params.h similarity index 100% rename from src/ge/graph/optimize/common/params.h rename to ge/graph/optimize/common/params.h diff --git a/src/ge/graph/optimize/graph_optimize.cc b/ge/graph/optimize/graph_optimize.cc similarity index 100% rename from src/ge/graph/optimize/graph_optimize.cc rename to ge/graph/optimize/graph_optimize.cc diff --git a/src/ge/graph/optimize/graph_optimize.h b/ge/graph/optimize/graph_optimize.h similarity index 100% rename from src/ge/graph/optimize/graph_optimize.h rename to ge/graph/optimize/graph_optimize.h diff --git a/src/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc similarity index 100% rename from src/ge/graph/optimize/mem_rw_conflict_optimize.cc rename to ge/graph/optimize/mem_rw_conflict_optimize.cc diff --git a/src/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc b/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc similarity index 100% rename from src/ge/graph/optimize/optimizer/allreduce_fusion_pass.cc rename to ge/graph/optimize/optimizer/allreduce_fusion_pass.cc diff --git a/src/ge/graph/optimize/optimizer/allreduce_fusion_pass.h b/ge/graph/optimize/optimizer/allreduce_fusion_pass.h similarity index 100% rename from src/ge/graph/optimize/optimizer/allreduce_fusion_pass.h rename to ge/graph/optimize/optimizer/allreduce_fusion_pass.h diff --git a/src/ge/graph/optimize/summary_optimize.cc b/ge/graph/optimize/summary_optimize.cc similarity index 100% rename from src/ge/graph/optimize/summary_optimize.cc rename to ge/graph/optimize/summary_optimize.cc diff --git a/src/ge/graph/partition/dynamic_shape_partition.cc b/ge/graph/partition/dynamic_shape_partition.cc similarity index 100% rename from src/ge/graph/partition/dynamic_shape_partition.cc rename to ge/graph/partition/dynamic_shape_partition.cc diff --git a/src/ge/graph/partition/dynamic_shape_partition.h b/ge/graph/partition/dynamic_shape_partition.h similarity index 100% rename from src/ge/graph/partition/dynamic_shape_partition.h rename to ge/graph/partition/dynamic_shape_partition.h diff --git a/src/ge/graph/partition/engine_place.cc b/ge/graph/partition/engine_place.cc similarity index 100% rename from src/ge/graph/partition/engine_place.cc rename to ge/graph/partition/engine_place.cc diff --git a/src/ge/graph/partition/engine_place.h b/ge/graph/partition/engine_place.h similarity index 100% rename from src/ge/graph/partition/engine_place.h rename to ge/graph/partition/engine_place.h diff --git a/src/ge/graph/partition/graph_partition.cc b/ge/graph/partition/graph_partition.cc similarity index 100% rename from src/ge/graph/partition/graph_partition.cc rename to ge/graph/partition/graph_partition.cc diff --git 
a/src/ge/graph/partition/graph_partition.h b/ge/graph/partition/graph_partition.h similarity index 100% rename from src/ge/graph/partition/graph_partition.h rename to ge/graph/partition/graph_partition.h diff --git a/src/ge/graph/passes/addn_pass.cc b/ge/graph/passes/addn_pass.cc similarity index 100% rename from src/ge/graph/passes/addn_pass.cc rename to ge/graph/passes/addn_pass.cc diff --git a/src/ge/graph/passes/addn_pass.h b/ge/graph/passes/addn_pass.h similarity index 100% rename from src/ge/graph/passes/addn_pass.h rename to ge/graph/passes/addn_pass.h diff --git a/src/ge/graph/passes/aicpu_constant_folding_pass.cc b/ge/graph/passes/aicpu_constant_folding_pass.cc similarity index 100% rename from src/ge/graph/passes/aicpu_constant_folding_pass.cc rename to ge/graph/passes/aicpu_constant_folding_pass.cc diff --git a/src/ge/graph/passes/aicpu_constant_folding_pass.h b/ge/graph/passes/aicpu_constant_folding_pass.h similarity index 100% rename from src/ge/graph/passes/aicpu_constant_folding_pass.h rename to ge/graph/passes/aicpu_constant_folding_pass.h diff --git a/src/ge/graph/passes/assert_pass.cc b/ge/graph/passes/assert_pass.cc similarity index 100% rename from src/ge/graph/passes/assert_pass.cc rename to ge/graph/passes/assert_pass.cc diff --git a/src/ge/graph/passes/assert_pass.h b/ge/graph/passes/assert_pass.h similarity index 100% rename from src/ge/graph/passes/assert_pass.h rename to ge/graph/passes/assert_pass.h diff --git a/src/ge/graph/passes/assign_pass.cc b/ge/graph/passes/assign_pass.cc similarity index 100% rename from src/ge/graph/passes/assign_pass.cc rename to ge/graph/passes/assign_pass.cc diff --git a/src/ge/graph/passes/assign_pass.h b/ge/graph/passes/assign_pass.h similarity index 100% rename from src/ge/graph/passes/assign_pass.h rename to ge/graph/passes/assign_pass.h diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.cc b/ge/graph/passes/atomic_addr_clean_pass.cc similarity index 100% rename from src/ge/graph/passes/atomic_addr_clean_pass.cc rename to ge/graph/passes/atomic_addr_clean_pass.cc diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.h b/ge/graph/passes/atomic_addr_clean_pass.h similarity index 100% rename from src/ge/graph/passes/atomic_addr_clean_pass.h rename to ge/graph/passes/atomic_addr_clean_pass.h diff --git a/src/ge/graph/passes/attach_stream_label_pass.cc b/ge/graph/passes/attach_stream_label_pass.cc similarity index 100% rename from src/ge/graph/passes/attach_stream_label_pass.cc rename to ge/graph/passes/attach_stream_label_pass.cc diff --git a/src/ge/graph/passes/attach_stream_label_pass.h b/ge/graph/passes/attach_stream_label_pass.h similarity index 100% rename from src/ge/graph/passes/attach_stream_label_pass.h rename to ge/graph/passes/attach_stream_label_pass.h diff --git a/src/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc similarity index 100% rename from src/ge/graph/passes/base_pass.cc rename to ge/graph/passes/base_pass.cc diff --git a/src/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h similarity index 100% rename from src/ge/graph/passes/base_pass.h rename to ge/graph/passes/base_pass.h diff --git a/src/ge/graph/passes/bitcast_pass.cc b/ge/graph/passes/bitcast_pass.cc similarity index 100% rename from src/ge/graph/passes/bitcast_pass.cc rename to ge/graph/passes/bitcast_pass.cc diff --git a/src/ge/graph/passes/bitcast_pass.h b/ge/graph/passes/bitcast_pass.h similarity index 100% rename from src/ge/graph/passes/bitcast_pass.h rename to ge/graph/passes/bitcast_pass.h diff --git 
a/src/ge/graph/passes/cast_remove_pass.cc b/ge/graph/passes/cast_remove_pass.cc similarity index 100% rename from src/ge/graph/passes/cast_remove_pass.cc rename to ge/graph/passes/cast_remove_pass.cc diff --git a/src/ge/graph/passes/cast_remove_pass.h b/ge/graph/passes/cast_remove_pass.h similarity index 100% rename from src/ge/graph/passes/cast_remove_pass.h rename to ge/graph/passes/cast_remove_pass.h diff --git a/src/ge/graph/passes/cast_translate_pass.cc b/ge/graph/passes/cast_translate_pass.cc similarity index 100% rename from src/ge/graph/passes/cast_translate_pass.cc rename to ge/graph/passes/cast_translate_pass.cc diff --git a/src/ge/graph/passes/cast_translate_pass.h b/ge/graph/passes/cast_translate_pass.h similarity index 100% rename from src/ge/graph/passes/cast_translate_pass.h rename to ge/graph/passes/cast_translate_pass.h diff --git a/src/ge/graph/passes/common_subexpression_elimination_pass.cc b/ge/graph/passes/common_subexpression_elimination_pass.cc similarity index 100% rename from src/ge/graph/passes/common_subexpression_elimination_pass.cc rename to ge/graph/passes/common_subexpression_elimination_pass.cc diff --git a/src/ge/graph/passes/common_subexpression_elimination_pass.h b/ge/graph/passes/common_subexpression_elimination_pass.h similarity index 100% rename from src/ge/graph/passes/common_subexpression_elimination_pass.h rename to ge/graph/passes/common_subexpression_elimination_pass.h diff --git a/src/ge/graph/passes/compile_nodes_pass.cc b/ge/graph/passes/compile_nodes_pass.cc similarity index 100% rename from src/ge/graph/passes/compile_nodes_pass.cc rename to ge/graph/passes/compile_nodes_pass.cc diff --git a/src/ge/graph/passes/compile_nodes_pass.h b/ge/graph/passes/compile_nodes_pass.h similarity index 100% rename from src/ge/graph/passes/compile_nodes_pass.h rename to ge/graph/passes/compile_nodes_pass.h diff --git a/src/ge/graph/passes/cond_pass.cc b/ge/graph/passes/cond_pass.cc similarity index 100% rename from src/ge/graph/passes/cond_pass.cc rename to ge/graph/passes/cond_pass.cc diff --git a/src/ge/graph/passes/cond_pass.h b/ge/graph/passes/cond_pass.h similarity index 100% rename from src/ge/graph/passes/cond_pass.h rename to ge/graph/passes/cond_pass.h diff --git a/src/ge/graph/passes/cond_remove_pass.cc b/ge/graph/passes/cond_remove_pass.cc similarity index 100% rename from src/ge/graph/passes/cond_remove_pass.cc rename to ge/graph/passes/cond_remove_pass.cc diff --git a/src/ge/graph/passes/cond_remove_pass.h b/ge/graph/passes/cond_remove_pass.h similarity index 100% rename from src/ge/graph/passes/cond_remove_pass.h rename to ge/graph/passes/cond_remove_pass.h diff --git a/src/ge/graph/passes/constant_folding_pass.cc b/ge/graph/passes/constant_folding_pass.cc similarity index 100% rename from src/ge/graph/passes/constant_folding_pass.cc rename to ge/graph/passes/constant_folding_pass.cc diff --git a/src/ge/graph/passes/constant_folding_pass.h b/ge/graph/passes/constant_folding_pass.h similarity index 100% rename from src/ge/graph/passes/constant_folding_pass.h rename to ge/graph/passes/constant_folding_pass.h diff --git a/src/ge/graph/passes/constant_fuse_same_pass.cc b/ge/graph/passes/constant_fuse_same_pass.cc similarity index 100% rename from src/ge/graph/passes/constant_fuse_same_pass.cc rename to ge/graph/passes/constant_fuse_same_pass.cc diff --git a/src/ge/graph/passes/constant_fuse_same_pass.h b/ge/graph/passes/constant_fuse_same_pass.h similarity index 100% rename from src/ge/graph/passes/constant_fuse_same_pass.h rename to 
ge/graph/passes/constant_fuse_same_pass.h diff --git a/src/ge/graph/passes/control_trigger_pass.cc b/ge/graph/passes/control_trigger_pass.cc similarity index 100% rename from src/ge/graph/passes/control_trigger_pass.cc rename to ge/graph/passes/control_trigger_pass.cc diff --git a/src/ge/graph/passes/control_trigger_pass.h b/ge/graph/passes/control_trigger_pass.h similarity index 100% rename from src/ge/graph/passes/control_trigger_pass.h rename to ge/graph/passes/control_trigger_pass.h diff --git a/src/ge/graph/passes/ctrl_edge_transfer_pass.cc b/ge/graph/passes/ctrl_edge_transfer_pass.cc similarity index 100% rename from src/ge/graph/passes/ctrl_edge_transfer_pass.cc rename to ge/graph/passes/ctrl_edge_transfer_pass.cc diff --git a/src/ge/graph/passes/ctrl_edge_transfer_pass.h b/ge/graph/passes/ctrl_edge_transfer_pass.h similarity index 100% rename from src/ge/graph/passes/ctrl_edge_transfer_pass.h rename to ge/graph/passes/ctrl_edge_transfer_pass.h diff --git a/src/ge/graph/passes/data_pass.cc b/ge/graph/passes/data_pass.cc similarity index 100% rename from src/ge/graph/passes/data_pass.cc rename to ge/graph/passes/data_pass.cc diff --git a/src/ge/graph/passes/data_pass.h b/ge/graph/passes/data_pass.h similarity index 100% rename from src/ge/graph/passes/data_pass.h rename to ge/graph/passes/data_pass.h diff --git a/src/ge/graph/passes/dimension_adjust_pass.cc b/ge/graph/passes/dimension_adjust_pass.cc similarity index 100% rename from src/ge/graph/passes/dimension_adjust_pass.cc rename to ge/graph/passes/dimension_adjust_pass.cc diff --git a/src/ge/graph/passes/dimension_adjust_pass.h b/ge/graph/passes/dimension_adjust_pass.h similarity index 100% rename from src/ge/graph/passes/dimension_adjust_pass.h rename to ge/graph/passes/dimension_adjust_pass.h diff --git a/src/ge/graph/passes/dimension_compute_pass.cc b/ge/graph/passes/dimension_compute_pass.cc similarity index 100% rename from src/ge/graph/passes/dimension_compute_pass.cc rename to ge/graph/passes/dimension_compute_pass.cc diff --git a/src/ge/graph/passes/dimension_compute_pass.h b/ge/graph/passes/dimension_compute_pass.h similarity index 100% rename from src/ge/graph/passes/dimension_compute_pass.h rename to ge/graph/passes/dimension_compute_pass.h diff --git a/src/ge/graph/passes/dropout_pass.cc b/ge/graph/passes/dropout_pass.cc similarity index 100% rename from src/ge/graph/passes/dropout_pass.cc rename to ge/graph/passes/dropout_pass.cc diff --git a/src/ge/graph/passes/dropout_pass.h b/ge/graph/passes/dropout_pass.h similarity index 100% rename from src/ge/graph/passes/dropout_pass.h rename to ge/graph/passes/dropout_pass.h diff --git a/src/ge/graph/passes/end_of_sequence_add_control_pass.cc b/ge/graph/passes/end_of_sequence_add_control_pass.cc similarity index 100% rename from src/ge/graph/passes/end_of_sequence_add_control_pass.cc rename to ge/graph/passes/end_of_sequence_add_control_pass.cc diff --git a/src/ge/graph/passes/end_of_sequence_add_control_pass.h b/ge/graph/passes/end_of_sequence_add_control_pass.h similarity index 100% rename from src/ge/graph/passes/end_of_sequence_add_control_pass.h rename to ge/graph/passes/end_of_sequence_add_control_pass.h diff --git a/src/ge/graph/passes/enter_pass.cc b/ge/graph/passes/enter_pass.cc similarity index 100% rename from src/ge/graph/passes/enter_pass.cc rename to ge/graph/passes/enter_pass.cc diff --git a/src/ge/graph/passes/enter_pass.h b/ge/graph/passes/enter_pass.h similarity index 100% rename from src/ge/graph/passes/enter_pass.h rename to ge/graph/passes/enter_pass.h 
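The entries above and below move the graph optimization passes from src/ge/graph/passes/ to ge/graph/passes/, almost all as pure renames. For orientation only (this sketch is not part of the patch), a minimal pass written against the BaseNodePass interface visible in the identity_pass.h and transpose_transdata_pass.h hunks further down might look roughly as follows; BaseNodePass, Status and NodePtr come from those hunks, while the pass name, the status codes and the exact include path are assumptions.

    // Illustrative sketch, not part of this patch.
    #include "graph/passes/base_pass.h"   // lives under ge/graph/passes/ after this change

    namespace ge {
    class NoopNodePass : public BaseNodePass {   // hypothetical example pass
     public:
      Status Run(NodePtr &node) override {
        if (node == nullptr) {
          return PARAM_INVALID;   // assumed GE status code for bad input
        }
        // A real pass would inspect or rewrite the node here; this one leaves it untouched.
        return SUCCESS;           // assumed GE status code for success
      }
    };
    }  // namespace ge

Passes of this shape are driven by the pass-framework sources renamed in the same patch (ge/graph/passes/base_pass.cc and ge/graph/passes/pass_manager.cc).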
diff --git a/src/ge/graph/passes/flow_ctrl_pass.cc b/ge/graph/passes/flow_ctrl_pass.cc similarity index 100% rename from src/ge/graph/passes/flow_ctrl_pass.cc rename to ge/graph/passes/flow_ctrl_pass.cc diff --git a/src/ge/graph/passes/flow_ctrl_pass.h b/ge/graph/passes/flow_ctrl_pass.h similarity index 100% rename from src/ge/graph/passes/flow_ctrl_pass.h rename to ge/graph/passes/flow_ctrl_pass.h diff --git a/src/ge/graph/passes/folding_pass.cc b/ge/graph/passes/folding_pass.cc similarity index 100% rename from src/ge/graph/passes/folding_pass.cc rename to ge/graph/passes/folding_pass.cc diff --git a/src/ge/graph/passes/folding_pass.h b/ge/graph/passes/folding_pass.h similarity index 100% rename from src/ge/graph/passes/folding_pass.h rename to ge/graph/passes/folding_pass.h diff --git a/src/ge/graph/passes/for_pass.cc b/ge/graph/passes/for_pass.cc similarity index 100% rename from src/ge/graph/passes/for_pass.cc rename to ge/graph/passes/for_pass.cc diff --git a/src/ge/graph/passes/for_pass.h b/ge/graph/passes/for_pass.h similarity index 100% rename from src/ge/graph/passes/for_pass.h rename to ge/graph/passes/for_pass.h diff --git a/src/ge/graph/passes/get_original_format_pass.cc b/ge/graph/passes/get_original_format_pass.cc similarity index 100% rename from src/ge/graph/passes/get_original_format_pass.cc rename to ge/graph/passes/get_original_format_pass.cc diff --git a/src/ge/graph/passes/get_original_format_pass.h b/ge/graph/passes/get_original_format_pass.h similarity index 100% rename from src/ge/graph/passes/get_original_format_pass.h rename to ge/graph/passes/get_original_format_pass.h diff --git a/src/ge/graph/passes/global_step_insert_pass.cc b/ge/graph/passes/global_step_insert_pass.cc similarity index 100% rename from src/ge/graph/passes/global_step_insert_pass.cc rename to ge/graph/passes/global_step_insert_pass.cc diff --git a/src/ge/graph/passes/global_step_insert_pass.h b/ge/graph/passes/global_step_insert_pass.h similarity index 100% rename from src/ge/graph/passes/global_step_insert_pass.h rename to ge/graph/passes/global_step_insert_pass.h diff --git a/src/ge/graph/passes/guarantee_const_pass.cc b/ge/graph/passes/guarantee_const_pass.cc similarity index 100% rename from src/ge/graph/passes/guarantee_const_pass.cc rename to ge/graph/passes/guarantee_const_pass.cc diff --git a/src/ge/graph/passes/guarantee_const_pass.h b/ge/graph/passes/guarantee_const_pass.h similarity index 100% rename from src/ge/graph/passes/guarantee_const_pass.h rename to ge/graph/passes/guarantee_const_pass.h diff --git a/src/ge/graph/passes/hccl_group_pass.cc b/ge/graph/passes/hccl_group_pass.cc similarity index 100% rename from src/ge/graph/passes/hccl_group_pass.cc rename to ge/graph/passes/hccl_group_pass.cc diff --git a/src/ge/graph/passes/hccl_group_pass.h b/ge/graph/passes/hccl_group_pass.h similarity index 100% rename from src/ge/graph/passes/hccl_group_pass.h rename to ge/graph/passes/hccl_group_pass.h diff --git a/src/ge/graph/passes/hccl_memcpy_pass.cc b/ge/graph/passes/hccl_memcpy_pass.cc similarity index 100% rename from src/ge/graph/passes/hccl_memcpy_pass.cc rename to ge/graph/passes/hccl_memcpy_pass.cc diff --git a/src/ge/graph/passes/hccl_memcpy_pass.h b/ge/graph/passes/hccl_memcpy_pass.h similarity index 100% rename from src/ge/graph/passes/hccl_memcpy_pass.h rename to ge/graph/passes/hccl_memcpy_pass.h diff --git a/src/ge/graph/passes/identity_pass.cc b/ge/graph/passes/identity_pass.cc similarity index 100% rename from src/ge/graph/passes/identity_pass.cc rename to 
ge/graph/passes/identity_pass.cc diff --git a/src/ge/graph/passes/identity_pass.h b/ge/graph/passes/identity_pass.h similarity index 99% rename from src/ge/graph/passes/identity_pass.h rename to ge/graph/passes/identity_pass.h index 21d990d5..a4a80efc 100644 --- a/src/ge/graph/passes/identity_pass.h +++ b/ge/graph/passes/identity_pass.h @@ -25,7 +25,6 @@ class IdentityPass : public BaseNodePass { explicit IdentityPass(bool force) : force_(force) {} ~IdentityPass() override = default; Status Run(NodePtr &node) override; - private: bool force_ = false; }; diff --git a/src/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc similarity index 100% rename from src/ge/graph/passes/infershape_pass.cc rename to ge/graph/passes/infershape_pass.cc diff --git a/src/ge/graph/passes/infershape_pass.h b/ge/graph/passes/infershape_pass.h similarity index 100% rename from src/ge/graph/passes/infershape_pass.h rename to ge/graph/passes/infershape_pass.h diff --git a/src/ge/graph/passes/input_output_connection_identify_pass.cc b/ge/graph/passes/input_output_connection_identify_pass.cc similarity index 100% rename from src/ge/graph/passes/input_output_connection_identify_pass.cc rename to ge/graph/passes/input_output_connection_identify_pass.cc diff --git a/src/ge/graph/passes/input_output_connection_identify_pass.h b/ge/graph/passes/input_output_connection_identify_pass.h similarity index 100% rename from src/ge/graph/passes/input_output_connection_identify_pass.h rename to ge/graph/passes/input_output_connection_identify_pass.h diff --git a/src/ge/graph/passes/isolated_op_remove_pass.cc b/ge/graph/passes/isolated_op_remove_pass.cc similarity index 100% rename from src/ge/graph/passes/isolated_op_remove_pass.cc rename to ge/graph/passes/isolated_op_remove_pass.cc diff --git a/src/ge/graph/passes/isolated_op_remove_pass.h b/ge/graph/passes/isolated_op_remove_pass.h similarity index 100% rename from src/ge/graph/passes/isolated_op_remove_pass.h rename to ge/graph/passes/isolated_op_remove_pass.h diff --git a/src/ge/graph/passes/iterator_op_pass.cc b/ge/graph/passes/iterator_op_pass.cc similarity index 100% rename from src/ge/graph/passes/iterator_op_pass.cc rename to ge/graph/passes/iterator_op_pass.cc diff --git a/src/ge/graph/passes/iterator_op_pass.h b/ge/graph/passes/iterator_op_pass.h similarity index 100% rename from src/ge/graph/passes/iterator_op_pass.h rename to ge/graph/passes/iterator_op_pass.h diff --git a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc similarity index 100% rename from src/ge/graph/passes/link_gen_mask_nodes_pass.cc rename to ge/graph/passes/link_gen_mask_nodes_pass.cc diff --git a/src/ge/graph/passes/link_gen_mask_nodes_pass.h b/ge/graph/passes/link_gen_mask_nodes_pass.h similarity index 100% rename from src/ge/graph/passes/link_gen_mask_nodes_pass.h rename to ge/graph/passes/link_gen_mask_nodes_pass.h diff --git a/src/ge/graph/passes/mark_agnostic_pass.cc b/ge/graph/passes/mark_agnostic_pass.cc similarity index 100% rename from src/ge/graph/passes/mark_agnostic_pass.cc rename to ge/graph/passes/mark_agnostic_pass.cc diff --git a/src/ge/graph/passes/mark_agnostic_pass.h b/ge/graph/passes/mark_agnostic_pass.h similarity index 100% rename from src/ge/graph/passes/mark_agnostic_pass.h rename to ge/graph/passes/mark_agnostic_pass.h diff --git a/src/ge/graph/passes/mark_graph_unknown_status_pass.cc b/ge/graph/passes/mark_graph_unknown_status_pass.cc similarity index 100% rename from 
src/ge/graph/passes/mark_graph_unknown_status_pass.cc rename to ge/graph/passes/mark_graph_unknown_status_pass.cc diff --git a/src/ge/graph/passes/mark_graph_unknown_status_pass.h b/ge/graph/passes/mark_graph_unknown_status_pass.h similarity index 100% rename from src/ge/graph/passes/mark_graph_unknown_status_pass.h rename to ge/graph/passes/mark_graph_unknown_status_pass.h diff --git a/src/ge/graph/passes/mark_same_addr_pass.cc b/ge/graph/passes/mark_same_addr_pass.cc similarity index 100% rename from src/ge/graph/passes/mark_same_addr_pass.cc rename to ge/graph/passes/mark_same_addr_pass.cc diff --git a/src/ge/graph/passes/mark_same_addr_pass.h b/ge/graph/passes/mark_same_addr_pass.h similarity index 100% rename from src/ge/graph/passes/mark_same_addr_pass.h rename to ge/graph/passes/mark_same_addr_pass.h diff --git a/src/ge/graph/passes/memcpy_addr_async_pass.cc b/ge/graph/passes/memcpy_addr_async_pass.cc similarity index 100% rename from src/ge/graph/passes/memcpy_addr_async_pass.cc rename to ge/graph/passes/memcpy_addr_async_pass.cc diff --git a/src/ge/graph/passes/memcpy_addr_async_pass.h b/ge/graph/passes/memcpy_addr_async_pass.h similarity index 100% rename from src/ge/graph/passes/memcpy_addr_async_pass.h rename to ge/graph/passes/memcpy_addr_async_pass.h diff --git a/src/ge/graph/passes/merge_pass.cc b/ge/graph/passes/merge_pass.cc similarity index 100% rename from src/ge/graph/passes/merge_pass.cc rename to ge/graph/passes/merge_pass.cc diff --git a/src/ge/graph/passes/merge_pass.h b/ge/graph/passes/merge_pass.h similarity index 100% rename from src/ge/graph/passes/merge_pass.h rename to ge/graph/passes/merge_pass.h diff --git a/src/ge/graph/passes/merge_to_stream_merge_pass.cc b/ge/graph/passes/merge_to_stream_merge_pass.cc similarity index 100% rename from src/ge/graph/passes/merge_to_stream_merge_pass.cc rename to ge/graph/passes/merge_to_stream_merge_pass.cc diff --git a/src/ge/graph/passes/merge_to_stream_merge_pass.h b/ge/graph/passes/merge_to_stream_merge_pass.h similarity index 100% rename from src/ge/graph/passes/merge_to_stream_merge_pass.h rename to ge/graph/passes/merge_to_stream_merge_pass.h diff --git a/src/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc similarity index 100% rename from src/ge/graph/passes/multi_batch_clone_pass.cc rename to ge/graph/passes/multi_batch_clone_pass.cc diff --git a/src/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h similarity index 100% rename from src/ge/graph/passes/multi_batch_clone_pass.h rename to ge/graph/passes/multi_batch_clone_pass.h diff --git a/src/ge/graph/passes/multi_batch_pass.cc b/ge/graph/passes/multi_batch_pass.cc similarity index 100% rename from src/ge/graph/passes/multi_batch_pass.cc rename to ge/graph/passes/multi_batch_pass.cc diff --git a/src/ge/graph/passes/multi_batch_pass.h b/ge/graph/passes/multi_batch_pass.h similarity index 100% rename from src/ge/graph/passes/multi_batch_pass.h rename to ge/graph/passes/multi_batch_pass.h diff --git a/src/ge/graph/passes/net_output_pass.cc b/ge/graph/passes/net_output_pass.cc similarity index 100% rename from src/ge/graph/passes/net_output_pass.cc rename to ge/graph/passes/net_output_pass.cc diff --git a/src/ge/graph/passes/net_output_pass.h b/ge/graph/passes/net_output_pass.h similarity index 100% rename from src/ge/graph/passes/net_output_pass.h rename to ge/graph/passes/net_output_pass.h diff --git a/src/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc similarity 
index 100% rename from src/ge/graph/passes/next_iteration_pass.cc rename to ge/graph/passes/next_iteration_pass.cc diff --git a/src/ge/graph/passes/next_iteration_pass.h b/ge/graph/passes/next_iteration_pass.h similarity index 100% rename from src/ge/graph/passes/next_iteration_pass.h rename to ge/graph/passes/next_iteration_pass.h diff --git a/src/ge/graph/passes/no_use_reshape_remove_pass.cc b/ge/graph/passes/no_use_reshape_remove_pass.cc similarity index 100% rename from src/ge/graph/passes/no_use_reshape_remove_pass.cc rename to ge/graph/passes/no_use_reshape_remove_pass.cc diff --git a/src/ge/graph/passes/no_use_reshape_remove_pass.h b/ge/graph/passes/no_use_reshape_remove_pass.h similarity index 100% rename from src/ge/graph/passes/no_use_reshape_remove_pass.h rename to ge/graph/passes/no_use_reshape_remove_pass.h diff --git a/src/ge/graph/passes/parallel_concat_start_op_pass.cc b/ge/graph/passes/parallel_concat_start_op_pass.cc similarity index 100% rename from src/ge/graph/passes/parallel_concat_start_op_pass.cc rename to ge/graph/passes/parallel_concat_start_op_pass.cc diff --git a/src/ge/graph/passes/parallel_concat_start_op_pass.h b/ge/graph/passes/parallel_concat_start_op_pass.h similarity index 100% rename from src/ge/graph/passes/parallel_concat_start_op_pass.h rename to ge/graph/passes/parallel_concat_start_op_pass.h diff --git a/src/ge/graph/passes/pass_manager.cc b/ge/graph/passes/pass_manager.cc similarity index 100% rename from src/ge/graph/passes/pass_manager.cc rename to ge/graph/passes/pass_manager.cc diff --git a/src/ge/graph/passes/pass_utils.cc b/ge/graph/passes/pass_utils.cc similarity index 100% rename from src/ge/graph/passes/pass_utils.cc rename to ge/graph/passes/pass_utils.cc diff --git a/src/ge/graph/passes/pass_utils.h b/ge/graph/passes/pass_utils.h similarity index 100% rename from src/ge/graph/passes/pass_utils.h rename to ge/graph/passes/pass_utils.h diff --git a/src/ge/graph/passes/permute_pass.cc b/ge/graph/passes/permute_pass.cc similarity index 100% rename from src/ge/graph/passes/permute_pass.cc rename to ge/graph/passes/permute_pass.cc diff --git a/src/ge/graph/passes/permute_pass.h b/ge/graph/passes/permute_pass.h similarity index 100% rename from src/ge/graph/passes/permute_pass.h rename to ge/graph/passes/permute_pass.h diff --git a/src/ge/graph/passes/placeholder_with_default_pass.cc b/ge/graph/passes/placeholder_with_default_pass.cc similarity index 100% rename from src/ge/graph/passes/placeholder_with_default_pass.cc rename to ge/graph/passes/placeholder_with_default_pass.cc diff --git a/src/ge/graph/passes/placeholder_with_default_pass.h b/ge/graph/passes/placeholder_with_default_pass.h similarity index 100% rename from src/ge/graph/passes/placeholder_with_default_pass.h rename to ge/graph/passes/placeholder_with_default_pass.h diff --git a/src/ge/graph/passes/prevent_gradient_pass.cc b/ge/graph/passes/prevent_gradient_pass.cc similarity index 100% rename from src/ge/graph/passes/prevent_gradient_pass.cc rename to ge/graph/passes/prevent_gradient_pass.cc diff --git a/src/ge/graph/passes/prevent_gradient_pass.h b/ge/graph/passes/prevent_gradient_pass.h similarity index 100% rename from src/ge/graph/passes/prevent_gradient_pass.h rename to ge/graph/passes/prevent_gradient_pass.h diff --git a/src/ge/graph/passes/print_op_pass.cc b/ge/graph/passes/print_op_pass.cc similarity index 100% rename from src/ge/graph/passes/print_op_pass.cc rename to ge/graph/passes/print_op_pass.cc diff --git a/src/ge/graph/passes/print_op_pass.h 
b/ge/graph/passes/print_op_pass.h similarity index 100% rename from src/ge/graph/passes/print_op_pass.h rename to ge/graph/passes/print_op_pass.h diff --git a/src/ge/graph/passes/prune_pass.cc b/ge/graph/passes/prune_pass.cc similarity index 100% rename from src/ge/graph/passes/prune_pass.cc rename to ge/graph/passes/prune_pass.cc diff --git a/src/ge/graph/passes/prune_pass.h b/ge/graph/passes/prune_pass.h similarity index 100% rename from src/ge/graph/passes/prune_pass.h rename to ge/graph/passes/prune_pass.h diff --git a/src/ge/graph/passes/ref_identity_delete_op_pass.cc b/ge/graph/passes/ref_identity_delete_op_pass.cc similarity index 100% rename from src/ge/graph/passes/ref_identity_delete_op_pass.cc rename to ge/graph/passes/ref_identity_delete_op_pass.cc diff --git a/src/ge/graph/passes/ref_identity_delete_op_pass.h b/ge/graph/passes/ref_identity_delete_op_pass.h similarity index 100% rename from src/ge/graph/passes/ref_identity_delete_op_pass.h rename to ge/graph/passes/ref_identity_delete_op_pass.h diff --git a/src/ge/graph/passes/remove_nodes_pass.cc b/ge/graph/passes/remove_nodes_pass.cc similarity index 100% rename from src/ge/graph/passes/remove_nodes_pass.cc rename to ge/graph/passes/remove_nodes_pass.cc diff --git a/src/ge/graph/passes/remove_nodes_pass.h b/ge/graph/passes/remove_nodes_pass.h similarity index 100% rename from src/ge/graph/passes/remove_nodes_pass.h rename to ge/graph/passes/remove_nodes_pass.h diff --git a/src/ge/graph/passes/replace_transshape_pass.cc b/ge/graph/passes/replace_transshape_pass.cc similarity index 100% rename from src/ge/graph/passes/replace_transshape_pass.cc rename to ge/graph/passes/replace_transshape_pass.cc diff --git a/src/ge/graph/passes/replace_transshape_pass.h b/ge/graph/passes/replace_transshape_pass.h similarity index 100% rename from src/ge/graph/passes/replace_transshape_pass.h rename to ge/graph/passes/replace_transshape_pass.h diff --git a/src/ge/graph/passes/replace_with_empty_const_pass.cc b/ge/graph/passes/replace_with_empty_const_pass.cc similarity index 100% rename from src/ge/graph/passes/replace_with_empty_const_pass.cc rename to ge/graph/passes/replace_with_empty_const_pass.cc diff --git a/src/ge/graph/passes/replace_with_empty_const_pass.h b/ge/graph/passes/replace_with_empty_const_pass.h similarity index 100% rename from src/ge/graph/passes/replace_with_empty_const_pass.h rename to ge/graph/passes/replace_with_empty_const_pass.h diff --git a/src/ge/graph/passes/reshape_recovery_pass.cc b/ge/graph/passes/reshape_recovery_pass.cc similarity index 100% rename from src/ge/graph/passes/reshape_recovery_pass.cc rename to ge/graph/passes/reshape_recovery_pass.cc diff --git a/src/ge/graph/passes/reshape_recovery_pass.h b/ge/graph/passes/reshape_recovery_pass.h similarity index 100% rename from src/ge/graph/passes/reshape_recovery_pass.h rename to ge/graph/passes/reshape_recovery_pass.h diff --git a/src/ge/graph/passes/reshape_remove_pass.cc b/ge/graph/passes/reshape_remove_pass.cc similarity index 100% rename from src/ge/graph/passes/reshape_remove_pass.cc rename to ge/graph/passes/reshape_remove_pass.cc diff --git a/src/ge/graph/passes/reshape_remove_pass.h b/ge/graph/passes/reshape_remove_pass.h similarity index 100% rename from src/ge/graph/passes/reshape_remove_pass.h rename to ge/graph/passes/reshape_remove_pass.h diff --git a/src/ge/graph/passes/resource_pair_add_control_pass.cc b/ge/graph/passes/resource_pair_add_control_pass.cc similarity index 100% rename from src/ge/graph/passes/resource_pair_add_control_pass.cc 
rename to ge/graph/passes/resource_pair_add_control_pass.cc diff --git a/src/ge/graph/passes/resource_pair_add_control_pass.h b/ge/graph/passes/resource_pair_add_control_pass.h similarity index 100% rename from src/ge/graph/passes/resource_pair_add_control_pass.h rename to ge/graph/passes/resource_pair_add_control_pass.h diff --git a/src/ge/graph/passes/resource_pair_remove_control_pass.cc b/ge/graph/passes/resource_pair_remove_control_pass.cc similarity index 100% rename from src/ge/graph/passes/resource_pair_remove_control_pass.cc rename to ge/graph/passes/resource_pair_remove_control_pass.cc diff --git a/src/ge/graph/passes/resource_pair_remove_control_pass.h b/ge/graph/passes/resource_pair_remove_control_pass.h similarity index 100% rename from src/ge/graph/passes/resource_pair_remove_control_pass.h rename to ge/graph/passes/resource_pair_remove_control_pass.h diff --git a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc b/ge/graph/passes/same_transdata_breadth_fusion_pass.cc similarity index 100% rename from src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc rename to ge/graph/passes/same_transdata_breadth_fusion_pass.cc diff --git a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h b/ge/graph/passes/same_transdata_breadth_fusion_pass.h similarity index 100% rename from src/ge/graph/passes/same_transdata_breadth_fusion_pass.h rename to ge/graph/passes/same_transdata_breadth_fusion_pass.h diff --git a/src/ge/graph/passes/save_pass.cc b/ge/graph/passes/save_pass.cc similarity index 100% rename from src/ge/graph/passes/save_pass.cc rename to ge/graph/passes/save_pass.cc diff --git a/src/ge/graph/passes/save_pass.h b/ge/graph/passes/save_pass.h similarity index 100% rename from src/ge/graph/passes/save_pass.h rename to ge/graph/passes/save_pass.h diff --git a/src/ge/graph/passes/set_input_output_offset_pass.cc b/ge/graph/passes/set_input_output_offset_pass.cc similarity index 100% rename from src/ge/graph/passes/set_input_output_offset_pass.cc rename to ge/graph/passes/set_input_output_offset_pass.cc diff --git a/src/ge/graph/passes/set_input_output_offset_pass.h b/ge/graph/passes/set_input_output_offset_pass.h similarity index 100% rename from src/ge/graph/passes/set_input_output_offset_pass.h rename to ge/graph/passes/set_input_output_offset_pass.h diff --git a/src/ge/graph/passes/shape_operate_op_remove_pass.cc b/ge/graph/passes/shape_operate_op_remove_pass.cc similarity index 100% rename from src/ge/graph/passes/shape_operate_op_remove_pass.cc rename to ge/graph/passes/shape_operate_op_remove_pass.cc diff --git a/src/ge/graph/passes/shape_operate_op_remove_pass.h b/ge/graph/passes/shape_operate_op_remove_pass.h similarity index 100% rename from src/ge/graph/passes/shape_operate_op_remove_pass.h rename to ge/graph/passes/shape_operate_op_remove_pass.h diff --git a/src/ge/graph/passes/snapshot_pass.cc b/ge/graph/passes/snapshot_pass.cc similarity index 100% rename from src/ge/graph/passes/snapshot_pass.cc rename to ge/graph/passes/snapshot_pass.cc diff --git a/src/ge/graph/passes/snapshot_pass.h b/ge/graph/passes/snapshot_pass.h similarity index 100% rename from src/ge/graph/passes/snapshot_pass.h rename to ge/graph/passes/snapshot_pass.h diff --git a/src/ge/graph/passes/stop_gradient_pass.cc b/ge/graph/passes/stop_gradient_pass.cc similarity index 100% rename from src/ge/graph/passes/stop_gradient_pass.cc rename to ge/graph/passes/stop_gradient_pass.cc diff --git a/src/ge/graph/passes/stop_gradient_pass.h b/ge/graph/passes/stop_gradient_pass.h similarity index 
100% rename from src/ge/graph/passes/stop_gradient_pass.h rename to ge/graph/passes/stop_gradient_pass.h diff --git a/src/ge/graph/passes/subexpression_migration_pass.cc b/ge/graph/passes/subexpression_migration_pass.cc similarity index 100% rename from src/ge/graph/passes/subexpression_migration_pass.cc rename to ge/graph/passes/subexpression_migration_pass.cc diff --git a/src/ge/graph/passes/subexpression_migration_pass.h b/ge/graph/passes/subexpression_migration_pass.h similarity index 100% rename from src/ge/graph/passes/subexpression_migration_pass.h rename to ge/graph/passes/subexpression_migration_pass.h diff --git a/src/ge/graph/passes/subgraph_pass.cc b/ge/graph/passes/subgraph_pass.cc similarity index 100% rename from src/ge/graph/passes/subgraph_pass.cc rename to ge/graph/passes/subgraph_pass.cc diff --git a/src/ge/graph/passes/subgraph_pass.h b/ge/graph/passes/subgraph_pass.h similarity index 100% rename from src/ge/graph/passes/subgraph_pass.h rename to ge/graph/passes/subgraph_pass.h diff --git a/src/ge/graph/passes/switch_data_edges_bypass.cc b/ge/graph/passes/switch_data_edges_bypass.cc similarity index 100% rename from src/ge/graph/passes/switch_data_edges_bypass.cc rename to ge/graph/passes/switch_data_edges_bypass.cc diff --git a/src/ge/graph/passes/switch_data_edges_bypass.h b/ge/graph/passes/switch_data_edges_bypass.h similarity index 100% rename from src/ge/graph/passes/switch_data_edges_bypass.h rename to ge/graph/passes/switch_data_edges_bypass.h diff --git a/src/ge/graph/passes/switch_dead_branch_elimination.cc b/ge/graph/passes/switch_dead_branch_elimination.cc similarity index 100% rename from src/ge/graph/passes/switch_dead_branch_elimination.cc rename to ge/graph/passes/switch_dead_branch_elimination.cc diff --git a/src/ge/graph/passes/switch_dead_branch_elimination.h b/ge/graph/passes/switch_dead_branch_elimination.h similarity index 100% rename from src/ge/graph/passes/switch_dead_branch_elimination.h rename to ge/graph/passes/switch_dead_branch_elimination.h diff --git a/src/ge/graph/passes/switch_logic_remove_pass.cc b/ge/graph/passes/switch_logic_remove_pass.cc similarity index 100% rename from src/ge/graph/passes/switch_logic_remove_pass.cc rename to ge/graph/passes/switch_logic_remove_pass.cc diff --git a/src/ge/graph/passes/switch_logic_remove_pass.h b/ge/graph/passes/switch_logic_remove_pass.h similarity index 100% rename from src/ge/graph/passes/switch_logic_remove_pass.h rename to ge/graph/passes/switch_logic_remove_pass.h diff --git a/src/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc similarity index 100% rename from src/ge/graph/passes/switch_to_stream_switch_pass.cc rename to ge/graph/passes/switch_to_stream_switch_pass.cc diff --git a/src/ge/graph/passes/switch_to_stream_switch_pass.h b/ge/graph/passes/switch_to_stream_switch_pass.h similarity index 100% rename from src/ge/graph/passes/switch_to_stream_switch_pass.h rename to ge/graph/passes/switch_to_stream_switch_pass.h diff --git a/src/ge/graph/passes/transop_breadth_fusion_pass.cc b/ge/graph/passes/transop_breadth_fusion_pass.cc similarity index 100% rename from src/ge/graph/passes/transop_breadth_fusion_pass.cc rename to ge/graph/passes/transop_breadth_fusion_pass.cc diff --git a/src/ge/graph/passes/transop_breadth_fusion_pass.h b/ge/graph/passes/transop_breadth_fusion_pass.h similarity index 100% rename from src/ge/graph/passes/transop_breadth_fusion_pass.h rename to ge/graph/passes/transop_breadth_fusion_pass.h diff --git 
a/src/ge/graph/passes/transop_depth_fusion_pass.cc b/ge/graph/passes/transop_depth_fusion_pass.cc similarity index 100% rename from src/ge/graph/passes/transop_depth_fusion_pass.cc rename to ge/graph/passes/transop_depth_fusion_pass.cc diff --git a/src/ge/graph/passes/transop_depth_fusion_pass.h b/ge/graph/passes/transop_depth_fusion_pass.h similarity index 100% rename from src/ge/graph/passes/transop_depth_fusion_pass.h rename to ge/graph/passes/transop_depth_fusion_pass.h diff --git a/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc b/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc similarity index 100% rename from src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc rename to ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc diff --git a/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.h b/ge/graph/passes/transop_nearby_allreduce_fusion_pass.h similarity index 100% rename from src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.h rename to ge/graph/passes/transop_nearby_allreduce_fusion_pass.h diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc b/ge/graph/passes/transop_symmetry_elimination_pass.cc similarity index 100% rename from src/ge/graph/passes/transop_symmetry_elimination_pass.cc rename to ge/graph/passes/transop_symmetry_elimination_pass.cc diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.h b/ge/graph/passes/transop_symmetry_elimination_pass.h similarity index 100% rename from src/ge/graph/passes/transop_symmetry_elimination_pass.h rename to ge/graph/passes/transop_symmetry_elimination_pass.h diff --git a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/ge/graph/passes/transop_without_reshape_fusion_pass.cc similarity index 100% rename from src/ge/graph/passes/transop_without_reshape_fusion_pass.cc rename to ge/graph/passes/transop_without_reshape_fusion_pass.cc diff --git a/src/ge/graph/passes/transop_without_reshape_fusion_pass.h b/ge/graph/passes/transop_without_reshape_fusion_pass.h similarity index 100% rename from src/ge/graph/passes/transop_without_reshape_fusion_pass.h rename to ge/graph/passes/transop_without_reshape_fusion_pass.h diff --git a/src/ge/graph/passes/transpose_transdata_pass.cc b/ge/graph/passes/transpose_transdata_pass.cc similarity index 100% rename from src/ge/graph/passes/transpose_transdata_pass.cc rename to ge/graph/passes/transpose_transdata_pass.cc diff --git a/src/ge/graph/passes/transpose_transdata_pass.h b/ge/graph/passes/transpose_transdata_pass.h similarity index 100% rename from src/ge/graph/passes/transpose_transdata_pass.h rename to ge/graph/passes/transpose_transdata_pass.h index 6b65e960..bf42f5de 100644 --- a/src/ge/graph/passes/transpose_transdata_pass.h +++ b/ge/graph/passes/transpose_transdata_pass.h @@ -23,7 +23,6 @@ namespace ge { class TransposeTransDataPass : public BaseNodePass { public: Status Run(NodePtr &node) override; - private: Status CheckOneInAndOneOutDataAnchor(NodePtr &node) const; Status RemoveTranspose(NodePtr &node); @@ -33,3 +32,4 @@ class TransposeTransDataPass : public BaseNodePass { }; } // namespace ge #endif // GE_GRAPH_PASSES_TRANSPOSE_TRANSDATA_PASS_H_ + diff --git a/src/ge/graph/passes/unused_args_clean_pass.cc b/ge/graph/passes/unused_args_clean_pass.cc similarity index 100% rename from src/ge/graph/passes/unused_args_clean_pass.cc rename to ge/graph/passes/unused_args_clean_pass.cc diff --git a/src/ge/graph/passes/unused_args_clean_pass.h b/ge/graph/passes/unused_args_clean_pass.h similarity index 100% rename from 
src/ge/graph/passes/unused_args_clean_pass.h rename to ge/graph/passes/unused_args_clean_pass.h diff --git a/src/ge/graph/passes/unused_const_pass.cc b/ge/graph/passes/unused_const_pass.cc similarity index 100% rename from src/ge/graph/passes/unused_const_pass.cc rename to ge/graph/passes/unused_const_pass.cc diff --git a/src/ge/graph/passes/unused_const_pass.h b/ge/graph/passes/unused_const_pass.h similarity index 100% rename from src/ge/graph/passes/unused_const_pass.h rename to ge/graph/passes/unused_const_pass.h diff --git a/src/ge/graph/passes/unused_op_remove_pass.cc b/ge/graph/passes/unused_op_remove_pass.cc similarity index 100% rename from src/ge/graph/passes/unused_op_remove_pass.cc rename to ge/graph/passes/unused_op_remove_pass.cc diff --git a/src/ge/graph/passes/unused_op_remove_pass.h b/ge/graph/passes/unused_op_remove_pass.h similarity index 100% rename from src/ge/graph/passes/unused_op_remove_pass.h rename to ge/graph/passes/unused_op_remove_pass.h diff --git a/src/ge/graph/passes/var_is_initialized_op_pass.cc b/ge/graph/passes/var_is_initialized_op_pass.cc similarity index 100% rename from src/ge/graph/passes/var_is_initialized_op_pass.cc rename to ge/graph/passes/var_is_initialized_op_pass.cc diff --git a/src/ge/graph/passes/var_is_initialized_op_pass.h b/ge/graph/passes/var_is_initialized_op_pass.h similarity index 100% rename from src/ge/graph/passes/var_is_initialized_op_pass.h rename to ge/graph/passes/var_is_initialized_op_pass.h diff --git a/src/ge/graph/passes/variable_format_pass.cc b/ge/graph/passes/variable_format_pass.cc similarity index 100% rename from src/ge/graph/passes/variable_format_pass.cc rename to ge/graph/passes/variable_format_pass.cc diff --git a/src/ge/graph/passes/variable_format_pass.h b/ge/graph/passes/variable_format_pass.h similarity index 100% rename from src/ge/graph/passes/variable_format_pass.h rename to ge/graph/passes/variable_format_pass.h diff --git a/src/ge/graph/passes/variable_op_pass.cc b/ge/graph/passes/variable_op_pass.cc similarity index 100% rename from src/ge/graph/passes/variable_op_pass.cc rename to ge/graph/passes/variable_op_pass.cc diff --git a/src/ge/graph/passes/variable_op_pass.h b/ge/graph/passes/variable_op_pass.h similarity index 100% rename from src/ge/graph/passes/variable_op_pass.h rename to ge/graph/passes/variable_op_pass.h diff --git a/src/ge/graph/passes/variable_prepare_op_pass.cc b/ge/graph/passes/variable_prepare_op_pass.cc similarity index 100% rename from src/ge/graph/passes/variable_prepare_op_pass.cc rename to ge/graph/passes/variable_prepare_op_pass.cc diff --git a/src/ge/graph/passes/variable_prepare_op_pass.h b/ge/graph/passes/variable_prepare_op_pass.h similarity index 100% rename from src/ge/graph/passes/variable_prepare_op_pass.h rename to ge/graph/passes/variable_prepare_op_pass.h diff --git a/src/ge/graph/passes/variable_ref_delete_op_pass.cc b/ge/graph/passes/variable_ref_delete_op_pass.cc similarity index 100% rename from src/ge/graph/passes/variable_ref_delete_op_pass.cc rename to ge/graph/passes/variable_ref_delete_op_pass.cc diff --git a/src/ge/graph/passes/variable_ref_delete_op_pass.h b/ge/graph/passes/variable_ref_delete_op_pass.h similarity index 100% rename from src/ge/graph/passes/variable_ref_delete_op_pass.h rename to ge/graph/passes/variable_ref_delete_op_pass.h diff --git a/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc b/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc similarity index 100% rename from 
src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc
rename to ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc
diff --git a/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.h b/ge/graph/passes/variable_ref_useless_control_out_delete_pass.h
similarity index 100%
rename from src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.h
rename to ge/graph/passes/variable_ref_useless_control_out_delete_pass.h
diff --git a/src/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc
similarity index 100%
rename from src/ge/graph/preprocess/graph_preprocess.cc
rename to ge/graph/preprocess/graph_preprocess.cc
diff --git a/src/ge/graph/preprocess/graph_preprocess.h b/ge/graph/preprocess/graph_preprocess.h
similarity index 100%
rename from src/ge/graph/preprocess/graph_preprocess.h
rename to ge/graph/preprocess/graph_preprocess.h
diff --git a/src/ge/graph/preprocess/insert_op/base_insert_op.h b/ge/graph/preprocess/insert_op/base_insert_op.h
similarity index 100%
rename from src/ge/graph/preprocess/insert_op/base_insert_op.h
rename to ge/graph/preprocess/insert_op/base_insert_op.h
diff --git a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/ge/graph/preprocess/insert_op/ge_aipp_op.cc
similarity index 100%
rename from src/ge/graph/preprocess/insert_op/ge_aipp_op.cc
rename to ge/graph/preprocess/insert_op/ge_aipp_op.cc
diff --git a/src/ge/graph/preprocess/insert_op/ge_aipp_op.h b/ge/graph/preprocess/insert_op/ge_aipp_op.h
similarity index 100%
rename from src/ge/graph/preprocess/insert_op/ge_aipp_op.h
rename to ge/graph/preprocess/insert_op/ge_aipp_op.h
diff --git a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc
similarity index 100%
rename from src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc
rename to ge/graph/preprocess/insert_op/util_insert_aipp_op.cc
diff --git a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.h b/ge/graph/preprocess/insert_op/util_insert_aipp_op.h
similarity index 100%
rename from src/ge/graph/preprocess/insert_op/util_insert_aipp_op.h
rename to ge/graph/preprocess/insert_op/util_insert_aipp_op.h
diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc
similarity index 100%
rename from src/ge/graph/preprocess/multi_batch_copy_graph.cc
rename to ge/graph/preprocess/multi_batch_copy_graph.cc
diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.h b/ge/graph/preprocess/multi_batch_copy_graph.h
similarity index 100%
rename from src/ge/graph/preprocess/multi_batch_copy_graph.h
rename to ge/graph/preprocess/multi_batch_copy_graph.h
diff --git a/src/ge/graph/preprocess/multi_batch_options.cc b/ge/graph/preprocess/multi_batch_options.cc
similarity index 100%
rename from src/ge/graph/preprocess/multi_batch_options.cc
rename to ge/graph/preprocess/multi_batch_options.cc
diff --git a/src/ge/graph/preprocess/multi_batch_options.h b/ge/graph/preprocess/multi_batch_options.h
similarity index 100%
rename from src/ge/graph/preprocess/multi_batch_options.h
rename to ge/graph/preprocess/multi_batch_options.h
diff --git a/ge/host_cpu_engine/CMakeLists.txt b/ge/host_cpu_engine/CMakeLists.txt
new file mode 100644
index 00000000..b9a23009
--- /dev/null
+++ b/ge/host_cpu_engine/CMakeLists.txt
@@ -0,0 +1,109 @@
+set(PROTO_LIST
+    "${METADEF_DIR}/proto/task.proto"
+)
+
+protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
+
+set(SRC_LIST
+    "engine/host_cpu_engine.cc"
+    "ops_kernel_store/host_cpu_ops_kernel_info.cc"
+    "ops_kernel_store/op/op_factory.cc"
+    "ops_kernel_store/op/host_op.cc"
+)
+
+set(CPU_OPS_KERNEL_LIST
+    "ops_kernel_store/host_cpu_ops_kernel_builder.cc"
+)
+
+############ libhost_cpu_engine.so ############
+add_library(host_cpu_engine SHARED ${SRC_LIST} ${PROTO_HDRS})
+
+target_compile_options(host_cpu_engine PRIVATE
+    -Werror
+)
+
+target_include_directories(host_cpu_engine PRIVATE
+    ${CMAKE_CURRENT_LIST_DIR}
+    ${GE_CODE_DIR}/ge
+    ${GE_CODE_DIR}/inc
+    ${GE_CODE_DIR}/inc/external
+    ${GE_CODE_DIR}/inc/framework
+    ${METADEF_DIR}/inc
+    ${METADEF_DIR}/inc/external
+    ${METADEF_DIR}/inc/external/graph
+    ${CMAKE_BINARY_DIR}
+    ${CMAKE_BINARY_DIR}/proto/ge
+    #### yellow zone ####
+    ${GE_CODE_DIR}/../inc
+    #### blue zone ####
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc
+)
+
+target_link_libraries(host_cpu_engine PRIVATE
+    $
+    -Wl,--no-as-needed
+    protobuf
+    c_sec
+    graph
+    register
+    slog
+    runtime
+    -Wl,--as-needed
+)
+
+############ atcstub/libhost_cpu_engine.so ############
+add_library(atc_host_cpu_engine SHARED ${SRC_LIST} ${PROTO_HDRS})
+
+target_compile_options(atc_host_cpu_engine PRIVATE
+    -Werror
+)
+
+target_compile_definitions(atc_host_cpu_engine PRIVATE
+    COMPILE_OMG_PACKAGE
+)
+
+target_include_directories(atc_host_cpu_engine PRIVATE
+    ${CMAKE_CURRENT_LIST_DIR}
+    ${GE_CODE_DIR}/ge
+    ${GE_CODE_DIR}/inc
+    ${GE_CODE_DIR}/inc/external
+    ${GE_CODE_DIR}/inc/framework
+    ${METADEF_DIR}/inc
+    ${METADEF_DIR}/inc/external
+    ${METADEF_DIR}/inc/external/graph
+    ${CMAKE_BINARY_DIR}
+    ${CMAKE_BINARY_DIR}/proto/ge
+    #### yellow zone ####
+    ${GE_CODE_DIR}/../inc
+    #### blue zone ####
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc
+)
+
+target_link_libraries(atc_host_cpu_engine PRIVATE
+    $
+    -Wl,--no-as-needed
+    protobuf
+    c_sec
+    graph
+    register
+    slog
+    runtime_compile
+    -Wl,--as-needed
+)
+
+set_target_properties(atc_host_cpu_engine PROPERTIES
+    OUTPUT_NAME host_cpu_engine
+    LIBRARY_OUTPUT_DIRECTORY atclib
+)
+
+############ install ############
+set(INSTALL_BASE_DIR "")
+set(INSTALL_LIBRARY_DIR lib)
+
+install(TARGETS host_cpu_engine OPTIONAL
+    LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
+)
+
+install(TARGETS atc_host_cpu_engine OPTIONAL
+    LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/atclib
+)
diff --git a/src/ge/host_cpu_engine/common/constant/constant.h b/ge/host_cpu_engine/common/constant/constant.h
similarity index 100%
rename from src/ge/host_cpu_engine/common/constant/constant.h
rename to ge/host_cpu_engine/common/constant/constant.h
diff --git a/src/ge/host_cpu_engine/engine/host_cpu_engine.cc b/ge/host_cpu_engine/engine/host_cpu_engine.cc
similarity index 100%
rename from src/ge/host_cpu_engine/engine/host_cpu_engine.cc
rename to ge/host_cpu_engine/engine/host_cpu_engine.cc
diff --git a/src/ge/host_cpu_engine/engine/host_cpu_engine.h b/ge/host_cpu_engine/engine/host_cpu_engine.h
similarity index 100%
rename from src/ge/host_cpu_engine/engine/host_cpu_engine.h
rename to ge/host_cpu_engine/engine/host_cpu_engine.h
diff --git a/src/ge/host_cpu_engine/module.mk b/ge/host_cpu_engine/module.mk
similarity index 100%
rename from src/ge/host_cpu_engine/module.mk
rename to ge/host_cpu_engine/module.mk
diff --git a/src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc b/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc
similarity index 100%
rename from src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc
rename to ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc
diff --git
a/src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h b/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h similarity index 100% rename from src/ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h rename to ge/host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h diff --git a/src/ge/host_cpu_engine/ops_kernel_store/op/host_op.cc b/ge/host_cpu_engine/ops_kernel_store/op/host_op.cc similarity index 100% rename from src/ge/host_cpu_engine/ops_kernel_store/op/host_op.cc rename to ge/host_cpu_engine/ops_kernel_store/op/host_op.cc diff --git a/src/ge/host_cpu_engine/ops_kernel_store/op/host_op.h b/ge/host_cpu_engine/ops_kernel_store/op/host_op.h similarity index 100% rename from src/ge/host_cpu_engine/ops_kernel_store/op/host_op.h rename to ge/host_cpu_engine/ops_kernel_store/op/host_op.h diff --git a/src/ge/host_cpu_engine/ops_kernel_store/op/op.h b/ge/host_cpu_engine/ops_kernel_store/op/op.h similarity index 100% rename from src/ge/host_cpu_engine/ops_kernel_store/op/op.h rename to ge/host_cpu_engine/ops_kernel_store/op/op.h diff --git a/src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc b/ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc similarity index 100% rename from src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc rename to ge/host_cpu_engine/ops_kernel_store/op/op_factory.cc diff --git a/src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.h b/ge/host_cpu_engine/ops_kernel_store/op/op_factory.h similarity index 100% rename from src/ge/host_cpu_engine/ops_kernel_store/op/op_factory.h rename to ge/host_cpu_engine/ops_kernel_store/op/op_factory.h diff --git a/ge/host_cpu_engine/proto/task.proto b/ge/host_cpu_engine/proto/task.proto new file mode 100644 index 00000000..36ae4847 --- /dev/null +++ b/ge/host_cpu_engine/proto/task.proto @@ -0,0 +1 @@ +../../proto/task.proto \ No newline at end of file diff --git a/src/ge/host_kernels/add_kernel.cc b/ge/host_kernels/add_kernel.cc similarity index 100% rename from src/ge/host_kernels/add_kernel.cc rename to ge/host_kernels/add_kernel.cc diff --git a/src/ge/host_kernels/add_kernel.h b/ge/host_kernels/add_kernel.h similarity index 100% rename from src/ge/host_kernels/add_kernel.h rename to ge/host_kernels/add_kernel.h diff --git a/src/ge/host_kernels/broadcast_args_kernel.cc b/ge/host_kernels/broadcast_args_kernel.cc similarity index 100% rename from src/ge/host_kernels/broadcast_args_kernel.cc rename to ge/host_kernels/broadcast_args_kernel.cc diff --git a/src/ge/host_kernels/broadcast_args_kernel.h b/ge/host_kernels/broadcast_args_kernel.h similarity index 100% rename from src/ge/host_kernels/broadcast_args_kernel.h rename to ge/host_kernels/broadcast_args_kernel.h diff --git a/src/ge/host_kernels/broadcast_gradient_args_kernel.cc b/ge/host_kernels/broadcast_gradient_args_kernel.cc similarity index 100% rename from src/ge/host_kernels/broadcast_gradient_args_kernel.cc rename to ge/host_kernels/broadcast_gradient_args_kernel.cc diff --git a/src/ge/host_kernels/broadcast_gradient_args_kernel.h b/ge/host_kernels/broadcast_gradient_args_kernel.h similarity index 100% rename from src/ge/host_kernels/broadcast_gradient_args_kernel.h rename to ge/host_kernels/broadcast_gradient_args_kernel.h diff --git a/src/ge/host_kernels/cast_kernel.cc b/ge/host_kernels/cast_kernel.cc similarity index 100% rename from src/ge/host_kernels/cast_kernel.cc rename to ge/host_kernels/cast_kernel.cc diff --git a/src/ge/host_kernels/cast_kernel.h b/ge/host_kernels/cast_kernel.h similarity index 100% rename from 
src/ge/host_kernels/cast_kernel.h rename to ge/host_kernels/cast_kernel.h diff --git a/src/ge/host_kernels/concat_offset_kernel.cc b/ge/host_kernels/concat_offset_kernel.cc similarity index 100% rename from src/ge/host_kernels/concat_offset_kernel.cc rename to ge/host_kernels/concat_offset_kernel.cc diff --git a/src/ge/host_kernels/concat_offset_kernel.h b/ge/host_kernels/concat_offset_kernel.h similarity index 100% rename from src/ge/host_kernels/concat_offset_kernel.h rename to ge/host_kernels/concat_offset_kernel.h diff --git a/src/ge/host_kernels/concat_v2_kernel.cc b/ge/host_kernels/concat_v2_kernel.cc similarity index 100% rename from src/ge/host_kernels/concat_v2_kernel.cc rename to ge/host_kernels/concat_v2_kernel.cc diff --git a/src/ge/host_kernels/concat_v2_kernel.h b/ge/host_kernels/concat_v2_kernel.h similarity index 100% rename from src/ge/host_kernels/concat_v2_kernel.h rename to ge/host_kernels/concat_v2_kernel.h diff --git a/src/ge/host_kernels/dynamic_stitch_kernel.cc b/ge/host_kernels/dynamic_stitch_kernel.cc similarity index 100% rename from src/ge/host_kernels/dynamic_stitch_kernel.cc rename to ge/host_kernels/dynamic_stitch_kernel.cc diff --git a/src/ge/host_kernels/dynamic_stitch_kernel.h b/ge/host_kernels/dynamic_stitch_kernel.h similarity index 100% rename from src/ge/host_kernels/dynamic_stitch_kernel.h rename to ge/host_kernels/dynamic_stitch_kernel.h diff --git a/src/ge/host_kernels/empty_kernel.cc b/ge/host_kernels/empty_kernel.cc similarity index 100% rename from src/ge/host_kernels/empty_kernel.cc rename to ge/host_kernels/empty_kernel.cc diff --git a/src/ge/host_kernels/empty_kernel.h b/ge/host_kernels/empty_kernel.h similarity index 100% rename from src/ge/host_kernels/empty_kernel.h rename to ge/host_kernels/empty_kernel.h diff --git a/src/ge/host_kernels/expanddims_kernel.cc b/ge/host_kernels/expanddims_kernel.cc similarity index 100% rename from src/ge/host_kernels/expanddims_kernel.cc rename to ge/host_kernels/expanddims_kernel.cc diff --git a/src/ge/host_kernels/expanddims_kernel.h b/ge/host_kernels/expanddims_kernel.h similarity index 100% rename from src/ge/host_kernels/expanddims_kernel.h rename to ge/host_kernels/expanddims_kernel.h diff --git a/src/ge/host_kernels/fill_kernel.cc b/ge/host_kernels/fill_kernel.cc similarity index 100% rename from src/ge/host_kernels/fill_kernel.cc rename to ge/host_kernels/fill_kernel.cc diff --git a/src/ge/host_kernels/fill_kernel.h b/ge/host_kernels/fill_kernel.h similarity index 100% rename from src/ge/host_kernels/fill_kernel.h rename to ge/host_kernels/fill_kernel.h diff --git a/src/ge/host_kernels/floordiv_kernel.cc b/ge/host_kernels/floordiv_kernel.cc similarity index 100% rename from src/ge/host_kernels/floordiv_kernel.cc rename to ge/host_kernels/floordiv_kernel.cc diff --git a/src/ge/host_kernels/floordiv_kernel.h b/ge/host_kernels/floordiv_kernel.h similarity index 100% rename from src/ge/host_kernels/floordiv_kernel.h rename to ge/host_kernels/floordiv_kernel.h diff --git a/src/ge/host_kernels/floormod_kernel.cc b/ge/host_kernels/floormod_kernel.cc similarity index 100% rename from src/ge/host_kernels/floormod_kernel.cc rename to ge/host_kernels/floormod_kernel.cc diff --git a/src/ge/host_kernels/floormod_kernel.h b/ge/host_kernels/floormod_kernel.h similarity index 100% rename from src/ge/host_kernels/floormod_kernel.h rename to ge/host_kernels/floormod_kernel.h diff --git a/src/ge/host_kernels/gather_v2_kernel.cc b/ge/host_kernels/gather_v2_kernel.cc similarity index 100% rename from 
src/ge/host_kernels/gather_v2_kernel.cc rename to ge/host_kernels/gather_v2_kernel.cc diff --git a/src/ge/host_kernels/gather_v2_kernel.h b/ge/host_kernels/gather_v2_kernel.h similarity index 100% rename from src/ge/host_kernels/gather_v2_kernel.h rename to ge/host_kernels/gather_v2_kernel.h diff --git a/src/ge/host_kernels/greater_kernel.cc b/ge/host_kernels/greater_kernel.cc similarity index 100% rename from src/ge/host_kernels/greater_kernel.cc rename to ge/host_kernels/greater_kernel.cc diff --git a/src/ge/host_kernels/greater_kernel.h b/ge/host_kernels/greater_kernel.h similarity index 100% rename from src/ge/host_kernels/greater_kernel.h rename to ge/host_kernels/greater_kernel.h diff --git a/src/ge/host_kernels/identity_kernel.cc b/ge/host_kernels/identity_kernel.cc similarity index 100% rename from src/ge/host_kernels/identity_kernel.cc rename to ge/host_kernels/identity_kernel.cc diff --git a/src/ge/host_kernels/identity_kernel.h b/ge/host_kernels/identity_kernel.h similarity index 100% rename from src/ge/host_kernels/identity_kernel.h rename to ge/host_kernels/identity_kernel.h diff --git a/src/ge/host_kernels/kernel_utils.cc b/ge/host_kernels/kernel_utils.cc similarity index 100% rename from src/ge/host_kernels/kernel_utils.cc rename to ge/host_kernels/kernel_utils.cc diff --git a/src/ge/host_kernels/kernel_utils.h b/ge/host_kernels/kernel_utils.h similarity index 100% rename from src/ge/host_kernels/kernel_utils.h rename to ge/host_kernels/kernel_utils.h diff --git a/src/ge/host_kernels/maximum_kernel.cc b/ge/host_kernels/maximum_kernel.cc similarity index 100% rename from src/ge/host_kernels/maximum_kernel.cc rename to ge/host_kernels/maximum_kernel.cc diff --git a/src/ge/host_kernels/maximum_kernel.h b/ge/host_kernels/maximum_kernel.h similarity index 100% rename from src/ge/host_kernels/maximum_kernel.h rename to ge/host_kernels/maximum_kernel.h diff --git a/src/ge/host_kernels/mul_kernel.cc b/ge/host_kernels/mul_kernel.cc similarity index 100% rename from src/ge/host_kernels/mul_kernel.cc rename to ge/host_kernels/mul_kernel.cc diff --git a/src/ge/host_kernels/mul_kernel.h b/ge/host_kernels/mul_kernel.h similarity index 100% rename from src/ge/host_kernels/mul_kernel.h rename to ge/host_kernels/mul_kernel.h diff --git a/src/ge/host_kernels/pack_kernel.cc b/ge/host_kernels/pack_kernel.cc similarity index 100% rename from src/ge/host_kernels/pack_kernel.cc rename to ge/host_kernels/pack_kernel.cc diff --git a/src/ge/host_kernels/pack_kernel.h b/ge/host_kernels/pack_kernel.h similarity index 100% rename from src/ge/host_kernels/pack_kernel.h rename to ge/host_kernels/pack_kernel.h diff --git a/src/ge/host_kernels/permute_kernel.cc b/ge/host_kernels/permute_kernel.cc similarity index 100% rename from src/ge/host_kernels/permute_kernel.cc rename to ge/host_kernels/permute_kernel.cc diff --git a/src/ge/host_kernels/permute_kernel.h b/ge/host_kernels/permute_kernel.h similarity index 100% rename from src/ge/host_kernels/permute_kernel.h rename to ge/host_kernels/permute_kernel.h diff --git a/src/ge/host_kernels/range_kernel.cc b/ge/host_kernels/range_kernel.cc similarity index 100% rename from src/ge/host_kernels/range_kernel.cc rename to ge/host_kernels/range_kernel.cc diff --git a/src/ge/host_kernels/range_kernel.h b/ge/host_kernels/range_kernel.h similarity index 100% rename from src/ge/host_kernels/range_kernel.h rename to ge/host_kernels/range_kernel.h diff --git a/src/ge/host_kernels/rank_kernel.cc b/ge/host_kernels/rank_kernel.cc similarity index 100% rename from 
src/ge/host_kernels/rank_kernel.cc rename to ge/host_kernels/rank_kernel.cc diff --git a/src/ge/host_kernels/rank_kernel.h b/ge/host_kernels/rank_kernel.h similarity index 100% rename from src/ge/host_kernels/rank_kernel.h rename to ge/host_kernels/rank_kernel.h diff --git a/src/ge/host_kernels/reduce_prod_kernel.cc b/ge/host_kernels/reduce_prod_kernel.cc similarity index 100% rename from src/ge/host_kernels/reduce_prod_kernel.cc rename to ge/host_kernels/reduce_prod_kernel.cc diff --git a/src/ge/host_kernels/reduce_prod_kernel.h b/ge/host_kernels/reduce_prod_kernel.h similarity index 100% rename from src/ge/host_kernels/reduce_prod_kernel.h rename to ge/host_kernels/reduce_prod_kernel.h diff --git a/src/ge/host_kernels/reformat_kernel.cc b/ge/host_kernels/reformat_kernel.cc similarity index 100% rename from src/ge/host_kernels/reformat_kernel.cc rename to ge/host_kernels/reformat_kernel.cc diff --git a/src/ge/host_kernels/reformat_kernel.h b/ge/host_kernels/reformat_kernel.h similarity index 100% rename from src/ge/host_kernels/reformat_kernel.h rename to ge/host_kernels/reformat_kernel.h diff --git a/src/ge/host_kernels/reshape_kernel.cc b/ge/host_kernels/reshape_kernel.cc similarity index 100% rename from src/ge/host_kernels/reshape_kernel.cc rename to ge/host_kernels/reshape_kernel.cc diff --git a/src/ge/host_kernels/reshape_kernel.h b/ge/host_kernels/reshape_kernel.h similarity index 100% rename from src/ge/host_kernels/reshape_kernel.h rename to ge/host_kernels/reshape_kernel.h diff --git a/src/ge/host_kernels/rsqrt_kernel.cc b/ge/host_kernels/rsqrt_kernel.cc similarity index 100% rename from src/ge/host_kernels/rsqrt_kernel.cc rename to ge/host_kernels/rsqrt_kernel.cc diff --git a/src/ge/host_kernels/rsqrt_kernel.h b/ge/host_kernels/rsqrt_kernel.h similarity index 100% rename from src/ge/host_kernels/rsqrt_kernel.h rename to ge/host_kernels/rsqrt_kernel.h diff --git a/src/ge/host_kernels/shape_kernel.cc b/ge/host_kernels/shape_kernel.cc similarity index 100% rename from src/ge/host_kernels/shape_kernel.cc rename to ge/host_kernels/shape_kernel.cc diff --git a/src/ge/host_kernels/shape_kernel.h b/ge/host_kernels/shape_kernel.h similarity index 100% rename from src/ge/host_kernels/shape_kernel.h rename to ge/host_kernels/shape_kernel.h diff --git a/src/ge/host_kernels/shape_n_kernel.cc b/ge/host_kernels/shape_n_kernel.cc similarity index 100% rename from src/ge/host_kernels/shape_n_kernel.cc rename to ge/host_kernels/shape_n_kernel.cc diff --git a/src/ge/host_kernels/shape_n_kernel.h b/ge/host_kernels/shape_n_kernel.h similarity index 100% rename from src/ge/host_kernels/shape_n_kernel.h rename to ge/host_kernels/shape_n_kernel.h diff --git a/src/ge/host_kernels/size_kernel.cc b/ge/host_kernels/size_kernel.cc similarity index 100% rename from src/ge/host_kernels/size_kernel.cc rename to ge/host_kernels/size_kernel.cc diff --git a/src/ge/host_kernels/size_kernel.h b/ge/host_kernels/size_kernel.h similarity index 100% rename from src/ge/host_kernels/size_kernel.h rename to ge/host_kernels/size_kernel.h diff --git a/src/ge/host_kernels/slice_d_kernel.cc b/ge/host_kernels/slice_d_kernel.cc similarity index 100% rename from src/ge/host_kernels/slice_d_kernel.cc rename to ge/host_kernels/slice_d_kernel.cc diff --git a/src/ge/host_kernels/slice_d_kernel.h b/ge/host_kernels/slice_d_kernel.h similarity index 100% rename from src/ge/host_kernels/slice_d_kernel.h rename to ge/host_kernels/slice_d_kernel.h diff --git a/src/ge/host_kernels/slice_kernel.cc b/ge/host_kernels/slice_kernel.cc 
similarity index 100% rename from src/ge/host_kernels/slice_kernel.cc rename to ge/host_kernels/slice_kernel.cc diff --git a/src/ge/host_kernels/slice_kernel.h b/ge/host_kernels/slice_kernel.h similarity index 100% rename from src/ge/host_kernels/slice_kernel.h rename to ge/host_kernels/slice_kernel.h diff --git a/src/ge/host_kernels/squeeze_kernel.cc b/ge/host_kernels/squeeze_kernel.cc similarity index 100% rename from src/ge/host_kernels/squeeze_kernel.cc rename to ge/host_kernels/squeeze_kernel.cc diff --git a/src/ge/host_kernels/squeeze_kernel.h b/ge/host_kernels/squeeze_kernel.h similarity index 100% rename from src/ge/host_kernels/squeeze_kernel.h rename to ge/host_kernels/squeeze_kernel.h diff --git a/src/ge/host_kernels/ssd_prior_box_kernel.cc b/ge/host_kernels/ssd_prior_box_kernel.cc similarity index 100% rename from src/ge/host_kernels/ssd_prior_box_kernel.cc rename to ge/host_kernels/ssd_prior_box_kernel.cc diff --git a/src/ge/host_kernels/ssd_prior_box_kernel.h b/ge/host_kernels/ssd_prior_box_kernel.h similarity index 100% rename from src/ge/host_kernels/ssd_prior_box_kernel.h rename to ge/host_kernels/ssd_prior_box_kernel.h diff --git a/src/ge/host_kernels/strided_slice_kernel.cc b/ge/host_kernels/strided_slice_kernel.cc similarity index 100% rename from src/ge/host_kernels/strided_slice_kernel.cc rename to ge/host_kernels/strided_slice_kernel.cc diff --git a/src/ge/host_kernels/strided_slice_kernel.h b/ge/host_kernels/strided_slice_kernel.h similarity index 100% rename from src/ge/host_kernels/strided_slice_kernel.h rename to ge/host_kernels/strided_slice_kernel.h diff --git a/src/ge/host_kernels/sub_kernel.cc b/ge/host_kernels/sub_kernel.cc similarity index 100% rename from src/ge/host_kernels/sub_kernel.cc rename to ge/host_kernels/sub_kernel.cc diff --git a/src/ge/host_kernels/sub_kernel.h b/ge/host_kernels/sub_kernel.h similarity index 100% rename from src/ge/host_kernels/sub_kernel.h rename to ge/host_kernels/sub_kernel.h diff --git a/src/ge/host_kernels/transdata_kernel.cc b/ge/host_kernels/transdata_kernel.cc similarity index 100% rename from src/ge/host_kernels/transdata_kernel.cc rename to ge/host_kernels/transdata_kernel.cc diff --git a/src/ge/host_kernels/transdata_kernel.h b/ge/host_kernels/transdata_kernel.h similarity index 100% rename from src/ge/host_kernels/transdata_kernel.h rename to ge/host_kernels/transdata_kernel.h diff --git a/src/ge/host_kernels/transpose_kernel.cc b/ge/host_kernels/transpose_kernel.cc similarity index 100% rename from src/ge/host_kernels/transpose_kernel.cc rename to ge/host_kernels/transpose_kernel.cc diff --git a/src/ge/host_kernels/transpose_kernel.h b/ge/host_kernels/transpose_kernel.h similarity index 100% rename from src/ge/host_kernels/transpose_kernel.h rename to ge/host_kernels/transpose_kernel.h diff --git a/src/ge/host_kernels/unpack_kernel.cc b/ge/host_kernels/unpack_kernel.cc similarity index 100% rename from src/ge/host_kernels/unpack_kernel.cc rename to ge/host_kernels/unpack_kernel.cc diff --git a/src/ge/host_kernels/unpack_kernel.h b/ge/host_kernels/unpack_kernel.h similarity index 100% rename from src/ge/host_kernels/unpack_kernel.h rename to ge/host_kernels/unpack_kernel.h diff --git a/src/ge/host_kernels/unsqueeze_kernel.cc b/ge/host_kernels/unsqueeze_kernel.cc similarity index 100% rename from src/ge/host_kernels/unsqueeze_kernel.cc rename to ge/host_kernels/unsqueeze_kernel.cc diff --git a/src/ge/host_kernels/unsqueeze_kernel.h b/ge/host_kernels/unsqueeze_kernel.h similarity index 100% rename from 
src/ge/host_kernels/unsqueeze_kernel.h rename to ge/host_kernels/unsqueeze_kernel.h diff --git a/src/ge/hybrid/common/npu_memory_allocator.cc b/ge/hybrid/common/npu_memory_allocator.cc similarity index 100% rename from src/ge/hybrid/common/npu_memory_allocator.cc rename to ge/hybrid/common/npu_memory_allocator.cc diff --git a/src/ge/hybrid/common/npu_memory_allocator.h b/ge/hybrid/common/npu_memory_allocator.h similarity index 100% rename from src/ge/hybrid/common/npu_memory_allocator.h rename to ge/hybrid/common/npu_memory_allocator.h diff --git a/src/ge/hybrid/common/tensor_value.cc b/ge/hybrid/common/tensor_value.cc similarity index 100% rename from src/ge/hybrid/common/tensor_value.cc rename to ge/hybrid/common/tensor_value.cc diff --git a/src/ge/hybrid/common/tensor_value.h b/ge/hybrid/common/tensor_value.h similarity index 100% rename from src/ge/hybrid/common/tensor_value.h rename to ge/hybrid/common/tensor_value.h diff --git a/src/ge/hybrid/executor/hybrid_execution_context.cc b/ge/hybrid/executor/hybrid_execution_context.cc similarity index 100% rename from src/ge/hybrid/executor/hybrid_execution_context.cc rename to ge/hybrid/executor/hybrid_execution_context.cc diff --git a/src/ge/hybrid/executor/hybrid_execution_context.h b/ge/hybrid/executor/hybrid_execution_context.h similarity index 100% rename from src/ge/hybrid/executor/hybrid_execution_context.h rename to ge/hybrid/executor/hybrid_execution_context.h diff --git a/src/ge/hybrid/executor/hybrid_model_async_executor.cc b/ge/hybrid/executor/hybrid_model_async_executor.cc similarity index 100% rename from src/ge/hybrid/executor/hybrid_model_async_executor.cc rename to ge/hybrid/executor/hybrid_model_async_executor.cc diff --git a/src/ge/hybrid/executor/hybrid_model_async_executor.h b/ge/hybrid/executor/hybrid_model_async_executor.h similarity index 100% rename from src/ge/hybrid/executor/hybrid_model_async_executor.h rename to ge/hybrid/executor/hybrid_model_async_executor.h diff --git a/src/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc similarity index 100% rename from src/ge/hybrid/executor/hybrid_model_executor.cc rename to ge/hybrid/executor/hybrid_model_executor.cc diff --git a/src/ge/hybrid/executor/hybrid_model_executor.h b/ge/hybrid/executor/hybrid_model_executor.h similarity index 100% rename from src/ge/hybrid/executor/hybrid_model_executor.h rename to ge/hybrid/executor/hybrid_model_executor.h diff --git a/src/ge/hybrid/executor/hybrid_profiler.cc b/ge/hybrid/executor/hybrid_profiler.cc similarity index 100% rename from src/ge/hybrid/executor/hybrid_profiler.cc rename to ge/hybrid/executor/hybrid_profiler.cc diff --git a/src/ge/hybrid/executor/hybrid_profiler.h b/ge/hybrid/executor/hybrid_profiler.h similarity index 100% rename from src/ge/hybrid/executor/hybrid_profiler.h rename to ge/hybrid/executor/hybrid_profiler.h diff --git a/src/ge/hybrid/executor/node_done_manager.cc b/ge/hybrid/executor/node_done_manager.cc similarity index 100% rename from src/ge/hybrid/executor/node_done_manager.cc rename to ge/hybrid/executor/node_done_manager.cc diff --git a/src/ge/hybrid/executor/node_done_manager.h b/ge/hybrid/executor/node_done_manager.h similarity index 100% rename from src/ge/hybrid/executor/node_done_manager.h rename to ge/hybrid/executor/node_done_manager.h diff --git a/src/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc similarity index 100% rename from src/ge/hybrid/executor/node_state.cc rename to ge/hybrid/executor/node_state.cc diff --git 
a/src/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h similarity index 100% rename from src/ge/hybrid/executor/node_state.h rename to ge/hybrid/executor/node_state.h diff --git a/src/ge/hybrid/executor/rt_callback_manager.cc b/ge/hybrid/executor/rt_callback_manager.cc similarity index 100% rename from src/ge/hybrid/executor/rt_callback_manager.cc rename to ge/hybrid/executor/rt_callback_manager.cc diff --git a/src/ge/hybrid/executor/rt_callback_manager.h b/ge/hybrid/executor/rt_callback_manager.h similarity index 100% rename from src/ge/hybrid/executor/rt_callback_manager.h rename to ge/hybrid/executor/rt_callback_manager.h diff --git a/src/ge/hybrid/executor/subgraph_context.cc b/ge/hybrid/executor/subgraph_context.cc similarity index 100% rename from src/ge/hybrid/executor/subgraph_context.cc rename to ge/hybrid/executor/subgraph_context.cc diff --git a/src/ge/hybrid/executor/subgraph_context.h b/ge/hybrid/executor/subgraph_context.h similarity index 100% rename from src/ge/hybrid/executor/subgraph_context.h rename to ge/hybrid/executor/subgraph_context.h diff --git a/src/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc similarity index 100% rename from src/ge/hybrid/executor/subgraph_executor.cc rename to ge/hybrid/executor/subgraph_executor.cc diff --git a/src/ge/hybrid/executor/subgraph_executor.h b/ge/hybrid/executor/subgraph_executor.h similarity index 100% rename from src/ge/hybrid/executor/subgraph_executor.h rename to ge/hybrid/executor/subgraph_executor.h diff --git a/src/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc similarity index 100% rename from src/ge/hybrid/executor/worker/execution_engine.cc rename to ge/hybrid/executor/worker/execution_engine.cc diff --git a/src/ge/hybrid/executor/worker/execution_engine.h b/ge/hybrid/executor/worker/execution_engine.h similarity index 100% rename from src/ge/hybrid/executor/worker/execution_engine.h rename to ge/hybrid/executor/worker/execution_engine.h diff --git a/src/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc similarity index 100% rename from src/ge/hybrid/executor/worker/shape_inference_engine.cc rename to ge/hybrid/executor/worker/shape_inference_engine.cc diff --git a/src/ge/hybrid/executor/worker/shape_inference_engine.h b/ge/hybrid/executor/worker/shape_inference_engine.h similarity index 100% rename from src/ge/hybrid/executor/worker/shape_inference_engine.h rename to ge/hybrid/executor/worker/shape_inference_engine.h diff --git a/src/ge/hybrid/executor/worker/task_compile_engine.cc b/ge/hybrid/executor/worker/task_compile_engine.cc similarity index 100% rename from src/ge/hybrid/executor/worker/task_compile_engine.cc rename to ge/hybrid/executor/worker/task_compile_engine.cc diff --git a/src/ge/hybrid/executor/worker/task_compile_engine.h b/ge/hybrid/executor/worker/task_compile_engine.h similarity index 100% rename from src/ge/hybrid/executor/worker/task_compile_engine.h rename to ge/hybrid/executor/worker/task_compile_engine.h diff --git a/src/ge/hybrid/hybrid_davinci_model.cc b/ge/hybrid/hybrid_davinci_model.cc similarity index 100% rename from src/ge/hybrid/hybrid_davinci_model.cc rename to ge/hybrid/hybrid_davinci_model.cc diff --git a/src/ge/hybrid/hybrid_davinci_model.h b/ge/hybrid/hybrid_davinci_model.h similarity index 100% rename from src/ge/hybrid/hybrid_davinci_model.h rename to ge/hybrid/hybrid_davinci_model.h diff --git 
a/src/ge/hybrid/hybrid_davinci_model_stub.cc b/ge/hybrid/hybrid_davinci_model_stub.cc similarity index 100% rename from src/ge/hybrid/hybrid_davinci_model_stub.cc rename to ge/hybrid/hybrid_davinci_model_stub.cc diff --git a/src/ge/hybrid/model/graph_item.cc b/ge/hybrid/model/graph_item.cc similarity index 100% rename from src/ge/hybrid/model/graph_item.cc rename to ge/hybrid/model/graph_item.cc diff --git a/src/ge/hybrid/model/graph_item.h b/ge/hybrid/model/graph_item.h similarity index 100% rename from src/ge/hybrid/model/graph_item.h rename to ge/hybrid/model/graph_item.h diff --git a/src/ge/hybrid/model/hybrid_model.cc b/ge/hybrid/model/hybrid_model.cc similarity index 100% rename from src/ge/hybrid/model/hybrid_model.cc rename to ge/hybrid/model/hybrid_model.cc diff --git a/src/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h similarity index 100% rename from src/ge/hybrid/model/hybrid_model.h rename to ge/hybrid/model/hybrid_model.h diff --git a/src/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc similarity index 100% rename from src/ge/hybrid/model/hybrid_model_builder.cc rename to ge/hybrid/model/hybrid_model_builder.cc diff --git a/src/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h similarity index 100% rename from src/ge/hybrid/model/hybrid_model_builder.h rename to ge/hybrid/model/hybrid_model_builder.h diff --git a/src/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc similarity index 100% rename from src/ge/hybrid/model/node_item.cc rename to ge/hybrid/model/node_item.cc diff --git a/src/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h similarity index 100% rename from src/ge/hybrid/model/node_item.h rename to ge/hybrid/model/node_item.h diff --git a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc similarity index 100% rename from src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc rename to ge/hybrid/node_executor/aicore/aicore_node_executor.cc diff --git a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h b/ge/hybrid/node_executor/aicore/aicore_node_executor.h similarity index 100% rename from src/ge/hybrid/node_executor/aicore/aicore_node_executor.h rename to ge/hybrid/node_executor/aicore/aicore_node_executor.h diff --git a/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/ge/hybrid/node_executor/aicore/aicore_op_task.cc similarity index 100% rename from src/ge/hybrid/node_executor/aicore/aicore_op_task.cc rename to ge/hybrid/node_executor/aicore/aicore_op_task.cc diff --git a/src/ge/hybrid/node_executor/aicore/aicore_op_task.h b/ge/hybrid/node_executor/aicore/aicore_op_task.h similarity index 100% rename from src/ge/hybrid/node_executor/aicore/aicore_op_task.h rename to ge/hybrid/node_executor/aicore/aicore_op_task.h diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc b/ge/hybrid/node_executor/aicore/aicore_task_builder.cc similarity index 100% rename from src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc rename to ge/hybrid/node_executor/aicore/aicore_task_builder.cc diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h b/ge/hybrid/node_executor/aicore/aicore_task_builder.h similarity index 100% rename from src/ge/hybrid/node_executor/aicore/aicore_task_builder.h rename to ge/hybrid/node_executor/aicore/aicore_task_builder.h diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc 
b/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc similarity index 100% rename from src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc rename to ge/hybrid/node_executor/aicore/aicore_task_compiler.cc diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.h b/ge/hybrid/node_executor/aicore/aicore_task_compiler.h similarity index 100% rename from src/ge/hybrid/node_executor/aicore/aicore_task_compiler.h rename to ge/hybrid/node_executor/aicore/aicore_task_compiler.h diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc similarity index 100% rename from src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc rename to ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h similarity index 100% rename from src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h rename to ge/hybrid/node_executor/aicpu/aicpu_ext_info.h diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc similarity index 100% rename from src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc rename to ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h similarity index 100% rename from src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h rename to ge/hybrid/node_executor/aicpu/aicpu_node_executor.h diff --git a/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc similarity index 100% rename from src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc rename to ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc diff --git a/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h similarity index 100% rename from src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h rename to ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h diff --git a/src/ge/hybrid/node_executor/controlop/control_op_executor.cc b/ge/hybrid/node_executor/controlop/control_op_executor.cc similarity index 100% rename from src/ge/hybrid/node_executor/controlop/control_op_executor.cc rename to ge/hybrid/node_executor/controlop/control_op_executor.cc diff --git a/src/ge/hybrid/node_executor/controlop/control_op_executor.h b/ge/hybrid/node_executor/controlop/control_op_executor.h similarity index 100% rename from src/ge/hybrid/node_executor/controlop/control_op_executor.h rename to ge/hybrid/node_executor/controlop/control_op_executor.h diff --git a/src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc similarity index 100% rename from src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc rename to ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc diff --git a/src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.h b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.h similarity index 100% rename from src/ge/hybrid/node_executor/ge_local/ge_local_node_executor.h rename to ge/hybrid/node_executor/ge_local/ge_local_node_executor.h diff --git a/src/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc similarity index 100% rename from 
src/ge/hybrid/node_executor/hccl/hccl_node_executor.cc rename to ge/hybrid/node_executor/hccl/hccl_node_executor.cc diff --git a/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h b/ge/hybrid/node_executor/hccl/hccl_node_executor.h similarity index 100% rename from src/ge/hybrid/node_executor/hccl/hccl_node_executor.h rename to ge/hybrid/node_executor/hccl/hccl_node_executor.h diff --git a/src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc b/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc rename to ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc diff --git a/src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h b/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h rename to ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.h diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc b/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc rename to ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h b/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h rename to ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.h diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel/kernel.h b/ge/hybrid/node_executor/host_cpu/kernel/kernel.h similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/kernel/kernel.h rename to ge/hybrid/node_executor/host_cpu/kernel/kernel.h diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc b/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc rename to ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h b/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h rename to ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.h diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc b/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc rename to ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h b/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h rename to ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc b/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc rename to ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h b/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h similarity index 100% rename from 
src/ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h rename to ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.h diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel_factory.cc b/ge/hybrid/node_executor/host_cpu/kernel_factory.cc similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/kernel_factory.cc rename to ge/hybrid/node_executor/host_cpu/kernel_factory.cc diff --git a/src/ge/hybrid/node_executor/host_cpu/kernel_factory.h b/ge/hybrid/node_executor/host_cpu/kernel_factory.h similarity index 100% rename from src/ge/hybrid/node_executor/host_cpu/kernel_factory.h rename to ge/hybrid/node_executor/host_cpu/kernel_factory.h diff --git a/src/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc similarity index 100% rename from src/ge/hybrid/node_executor/node_executor.cc rename to ge/hybrid/node_executor/node_executor.cc diff --git a/src/ge/hybrid/node_executor/node_executor.h b/ge/hybrid/node_executor/node_executor.h similarity index 100% rename from src/ge/hybrid/node_executor/node_executor.h rename to ge/hybrid/node_executor/node_executor.h diff --git a/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc b/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc similarity index 100% rename from src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc rename to ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc diff --git a/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h b/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h similarity index 100% rename from src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h rename to ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h diff --git a/src/ge/hybrid/node_executor/rts/rts_node_executor.cc b/ge/hybrid/node_executor/rts/rts_node_executor.cc similarity index 100% rename from src/ge/hybrid/node_executor/rts/rts_node_executor.cc rename to ge/hybrid/node_executor/rts/rts_node_executor.cc diff --git a/src/ge/hybrid/node_executor/rts/rts_node_executor.h b/ge/hybrid/node_executor/rts/rts_node_executor.h similarity index 100% rename from src/ge/hybrid/node_executor/rts/rts_node_executor.h rename to ge/hybrid/node_executor/rts/rts_node_executor.h diff --git a/src/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc similarity index 100% rename from src/ge/hybrid/node_executor/task_context.cc rename to ge/hybrid/node_executor/task_context.cc diff --git a/src/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h similarity index 100% rename from src/ge/hybrid/node_executor/task_context.h rename to ge/hybrid/node_executor/task_context.h diff --git a/src/ge/inc/graph_pass.h b/ge/inc/graph_pass.h similarity index 100% rename from src/ge/inc/graph_pass.h rename to ge/inc/graph_pass.h diff --git a/src/ge/inc/kernel.h b/ge/inc/kernel.h similarity index 100% rename from src/ge/inc/kernel.h rename to ge/inc/kernel.h diff --git a/src/ge/inc/kernel_factory.h b/ge/inc/kernel_factory.h similarity index 100% rename from src/ge/inc/kernel_factory.h rename to ge/inc/kernel_factory.h diff --git a/src/ge/inc/pass.h b/ge/inc/pass.h similarity index 100% rename from src/ge/inc/pass.h rename to ge/inc/pass.h diff --git a/src/ge/inc/pass_manager.h b/ge/inc/pass_manager.h similarity index 100% rename from src/ge/inc/pass_manager.h rename to ge/inc/pass_manager.h diff 
--git a/src/ge/init/gelib.cc b/ge/init/gelib.cc
similarity index 100%
rename from src/ge/init/gelib.cc
rename to ge/init/gelib.cc
diff --git a/src/ge/init/gelib.h b/ge/init/gelib.h
similarity index 100%
rename from src/ge/init/gelib.h
rename to ge/init/gelib.h
diff --git a/src/ge/ir_build/atc_ir_common.cc b/ge/ir_build/atc_ir_common.cc
similarity index 100%
rename from src/ge/ir_build/atc_ir_common.cc
rename to ge/ir_build/atc_ir_common.cc
diff --git a/src/ge/ir_build/atc_ir_common.h b/ge/ir_build/atc_ir_common.h
similarity index 100%
rename from src/ge/ir_build/atc_ir_common.h
rename to ge/ir_build/atc_ir_common.h
diff --git a/src/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc
similarity index 100%
rename from src/ge/ir_build/ge_ir_build.cc
rename to ge/ir_build/ge_ir_build.cc
diff --git a/src/ge/model/ge_model.cc b/ge/model/ge_model.cc
similarity index 100%
rename from src/ge/model/ge_model.cc
rename to ge/model/ge_model.cc
diff --git a/src/ge/model/ge_model.h b/ge/model/ge_model.h
similarity index 100%
rename from src/ge/model/ge_model.h
rename to ge/model/ge_model.h
diff --git a/src/ge/model/ge_root_model.cc b/ge/model/ge_root_model.cc
similarity index 100%
rename from src/ge/model/ge_root_model.cc
rename to ge/model/ge_root_model.cc
diff --git a/src/ge/model/ge_root_model.h b/ge/model/ge_root_model.h
similarity index 100%
rename from src/ge/model/ge_root_model.h
rename to ge/model/ge_root_model.h
diff --git a/src/ge/module.mk b/ge/module.mk
similarity index 100%
rename from src/ge/module.mk
rename to ge/module.mk
diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt
new file mode 100644
index 00000000..a5a334bd
--- /dev/null
+++ b/ge/offline/CMakeLists.txt
@@ -0,0 +1,81 @@
+set(PROTO_LIST
+    "${METADEF_DIR}/proto/om.proto"
+    "${METADEF_DIR}/proto/ge_ir.proto"
+    "${METADEF_DIR}/proto/insert_op.proto"
+    "${METADEF_DIR}/proto/task.proto"
+)
+
+protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
+
+set(SRC_LIST
+    "main.cc"
+    "single_op_parser.cc"
+    "../session/omg.cc"
+    "../ir_build/atc_ir_common.cc"
+)
+
+############ atc ############
+add_executable(atc ${SRC_LIST} ${PROTO_HDRS})
+
+target_compile_options(atc PRIVATE
+    -Werror
+    -O2
+)
+
+target_compile_definitions(atc PRIVATE
+    PROTOBUF_INLINE_NOT_IN_HEADERS=0
+    COMPILE_OMG_PACKAGE
+)
+
+target_include_directories(atc PRIVATE
+    ${CMAKE_CURRENT_LIST_DIR}
+    ${GE_CODE_DIR}
+    ${GE_CODE_DIR}/ge
+    ${GE_CODE_DIR}/inc/external
+    ${GE_CODE_DIR}/common/inc/external
+    ${GE_CODE_DIR}/common/inc/external/graph
+    ${GE_CODE_DIR}/inc
+    ${GE_CODE_DIR}/inc/framework
+    ${METADEF_DIR}/inc
+    ${METADEF_DIR}/inc/graph
+    ${METADEF_DIR}/inc/register
+    ${METADEF_DIR}/inc/external
+    ${METADEF_DIR}/inc/external/graph
+    ${METADEF_DIR}/inc/external/register
+    ${PARSER_DIR}
+    ${CMAKE_BINARY_DIR}
+    ${CMAKE_BINARY_DIR}/proto/ge
+    #### yellow zone ####
+    ${GE_CODE_DIR}/../inc
+    ${GE_CODE_DIR}/../inc/common
+    #### blue zone ####
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
+)
+
+target_link_libraries(atc PRIVATE
+    $
+    protobuf
+    ge_common
+    register
+    c_sec
+    graph
+    error_manager
+    ge_compiler
+    parser_common
+    gflags
+    json
+    runtime_compile
+    slog
+    mmpa
+    -lrt
+    -ldl
+)
+
+############ install ############
+set(INSTALL_BASE_DIR "")
+set(INSTALL_LIBRARY_DIR lib)
+
+install(TARGETS atc OPTIONAL
+    LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
+)
diff --git a/ge/offline/main.cc b/ge/offline/main.cc
new file mode 100644
index 00000000..9fa2cfba
--- /dev/null
+++ b/ge/offline/main.cc
@@ -0,0
+1,1334 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/gflags_util.h" +#include "common/util.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "ge/ge_api.h" +#include "generator/ge_generator.h" +#include "graph/anchor.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/graph.h" +#include "graph/op_desc.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/type_utils.h" +#include "init/gelib.h" +#include "ir_build/atc_ir_common.h" +#include "omg/omg.h" +#include "omg/parser/parser_factory.h" +#include "omg/parser/parser_inner_ctx.h" +#include "parser/common/register_tbe.h" +#include "register/op_registry.h" +#include "single_op_parser.h" + +using domi::BuildMode; +using domi::OpRegistrationData; +using domi::OpRegistry; +using domi::Status; +using domi::SUCCESS; +using ge::GEN_OM_MODEL; +using ge::GflagsUtils; +using ge::MODEL_TO_JSON; +using ge::ONLY_PRE_CHECK; +using ge::ParseInputShape; +using ge::PBTXT_TO_JSON; +using std::map; +using std::pair; +using std::shared_ptr; +using std::string; +using std::vector; + +static bool is_dynamic_input = false; + +// 310 limited 8G size +const char *const kGraphMemoryManagerMallocMaxSize = "8*1024*1024*1024"; +const char *const kModeSupport = "only support 0(model to framework model), " + "1(framework model to json), 3(only pre-check), 5(pbtxt to json)"; +const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow)"; + +// limit available mem size 2G +const long kMinAvailableMem = 2 * 1024 * 1024; + +DEFINE_string(model, "", "The model file."); +DEFINE_string(output, "", "The output file path&name."); +DEFINE_int32(framework, -1, "Framework type(0:Caffe; 1:MindSpore; 3:Tensorflow)."); +DEFINE_string(weight, "", "Optional; weight file. Required when framework is Caffe."); + +DEFINE_string(input_shape, "", + "Optional; shape of input data. Required when framework is caffe " + "or TensorFLow or MindSpore." + "Format: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\""); +DEFINE_bool(h, false, "show this help message"); +DEFINE_string(cal_conf, "", "Optional; the calibration config file."); + +DEFINE_string(insert_op_conf, "", "Optional; the config file to insert new op, for example AIPP op."); +DEFINE_string(op_name_map, "", "Optional; custom op name mapping file."); + +DEFINE_string(target, "", "Optional; mini."); + +DEFINE_string(om, "", "The model file to be converted to json."); +DEFINE_string(json, "", "The output json file path&name which is converted from a model."); +DEFINE_int32(mode, 0, + "Optional; run mode, 0(default): model => framework model; 1: " + "framework model => json; 3: only pre-check; 5: pbtxt => json."); + +#if !defined(__ANDROID__) && !defined(ANDROID) +DEFINE_int32(encrypt_mode, -1, "Optional; the encrypt flag. 
0: encrypt; -1(default): not encrypt"); +DEFINE_string(encrypt_key, "", "Optional; the encrypt_key file."); +DEFINE_string(certificate, "", "Optional; the certificate file."); +DEFINE_string(hardware_key, "", "Optional; the ISV key file."); +DEFINE_string(private_key, "", "Optional; the private key file."); +#endif + +DEFINE_string(out_nodes, "", + "Optional; output nodes designated by users." + "Format: \"node_name1:0;node_name1:1;node_name2:0\""); + +DEFINE_string(precision_mode, "force_fp16", + "Optional; precision mode." + "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype."); + +DEFINE_string(input_format, "", + "Optional; input_format, format of input data, NCHW;NHWC." + "Format:\"NHWC\""); + +DEFINE_string(check_report, "check_result.json", "Optional; the pre-checking report file."); + +DEFINE_string(input_fp16_nodes, "", + "Optional; input node datatype is fp16 and format is NC1HWC0." + "Format:\"node_name1;node_name2\""); + +DEFINE_string(is_output_adjust_hw_layout, "", + "Optional; Net output node's datatype is fp16 and format is " + "NC1HWC0, or not." + "Format:\"false,true,false,true\""); + +DEFINE_string(is_input_adjust_hw_layout, "", + "Optional; Intput node's datatype is fp16 and format is " + "NC1HWC0, or not." + "Format:\"false,true,false,true\""); + +DEFINE_string(output_type, "", + "Optional; output type! " + "Support FP32,FP16,INT8,INT16,UINT16,UINT8,INT32,INT64,UINT32,UINT64,DOUBLE."); + +DEFINE_string(op_select_implmode, "", + "Optional; op select implmode! " + "Support high_precision, high_performance."); + +DEFINE_string(optypelist_for_implmode, "", + "Optional; Nodes need use implmode selected in op_select_implmode " + "Format:\"node_name1,node_name2\""); + +DEFINE_string(singleop, "", "Optional; If set, generate single op model with the given json file."); + +DEFINE_int32(disable_reuse_memory, 0, "Optional; If set to 1, disable reuse memory when generating if."); + +DEFINE_string(auto_tune_mode, "", "Optional; Set tune mode."); + +DEFINE_string(soc_version, "", "The soc version."); + +DEFINE_string(core_type, "AiCore", "Optional; If set to VectorCore, only use vector core."); + +DEFINE_string(aicore_num, "", "Optional; Set aicore num"); + +DEFINE_string(buffer_optimize, "l2_optimize", "Optional; buffer optimize"); + +DEFINE_string(fusion_switch_file, "", "Optional; Set fusion switch file path"); + +DEFINE_string(save_original_model, "", "Optional; enable output original offline model. false(default)"); + +DEFINE_string(dynamic_batch_size, "", + "Optional; If set, generate dynamic multi batch model. " + "Different batch sizes are split by ','." + "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); + +DEFINE_string(dynamic_image_size, "", + "Optional; If set, generate dynamic multi image size model." + "Different groups of image size are split by ';'," + "while different dimensions of each group are split by ','." + "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); + +DEFINE_string(dynamic_dims, "", + "Optional; If set, generate dynamic input size model. " + "Different groups of size are split by ';', while different dimensions of each group are split by ','." + "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); + +DEFINE_string(enable_small_channel, "0", "Optional; If set to 1, small channel is enabled."); + +DEFINE_string(enable_compress_weight, "false", + "Optional; enable compress weight. 
true: enable; false(default): disable"); + +DEFINE_string(compress_weight_conf, "", "Optional; the config file to compress weight"); + +DEFINE_string(enable_single_stream, "", "Optional; enable single stream. true: enable; false(default): disable"); + +DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, warning, error, null"); + +DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); + +DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(default): close debug;" + "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); +DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass," + "multiple names can be set and separated by ','."); + +class GFlagUtils { + public: + /** + * @name InitGFlag + * @brief initialize gflag + * @return void + */ + static void InitGFlag(int argc, char *argv[]) { + // -help + gflags::SetUsageMessage( + "usage: ./atc \n" + "generate offline model example:\n" + "./atc --model=./alexnet.prototxt --weight=./alexnet.caffemodel \n" + "--framework=0 --output=./domi \n" + "generate offline model for single op example:\n" + "./atc --singleop=./op_list.json --output=./op_model \n" + "===== Basic Functionality =====\n" + "[General]\n" + " --h/help Show this help message\n" + " --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format " + "3: only pre-check; 5: convert pbtxt file to JSON format\n" + "\n[Input]\n" + " --model Model file\n" + " --weight Weight file. Required when framework is Caffe\n" + " --om The model file to be converted to json\n" + " --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow\n" + " --input_format Format of input data. E.g.: \"NCHW\"\n" + " --input_shape Shape of input data. Separate multiple nodes with semicolons (;)." + "Use double quotation marks (\") to enclose each argument.\n" + " E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" + " --dynamic_batch_size Set dynamic batch size. E.g: \"batchsize1,batchsize2,batchsize3\"\n" + " --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;)." + "Use double quotation marks (\") to enclose each argument.\n" + " E.g: \"imagesize1_height,imagesize1_width;imagesize2_height,imagesize2_width\"\n" + " --dynamic_dims Set dynamic dims. Separate multiple nodes with semicolons (;)." + "Use double quotation marks (\") to enclose each argument. E.g: \"dims1_n1,dims1_n2;dims2_n1,dims2_n2\"\n" + " --singleop Single op definition file. atc will generate offline " + "model(s) for single op if --singleop is set.\n" + "\n[Output]\n" + " --output Output file path&name(needn't suffix, will add " + ".om automatically). \n" + " If --singleop is set, this arg specifies the directory to " + "which the single op offline model will be generated\n" + " --output_type Set net output type. Support FP32, FP16, UINT8." + "E.g.: FP16, indicates that all out nodes are set to FP16.\n" + " \"node1:0:FP16;node2:1:FP32\", indicates setting the datatype of multiple out nodes.\n" + " --check_report The pre-checking report file. Default value is: " + "\"check_result.json\"\n" + " --json The output json file path&name which is " + "converted from a model\n" + "\n[Target]\n" + " --soc_version The soc version.\n" + " --core_type Set core type AiCore or VectorCore. VectorCore: use vector core. 
" + "Default value is: AiCore\n" + " --aicore_num Set aicore num\n" + "===== Advanced Functionality =====\n" + "[Feature]\n" + " --out_nodes Output nodes designated by users. Separate multiple nodes with semicolons (;)." + "Use double quotation marks (\") to enclose each argument.\n" + " E.g.: \"node_name1:0;node_name1:1;node_name2:0\"\n" + " --input_fp16_nodes Input node datatype is fp16. Separate multiple nodes with semicolons " + "(;)." + "Use double quotation marks (\") to enclose each argument." + "E.g.: \"node_name1;node_name2\"\n" + " --insert_op_conf Config file to insert new op\n" + " --op_name_map Custom op name mapping file\n" + " Note: A semicolon(;) cannot be included in each " + "path, otherwise the resolved path will not match the expected one.\n" + " --is_input_adjust_hw_layout Intput node datatype is fp16 and format is " + "NC1HWC0, used with input_fp16_nodes E.g.: \"true,true,false,true\"\n" + " --is_output_adjust_hw_layout Net output node datatype is fp16 and format is " + "NC1HWC0, used with out_nodes. E.g.: \"true,true,false,true\"\n" + "\n[Model Tuning]\n" + " --disable_reuse_memory The switch of reuse memory. Default value is : 0." + "0 means reuse memory, 1 means do not reuse memory.\n" + " --fusion_switch_file Set fusion switch file path\n" + " --enable_scope_fusion_passes validate the non-general scope fusion passes," + "multiple names can be set and separated by ','. E.g.: ScopePass1,ScopePass2,...\n" + " --enable_single_stream Enable single stream. true: enable; false(default): disable\n" + " --enable_small_channel Set enable small channel. 0(default): disable; 1: enable\n" + " --enable_compress_weight Enable compress weight. true: enable; false(default): disable\n" + " --compress_weight_conf Config file to compress weight\n" + " --buffer_optimize Set buffer optimize. \"l2_optimize\" (default). Set \"off_optimize\" to close\n" + "\n[Operator Tuning]\n" + " --precision_mode precision mode, support force_fp16(default), allow_mix_precision, " + "allow_fp32_to_fp16, must_keep_origin_dtype.\n" + " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n" + " --op_select_implmode Set op select implmode. Support high_precision, high_performance." + "default: high_performance\n" + " --optypelist_for_implmode Appoint which op to select implmode, cooperated with op_select_implmode.\n" + " Separate multiple nodes with commas (,). Use double quotation marks (\") " + " to enclose each argument. E.g.: \"node_name1,node_name2\"\n" + " --op_debug_level Debug enable for TBE operator building.\n" + " 0 (default): Disable debug; 1: Enable TBE pipe_all, " + "and generate the operator CCE file and Python-CCE mapping file (.json);\n" + " 2: Enable TBE pipe_all, generate the operator CCE file and Python-CCE mapping file " + "(.json), and enable the CCE compiler -O0-g.\n" + "\n[Debug]\n" + " --save_original_model Control whether to output original model. E.g.: true: output original model\n" + " --log Generate log with level. Support debug, info, warning, error, null\n" + " --dump_mode The switch of dump json with shape, to be used with mode 1." 
+ "0(default): disable; 1: enable."); + + gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); + // Using gflags to analyze input parameters + GflagsUtils::ChangeHelpFlags(FLAGS_h); + gflags::HandleCommandLineHelpFlags(); + } + + static Status CheckDumpInfershapeJsonFlags() { + Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); + GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, + "check custom aicpu run so failed!"); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), + return domi::FAILED, "Input parameter[--weight]'s value[%s] is invalid!", + FLAGS_weight.c_str()); + return domi::SUCCESS; + } + + static Status CheckFlags() { + Status ret = ge::SUCCESS; + // No model file information passed in + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_model == "", + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"model"}); + ret = ge::FAILED, "Input parameter[--model]'s value is empty!"); + + // check param disable_reuse_memory + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::CheckDisableReuseMemoryParamValid(to_string(FLAGS_disable_reuse_memory)) != ge::SUCCESS, + ret = ge::FAILED, "check disable_reuse_memory failed!"); + + // check optypelist_for_implmode and op_select_implmode + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, + FLAGS_op_select_implmode) != ge::SUCCESS, + ret = ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); + // No output file information passed in + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "", + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"output"}); + ret = ge::FAILED, "Input parameter[--output]'s value is empty!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + CheckFrameWorkValid(FLAGS_framework, FLAGS_weight) != ge::SUCCESS, + ret = ge::FAILED, + "CheckFrameWorkValid failed"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size, + FLAGS_dynamic_dims, FLAGS_input_shape, + FLAGS_input_format, is_dynamic_input) != ge::SUCCESS, + ret = ge::FAILED, "check dynamic size(batch size, image size or dims) failed!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + !FLAGS_insert_op_conf.empty() && !FLAGS_dynamic_dims.empty(), + ErrorManager::GetInstance().ATCReportErrMessage("E10001", + {"parameter", "value", "reason"}, + {"--insert_op_conf", FLAGS_insert_op_conf, + "dynamic dims function does not support aipp"}); + ret = ge::FAILED, "dynamic dims function does not support aipp"); + +#if !defined(__ANDROID__) && !defined(ANDROID) + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!CheckEncryptModeValid(FLAGS_encrypt_mode), ret = ge::FAILED, + "encrypt_mode %d not valid!!", FLAGS_encrypt_mode); + + if (FLAGS_encrypt_mode == 0) { // Encryption mode + GELOGI("ge will run with encrypt!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_encrypt_key), ret = ge::FAILED, + "encrypt_key file not found!!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_certificate), ret = ge::FAILED, + "certificate file not found!!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_hardware_key), ret = ge::FAILED, + "hardware_key file not found!!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_private_key), ret = ge::FAILED, + "private_key file not found!!"); + } else { // No encryption + GELOGI("ge will run without encrypt!"); + } +#endif + + /** + * Check the 
validity of the I / O file path + */ + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_model != "" && !ge::CheckInputPathValid(FLAGS_model, "--model"), ret = ge::FAILED, + "model file %s not found!!", FLAGS_model.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), + ret = ge::FAILED, "weight file %s not found!!", + FLAGS_weight.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "--cal_conf"), + ret = ge::FAILED, "calibration config file %s not found!!", + FLAGS_cal_conf.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "--op_name_map"), + ret = ge::FAILED, "op config file %s not found!!", + FLAGS_op_name_map.c_str()); + + GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(FLAGS_insert_op_conf)) == ge::SUCCESS, + ret = ge::FAILED, "check insert op conf failed!"); + + GE_CHK_BOOL_EXEC(ge::CheckCompressWeightParamValid( + FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, + ret = ge::FAILED, "check compress weight failed!"); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + !ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED, + "check_report file %s not found!!", FLAGS_check_report.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_mode == GEN_OM_MODEL && FLAGS_output != "" && + (!ge::CheckOutputPathValid(FLAGS_output, "--output") || !CheckPathWithName(FLAGS_output)), + ret = ge::FAILED, "output path %s is not valid!!", FLAGS_output.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_save_original_model != "" && + FLAGS_save_original_model != "true" && + FLAGS_save_original_model != "false", + ErrorManager::GetInstance().ATCReportErrMessage( + "E10005", {"parameter", "value"}, {"save_original_model", FLAGS_save_original_model}); + ret = ge::FAILED, + "Input parameter[--save_original_model]'s value[%s] must be true or false.", + FLAGS_save_original_model.c_str()); + GE_CHK_BOOL_EXEC(ge::CheckBufferOptimizeParamValid(FLAGS_buffer_optimize) == ge::SUCCESS, + ret = ge::FAILED, "check output type failed!"); + + GE_CHK_BOOL_EXEC( + ge::CheckEnableSingleStreamParamValid(std::string(FLAGS_enable_single_stream)) == ge::SUCCESS, + ret = ge::FAILED, "check enable single stream failed!"); + + return ret; + } + + /** + * Verifying the parameters of converting model to JSON + * 1. Fmk_model + * 2. 
out_json + **/ + static Status CheckConverJsonParamFlags() { + Status ret = ge::SUCCESS; + + // No model path passed in + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); + ret = ge::FAILED, + "Input parameter[--om]'s value is empty!!"); + + // JSON path not passed in + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_json == "", + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"json"}); + ret = ge::FAILED, + "Input parameter[--json]'s value is empty!!"); + + // Check if the model path is valid + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"), + ret = ge::FAILED, + "model file path is invalid: %s.", FLAGS_om.c_str()); + + // Check whether the JSON path is valid + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_json != "" && !ge::CheckOutputPathValid(FLAGS_json, "--json"), + ret = ge::FAILED, + "json file path is invalid: %s.", FLAGS_json.c_str()); + + return ret; + } + + /** + * Check command line parameters for explicit settings + * true: Explicit setup + * false: Not set up + * */ + static bool CheckFlagSet(string flag) { + gflags::CommandLineFlagInfo info; + return !(gflags::GetCommandLineFlagInfo(flag.c_str(), &info) && info.is_default); + } + + private: + static bool CheckEncryptModeValid(const int encrypt_mode) { +#if !defined(__ANDROID__) && !defined(ANDROID) + if (encrypt_mode != 0 && encrypt_mode != -1) { + DOMI_LOGE("encrypt mode must be 0 or -1"); + return false; + } +#else + if (encrypt_mode != -1) { + DOMI_LOGE("encrypt mode must be -1"); + return false; + } +#endif + + return true; + } + + static Status CheckFrameWorkValid(int framework, const std::string weight_file) { + if (framework != (int32_t)domi::CAFFE && framework != (int32_t)domi::TENSORFLOW && + framework != (int32_t)domi::MINDSPORE && framework != (int32_t)domi::ONNX) { + // No framework information was passed in or the entered framework is illegal + ErrorManager::GetInstance().ATCReportErrMessage( + "E10007", {"parameter", "support"}, + {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow)"}); + DOMI_LOGE("Input parameter[--framework] is mandatory and it's value must be: " + "0(Caffe) or 1(MindSpore) or 3(TensorFlow)."); + return domi::PARAM_INVALID; + } + + if ((framework == (int32_t)domi::CAFFE) && (weight_file == "")) { + ErrorManager::GetInstance().ATCReportErrMessage("E10008", {"parameter"}, {"weight"}); + DOMI_LOGE("Input parameter[--weight]'s value is empty when framework is 0(CAFFE)!"); + return domi::PARAM_INVALID; + } + + if ((framework == (int32_t)domi::TENSORFLOW) && (weight_file != "")) { + GELOGW("Parameter weight is ignored for TensorFlow."); + } + + if ((framework == (int32_t)domi::ONNX) && (weight_file != "")) { + GELOGW("Parameter weight is ignored for Onnx."); + } + return domi::SUCCESS; + } + + static bool CheckPathWithName(const std::string &fileName) { + // Determine file path length + if (fileName.size() > static_cast(PATH_MAX)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10021", {"parameter", "size"}, {"output", std::to_string(PATH_MAX)}); + GELOGE(ge::FAILED, "Input parameter[--output]'s path is too long, it must be less than %d", PATH_MAX); + return false; + } + + // Find the last separator + int slashPosition = fileName.size() - 1; + for (; slashPosition >= 0; slashPosition--) { + if (fileName[slashPosition] == '\\' || fileName[slashPosition] == '/') { + break; + } + } + + // Failure if no filename follows the path + if (slashPosition == 
static_cast(fileName.size() - 1)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10022", {"parameter", "filename"}, {"output", fileName}); + DOMI_LOGE("Input parameter[--output]'s path[%s] not include file name", fileName.c_str()); + return false; + } + + return true; + } +}; + +void SetDynamicInputSizeOptions() { + if (!FLAGS_dynamic_batch_size.empty()) { + domi::GetContext().dynamic_batch_size = FLAGS_dynamic_batch_size; + } + if (!FLAGS_dynamic_image_size.empty()) { + domi::GetContext().dynamic_image_size = FLAGS_dynamic_image_size; + } + if (!FLAGS_dynamic_dims.empty()) { + domi::GetContext().dynamic_dims = FLAGS_dynamic_dims; + } +} + +/// Validate the non-general scope fusion pass. +/// The parameter is set to the name of the fusion rule. +/// Multiple names can be set and separated by ",". +void SetEnableScopeFusionPasses(const std::string pass_names) { + ge::GetParserContext().enable_scope_fusion_passes = pass_names; +} + +static bool CheckInputFormat() { + if (FLAGS_input_format.empty()) { + // Set default format + if (FLAGS_framework == static_cast(domi::TENSORFLOW)) { + FLAGS_input_format = "NHWC"; + } else { + FLAGS_input_format = "NCHW"; + } + return true; + } else if ((FLAGS_framework == static_cast(domi::CAFFE))) { // caffe + if (ge::caffe_support_input_format.find(FLAGS_input_format) != ge::caffe_support_input_format.end()) { + return true; + } + // only support NCHW ND + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, ge::kCaffeFormatSupport}); + GELOGE(ge::FAILED, + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kCaffeFormatSupport); + return false; + } else if ((FLAGS_framework == static_cast(domi::TENSORFLOW))) { // tf + if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) { + return true; + } + // only support NCHW NHWC ND NCDHW NDHWC + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, ge::kTFFormatSupport}); + GELOGE(ge::FAILED, + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kTFFormatSupport); + return false; + } else if (FLAGS_framework == static_cast(domi::ONNX)) { + if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) { + return true; + } + // only support NCHW ND + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--input_format", FLAGS_input_format, ge::kONNXFormatSupport}); + GELOGE(ge::FAILED, + "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kONNXFormatSupport); + return false; + } + return true; +} + +#if !defined(__ANDROID__) && !defined(ANDROID) +static void GetCustomOpPath(std::string &customop_path) { + GELOGI("Enter get custom op path schedule"); + std::string fmk_type = ge::TypeUtils::FmkTypeToSerialString(static_cast(FLAGS_framework)); + GELOGI("Framework type is %s.", fmk_type.c_str()); + + const char *path_env = std::getenv("ASCEND_OPP_PATH"); + if (path_env != nullptr) { + std::string path = path_env; + customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); + GELOGI("Get custom so path from env : %s", path_env); + return; + } + std::string path_base = ge::GELib::GetPath(); + GELOGI("path_base is %s", path_base.c_str()); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, 
path_base.rfind('/') + 1); + customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type); + return; +} + +void GetPluginSoFileList(const string &path, vector &fileList, string &caffe_parser_path) { + // Support to split multiple so directories by ":" + GELOGI("path is %s", path.c_str()); + vector v_path = ge::StringUtils::Split(path, ':'); + for (size_t i = 0; i < v_path.size(); ++i) { + ge::FindParserSo(v_path[i], fileList, caffe_parser_path); + GELOGI("CustomOpLib full name = %s", v_path[i].c_str()); + } +} + +void LoadModelParserLib(std::string caffe_parser_path) { + if (FLAGS_framework == static_cast(domi::TENSORFLOW)) { + void *tf_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL); + if (tf_handle == nullptr) { + GELOGW("dlopen fmk library [libfmk_parser.so] failed."); + return; + } + GELOGI("plugin load libfmk_parser.so success."); + } else if (FLAGS_framework == static_cast(domi::CAFFE)) { + // What we are dealing with here is that the user modifies the caffe.proto scenario. + // If no lib_Caffe_Parser.so is found under the plugin path, use the default lib_Caffe_Parser.so path. + caffe_parser_path = caffe_parser_path.empty() ? "lib_caffe_parser.so" : caffe_parser_path; + + void *handle = dlopen(caffe_parser_path.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + GELOGW("dlopen failed, plugin name:%s. Message(%s).", caffe_parser_path.c_str(), dlerror()); + return; + } + GELOGI("plugin load %s success.", caffe_parser_path.c_str()); + // According to the dependency, the Caffe parsing module of the framework is loaded here( libfmk_parser.so). + // (depend on the lib_caffe_parser.so) + void *fmk_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL); + if (fmk_handle == nullptr) { + GELOGW("dlopen fmk library [libfmk_parser.so] failed."); + if (dlclose(handle) != 0) { + GELOGW("dlclose lib_caffe_parser.so failed."); + } + return; + } + GELOGI("plugin load libfmk_parser.so success."); + } else if (FLAGS_framework == static_cast(domi::ONNX)) { + void *handle = dlopen("libfmk_onnx_parser.so", RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + GELOGW("dlopen fmk library [libfmk_onnx_parser.so] failed."); + return; + } + GELOGI("plugin load libfmk_onnx_parser.so success."); + } else { + GELOGW("Framework:%s is not support.", + ge::TypeUtils::FmkTypeToSerialString(static_cast(FLAGS_framework)).c_str()); + return; + } + return; +} + +void LoadCustomOpLib(bool need_load_ops_plugin) { + std::string plugin_path; + GetCustomOpPath(plugin_path); + + vector fileList; + string caffe_parser_path = ""; + + // whether there are files in the plugin so path + GetPluginSoFileList(plugin_path, fileList, caffe_parser_path); + + // no file + if (fileList.empty() && caffe_parser_path.empty()) { + GELOGW("can not find any plugin file in plugin_path: %s", plugin_path.c_str()); + } + + LoadModelParserLib(caffe_parser_path); + if (!need_load_ops_plugin) { + GELOGI("No need to load ops plugin so."); + return; + } + OpRegistry::Instance()->registrationDatas.clear(); + // load other so files except lib_caffe_parser.so in the plugin so path + for (auto elem : fileList) { + ge::StringUtils::Trim(elem); + + void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + GELOGW("dlopen failed, plugin name:%s. 
Message(%s).", elem.c_str(), dlerror()); + } else { + GELOGI("plugin load %s success.", elem.c_str()); + } + } + + std::vector registrationDatas = OpRegistry::Instance()->registrationDatas; + for (OpRegistrationData reg_data : registrationDatas) { + if (reg_data.GetFrameworkType() == static_cast(FLAGS_framework)) { + (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data); + (void)OpRegistry::Instance()->Register(reg_data); + } + } +} + +void SaveCustomCaffeProtoPath() { + GELOGI("Enter save custom caffe proto path."); + + std::string path_base = ge::GELib::GetPath(); + GELOGI("path_base is %s", path_base.c_str()); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); + ge::GetParserContext().caffe_proto_path = path_base + "include/proto/"; + + string customop_path; + const char *path_env = std::getenv("ASCEND_OPP_PATH"); + if (path_env != nullptr) { + std::string path = path_env; + customop_path = path + "/framework/custom/caffe/"; + GELOGI("Get custom proto path from env : %s", path_env); + ge::GetParserContext().custom_proto_path = customop_path; + return; + } + customop_path = path_base + "ops/framework/custom/caffe/"; + ge::GetParserContext().custom_proto_path = customop_path; + return; +} + +#endif + +Status CreateInputsForInference(const ge::Graph &graph, vector &inputs) { + auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + for (ge::NodePtr &input_node : compute_graph->GetAllNodes()) { + GE_CHECK_NOTNULL(input_node); + ge::OpDescPtr op = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op); + if (op->GetType() == ge::DATA) { + GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); + ge::GeTensorDesc tensor = op->GetInputDesc(0); + string data_op_name = op->GetName(); + GELOGI("Data op name is: %s", data_op_name.c_str()); + ge::GeShape data_shape; + auto iter = domi::GetContext().input_dims.find(data_op_name); + if (iter != domi::GetContext().input_dims.end()) { + data_shape = ge::GeShape(iter->second); + GELOGI("Data op get shape from Context."); + } else { + data_shape = tensor.GetShape(); + GELOGI("Data op get shape from InputDesc in geir graph."); + } + + ge::DataType data_type = tensor.GetDataType(); + string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type); + GELOGI("Data op get data type:%s from InputDesc in geir graph.", data_type_str.c_str()); + + ge::GeTensor input_tensor; + ge::GeTensorDesc desc(data_shape, ge::Format(domi::GetContext().format), data_type); + input_tensor.SetTensorDesc(desc); + inputs.push_back(input_tensor); + } + } + GELOGI("Build ME model, inputs size is: %zu", inputs.size()); + return ge::SUCCESS; +} + +domi::Status GenerateInfershapeJson() { + if (!CheckInputFormat()) { + GELOGE(ge::FAILED, "Check input_format failed"); + return domi::FAILED; + } + Status ret = GFlagUtils::CheckDumpInfershapeJsonFlags(); + GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check flags failed!"); + + ge::GeGenerator ge_generator; + std::map options; + ge::Status geRet = ge_generator.Initialize(options, domi::GetContext()); + if (geRet != ge::SUCCESS) { + DOMI_LOGE("GeGenerator initialize failed!"); + return domi::FAILED; + } + + ge::Graph graph; + std::map atc_params; + atc_params.insert(std::pair("input_format", FLAGS_input_format)); + ret = ParseGraph(graph, atc_params, FLAGS_om.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType) FLAGS_framework, + "", FLAGS_target.c_str(), (ge::RunMode) FLAGS_mode, false); + if 
(ret != ge::SUCCESS) { + DOMI_LOGE("ATC Parse graph domi::FAILED"); + (void)ge_generator.Finalize(); + return domi::FAILED; + } + + geRet = ge_generator.GenerateInfershapeGraph(graph); + if (geRet != ge::SUCCESS) { + DOMI_LOGE("ATC GenerateInfershapeJson failed"); + (void)ge_generator.Finalize(); + return domi::FAILED; + } + if (DumpInfershapeJson(graph, FLAGS_json.c_str()) != SUCCESS) { + DOMI_LOGE("ATC DumpInfershapeJson failed"); + (void)ge_generator.Finalize(); + return domi::FAILED; + } + (void)ge_generator.Finalize(); + return ge::SUCCESS; +} + +static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) { + Status ret = ge::SUCCESS; + if (fwk_type == -1) { + ret = ge::ConvertOmModelToJson(model_file.c_str(), json_file.c_str()); + return ret; + } + + if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE) && (fwk_type != domi::ONNX)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--framework", std::to_string(fwk_type), kModelToJsonSupport}); + GELOGE(ge::FAILED, "Invalid value for --framework[%d], %s.", fwk_type, kModelToJsonSupport); + ret = ge::FAILED; + } + + if (FLAGS_dump_mode != "0" && FLAGS_dump_mode != "1") { + ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"dump_mode"}); + GELOGE(ge::FAILED, "Input parameter[--dump_mode]'s value must be 1 or 0."); + ret = ge::FAILED; + } + + if (ret != ge::SUCCESS) return ret; + + // Need to save caffe.proto path + SaveCustomCaffeProtoPath(); + + if (FLAGS_dump_mode == "0") { + // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so. + LoadCustomOpLib(false); + ret = ge::ConvertFwkModelToJson((domi::FrameworkType)fwk_type, model_file.c_str(), json_file.c_str()); + } else if (FLAGS_dump_mode == "1") { + // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so and ops plugin so. 
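+    // dump_mode "1" additionally runs infershape on the parsed graph, so the op plugins must be
+    // registered as well. Illustrative invocation (flag names taken from this file):
+    //   ./atc --mode=1 --framework=3 --om=./model.pb --json=./out.json --dump_mode=1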
+ LoadCustomOpLib(true); + ret = GenerateInfershapeJson(); + } + + return ret; +} + +domi::Status GenerateModel(std::map &options, std::string output) { + ge::GeGenerator ge_generator; + ge::Status geRet = ge::SUCCESS; + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + geRet = ge::GELib::Initialize(options); + if (geRet != ge::SUCCESS) { + DOMI_LOGE("GE initialize failed!"); + return domi::FAILED; + } + } + geRet = ge_generator.Initialize(options, domi::GetContext()); + if (geRet != ge::SUCCESS) { + DOMI_LOGE("GeGenerator initialize failed!"); + (void)ge::GELib::GetInstance()->Finalize(); + return domi::FAILED; + } + + ge::Graph graph; + std::vector inputs; + if (FLAGS_framework == domi::MINDSPORE) { + // load model from file + ge::Model load_model = ge::Model("loadmodel", "version2"); + auto ret1 = load_model.LoadFromFile(FLAGS_model); + if (ret1 != ge::GRAPH_SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model}); + DOMI_LOGE("Load model from %s failed, please check model file or " + "input parameter[--framework] is correct", FLAGS_model.c_str()); + (void)ge_generator.Finalize(); + (void)ge::GELib::GetInstance()->Finalize(); + return domi::FAILED; + } + + graph = load_model.GetGraph(); + + GE_CHK_STATUS_EXEC(ge::InitDomiOmgContext(FLAGS_input_shape, FLAGS_input_format, "", is_dynamic_input), + GELOGE(ge::FAILED, "ATC Generate call InitDomiOmgContext ret fail"); + (void)ge_generator.Finalize(); (void)ge::GELib::GetInstance()->Finalize(); return domi::FAILED); + + Status ret = CreateInputsForInference(graph, inputs); + if (ret != ge::SUCCESS) { + GELOGE(ge::FAILED, "create inputs for inference failed."); + (void)ge_generator.Finalize(); + (void)ge::GELib::GetInstance()->Finalize(); + return domi::FAILED; + } + + } else { + std::map atc_params; + atc_params.insert(std::pair("input_shape", FLAGS_input_shape)); + atc_params.insert(std::pair("out_nodes", FLAGS_out_nodes)); + atc_params.insert(std::pair("input_format", FLAGS_input_format)); + atc_params.insert(std::pair("check_report", FLAGS_check_report)); + atc_params.insert(std::pair("input_fp16_nodes", FLAGS_input_fp16_nodes)); + atc_params.insert(std::pair("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout)); + atc_params.insert(std::pair("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout)); + atc_params.insert(std::pair("compress_weight_conf", FLAGS_compress_weight_conf)); + atc_params.insert(std::pair(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); + atc_params.insert(std::pair("output", output)); + + Status ret = + ParseGraph(graph, atc_params, FLAGS_model.c_str(), FLAGS_weight.c_str(), (domi::FrameworkType)FLAGS_framework, + FLAGS_op_name_map.c_str(), FLAGS_target.c_str(), (ge::RunMode)FLAGS_mode, is_dynamic_input); + + // in ONLY_PRE_CHECK mode, pre-checking report has already saved in ParseGraph + if (FLAGS_mode == ge::ONLY_PRE_CHECK) { + (void)ge_generator.Finalize(); + (void)ge::GELib::GetInstance()->Finalize(); + if (ret != ge::SUCCESS) { + DOMI_LOGE("ATC precheck fail."); + return domi::FAILED; + } + return domi::SUCCESS; + } + + if (ret != ge::SUCCESS) { + DOMI_LOGE("ATC Parse graph domi::FAILED"); + DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. 
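+ // checking error log)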
(for test case + (void)ge_generator.Finalize(); + (void)ge::GELib::GetInstance()->Finalize(); + return domi::FAILED; + } + if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) { + DOMI_LOGE("Set output node info fail."); + (void)ge_generator.Finalize(); + (void)ge::GELib::GetInstance()->Finalize(); + return domi::FAILED; + } + } + + geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); + if (geRet != ge::SUCCESS) { + DOMI_LOGE("GE GenerateOfflineModel execute failed"); + DOMI_LOGE("ATC Generate execute failed"); // Duplicate log. (for test case + // checking error log) + (void)ge_generator.Finalize(); + (void)ge::GELib::GetInstance()->Finalize(); + return domi::FAILED; + } + (void)ge_generator.Finalize(); + (void)ge::GELib::GetInstance()->Finalize(); + return ge::SUCCESS; +} + +static void SetEnvForSingleOp(std::map &options) { + string flag_on = "1"; + string flag_off = "0"; + options.emplace(ge::GE_FE_FLAG, flag_on); + options.emplace(ge::STREAM_NUM, "1"); // single op only use one stream + options.emplace(ge::RUN_FLAG, flag_off); + options.emplace(ge::OPTION_GRAPH_RUN_MODE, flag_off); + options.emplace(ge::SINGLE_OP_FLAG, flag_on); + options.emplace(ge::PRECISION_MODE, FLAGS_precision_mode); + options.emplace(ge::SOC_VERSION, FLAGS_soc_version); + options.emplace(ge::CORE_TYPE, FLAGS_core_type); + options.emplace(ge::AICORE_NUM, FLAGS_aicore_num); + options.emplace(ge::OP_SELECT_IMPL_MODE, FLAGS_op_select_implmode); + options.emplace(ge::OPTYPELIST_FOR_IMPLMODE, FLAGS_optypelist_for_implmode); + options.emplace(ge::AUTO_TUNE_MODE, FLAGS_auto_tune_mode); + options.emplace(ge::GRAPH_MEMORY_MAX_SIZE, kGraphMemoryManagerMallocMaxSize); + options.emplace(ge::OP_DEBUG_LEVEL, to_string(FLAGS_op_debug_level)); +} + +domi::Status GenerateSingleOp(const std::string& json_file_path) { + if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) { + DOMI_LOGE("output path %s is not valid!", FLAGS_output.c_str()); + return domi::FAILED; + } + // check optypelist_for_implmode and op_select_implmode + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::CheckImplmodeParamValid(FLAGS_optypelist_for_implmode, FLAGS_op_select_implmode) != ge::SUCCESS, + return ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); + + std::map options; + // need to be changed when ge.ini plan is done + SetEnvForSingleOp(options); + + auto ret = ge::GELib::Initialize(options); + if (ret != ge::SUCCESS) { + DOMI_LOGE("GE initialize failed!"); + return domi::FAILED; + } + + ge::GeGenerator generator; + ret = generator.Initialize(options, domi::GetContext()); + if (ret != SUCCESS) { + DOMI_LOGE("GeGenerator initialize failed!"); + (void)ge::GELib::GetInstance()->Finalize(); + return domi::FAILED; + } + + vector build_params; + if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) { + DOMI_LOGE("parse single op json file failed"); + (void)generator.Finalize(); + (void)ge::GELib::GetInstance()->Finalize(); + return domi::FAILED; + } + + int index = 0; + for (auto ¶m : build_params) { + string output_path; + if (!FLAGS_output.empty()) { + output_path = FLAGS_output + "/"; + } + output_path += param.file_name; + ret = generator.BuildSingleOpModel(param.op_desc, param.inputs, param.outputs, output_path); + if (ret != SUCCESS) { + DOMI_LOGE("Compile op failed. ge ret = %u, op index = %d", ret, index); + ret = domi::FAILED; + break; + } + GELOGI("Compile op success. 
op index = %d, output = %s", index, output_path.c_str()); + index += 1; + } + + (void)generator.Finalize(); + (void)ge::GELib::GetInstance()->Finalize(); + return ret; +} + +domi::Status GenerateOmModel() { + if (!CheckInputFormat()) { + GELOGE(ge::FAILED, "Check input_format failed"); + return domi::FAILED; + } + Status ret = GFlagUtils::CheckFlags(); + GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, + "Check flags failed! Please check whether some atc params that include semicolons[;] use double " + "quotation marks (\") to enclose each argument such as out_nodes, input_shape, dynamic_image_size"); +#if !defined(__ANDROID__) && !defined(ANDROID) + // Load custom operator Library + LoadCustomOpLib(true); + + SaveCustomCaffeProtoPath(); + + ret = ge::CheckCustomAiCpuOpLib(); + + GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "check custom aicpu run so failed!"); +#endif + + const int f_stream_num = 1; + std::map options; + options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(FLAGS_framework))); + options.insert(std::pair(string(ge::STREAM_NUM), to_string(f_stream_num))); + options.insert(std::pair(string(ge::CALIBRATION_CONF_FILE), FLAGS_cal_conf)); + options.insert(std::pair(string(ge::ENCRYPT_MODE), to_string(FLAGS_encrypt_mode))); + options.insert(std::pair(string(ge::EK_FILE), FLAGS_encrypt_key)); + options.insert(std::pair(string(ge::CERT_FILE), FLAGS_certificate)); + options.insert(std::pair(string(ge::HW_KEY_FILE), FLAGS_hardware_key)); + options.insert(std::pair(string(ge::PRIVATE_KEY_FILE), FLAGS_private_key)); + options.insert(std::pair(string(ge::OUTPUT_NODE_NAME), FLAGS_out_nodes)); + options.insert(std::pair(string(ge::INSERT_OP_FILE), FLAGS_insert_op_conf)); + options.insert(std::pair(string(ge::PRECISION_MODE), FLAGS_precision_mode)); + + options.insert(std::pair(string(ge::RUN_FLAG), to_string(0))); + options.insert(std::pair(string(ge::TRAIN_FLAG), to_string(0))); + + if (!FLAGS_output_type.empty()) { + options.insert(std::pair(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); + } + + options.insert(std::pair(string(ge::OP_SELECT_IMPL_MODE), FLAGS_op_select_implmode)); + options.insert(std::pair(string(ge::OPTYPELIST_FOR_IMPLMODE), FLAGS_optypelist_for_implmode)); + + if (!FLAGS_input_fp16_nodes.empty()) { + GELOGI("FLAGS_input_fp16_nodes : %s .", FLAGS_input_fp16_nodes.c_str()); + options.insert(std::pair(ge::INPUT_FP16_NODES, FLAGS_input_fp16_nodes)); + } + + options.insert(std::pair(string(ge::AUTO_TUNE_MODE), FLAGS_auto_tune_mode)); + + options.insert( + std::pair(string(ge::OPTION_EXEC_DISABLE_REUSED_MEMORY), to_string(FLAGS_disable_reuse_memory))); + + options.insert(std::pair(string(ge::SOC_VERSION), FLAGS_soc_version)); + + options.insert(std::pair(string(ge::CORE_TYPE), FLAGS_core_type)); + + options.insert(std::pair(string(ge::AICORE_NUM), FLAGS_aicore_num)); + + options.insert(std::pair(string(ge::BUFFER_OPTIMIZE), FLAGS_buffer_optimize)); + + options.insert(std::pair(string(ge::ENABLE_SMALL_CHANNEL), FLAGS_enable_small_channel)); + + options.insert(std::pair(string(ge::FUSION_SWITCH_FILE), FLAGS_fusion_switch_file)); + + options.insert(std::pair(string(ge::ENABLE_COMPRESS_WEIGHT), + (FLAGS_enable_compress_weight == "true") ? 
+ ge::kEnableCompressWeightTrue : ge::kEnableCompressWeightFalse)); + + options.insert(std::pair(string(ge::GRAPH_MEMORY_MAX_SIZE), kGraphMemoryManagerMallocMaxSize)); + + options.insert(std::pair(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream)); + + SetDynamicInputSizeOptions(); + + if (!FLAGS_save_original_model.empty()) { + options.insert(std::pair(string(ge::SAVE_ORIGINAL_MODEL), FLAGS_save_original_model)); + options.insert(std::pair(string(ge::ORIGINAL_MODEL_FILE), FLAGS_output + "_original.om")); + } + + options.insert(std::pair(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level))); + // set enable scope fusion passes + SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes); + // print atc option map + ge::PrintOptionMap(options, "atc option"); + + // When the ATC module is transferred to a model, the suffix ".om" is automatically added to the model name + FLAGS_output = FLAGS_output + ".om"; + ret = GenerateModel(options, FLAGS_output); + if (ret != domi::SUCCESS) { + return domi::FAILED; + } + + return domi::SUCCESS; +} + +domi::Status ConvertModelToJson() { + Status ret = GFlagUtils::CheckConverJsonParamFlags(); + GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check convert json params flags failed!"); + + ret = ConvertModelToJson(FLAGS_framework, FLAGS_om, FLAGS_json); + + GE_IF_BOOL_EXEC(ret != domi::SUCCESS, return domi::FAILED); + return domi::SUCCESS; +} + +bool CheckRet(domi::Status ret) { + if (ret != domi::SUCCESS) { + if (FLAGS_mode == ONLY_PRE_CHECK) { + GELOGW("ATC precheck failed."); + } else if (FLAGS_mode == GEN_OM_MODEL) { + GELOGW("ATC generate offline model failed."); + } else if (FLAGS_mode == MODEL_TO_JSON) { + GELOGW("ATC convert model to json file failed."); + } else if (FLAGS_mode == PBTXT_TO_JSON) { + GELOGW("ATC convert pbtxt to json file failed."); + } else { + return false; + } + return false; + } + + if (FLAGS_mode == ONLY_PRE_CHECK) { + GELOGI("ATC precheck success."); + } else if (FLAGS_mode == GEN_OM_MODEL) { + GELOGI("ATC generate offline model success."); + } else if (FLAGS_mode == MODEL_TO_JSON) { + GELOGI("ATC convert model to json file success."); + } else if (FLAGS_mode == PBTXT_TO_JSON) { + GELOGI("ATC convert pbtxt to json file success."); + } + return true; +} + +domi::Status ConvertPbtxtToJson() { + Status ret = GFlagUtils::CheckConverJsonParamFlags(); + if (ret != domi::SUCCESS) { + GELOGE(ge::FAILED, "Check convert json params flags failed!"); + return domi::FAILED; + } + + ret = ge::ConvertPbtxtToJson(FLAGS_om.c_str(), FLAGS_json.c_str()); + if (ret != domi::SUCCESS) { + GELOGE(ge::FAILED, "ConvertPbtxtToJson fail."); + return domi::FAILED; + } + + return domi::SUCCESS; +} + +int init(int argc, char* argv[]) { + GFlagUtils::InitGFlag(argc, argv); + // set log level + int ret = -1; + const std::set log_level = {"null", "debug", "info", "warning", "error"}; + if (log_level.count(FLAGS_log) == 0) { + std::cout << "E10010: invalid value for --log:" << FLAGS_log + <<", only support debug, info, warning, error, null"<< std::endl; + return ret; + } + + ret = ge::CheckLogParamValidAndSetLogLevel(FLAGS_log); + if (ret != 0) { + return ret; + } + + std::string path_base = ge::GELib::GetPath(); + ret = ErrorManager::GetInstance().Init(path_base); + if (ret != 0) { + DOMI_LOGE("ErrorManager init fail !"); + return ret; + } + + return 0; +} + +long GetMemInfo(const std::string &key) { + std::string file_path = "/proc/meminfo"; + std::ifstream fs(file_path, std::ifstream::in); + if (!fs.is_open()) { + 
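+    // /proc/meminfo may be absent or unreadable (e.g. in a restricted container); returning 0 here
+    // makes CheckMemInfo() skip the available-memory check instead of failing the whole run.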
GELOGW("Can not open %s .", file_path.c_str()); + return 0; + } + std::string line; + while (getline(fs, line)) { // line not with \n + if (line.find(key) != std::string::npos) { + GELOGI("Find mem [%s] info line [%s]", key.c_str(), line.c_str()); + fs.close(); + size_t pos = line.find(":"); + if (pos == std::string::npos) { + return 0; + } + std::string current_mem_info_str = line.substr(pos + 1); + ge::StringUtils::Trim(current_mem_info_str); + GELOGI("Find mem [%s] info [%s].", key.c_str(), current_mem_info_str.c_str()); + return stol(current_mem_info_str); + } + } + fs.close(); // close the file + return 0; +} + +bool CheckMemInfo() { + if (FLAGS_auto_tune_mode.empty()) { + return true; + } + // only check current available mem when auto_tune_mode is set. + long current_mem_available = GetMemInfo("MemAvailable"); + GELOGI("Get mem available [%lu].", current_mem_available); + std::cout << "Current available mem is " << current_mem_available << "kB." << std::endl; + if ((current_mem_available > 0) && (current_mem_available < kMinAvailableMem)) { + GELOGE(ge::PARAM_INVALID, "Current available mem [%lu] can not be smaller than [%lu] .", + current_mem_available, kMinAvailableMem); + ErrorManager::GetInstance().ATCReportErrMessage("E10044", {"value", "min_value"}, + {to_string(current_mem_available), to_string(kMinAvailableMem)}); + return false; + } + return true; +} + +int main(int argc, char* argv[]) { + Status ret = domi::SUCCESS; + std::cout << "ATC start working now, please wait for a moment." << std::endl; + + // Initialize + if (init(argc, argv) != 0) { + std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; + return -1; + } + do { + if (!CheckMemInfo()) { + GELOGE(ge::PARAM_INVALID, "Current available mem is too small"); + ret = domi::FAILED; + break; + } + if (!FLAGS_singleop.empty()) { + ret = GenerateSingleOp(FLAGS_singleop); + break; + } + + // default mode(mode:0), Open source model to model + if (GEN_OM_MODEL == FLAGS_mode || ONLY_PRE_CHECK == FLAGS_mode) { + GE_IF_BOOL_EXEC(GenerateOmModel() != domi::SUCCESS, ret = domi::FAILED; break); + } else if (MODEL_TO_JSON == FLAGS_mode) { // Mode 1, transfer model to JSON + GE_CHK_BOOL_EXEC(ConvertModelToJson() == domi::SUCCESS, ret = domi::FAILED; + break, "ATC ConvertJson execute failed!!"); + } else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) { + GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; + break, "ATC convert pbtxt to json execute failed!!"); + } else { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --mode[%d], %s.", FLAGS_mode, kModeSupport); + ret = domi::FAILED; + break; + } + } while (0); + + if (!CheckRet(ret)) { + std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; + int result = ErrorManager::GetInstance().OutputErrMessage(STDOUT_FILENO); + if (result != 0) { + DOMI_LOGE("ErrorManager outputErrMessage fail !"); + } + GELOGI("Current mem available mem is [%lu]", GetMemInfo("MemAvailable")); + return ret; + } else { + std::cout << "ATC run success, welcome to the next use." 
<< std::endl; + (void)ErrorManager::GetInstance().OutputMessage(STDOUT_FILENO); + return 0; + } +} /*lint +e530*/ diff --git a/ge/offline/module.mk b/ge/offline/module.mk new file mode 100644 index 00000000..42b217db --- /dev/null +++ b/ge/offline/module.mk @@ -0,0 +1,52 @@ + +LOCAL_PATH := $(call my-dir) + +include $(CLEAR_VARS) + +LOCAL_MODULE := atc + +LOCAL_CFLAGS += -Werror +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 + +LOCAL_SRC_FILES := \ + main.cc \ + single_op_parser.cc \ + ../session/omg.cc \ + ../ir_build/atc_ir_common.cc \ + +LOCAL_C_INCLUDES := \ + $(LOCAL_PATH)/../ ./ \ + $(TOPDIR)inc \ + $(TOPDIR)inc/external \ + $(TOPDIR)inc/external/graph \ + $(TOPDIR)inc/framework \ + $(TOPDIR)inc/framework/domi \ + $(TOPDIR)libc_sec/include \ + $(TOPDIR)inc/common/util \ + third_party/json/include \ + third_party/gflags/include \ + third_party/protobuf/include \ + proto/om.proto \ + proto/ge_ir.proto \ + proto/task.proto \ + proto/insert_op.proto \ + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libge_common \ + libprotobuf \ + libslog \ + libgraph \ + libregister \ + liberror_manager \ + libge_compiler \ + libruntime_compile \ + libparser_common \ + liberror_manager \ + +LOCAL_STATIC_LIBRARIES := libgflags + +LOCAL_LDFLAGS := -lrt -ldl + +include $(BUILD_HOST_EXECUTABLE) + diff --git a/ge/offline/proto/ge_ir.proto b/ge/offline/proto/ge_ir.proto new file mode 100644 index 00000000..f60a0f89 --- /dev/null +++ b/ge/offline/proto/ge_ir.proto @@ -0,0 +1 @@ +../../../../inc/common/proto/ge_ir.proto \ No newline at end of file diff --git a/ge/offline/proto/insert_op.proto b/ge/offline/proto/insert_op.proto new file mode 100644 index 00000000..27b233e5 --- /dev/null +++ b/ge/offline/proto/insert_op.proto @@ -0,0 +1 @@ +../../../../inc/common/proto/insert_op.proto \ No newline at end of file diff --git a/ge/offline/proto/om.proto b/ge/offline/proto/om.proto new file mode 100644 index 00000000..91c581bb --- /dev/null +++ b/ge/offline/proto/om.proto @@ -0,0 +1 @@ +../../../../inc/common/proto/om.proto \ No newline at end of file diff --git a/ge/offline/proto/task.proto b/ge/offline/proto/task.proto new file mode 100644 index 00000000..36ae4847 --- /dev/null +++ b/ge/offline/proto/task.proto @@ -0,0 +1 @@ +../../proto/task.proto \ No newline at end of file diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc new file mode 100644 index 00000000..34ac7d5f --- /dev/null +++ b/ge/offline/single_op_parser.cc @@ -0,0 +1,503 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "single_op_parser.h" + +#include +#include +#include +#include + +#include + +#include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" +#include "common/ge_inner_error_codes.h" +#include "framework/common/util.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/operator_factory_impl.h" + +using Json = nlohmann::json; +using std::string; +using std::vector; +using std::map; + +namespace ge { +namespace { +constexpr char const *kKeyOp = "op"; +constexpr char const *kKeyInputDesc = "input_desc"; +constexpr char const *kKeyOutputDesc = "output_desc"; +constexpr char const *kKeyAttr = "attr"; +constexpr char const *kKeyName = "name"; +constexpr char const *kKeyType = "type"; +constexpr char const *kKeyShape = "shape"; +constexpr char const *kKeyShapeRange = "shape_range"; +constexpr char const *kKeyValue = "value"; +constexpr char const *kKeyFormat = "format"; +constexpr char const *kFileSuffix = ".om"; +constexpr int kDumpJsonIndent = 2; +constexpr int kShapeRangePairSize = 2; +constexpr int kShapeRangeLow = 0; +constexpr int kShapeRangeHigh = 1; + +map kAttrTypeDict = { + {"bool", GeAttrValue::VT_BOOL}, + {"int", GeAttrValue::VT_INT}, + {"float", GeAttrValue::VT_FLOAT}, + {"string", GeAttrValue::VT_STRING}, + {"list_bool", GeAttrValue::VT_LIST_BOOL}, + {"list_int", GeAttrValue::VT_LIST_INT}, + {"list_float", GeAttrValue::VT_LIST_FLOAT}, + {"list_string", GeAttrValue::VT_LIST_STRING}, + {"list_list_int", GeAttrValue::VT_LIST_LIST_INT}, + {"data_type", GeAttrValue::VT_DATA_TYPE}, +}; + +map kDataTypeDict = { + {"bool", DT_BOOL}, + {"int8", DT_INT8}, + {"uint8", DT_UINT8}, + {"int16", DT_INT16}, + {"uint16", DT_UINT16}, + {"int32", DT_INT32}, + {"uint32", DT_UINT32}, + {"int64", DT_INT64}, + {"uint64", DT_UINT64}, + {"float16", DT_FLOAT16}, + {"half", DT_FLOAT16}, + {"fp16", DT_FLOAT16}, + {"float", DT_FLOAT}, + {"float32", DT_FLOAT}, + {"double", DT_DOUBLE}, +}; + +map kFormatDict = { + {"nchw", FORMAT_NCHW}, + {"nhwc", FORMAT_NHWC}, + {"nd", FORMAT_ND}, + {"fractal_nz", FORMAT_FRACTAL_NZ}, + {"fractal_z", FORMAT_FRACTAL_Z}, + {"nc1hwc0", FORMAT_NC1HWC0}, +}; +} + +template +void SetAttrValue(const Json &j, SingleOpAttr &attr) { + attr.value.SetValue(j.at(kKeyValue).get()); +} + +template +T GetValue(const map &dict, string &key, T default_val) { + transform(key.begin(), key.end(), key.begin(), ::tolower); + auto it = dict.find(key); + if (it == dict.end()) { + return default_val; + } + + return it->second; +} + +void from_json(const Json &j, SingleOpTensorDesc &desc) { + desc.dims = j.at(kKeyShape).get>(); + auto it = j.find(kKeyShapeRange); + if (it != j.end()) { + desc.dim_ranges = j.at(kKeyShapeRange).get>>(); + } + string format_str = j.at(kKeyFormat).get(); + string type_str = j.at(kKeyType).get(); + desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); + desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); + auto tensor_name = j.find(kKeyName); + if (tensor_name != j.end()) { + desc.name = tensor_name->get(); + } +} + +void from_json(const Json &j, SingleOpAttr &attr) { + attr.name = j.at(kKeyName).get(); + attr.type = j.at(kKeyType).get(); + auto it = kAttrTypeDict.find(attr.type); + if (it == kAttrTypeDict.end()) { + GELOGE(UNSUPPORTED, "Parse attr[%s] failed. 
Unsupported type: %s", attr.name.c_str(), attr.type.c_str()); + return; + } + + switch (it->second) { + case GeAttrValue::VT_BOOL: + SetAttrValue(j, attr); + break; + case GeAttrValue::VT_INT: + SetAttrValue(j, attr); + break; + case GeAttrValue::VT_FLOAT: + SetAttrValue(j, attr); + break; + case GeAttrValue::VT_STRING: + SetAttrValue(j, attr); + break; + case GeAttrValue::VT_LIST_BOOL: + SetAttrValue>(j, attr); + break; + case GeAttrValue::VT_LIST_INT: + SetAttrValue>(j, attr); + break; + case GeAttrValue::VT_LIST_FLOAT: + SetAttrValue>(j, attr); + break; + case GeAttrValue::VT_LIST_STRING: + SetAttrValue>(j, attr); + break; + case GeAttrValue::VT_LIST_LIST_INT: + SetAttrValue>>(j, attr); + break; + case GeAttrValue::VT_DATA_TYPE: + SetAttrValue(j, attr); + break; + default: + GELOGE(UNSUPPORTED, "Parse attr[%s] failed. Unsupported type: %s", attr.name.c_str(), attr.type.c_str()); + break; + } +} + +void from_json(const Json &j, SingleOpDesc &desc) { + desc.op = j.at(kKeyOp).get(); + + auto input_desc = j.find(kKeyInputDesc); + if (input_desc != j.end()) { + desc.input_desc = input_desc->get>(); + } + + auto output_desc = j.find(kKeyOutputDesc); + if (output_desc != j.end()) { + desc.output_desc = output_desc->get>(); + } + + auto attr_field = j.find(kKeyAttr); + if (attr_field != j.end()) { + desc.attrs = attr_field->get>(); + } +} + +Status SingleOpParser::ReadJsonFile(const std::string &file, Json &json_obj) { + std::string real_path = RealPath(file.c_str()); + if (real_path.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10023", {"value"}, {file}); + GELOGE(FAILED, "Input parameter[--singleop]'s value[%s] is not a valid path.", file.c_str()); + return INTERNAL_ERROR; + } + + std::ifstream ifs(real_path); + if (!ifs.is_open()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10024", {"value"}, {file}); + GELOGE(FAILED, "Open file[%s] provided in input parameter[--singleop] failed.", file.c_str()); + return FAILED; + } + try { + ifs >> json_obj; + } catch (const std::exception &e) { + ErrorManager::GetInstance().ATCReportErrMessage("E10025", {"realpath", "errmsg"}, {real_path, e.what()}); + GELOGE(PARAM_INVALID, "Parse file[%s] provided in input parameter[--singleop] failed, exception = %s.", + real_path.c_str(), e.what()); + return PARAM_INVALID; + } + + ifs.close(); + return SUCCESS; +} + +bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { + if (op_desc.op.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10026"); + GELOGE(PARAM_INVALID, "Op name is empty"); + return false; + } + + int index = 0; + for (auto &tensor_desc : op_desc.input_desc) { + if (tensor_desc.type == DT_UNDEFINED) { + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)}); + GELOGE(false, "Input's dataType is invalid when the index is %d", index); + return false; + } + + if (tensor_desc.format == FORMAT_RESERVED) { + ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index); + return false; + } + ++index; + } + + index = 0; + for (auto &tensor_desc : op_desc.output_desc) { + if (tensor_desc.type == DT_UNDEFINED) { + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index); + return false; + } + + if (tensor_desc.format == FORMAT_RESERVED) { + 
ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"output", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Output's format is invalid when the index is %d", index); + return false; + } + ++index; + } + + for (auto &attr : op_desc.attrs) { + if (attr.name.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10029"); + GELOGE(PARAM_INVALID, "attr name is empty"); + return false; + } + + if (attr.value.IsEmpty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10030", {"attrname"}, {attr.name}); + GELOGE(PARAM_INVALID, "Parse attr \"%s\" failed. ", attr.name.c_str()); + return false; + } + } + + return true; +} + +std::unique_ptr SingleOpParser::CreateOpDesc(const string &op_type) { + return std::unique_ptr(new(std::nothrow) OpDesc(op_type, op_type)); +} + +Status SingleOpParser::ConvertToBuildParam(int index, + const SingleOpDesc &single_op_desc, + SingleOpBuildParam &build_param) { + auto op_desc = CreateOpDesc(single_op_desc.op); + if (op_desc == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to create instance of opDesc"); + return MEMALLOC_FAILED; + } + + std::stringstream file_name; + file_name << index; + file_name << "_" << single_op_desc.op; + for (auto &desc : single_op_desc.input_desc) { + file_name << "_" << desc.type << "_" << desc.format; + for (auto dim : desc.dims) { + file_name << "_" << dim; + } + GeTensorDesc ge_tensor_desc(GeShape(desc.dims), + desc.format, + desc.type); + ge_tensor_desc.SetOriginFormat(desc.format); + GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); + TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); + TensorUtils::SetInputTensor(ge_tensor_desc, true); + TensorUtils::SetOutputTensor(ge_tensor_desc, false); + if (desc.name.empty()) { + op_desc->AddInputDesc(ge_tensor_desc); + } else { + op_desc->AddInputDesc(desc.name, ge_tensor_desc); + } + build_param.inputs.emplace_back(ge_tensor_desc); + } + + for (auto &desc : single_op_desc.output_desc) { + file_name << "_" << desc.type << "_" << desc.format; + for (auto dim : desc.dims) { + file_name << "_" << dim; + } + + GeTensorDesc ge_tensor_desc(GeShape(desc.dims), + desc.format, + desc.type); + ge_tensor_desc.SetOriginFormat(desc.format); + GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc)); + TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); + TensorUtils::SetInputTensor(ge_tensor_desc, false); + TensorUtils::SetOutputTensor(ge_tensor_desc, true); + if (desc.name.empty()) { + op_desc->AddOutputDesc(ge_tensor_desc); + } else { + op_desc->AddOutputDesc(desc.name, ge_tensor_desc); + } + build_param.outputs.emplace_back(ge_tensor_desc); + } + + for (const auto &attr : single_op_desc.attrs) { + op_desc->SetAttr(attr.name, attr.value); + } + + if (VerifyOpInputOutputSizeByIr(*op_desc) != SUCCESS) { + GELOGE(PARAM_INVALID, "Verify op [%s] input or output size failed.", op_desc->GetType().c_str()); + return PARAM_INVALID; + } + + file_name << kFileSuffix; + build_param.file_name = file_name.str(); + build_param.op_desc.reset(op_desc.release()); + return SUCCESS; +} + +Status SingleOpParser::VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc) { + ge::Operator operator_ir = ge::OperatorFactory::CreateOperator("tmp_operator", current_op_desc.GetType()); + if (!operator_ir.IsEmpty()) { + auto opdesc_ir = ge::OpDescUtils::GetOpDescFromOperator(operator_ir); + GE_CHECK_NOTNULL(opdesc_ir); + size_t current_opdesc_inputs_num = current_op_desc.GetInputsSize(); + size_t ir_opdesc_inputs_num = 
opdesc_ir->GetInputsSize(); + if (current_opdesc_inputs_num < ir_opdesc_inputs_num) { + string reason = "is smaller than the ir needed input size " + std::to_string(ir_opdesc_inputs_num); + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {current_op_desc.GetName(), "input size " + std::to_string(current_opdesc_inputs_num), reason}); + GELOGE(PARAM_INVALID, "This op [%s] input size %zu is smaller than the ir needed input size %zu", + current_op_desc.GetName().c_str(), current_opdesc_inputs_num, ir_opdesc_inputs_num); + return PARAM_INVALID; + } + size_t current_opdesc_outputs_num = current_op_desc.GetOutputsSize(); + size_t ir_opdesc_outputs_num = opdesc_ir->GetOutputsSize(); + if (current_opdesc_outputs_num < ir_opdesc_outputs_num) { + string reason = "is smaller than the ir needed output size " + std::to_string(ir_opdesc_outputs_num); + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {current_op_desc.GetName(), "output size " + std::to_string(current_opdesc_outputs_num), reason}); + GELOGE(PARAM_INVALID, "This op [%s] output size %zu is smaller than the ir needed output size %zu", + current_op_desc.GetName().c_str(), current_opdesc_outputs_num, ir_opdesc_outputs_num); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +Status SingleOpParser::SetShapeRange(const std::string &op_name, + const SingleOpTensorDesc &tensor_desc, + GeTensorDesc &ge_tensor_desc) { + auto num_shape_ranges = tensor_desc.dim_ranges.size(); + GELOGD("Number of shape ranges = %zu", num_shape_ranges); + auto it = std::find(tensor_desc.dims.begin(), tensor_desc.dims.end(), ge::UNKNOWN_DIM_NUM); + if (it != tensor_desc.dims.end()) { + if (tensor_desc.dims != ge::UNKNOWN_RANK) { + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {op_name, + "shape", + "has unknown rank but dim size is not one"}); + GELOGE(PARAM_INVALID, "Invalid tensor shape: [%s]", ge_tensor_desc.MutableShape().ToString().c_str()); + return PARAM_INVALID; + } + if (!tensor_desc.dim_ranges.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {op_name, + "shape range", + "is not needed while the rank the shape is unknown"}); + GELOGE(PARAM_INVALID, "Shape range is not needed while the rank the shape is unknown"); + return PARAM_INVALID; + } + + GELOGD("Shape is unknown rank, do not set shape range"); + return SUCCESS; + } + + std::vector> shape_range; + size_t range_index = 0; + for (auto dim : tensor_desc.dims) { + if (dim >= 0) { + shape_range.emplace_back(dim, dim); + GELOGD("Adding shape range: [%ld, %ld]", dim, dim); + } else { + GELOGD("To get shape range by index = %zu", range_index); + if (range_index >= num_shape_ranges) { + string reason = "is smaller than the unknown dim size " + std::to_string(++range_index); + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {op_name, + "shape range size " + std::to_string(num_shape_ranges), + reason}); + GELOGE(PARAM_INVALID, "The number of shape_range mismatches that of unknown dims."); + return PARAM_INVALID; + } + + auto &range = tensor_desc.dim_ranges[range_index]; + if (range.size() != kShapeRangePairSize) { + string reason = "has " + std::to_string(range.size()) + " item(s)"; + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {op_name, + "shape range " + std::to_string(range_index), + reason}); + GELOGE(PARAM_INVALID, "Invalid shape range 
entry. index = %zu, size = %zu", range_index, range.size()); + return PARAM_INVALID; + } + + shape_range.emplace_back(range[kShapeRangeLow], range[kShapeRangeHigh]); + GELOGD("Adding shape range: [%ld, %ld]", range[kShapeRangeLow], range[kShapeRangeHigh]); + ++range_index; + } + } + + if (num_shape_ranges != range_index) { + string reason = "is greater than the unknown dim size " + std::to_string(range_index); + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {op_name, + "shape range size " + std::to_string(num_shape_ranges), + reason}); + GELOGE(PARAM_INVALID, + "The number of shape_range(%zu) mismatches that of unknown dims(%zu).", + num_shape_ranges, + range_index); + return PARAM_INVALID; + } + + if (range_index > 0) { + ge_tensor_desc.SetShapeRange(shape_range); + } + + return SUCCESS; +} + +Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector &op_list) { + Json single_op_list_json; + auto ret = ReadJsonFile(file, single_op_list_json); + if (ret != SUCCESS) { + return ret; + } + + int index = 0; + for (const Json &single_op_json : single_op_list_json) { + SingleOpDesc single_op_desc; + try { + GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str()); + single_op_desc = single_op_json; + } catch (const nlohmann::json::exception &e) { + ErrorManager::GetInstance().ATCReportErrMessage("E10032", {"index", "jsonfile", "exception"}, + {std::to_string(index), file, e.what()}); + GELOGE(PARAM_INVALID, "Parse the index[%d] of op failed when read json file[%s], exception %s", + index, file.c_str(), e.what()); + return PARAM_INVALID; + } + + if (!Validate(single_op_desc)) { + GELOGE(PARAM_INVALID, "Validate the index[%d] of op failed when read json file[%s].", index, file.c_str()); + return PARAM_INVALID; + } + + SingleOpBuildParam param; + ret = ConvertToBuildParam(index, single_op_desc, param); + if (ret != SUCCESS) { + return ret; + } + + op_list.emplace_back(param); + GELOGI("Parse the index[%d] of op success", index); + index += 1; + } + + return SUCCESS; +} +} // namespace ge + diff --git a/ge/offline/single_op_parser.h b/ge/offline/single_op_parser.h new file mode 100644 index 00000000..9a1bd962 --- /dev/null +++ b/ge/offline/single_op_parser.h @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/ge/offline/single_op_parser.h b/ge/offline/single_op_parser.h new file mode 100644 index 00000000..9a1bd962 --- /dev/null +++ b/ge/offline/single_op_parser.h @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ACL_TOOLS_COMPILE_PARSER_H +#define ACL_TOOLS_COMPILE_PARSER_H + +#include <string> +#include <vector> + +#include <nlohmann/json.hpp> + +#include "ge/ge_api_error_codes.h" +#include "graph/types.h" +#include "graph/ge_attr_value.h" +#include "graph/op_desc.h" + +namespace ge { +struct SingleOpTensorDesc { + std::string name; + std::vector<int64_t> dims; + std::vector<std::vector<int64_t>> dim_ranges; + ge::Format format = ge::FORMAT_RESERVED; + ge::DataType type = ge::DT_UNDEFINED; +}; + +struct SingleOpAttr { + std::string name; + std::string type; + ge::GeAttrValue value; +}; + +struct SingleOpDesc { + std::string op; + std::vector<SingleOpTensorDesc> input_desc; + std::vector<SingleOpTensorDesc> output_desc; + std::vector<SingleOpAttr> attrs; +}; + +struct SingleOpBuildParam { + ge::OpDescPtr op_desc; + std::vector<GeTensorDesc> inputs; + std::vector<GeTensorDesc> outputs; + std::string file_name; +}; + +void from_json(const nlohmann::json &json, SingleOpTensorDesc &desc); + +void from_json(const nlohmann::json &json, SingleOpAttr &desc); + +void from_json(const nlohmann::json &json, SingleOpDesc &desc); + +class SingleOpParser { + public: + static Status ParseSingleOpList(const std::string &file, std::vector<SingleOpBuildParam> &op_list); + + private: + static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj); + static bool Validate(const SingleOpDesc &op_desc); + static std::unique_ptr<OpDesc> CreateOpDesc(const std::string &op_type); + static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param); + static Status VerifyOpInputOutputSizeByIr(const OpDesc &current_op_desc); + static Status SetShapeRange(const std::string &op_name, const SingleOpTensorDesc &tensor_desc, GeTensorDesc &ge_tensor_desc); +}; +} // namespace ge + +#endif // ACL_TOOLS_COMPILE_PARSER_H
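The from_json declarations above are the nlohmann::json ADL hooks that let ParseSingleOpList assign a parsed JSON object straight into a SingleOpDesc. A minimal sketch of how such a hook is typically written (the key names "op", "input_desc", "output_desc" and "attr" are assumptions for illustration; the actual implementations and key constants live in single_op_parser.cc):

    void from_json(const nlohmann::json &json, SingleOpDesc &desc) {
      // Assumed key names, for illustration only.
      desc.op = json.at("op").get<std::string>();
      desc.input_desc = json.at("input_desc").get<std::vector<SingleOpTensorDesc>>();
      desc.output_desc = json.at("output_desc").get<std::vector<SingleOpTensorDesc>>();
      if (json.find("attr") != json.end()) {  // treat attributes as optional in this sketch
        desc.attrs = json.at("attr").get<std::vector<SingleOpAttr>>();
      }
    }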
diff --git a/src/ge/omm/csa_interact.cc b/ge/omm/csa_interact.cc similarity index 100% rename from src/ge/omm/csa_interact.cc rename to ge/omm/csa_interact.cc diff --git a/src/ge/omm/csa_interact.h b/ge/omm/csa_interact.h similarity index 85% rename from src/ge/omm/csa_interact.h rename to ge/omm/csa_interact.h index b135d8e6..0a609e09 100644 --- a/src/ge/omm/csa_interact.h +++ b/ge/omm/csa_interact.h @@ -56,7 +56,10 @@ enum ErrorModule { }; struct CsaErrorCode { - CsaErrorCode() : module_ret_errcode(0), error_module(ERROR_MODULE_FMK), job_sub_state(JOBSUBSTATE_OTHER) {} + CsaErrorCode() + : module_ret_errcode(0), + error_module(ERROR_MODULE_FMK), + job_sub_state(JOBSUBSTATE_OTHER) {} ~CsaErrorCode() {} uint32_t module_ret_errcode; ErrorModule error_module; @@ -86,8 +89,10 @@ class CsaInteract { /// @param [in] error_module error module identified by FMK /// @return Status /// - Status WriteJobState(JobState job_state, JobSubState job_sub_state = JOBSUBSTATE_OTHER, - uint32_t module_ret_errcode = SUCCESS, ErrorModule error_module = ERROR_MODULE_FMK); + Status WriteJobState(JobState job_state, + JobSubState job_sub_state = JOBSUBSTATE_OTHER, + uint32_t module_ret_errcode = SUCCESS, + ErrorModule error_module = ERROR_MODULE_FMK); /// /// @brief Update error code in the job state file @@ -96,7 +101,8 @@ class CsaInteract { /// @param [in] module_ret_errcode module process error code /// @param [in] job_sub_state detailed job state /// @return void /// - void WriteErrorCode(uint32_t module_ret_errcode, ErrorModule error_module, JobSubState job_sub_state); + void WriteErrorCode(uint32_t module_ret_errcode, ErrorModule error_module, + JobSubState job_sub_state); /// /// @brief Record errors that occurred durning the training @@ -105,7 +111,9 @@ class CsaInteract { /// @param [in] module_ret_errcode module process error code /// @param [in] job_sub_state detailed job state /// @return void /// - void StoreInternalErrorCode(uint32_t module_ret_errcode, ErrorModule error_module, JobSubState job_sub_state); + void StoreInternalErrorCode(uint32_t module_ret_errcode, + ErrorModule error_module, + JobSubState job_sub_state); /// /// @brief Update training error code in the job state file @@ -122,7 +130,11 @@ class CsaInteract { private: CsaInteract() - : dev_index_(0), job_id_(0), is_init_(false), curr_state_(JOBSTATE_UNKOWN), is_have_internal_error_(false) {} + : dev_index_(0), + job_id_(0), + is_init_(false), + curr_state_(JOBSTATE_UNKOWN), + is_have_internal_error_(false) {} ~CsaInteract() {} @@ -168,3 +180,4 @@ class CsaInteract { } // namespace ge #endif // GE_OMM_CSA_INTERACT_H_ + diff --git a/src/ge/opskernel_manager/ops_kernel_manager.cc b/ge/opskernel_manager/ops_kernel_manager.cc similarity index 100% rename from src/ge/opskernel_manager/ops_kernel_manager.cc rename to ge/opskernel_manager/ops_kernel_manager.cc diff --git a/src/ge/opskernel_manager/ops_kernel_manager.h b/ge/opskernel_manager/ops_kernel_manager.h similarity index 100% rename from src/ge/opskernel_manager/ops_kernel_manager.h rename to ge/opskernel_manager/ops_kernel_manager.h diff --git a/src/ge/opskernel_manager/optimizer_priority.pbtxt b/ge/opskernel_manager/optimizer_priority.pbtxt old mode 100755 new mode 100644 similarity index 100% rename from src/ge/opskernel_manager/optimizer_priority.pbtxt rename to ge/opskernel_manager/optimizer_priority.pbtxt diff --git a/ge/plugin/engine/CMakeLists.txt b/ge/plugin/engine/CMakeLists.txt new file mode 100644 index 00000000..87a6d682 --- /dev/null +++ b/ge/plugin/engine/CMakeLists.txt @@ -0,0 +1,49 @@ +set(SRC_LIST + "dnnengines.cc" + "engine_manage.cc" +) + +############ libengine.so ############ +add_library(engine SHARED ${SRC_LIST}) + +target_compile_options(engine PRIVATE + -Werror +) + +target_compile_definitions(engine PRIVATE + REUSE_MEMORY=1 + PROTOBUF_INLINE_NOT_IN_HEADERS=0 +) + +target_include_directories(engine PRIVATE + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc/ + ${GE_CODE_DIR}/inc/framework + ${GE_CODE_DIR}/inc/framework/common + ${GE_CODE_DIR}/inc/external + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc +) + +target_link_libraries(engine PRIVATE + $<BUILD_INTERFACE:intf_pub> + -Wl,--no-as-needed + slog + -Wl,--as-needed + -lrt + -ldl +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS engine OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) diff --git a/src/ge/plugin/engine/dnnengines.cc b/ge/plugin/engine/dnnengines.cc similarity index 100% rename from src/ge/plugin/engine/dnnengines.cc rename to ge/plugin/engine/dnnengines.cc diff --git a/src/ge/plugin/engine/dnnengines.h b/ge/plugin/engine/dnnengines.h similarity index 100% rename from src/ge/plugin/engine/dnnengines.h rename to ge/plugin/engine/dnnengines.h diff --git a/src/ge/plugin/engine/engine_manage.cc b/ge/plugin/engine/engine_manage.cc similarity index 100% rename from src/ge/plugin/engine/engine_manage.cc rename to ge/plugin/engine/engine_manage.cc diff --git a/src/ge/plugin/engine/engine_manage.h b/ge/plugin/engine/engine_manage.h similarity index 100% rename from src/ge/plugin/engine/engine_manage.h rename to ge/plugin/engine/engine_manage.h diff --git a/src/ge/plugin/engine/module.mk b/ge/plugin/engine/module.mk similarity index 100% rename from src/ge/plugin/engine/module.mk rename to ge/plugin/engine/module.mk diff
--git a/ge/proto/caffe/caffe.proto b/ge/proto/caffe/caffe.proto new file mode 100644 index 00000000..3f45aae2 --- /dev/null +++ b/ge/proto/caffe/caffe.proto @@ -0,0 +1,1821 @@ +syntax = "proto2"; + +package domi.caffe; + +// Specifies the shape (dimensions) of a Blob. +message BlobShape { + repeated int64 dim = 1 [packed = true]; +} + +message BlobProto { + optional BlobShape shape = 7; + repeated float data = 5 [packed = true]; + repeated float diff = 6 [packed = true]; + repeated double double_data = 8 [packed = true]; + repeated double double_diff = 9 [packed = true]; + optional bytes int8_data = 10; + repeated int32 int32_data = 11 [packed = true]; + repeated uint64 uint64_data = 12 [packed = true]; + // 4D dimensions -- deprecated. Use "shape" instead. + optional int32 num = 1 [default = 0]; + optional int32 channels = 2 [default = 0]; + optional int32 height = 3 [default = 0]; + optional int32 width = 4 [default = 0]; +} + +// The BlobProtoVector is simply a way to pass multiple blobproto instances +// around. +message BlobProtoVector { + repeated BlobProto blobs = 1; +} + +message Datum { + optional int32 channels = 1; + optional int32 height = 2; + optional int32 width = 3; + // the actual image data, in bytes + optional bytes data = 4; + optional int32 label = 5; + // Optionally, the datum could also hold float data. + repeated float float_data = 6; + // If true data contains an encoded image that need to be decoded + optional bool encoded = 7 [default = false]; +} + +message FillerParameter { + // The filler type. + optional string type = 1 [default = 'constant']; + optional float value = 2 [default = 0]; // the value in constant filler + optional float min = 3 [default = 0]; // the min value in uniform filler + optional float max = 4 [default = 1]; // the max value in uniform filler + optional float mean = 5 [default = 0]; // the mean value in Gaussian filler + optional float std = 6 [default = 1]; // the std value in Gaussian filler + // The expected number of non-zero output weights for a given input in + // Gaussian filler -- the default -1 means don't perform sparsification. + optional int32 sparse = 7 [default = -1]; + // Normalize the filler variance by fan_in, fan_out, or their average. + // Applies to 'xavier' and 'msra' fillers. + enum VarianceNorm { + FAN_IN = 0; + FAN_OUT = 1; + AVERAGE = 2; + } + optional VarianceNorm variance_norm = 8 [default = FAN_IN]; +} + +message NetParameter { + optional string name = 1; // consider giving the network a name + // DEPRECATED. See InputParameter. The input blobs to the network. + repeated string input = 3; + // DEPRECATED. See InputParameter. The shape of the input blobs. + repeated BlobShape input_shape = 8; + + // 4D input dimensions -- deprecated. Use "input_shape" instead. + // If specified, for each input blob there should be four + // values specifying the num, channels, height and width of the input blob. + // Thus, there should be a total of (4 * #input) numbers. + repeated int32 input_dim = 4; + + // Whether the network will force every layer to carry out backward operation. + // If set False, then whether to carry out backward is determined + // automatically according to the net structure and learning rates. + optional bool force_backward = 5 [default = false]; + // The current "state" of the network, including the phase, level, and stage. + // Some layers may be included/excluded depending on this state and the states + // specified in the layers' include and exclude fields. 
+ optional NetState state = 6; + + // Print debugging information about results while running Net::Forward, + // Net::Backward, and Net::Update. + optional bool debug_info = 7 [default = false]; + + // The layers that make up the net. Each of their configurations, including + // connectivity and behavior, is specified as a LayerParameter. + repeated LayerParameter layer = 100; // ID 100 so layers are printed last. + + // DEPRECATED: use 'layer' instead. + repeated V1LayerParameter layers = 2; +} + +// NOTE +// Update the next available ID when you add a new SolverParameter field. +// +// SolverParameter next available ID: 42 (last added: layer_wise_reduce) +message SolverParameter { + ////////////////////////////////////////////////////////////////////////////// + // Specifying the train and test networks + // + // Exactly one train net must be specified using one of the following fields: + // train_net_param, train_net, net_param, net + // One or more test nets may be specified using any of the following fields: + // test_net_param, test_net, net_param, net + // If more than one test net field is specified (e.g., both net and + // test_net are specified), they will be evaluated in the field order given + // above: (1) test_net_param, (2) test_net, (3) net_param/net. + // A test_iter must be specified for each test_net. + // A test_level and/or a test_stage may also be specified for each test_net. + ////////////////////////////////////////////////////////////////////////////// + + // Proto filename for the train net, possibly combined with one or more + // test nets. + optional string net = 24; + // Inline train net param, possibly combined with one or more test nets. + optional NetParameter net_param = 25; + + optional string train_net = 1; // Proto filename for the train net. + repeated string test_net = 2; // Proto filenames for the test nets. + optional NetParameter train_net_param = 21; // Inline train net params. + repeated NetParameter test_net_param = 22; // Inline test net params. + + // The states for the train/test nets. Must be unspecified or + // specified once per net. + // + // By default, all states will have solver = true; + // train_state will have phase = TRAIN, + // and all test_state's will have phase = TEST. + // Other defaults are set according to the NetState defaults. + optional NetState train_state = 26; + repeated NetState test_state = 27; + + // The number of iterations for each test net. + repeated int32 test_iter = 3; + + // The number of iterations between two testing phases. + optional int32 test_interval = 4 [default = 0]; + optional bool test_compute_loss = 19 [default = false]; + // If true, run an initial test pass before the first iteration, + // ensuring memory availability and printing the starting value of the loss. + optional bool test_initialization = 32 [default = true]; + optional float base_lr = 5; // The base learning rate + // the number of iterations between displaying info. If display = 0, no info + // will be displayed. + optional int32 display = 6; + // Display the loss averaged over the last average_loss iterations + optional int32 average_loss = 33 [default = 1]; + optional int32 max_iter = 7; // the maximum number of iterations + // accumulate gradients over `iter_size` x `batch_size` instances + optional int32 iter_size = 36 [default = 1]; + + // The learning rate decay policy. The currently implemented learning rate + // policies are as follows: + // - fixed: always return base_lr. 
+ // - step: return base_lr * gamma ^ (floor(iter / step)) + // - exp: return base_lr * gamma ^ iter + // - inv: return base_lr * (1 + gamma * iter) ^ (- power) + // - multistep: similar to step but it allows non uniform steps defined by + // stepvalue + // - poly: the effective learning rate follows a polynomial decay, to be + // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) + // - sigmoid: the effective learning rate follows a sigmod decay + // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) + // + // where base_lr, max_iter, gamma, step, stepvalue and power are defined + // in the solver parameter protocol buffer, and iter is the current iteration. + optional string lr_policy = 8; + optional float gamma = 9; // The parameter to compute the learning rate. + optional float power = 10; // The parameter to compute the learning rate. + optional float momentum = 11; // The momentum value. + optional float weight_decay = 12; // The weight decay. + // regularization types supported: L1 and L2 + // controlled by weight_decay + optional string regularization_type = 29 [default = "L2"]; + // the stepsize for learning rate policy "step" + optional int32 stepsize = 13; + // the stepsize for learning rate policy "multistep" + repeated int32 stepvalue = 34; + + // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, + // whenever their actual L2 norm is larger. + optional float clip_gradients = 35 [default = -1]; + + optional int32 snapshot = 14 [default = 0]; // The snapshot interval + optional string snapshot_prefix = 15; // The prefix for the snapshot. + // whether to snapshot diff in the results or not. Snapshotting diff will help + // debugging but the final protocol buffer size will be much larger. + optional bool snapshot_diff = 16 [default = false]; + enum SnapshotFormat { + HDF5 = 0; + BINARYPROTO = 1; + } + optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; + // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. + enum SolverMode { + CPU = 0; + GPU = 1; + } + optional SolverMode solver_mode = 17 [default = GPU]; + // the device_id will that be used in GPU mode. Use device_id = 0 in default. + optional int32 device_id = 18 [default = 0]; + // If non-negative, the seed with which the Solver will initialize the Caffe + // random number generator -- useful for reproducible results. Otherwise, + // (and by default) initialize using a seed derived from the system clock. + optional int64 random_seed = 20 [default = -1]; + + // type of the solver + optional string type = 40 [default = "SGD"]; + + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam + optional float delta = 31 [default = 1e-8]; + // parameters for the Adam solver + optional float momentum2 = 39 [default = 0.999]; + + // RMSProp decay value + // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) + optional float rms_decay = 38 [default = 0.99]; + + // If true, print information about the state of the net that may help with + // debugging learning problems. + optional bool debug_info = 23 [default = false]; + + // If false, don't save a snapshot after training finishes. 
+ optional bool snapshot_after_train = 28 [default = true]; + + // DEPRECATED: old solver enum types, use string instead + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + RMSPROP = 3; + ADADELTA = 4; + ADAM = 5; + } + // DEPRECATED: use type instead of solver_type + optional SolverType solver_type = 30 [default = SGD]; + + // Overlap compute and communication for data parallel training + optional bool layer_wise_reduce = 41 [default = true]; +} + +// A message that stores the solver snapshots +message SolverState { + optional int32 iter = 1; // The current iteration + optional string learned_net = 2; // The file that stores the learned net. + repeated BlobProto history = 3; // The history for sgd solvers + optional int32 current_step = 4 [default = 0]; // The current step for learning rate +} + +enum Phase { + TRAIN = 0; + TEST = 1; +} + +message NetState { + optional Phase phase = 1 [default = TEST]; + optional int32 level = 2 [default = 0]; + repeated string stage = 3; +} + +message NetStateRule { + // Set phase to require the NetState have a particular phase (TRAIN or TEST) + // to meet this rule. + optional Phase phase = 1; + + // Set the minimum and/or maximum levels in which the layer should be used. + // Leave undefined to meet the rule regardless of level. + optional int32 min_level = 2; + optional int32 max_level = 3; + + // Customizable sets of stages to include or exclude. + // The net must have ALL of the specified stages and NONE of the specified + // "not_stage"s to meet the rule. + // (Use multiple NetStateRules to specify conjunctions of stages.) + repeated string stage = 4; + repeated string not_stage = 5; +} + +// Specifies training parameters (multipliers on global learning constants, +// and the name and other settings used for weight sharing). +message ParamSpec { + // The names of the parameter blobs -- useful for sharing parameters among + // layers, but never required otherwise. To share a parameter between two + // layers, give it a (non-empty) name. + optional string name = 1; + + // Whether to require shared weights to have the same shape, or just the same + // count -- defaults to STRICT if unspecified. + optional DimCheckMode share_mode = 2; + enum DimCheckMode { + // STRICT (default) requires that num, channels, height, width each match. + STRICT = 0; + // PERMISSIVE requires only the count (num*channels*height*width) to match. + PERMISSIVE = 1; + } + + // The multiplier on the global learning rate for this parameter. + optional float lr_mult = 3 [default = 1.0]; + + // The multiplier on the global weight decay for this parameter. + optional float decay_mult = 4 [default = 1.0]; +} + +// NOTE +// Update the next available ID when you add a new LayerParameter field. +// +// LayerParameter next available layer-specific ID: 151 (last added: smooth_l1_loss_param) +message LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the layer type + repeated string bottom = 3; // the name of each bottom blob + repeated string top = 4; // the name of each top blob + + // The train / test phase for computation. + optional Phase phase = 10; + + // The amount of weight to assign each top blob in the objective. + // Each layer assigns a default value, usually of either 0 or 1, + // to each top blob. + repeated float loss_weight = 5; + + // Specifies training parameters (multipliers on global learning constants, + // and the name and other settings used for weight sharing). 
+ repeated ParamSpec param = 6; + + // The blobs containing the numeric parameters of the layer. + repeated BlobProto blobs = 7; + + // Specifies whether to backpropagate to each bottom. If unspecified, + // Caffe will automatically infer whether each input needs backpropagation + // to compute parameter gradients. If set to true for some inputs, + // backpropagation to those inputs is forced; if set false for some inputs, + // backpropagation to those inputs is skipped. + // + // The size must be either 0 or equal to the number of bottoms. + repeated bool propagate_down = 11; + + // Rules controlling whether and when a layer is included in the network, + // based on the current NetState. You may specify a non-zero number of rules + // to include OR exclude, but not both. If no include or exclude rules are + // specified, the layer is always included. If the current NetState meets + // ANY (i.e., one or more) of the specified rules, the layer is + // included/excluded. + repeated NetStateRule include = 8; + repeated NetStateRule exclude = 9; + + // Parameters for data pre-processing. + optional TransformationParameter transform_param = 100; + + // Parameters shared by loss layers. + optional LossParameter loss_param = 101; + + // Layer type-specific parameters. + // + // Note: certain layers may have more than one computational engine + // for their implementation. These layers include an Engine type and + // engine parameter for selecting the implementation. + // The default for the engine is set by the ENGINE switch at compile-time. + optional AccuracyParameter accuracy_param = 102; + optional ArgMaxParameter argmax_param = 103; + optional BatchNormParameter batch_norm_param = 139; + optional BiasParameter bias_param = 141; + optional ConcatParameter concat_param = 104; + optional ContrastiveLossParameter contrastive_loss_param = 105; + optional ConvolutionParameter convolution_param = 106; + optional CropParameter crop_param = 144; + optional DataParameter data_param = 107; + optional DetectionOutputParameter detection_output_param = 150; + optional DropoutParameter dropout_param = 108; + optional DummyDataParameter dummy_data_param = 109; + optional EltwiseParameter eltwise_param = 110; + optional ELUParameter elu_param = 140; + optional EmbedParameter embed_param = 137; + optional ExpParameter exp_param = 111; + optional FlattenParameter flatten_param = 135; + optional HDF5DataParameter hdf5_data_param = 112; + optional HDF5OutputParameter hdf5_output_param = 113; + optional HingeLossParameter hinge_loss_param = 114; + optional ImageDataParameter image_data_param = 115; + optional InfogainLossParameter infogain_loss_param = 116; + optional InnerProductParameter inner_product_param = 117; + optional InputParameter input_param = 143; + optional LogParameter log_param = 134; + optional LRNParameter lrn_param = 118; + optional MemoryDataParameter memory_data_param = 119; + optional MVNParameter mvn_param = 120; + optional ParameterParameter parameter_param = 145; + optional PoolingParameter pooling_param = 121; + optional PowerParameter power_param = 122; + optional PReLUParameter prelu_param = 131; + optional PythonParameter python_param = 130; + optional RecurrentParameter recurrent_param = 146; + optional ReductionParameter reduction_param = 136; + optional ReLUParameter relu_param = 123; + optional ReshapeParameter reshape_param = 133; + optional ScaleParameter scale_param = 142; + optional SigmoidParameter sigmoid_param = 124; + optional SmoothL1LossParameter smooth_l1_loss_param = 
148; + optional SoftmaxParameter softmax_param = 125; + optional SPPParameter spp_param = 132; + optional SliceParameter slice_param = 126; + optional TanHParameter tanh_param = 127; + optional ThresholdParameter threshold_param = 128; + optional TileParameter tile_param = 138; + optional WindowDataParameter window_data_param = 129; + optional PermuteParameter permute_param = 202; + optional PriorBoxParameter prior_box_param = 203; + optional NormalizeParameter norm_param = 206; + optional PSROIPoolingParameter psroi_pooling_param = 207; + optional FreespaceExtractParameter freespace_extract_param = 151; + optional PostprocessParameter postprocess_param = 152; + optional SpatialTransformParameter spatial_transform_param = 153; + optional ROIAlignParameter roi_align_param = 154; + optional ReorgParameter reorg_param = 155; + optional RegionParameter region_param = 156; + optional ReverseParameter reverse_param = 157; + optional InterpParameter interp_param = 158; + optional ShuffleChannelParameter shuffle_channel_param = 159; + optional UpsampleParameter upsample_param = 160; + optional ROIPoolingParameter roi_pooling_param = 161; + optional YoloParameter yolo_param = 199; + optional YoloV3DetectionOutputParameter yolov3_detection_output_param = 200; + optional ProposalParameter proposal_param = 201; + optional FSRDetectionOutputParameter fsrdetectionoutput_param = 222; + optional SSDDetectionOutputParameter ssddetectionoutput_param = 232; + optional YoloV2DetectionOutputParameter yolov2_detection_output_param = 204; + optional QuantParameter quant_param = 208; + optional CondTakeParameter condtake_param = 233; + optional MatrixInverseParameter matrix_inverse_param = 210; + optional WarpPerspectiveParameter warp_perspective_param = 234; + optional BatchMatMulParameter batch_matmul_param = 235; + optional SpatialTransformerParameter st_param = 5000; + optional YoloV3DetectionOutputV2Parameter yolov3_detection_output_v2_param = 5001; +} + +// Message that stores parameters used to apply transformation +// to the data layer's data +message TransformationParameter { + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 1 [default = 1]; + // Specify if we want to randomly mirror data. + optional bool mirror = 2 [default = false]; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 3 [default = 0]; + // mean_file and mean_value cannot be specified at the same time + optional string mean_file = 4; + // if specified can be repeated once (would substract it from all the channels) + // or can be repeated the same number of times as channels + // (would subtract them from the corresponding channel) + repeated float mean_value = 5; + // Force the decoded image to have 3 color channels. + optional bool force_color = 6 [default = false]; + // Force the decoded image to have 1 color channels. + optional bool force_gray = 7 [default = false]; +} + +// Message that stores parameters shared by loss layers +message LossParameter { + // If specified, ignore instances with the given label. + optional int32 ignore_label = 1; + // How to normalize the loss for loss layers that aggregate across batches, + // spatial dimensions, or other dimensions. Currently only implemented in + // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers. 
+ enum NormalizationMode { + // Divide by the number of examples in the batch times spatial dimensions. + // Outputs that receive the ignore label will NOT be ignored in computing + // the normalization factor. + FULL = 0; + // Divide by the total number of output locations that do not take the + // ignore_label. If ignore_label is not set, this behaves like FULL. + VALID = 1; + // Divide by the batch size. + BATCH_SIZE = 2; + // Do not normalize the loss. + NONE = 3; + } + // For historical reasons, the default normalization for + // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID. + optional NormalizationMode normalization = 3 [default = VALID]; + // Deprecated. Ignored if normalization is specified. If normalization + // is not specified, then setting this to false will be equivalent to + // normalization = BATCH_SIZE to be consistent with previous behavior. + optional bool normalize = 2; +} + +// Messages that store parameters used by individual layer types follow, in +// alphabetical order. + +message AccuracyParameter { + // When computing accuracy, count as correct by comparing the true label to + // the top k scoring classes. By default, only compare to the top scoring + // class (i.e. argmax). + optional uint32 top_k = 1 [default = 1]; + + // The "label" axis of the prediction blob, whose argmax corresponds to the + // predicted label -- may be negative to index from the end (e.g., -1 for the + // last axis). For example, if axis == 1 and the predictions are + // (N x C x H x W), the label blob is expected to contain N*H*W ground truth + // labels with integer values in {0, 1, ..., C-1}. + optional int32 axis = 2 [default = 1]; + + // If specified, ignore instances with the given label. + optional int32 ignore_label = 3; +} + +message ArgMaxParameter { + // If true produce pairs (argmax, maxval) + optional bool out_max_val = 1 [default = false]; + optional uint32 top_k = 2 [default = 1]; + // The axis along which to maximise -- may be negative to index from the + // end (e.g., -1 for the last axis). + // By default ArgMaxLayer maximizes over the flattened trailing dimensions + // for each index of the first / num dimension. + optional int32 axis = 3; +} + +message ConcatParameter { + // The axis along which to concatenate -- may be negative to index from the + // end (e.g., -1 for the last axis). Other axes must have the + // same dimension for all the bottom blobs. + // By default, ConcatLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 2 [default = 1]; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 concat_dim = 1 [default = 1]; +} + +message BatchNormParameter { + // If false, normalization is performed over the current mini-batch + // and global statistics are accumulated (but not yet used) by a moving + // average. + // If true, those accumulated mean and variance values are used for the + // normalization. + // By default, it is set to false when the network is in the training + // phase and true when the network is in the testing phase. + optional bool use_global_stats = 1; + // What fraction of the moving average remains each iteration? + // Smaller values make the moving average decay faster, giving more + // weight to the recent values. + // Each iteration updates the moving average @f$S_{t-1}@f$ with the + // current mean @f$ Y_t @f$ by + // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ + // is the moving_average_fraction parameter. 
+ optional float moving_average_fraction = 2 [default = .999]; + // Small value to add to the variance estimate so that we don't divide by + // zero. + optional float eps = 3 [default = 1e-5]; +} + +message BiasParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar bias. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the bias + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to add a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer.) + // The initialization for the learned bias parameter. + // Default is the zero (0) initialization, resulting in the BiasLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + optional bool bias_from_blob = 4 [default = true]; +} + +message ContrastiveLossParameter { + // margin for dissimilar pair + optional float margin = 1 [default = 1.0]; + // The first implementation of this cost did not exactly match the cost of + // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. + // legacy_version = false (the default) uses (margin - d)^2 as proposed in the + // Hadsell paper. New models should probably use this version. + // legacy_version = true uses (margin - d^2). This is kept to support / + // reproduce existing models and results + optional bool legacy_version = 2 [default = false]; +} + +message ConvolutionParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in all spatial dimensions, or once per spatial dimension. + repeated uint32 pad = 3; // The padding size; defaults to 0 + repeated uint32 kernel_size = 4; // The kernel size + repeated uint32 stride = 6; // The stride; defaults to 1 + // Factor used to dilate the kernel, (implicitly) zero-filling the resulting + // holes. (Kernel dilation is sometimes referred to by its use in the + // algorithme à trous from Holschneider et al. 1987.) + repeated uint32 dilation = 18; // The dilation; defaults to 1 + + // For 2D convolution only, the *_h and *_w versions may also be used to + // specify both spatial dimensions. 
+ optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) + optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) + optional uint32 kernel_h = 11; // The kernel height (2D only) + optional uint32 kernel_w = 12; // The kernel width (2D only) + optional uint32 stride_h = 13; // The stride height (2D only) + optional uint32 stride_w = 14; // The stride width (2D only) + + optional uint32 group = 5 [default = 1]; // The group size for group conv + + optional FillerParameter weight_filler = 7; // The filler for the weight + optional FillerParameter bias_filler = 8; // The filler for the bias + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; + + // The axis to interpret as "channels" when performing convolution. + // Preceding dimensions are treated as independent inputs; + // succeeding dimensions are treated as "spatial". + // With (N, C, H, W) inputs, and axis == 1 (the default), we perform + // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for + // groups g>1) filters across the spatial axes (H, W) of the input. + // With (N, C, D, H, W) inputs, and axis == 1, we perform + // N independent 3D convolutions, sliding (C/g)-channels + // filters across the spatial axes (D, H, W) of the input. + optional int32 axis = 16 [default = 1]; + + // Whether to force use of the general ND convolution, even if a specific + // implementation for blobs of the appropriate number of spatial dimensions + // is available. (Currently, there is only a 2D-specific convolution + // implementation; for input blobs with num_axes != 2, this option is + // ignored and the ND implementation will be used.) + optional bool force_nd_im2col = 17 [default = false]; +} + +message CropParameter { + // To crop, elements of the first bottom are selected to fit the dimensions + // of the second, reference bottom. The crop is configured by + // - the crop `axis` to pick the dimensions for cropping + // - the crop `offset` to set the shift for all/each dimension + // to align the cropped bottom with the reference bottom. + // All dimensions up to but excluding `axis` are preserved, while + // the dimensions including and trailing `axis` are cropped. + // If only one `offset` is set, then all dimensions are offset by this amount. + // Otherwise, the number of offsets must equal the number of cropped axes to + // shift the crop in each dimension accordingly. + // Note: standard dimensions are N,C,H,W so the default is a spatial crop, + // and `axis` may be negative to index from the end (e.g., -1 for the last + // axis). + optional int32 axis = 1 [default = 2]; + repeated uint32 offset = 2; +} + +message DataParameter { + enum DB { + LEVELDB = 0; + LMDB = 1; + } + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + // DEPRECATED. Each solver accesses a different subset of the database. + optional uint32 rand_skip = 7 [default = 0]; + optional DB backend = 8 [default = LEVELDB]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. 
Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + // Force the encoded image to have 3 color channels + optional bool force_encoded_color = 9 [default = false]; + // Prefetch queue (Increase if data feeding bandwidth varies, within the + // limit of device memory for GPU training) + optional uint32 prefetch = 10 [default = 4]; +} + +message DropoutParameter { + optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio + optional bool scale_train = 2 [default = true]; // scale train or test phase +} + +// DummyDataLayer fills any number of arbitrarily shaped blobs with random +// (or constant) data generated by "Fillers" (see "message FillerParameter"). +message DummyDataParameter { + // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N + // shape fields, and 0, 1 or N data_fillers. + // + // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. + // If 1 data_filler is specified, it is applied to all top blobs. If N are + // specified, the ith is applied to the ith top blob. + repeated FillerParameter data_filler = 1; + repeated BlobShape shape = 6; + + // 4D dimensions -- deprecated. Use "shape" instead. + repeated uint32 num = 2; + repeated uint32 channels = 3; + repeated uint32 height = 4; + repeated uint32 width = 5; +} + +message EltwiseParameter { + enum EltwiseOp { + PROD = 0; + SUM = 1; + MAX = 2; + } + optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation + repeated float coeff = 2; // blob-wise coefficient for SUM operation + + // Whether to use an asymptotically slower (for >2 inputs) but stabler method + // of computing the gradient for the PROD operation. (No effect for SUM op.) + optional bool stable_prod_grad = 3 [default = true]; +} + +// Message that stores parameters used by ELULayer +message ELUParameter { + // Described in: + // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate + // Deep Network Learning by Exponential Linear Units (ELUs). arXiv + optional float alpha = 1 [default = 1]; +} + +// Message that stores parameters used by EmbedLayer +message EmbedParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + // The input is given as integers to be interpreted as one-hot + // vector indices with dimension num_input. Hence num_input should be + // 1 greater than the maximum possible input value. + optional uint32 input_dim = 2; + + optional bool bias_term = 3 [default = true]; // Whether to use a bias term + optional FillerParameter weight_filler = 4; // The filler for the weight + optional FillerParameter bias_filler = 5; // The filler for the bias + +} + +// Message that stores parameters used by ExpLayer +message ExpParameter { + // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = exp(shift + scale * x). 
+ optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +/// Message that stores parameters used by FlattenLayer +message FlattenParameter { + // The first axis to flatten: all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 1 [default = 1]; + + // The last axis to flatten: all following axes are retained in the output. + // May be negative to index from the end (e.g., the default -1 for the last + // axis). + optional int32 end_axis = 2 [default = -1]; +} + +// Message that stores parameters used by HDF5DataLayer +message HDF5DataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 2; + + // Specify whether to shuffle the data. + // If shuffle == true, the ordering of the HDF5 files is shuffled, + // and the ordering of data within any given HDF5 file is shuffled, + // but data between different files are not interleaved; all of a file's + // data are output (in a random order) before moving onto another file. + optional bool shuffle = 3 [default = false]; +} + +message HDF5OutputParameter { + optional string file_name = 1; +} + +message HingeLossParameter { + enum Norm { + L1 = 1; + L2 = 2; + } + // Specify the Norm to use L1 or L2 + optional Norm norm = 1 [default = L1]; +} + +message ImageDataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4 [default = 1]; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 7 [default = 0]; + // Whether or not ImageLayer should shuffle the list of files at every epoch. + optional bool shuffle = 8 [default = false]; + // It will also resize images if new_height or new_width are not zero. + optional uint32 new_height = 9 [default = 0]; + optional uint32 new_width = 10 [default = 0]; + // Specify if the images are color or gray + optional bool is_color = 11 [default = true]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + optional string root_folder = 12 [default = ""]; +} + +message InfogainLossParameter { + // Specify the infogain matrix source. 
+ optional string source = 1; + optional int32 axis = 2 [default = 1]; // axis of prob +} + +message InnerProductParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 3; // The filler for the weight + optional FillerParameter bias_filler = 4; // The filler for the bias + + // The first axis to be lumped into a single inner product computation; + // all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 5 [default = 1]; + // Specify whether to transpose the weight matrix or not. + // If transpose == true, any operations will be performed on the transpose + // of the weight matrix. The weight matrix itself is not going to be transposed + // but rather the transfer flag of operations will be toggled accordingly. + optional bool transpose = 6 [default = false]; +} + +message InputParameter { + // This layer produces N >= 1 top blob(s) to be assigned manually. + // Define N shapes to set a shape for each top. + // Define 1 shape to set the same shape for every top. + // Define no shape to defer to reshaping manually. + repeated BlobShape shape = 1; +} + +// Message that stores parameters used by LogLayer +message LogParameter { + // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = ln(shift + scale * x) = log_e(shift + scale * x) + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by LRNLayer +message LRNParameter { + optional uint32 local_size = 1 [default = 5]; + optional float alpha = 2 [default = 1.]; + optional float beta = 3 [default = 0.75]; + enum NormRegion { + ACROSS_CHANNELS = 0; + WITHIN_CHANNEL = 1; + } + optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; + optional float k = 5 [default = 1.]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +message MemoryDataParameter { + optional uint32 batch_size = 1; + optional uint32 channels = 2; + optional uint32 height = 3; + optional uint32 width = 4; +} + +message MVNParameter { + // This parameter can be set to false to normalize mean only + optional bool normalize_variance = 1 [default = true]; + + // This parameter can be set to true to perform DNN-like MVN + optional bool across_channels = 2 [default = false]; + + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 3 [default = 1e-9]; +} + +message ParameterParameter { + optional BlobShape shape = 1; +} + +message PoolingParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 1 [default = MAX]; // The pooling method + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. 
+ optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 2; // The kernel size (square) + optional uint32 kernel_h = 5; // The kernel height + optional uint32 kernel_w = 6; // The kernel width + optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 7; // The stride height + optional uint32 stride_w = 8; // The stride width + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 11 [default = DEFAULT]; + // If global_pooling then it will pool over the size of the bottom by doing + // kernel_h = bottom->height and kernel_w = bottom->width + optional bool global_pooling = 12 [default = false]; + optional bool ceil_mode = 13 [default = true]; + // How to calculate the output size - using ceil (default) or floor rounding. + enum RoundMode { + CEIL = 0; + FLOOR = 1; + } + optional RoundMode round_mode = 14 [default = CEIL]; +} + +message PowerParameter { + // PowerLayer computes outputs y = (shift + scale * x) ^ power. + optional float power = 1 [default = 1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +message PythonParameter { + optional string module = 1; + optional string layer = 2; + // This value is set to the attribute `param_str` of the `PythonLayer` object + // in Python before calling the `setup()` method. This could be a number, + // string, dictionary in Python dict format, JSON, etc. You may parse this + // string in `setup` method and use it in `forward` and `backward`. + optional string param_str = 3 [default = '']; + // Whether this PythonLayer is shared among worker solvers during data parallelism. + // If true, each worker solver sequentially run forward from this layer. + // This value should be set true if you are using it as a data layer. + optional bool share_in_parallel = 4 [default = false]; +} + +// Message that stores parameters used by RecurrentLayer +message RecurrentParameter { + // The dimension of the output (and usually hidden state) representation -- + // must be explicitly set to non-zero. + optional uint32 num_output = 1 [default = 0]; + + optional FillerParameter weight_filler = 2; // The filler for the weight + optional FillerParameter bias_filler = 3; // The filler for the bias + + // Whether to enable displaying debug_info in the unrolled recurrent net. + optional bool debug_info = 4 [default = false]; + + // Whether to add as additional inputs (bottoms) the initial hidden state + // blobs, and add as additional outputs (tops) the final timestep hidden state + // blobs. The number of additional bottom/top blobs required depends on the + // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. + optional bool expose_hidden = 5 [default = false]; +} + +// Message that stores parameters used by ReductionLayer +message ReductionParameter { + enum ReductionOp { + SUM = 1; + ASUM = 2; + SUMSQ = 3; + MEAN = 4; + } + + optional ReductionOp operation = 1 [default = SUM]; // reduction operation + + // The first axis to reduce to a scalar -- may be negative to index from the + // end (e.g., -1 for the last axis). + // (Currently, only reduction along ALL "tail" axes is supported; reduction + // of axis M through N, where N < num_axes - 1, is unsupported.) 
+ // Suppose we have an n-axis bottom Blob with shape: + // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). + // If axis == m, the output Blob will have shape + // (d0, d1, d2, ..., d(m-1)), + // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) + // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. + // If axis == 0 (the default), the output Blob always has the empty shape + // (count 1), performing reduction across the entire input -- + // often useful for creating new loss functions. + optional int32 axis = 2 [default = 0]; + + optional float coeff = 3 [default = 1.0]; // coefficient for output +} + +// Message that stores parameters used by ReLULayer +message ReLUParameter { + // Allow non-zero slope for negative inputs to speed up optimization + // Described in: + // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities + // improve neural network acoustic models. In ICML Workshop on Deep Learning + // for Audio, Speech, and Language Processing. + optional float negative_slope = 1 [default = 0]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 2 [default = DEFAULT]; +} + +message ReshapeParameter { + // Specify the output dimensions. If some of the dimensions are set to 0, + // the corresponding dimension from the bottom layer is used (unchanged). + // Exactly one dimension may be set to -1, in which case its value is + // inferred from the count of the bottom blob and the remaining dimensions. + // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: + // + // layer { + // type: "Reshape" bottom: "input" top: "output" + // reshape_param { ... } + // } + // + // If "input" is 2D with shape 2 x 8, then the following reshape_param + // specifications are all equivalent, producing a 3D blob "output" with shape + // 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } + // reshape_param { shape { dim: 0 dim:-1 dim: 4 } } + // + optional BlobShape shape = 1; + + // axis and num_axes control the portion of the bottom blob's shape that are + // replaced by (included in) the reshape. By default (axis == 0 and + // num_axes == -1), the entire bottom blob shape is included in the reshape, + // and hence the shape field must specify the entire output shape. + // + // axis may be non-zero to retain some portion of the beginning of the input + // shape (and may be negative to index from the end; e.g., -1 to begin the + // reshape after the last axis, including nothing in the reshape, + // -2 to include only the last axis, etc.). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are all equivalent, + // producing a blob "output" with shape 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } + // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } + // + // num_axes specifies the extent of the reshape. + // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on + // input axes in the range [axis, axis+num_axes]. + // num_axes may also be -1, the default, to include all remaining axes + // (starting from axis). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. 
+ // Then the following ReshapeLayer specifications are equivalent, + // producing a blob "output" with shape 1 x 2 x 8. + // + // reshape_param { shape { dim: 1 dim: 2 dim: 8 } } + // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 } + // reshape_param { shape { dim: 1 } num_axes: 0 } + // + // On the other hand, these would produce output blob shape 2 x 1 x 8: + // + // reshape_param { shape { dim: 2 dim: 1 dim: 8 } } + // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 } + // + optional int32 axis = 2 [default = 0]; + optional int32 num_axes = 3 [default = -1]; +} + + +message ScaleParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar multiplier. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the scale + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer.) + // The initialization for the learned scale parameter. + // Default is the unit (1) initialization, resulting in the ScaleLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + + // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but + // may be more efficient). Initialized with bias_filler (defaults to 0). + optional bool bias_term = 4 [default = false]; + optional FillerParameter bias_filler = 5; + optional bool scale_from_blob = 6 [default = true]; +} + +message SigmoidParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +message SliceParameter { + // The axis along which to slice -- may be negative to index from the end + // (e.g., -1 for the last axis). + // By default, SliceLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 3 [default = 1]; + repeated uint32 slice_point = 2; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 slice_dim = 1 [default = 1]; +} + +message SmoothL1LossParameter { + // SmoothL1Loss(x) = + // 0.5 * (sigma * x) ** 2 -- if x < 1.0 / sigma / sigma + // |x| - 0.5 / sigma / sigma -- otherwise + optional float sigma = 1 [default = 1]; +} + +// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer +message SoftmaxParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; + + // The axis along which to perform the softmax -- may be negative to index + // from the end (e.g., -1 for the last axis). 
+ // Any other axes will be evaluated as independent softmaxes. + optional int32 axis = 2 [default = 1]; +} + +message TanHParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by TileLayer +message TileParameter { + // The index of the axis to tile. + optional int32 axis = 1 [default = 1]; + + // The number of copies (tiles) of the blob to output. + optional int32 tiles = 2; +} + +// Message that stores parameters used by ThresholdLayer +message ThresholdParameter { + optional float threshold = 1 [default = 0]; // Strictly positive values +} + +message WindowDataParameter { + // Specify the data source. + optional string source = 1; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // Specify the batch size. + optional uint32 batch_size = 4; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 5 [default = 0]; + // Specify if we want to randomly mirror data. + optional bool mirror = 6 [default = false]; + // Foreground (object) overlap threshold + optional float fg_threshold = 7 [default = 0.5]; + // Background (non-object) overlap threshold + optional float bg_threshold = 8 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float fg_fraction = 9 [default = 0.25]; + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 context_pad = 10 [default = 0]; + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string crop_mode = 11 [default = "warp"]; + // cache_images: will load all images in memory for faster access + optional bool cache_images = 12 [default = false]; + // append root_folder to locate images + optional string root_folder = 13 [default = ""]; +} + +message SPPParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional uint32 pyramid_height = 1; + optional PoolMethod pool = 2 [default = MAX]; // The pooling method + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +// DEPRECATED: use LayerParameter. 
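+// As a rough illustration only (the layer and blob names below are hypothetical, not taken
+// from this file), the replacement LayerParameter style expresses a layer such as ReLU in
+// prototxt as:
+//
+//   layer {
+//     name: "relu1"  type: "ReLU"  bottom: "conv1"  top: "conv1"
+//     relu_param { negative_slope: 0.1 }
+//   }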
+message V1LayerParameter { + repeated string bottom = 2; + repeated string top = 3; + optional string name = 4; + repeated NetStateRule include = 32; + repeated NetStateRule exclude = 33; + enum LayerType { + NONE = 0; + ABSVAL = 35; + ACCURACY = 1; + ARGMAX = 30; + BNLL = 2; + CONCAT = 3; + CONTRASTIVE_LOSS = 37; + CONVOLUTION = 4; + DATA = 5; + DECONVOLUTION = 39; + DROPOUT = 6; + DUMMY_DATA = 32; + EUCLIDEAN_LOSS = 7; + ELTWISE = 25; + EXP = 38; + FLATTEN = 8; + HDF5_DATA = 9; + HDF5_OUTPUT = 10; + HINGE_LOSS = 28; + IM2COL = 11; + IMAGE_DATA = 12; + INFOGAIN_LOSS = 13; + INNER_PRODUCT = 14; + LRN = 15; + MEMORY_DATA = 29; + MULTINOMIAL_LOGISTIC_LOSS = 16; + MVN = 34; + POOLING = 17; + POWER = 26; + RELU = 18; + SIGMOID = 19; + SIGMOID_CROSS_ENTROPY_LOSS = 27; + SILENCE = 36; + SOFTMAX = 20; + SOFTMAX_LOSS = 21; + SPLIT = 22; + SLICE = 33; + TANH = 23; + WINDOW_DATA = 24; + THRESHOLD = 31; + QUANT = 208; + DEQUANT = 209; + } + optional LayerType type = 5; + repeated BlobProto blobs = 6; + repeated string param = 1001; + repeated DimCheckMode blob_share_mode = 1002; + enum DimCheckMode { + STRICT = 0; + PERMISSIVE = 1; + } + repeated float blobs_lr = 7; + repeated float weight_decay = 8; + repeated float loss_weight = 35; + optional AccuracyParameter accuracy_param = 27; + optional ArgMaxParameter argmax_param = 23; + optional ConcatParameter concat_param = 9; + optional ContrastiveLossParameter contrastive_loss_param = 40; + optional ConvolutionParameter convolution_param = 10; + optional DataParameter data_param = 11; + optional DropoutParameter dropout_param = 12; + optional DummyDataParameter dummy_data_param = 26; + optional EltwiseParameter eltwise_param = 24; + optional ExpParameter exp_param = 41; + optional HDF5DataParameter hdf5_data_param = 13; + optional HDF5OutputParameter hdf5_output_param = 14; + optional HingeLossParameter hinge_loss_param = 29; + optional ImageDataParameter image_data_param = 15; + optional InfogainLossParameter infogain_loss_param = 16; + optional InnerProductParameter inner_product_param = 17; + optional LRNParameter lrn_param = 18; + optional MemoryDataParameter memory_data_param = 22; + optional MVNParameter mvn_param = 34; + optional PoolingParameter pooling_param = 19; + optional PowerParameter power_param = 21; + optional ReLUParameter relu_param = 30; + optional SigmoidParameter sigmoid_param = 38; + optional SoftmaxParameter softmax_param = 39; + optional SliceParameter slice_param = 31; + optional TanHParameter tanh_param = 37; + optional ThresholdParameter threshold_param = 25; + optional WindowDataParameter window_data_param = 20; + optional TransformationParameter transform_param = 36; + optional LossParameter loss_param = 42; + optional V0LayerParameter layer = 1; +} + +// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters +// in Caffe. We keep this message type around for legacy support. +message V0LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the string to specify the layer type + + // Parameters to specify layers with inner products. 
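+  // For example (an illustrative, hypothetical sketch; the name and values are not defaults
+  // from this file), a V0-style inner product layer with 4096 outputs might be declared as:
+  //
+  //   layer {
+  //     name: "fc6"  type: "innerproduct"  num_output: 4096
+  //     weight_filler { type: "gaussian"  std: 0.005 }
+  //     bias_filler { type: "constant"  value: 0.1 }
+  //   }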
+ optional uint32 num_output = 3; // The number of outputs for the layer + optional bool biasterm = 4 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 5; // The filler for the weight + optional FillerParameter bias_filler = 6; // The filler for the bias + + optional uint32 pad = 7 [default = 0]; // The padding size + optional uint32 kernelsize = 8; // The kernel size + optional uint32 group = 9 [default = 1]; // The group size for group conv + optional uint32 stride = 10 [default = 1]; // The stride + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 11 [default = MAX]; // The pooling method + optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio + + optional uint32 local_size = 13 [default = 5]; // for local response norm + optional float alpha = 14 [default = 1.]; // for local response norm + optional float beta = 15 [default = 0.75]; // for local response norm + optional float k = 22 [default = 1.]; + + // For data layers, specify the data source + optional string source = 16; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 17 [default = 1]; + optional string meanfile = 18; + // For data layers, specify the batch size. + optional uint32 batchsize = 19; + // For data layers, specify if we would like to randomly crop an image. + optional uint32 cropsize = 20 [default = 0]; + // For data layers, specify if we want to randomly mirror data. + optional bool mirror = 21 [default = false]; + + // The blobs containing the numeric parameters of the layer + repeated BlobProto blobs = 50; + // The ratio that is multiplied on the global learning rate. If you want to + // set the learning ratio for one blob, you need to set it for all blobs. + repeated float blobs_lr = 51; + // The weight decay that is multiplied on the global weight decay. + repeated float weight_decay = 52; + + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 53 [default = 0]; + + // Fields related to detection (det_*) + // foreground (object) overlap threshold + optional float det_fg_threshold = 54 [default = 0.5]; + // background (non-object) overlap threshold + optional float det_bg_threshold = 55 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float det_fg_fraction = 56 [default = 0.25]; + + // optional bool OBSOLETE_can_clobber = 57 [default = true]; + + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 det_context_pad = 58 [default = 0]; + + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string det_crop_mode = 59 [default = "warp"]; + + // For ReshapeLayer, one needs to specify the new dimensions. + optional int32 new_num = 60 [default = 0]; + optional int32 new_channels = 61 [default = 0]; + optional int32 new_height = 62 [default = 0]; + optional int32 new_width = 63 [default = 0]; + + // Whether or not ImageLayer should shuffle the list of files at every epoch. 
+ // It will also resize images if new_height or new_width are not zero. + optional bool shuffle_images = 64 [default = false]; + + // For ConcatLayer, one needs to specify the dimension for concatenation, and + // the other dimensions must be the same for all the bottom blobs. + // By default it will concatenate blobs along the channels dimension. + optional uint32 concat_dim = 65 [default = 1]; + + optional HDF5OutputParameter hdf5_output_param = 1001; +} + +message PReLUParameter { + // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers: + // Surpassing Human-Level Performance on ImageNet Classification, 2015. + + // Initial value of a_i. Default is a_i=0.25 for all i. + optional FillerParameter filler = 1; + // Whether or not slope parameters are shared across channels. + optional bool channel_shared = 2 [default = false]; +} + +// Message that stores parameters used by DetectionOutputLayer +//message DetectionOutputParameter { +// optional int32 num_classes = 1 [default = 21]; +// optional float nms_threshold = 2 [default = 0.3]; +// optional int32 top_k = 3; +// optional float confidence_threshold = 4 [default = 0.8]; +//} + +// Message that store parameters used by PriorBoxLayer +message PriorBoxParameter { + // Encode/decode type. + enum CodeType { + CORNER = 1; + CENTER_SIZE = 2; + CORNER_SIZE = 3; + } + // Minimum box size (in pixels). Required! + repeated float min_size = 1; + // Maximum box size (in pixels). Required! + repeated float max_size = 2; + // Various of aspect ratios. Duplicate ratios will be ignored. + // If none is provided, we use default ratio 1. + repeated float aspect_ratio = 3; + // If true, will flip each aspect ratio. + // For example, if there is aspect ratio "r", + // we will generate aspect ratio "1.0/r" as well. + optional bool flip = 4 [default = true]; + // If true, will clip the prior so that it is within [0, 1] + optional bool clip = 5 [default = false]; + // Variance for adjusting the prior bboxes. + repeated float variance = 6; + // By default, we calculate img_height, img_width, step_x, step_y based on + // bottom[0] (feat) and bottom[1] (img). Unless these values are explicitely + // provided. + // Explicitly provide the img_size. + optional uint32 img_size = 7; + // Either img_size or img_h/img_w should be specified; not both. + optional uint32 img_h = 8; + optional uint32 img_w = 9; + + // Explicitly provide the step size. + optional float step = 10; + // Either step or step_h/step_w should be specified; not both. + optional float step_h = 11; + optional float step_w = 12; + + // Offset to the top left corner of each cell. + optional float offset = 13 [default = 0.5]; +} + +// Message that stores parameters used by PermutetLayer +message PermuteParameter { + // The new orders of the axes of data. Notice it should be with + // in the same range as the input data, and it starts from 0. + // Do not provide repeated order. + repeated uint32 order = 1; +} + +message NormalizeParameter { + optional bool across_spatial = 1 [default = true]; + // Initial value of scale. Default is 1.0 for all + optional FillerParameter scale_filler = 2; + // Whether or not scale parameters are shared across channels. + optional bool channel_shared = 3 [default = true]; + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 4 [default = 1e-10]; +} + +// needed by ssd +message SaveOutputParameter { + // Output directory. If not empty, we will save the results. 
+ optional string output_directory = 1; + // Output name prefix. + optional string output_name_prefix = 2; + // Output format. + // VOC - PASCAL VOC output format. + // COCO - MS COCO output format. + optional string output_format = 3; + // If you want to output results, must also provide the following two files. + // Otherwise, we will ignore saving results. + // label map file. + optional string label_map_file = 4; + // A file which contains a list of names and sizes with same order + // of the input DB. The file is in the following format: + // name height width + // ... + optional string name_size_file = 5; + // Number of test images. It can be less than the lines specified in + // name_size_file. For example, when we only want to evaluate on part + // of the test images. + optional uint32 num_test_image = 6; + // The resize parameter used in saving the data. + // optional ResizeParameter resize_param = 7; +} + +message NonMaximumSuppressionParameter { + // Threshold to be used in nms. + optional float nms_threshold = 1 [default = 0.3]; + // Maximum number of results to be kept. + optional int32 top_k = 2; + // Parameter for adaptive nms. + optional float eta = 3 [default = 1.0]; +} + +message GeneralNmsParameter { + optional int32 post_top_k = 1 ; + optional float nms_threshold = 2 [default = 0]; + optional float iou_threshold_decay = 3 [default = 1.0]; + optional float coor_scale_factor = 4 [default = 1.0]; +} + +// Message that store parameters used by DetectionOutputLayer, ssd/fasterRcnn +message DetectionOutputParameter { + optional int32 num_classes = 1; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional NonMaximumSuppressionParameter nms_param = 4; + optional SaveOutputParameter save_output_param = 5; + optional PriorBoxParameter.CodeType code_type = 6 [default = CENTER_SIZE]; + optional bool variance_encoded_in_target = 8 [default = true]; + optional int32 keep_top_k = 7; + optional float confidence_threshold = 9; + optional float nms_threshold = 13; + optional int32 top_k = 14; + optional int32 boxes = 15 [default = 1]; + optional bool relative = 17 [default = true]; + optional float objectness_threshold = 18 [default = 0.5]; + optional float class_threshold = 19 [default = 0.5]; + repeated float biases = 20; + optional GeneralNmsParameter general_nms_param = 21; + optional float objectness_score = 22; +} +message PSROIPoolingParameter { + required float spatial_scale = 1; + required int32 output_dim = 2; // output channel number + required int32 group_size = 3; // number of groups to encode position-sensitive score maps +} +// Message that stores parameters used by FreespaceExtractLayer +message FreespaceExtractParameter { + optional float org_height = 1; +} + +// Message that stores parameters used by DetectpostprocessLayer +message PostprocessParameter { + optional float nms_thresh = 1 [default = 0.3]; + optional float conf_thresh = 2 [default = 0.5]; + optional uint32 post_nms_topn = 3 [default = 100]; + optional uint32 cls_num = 4 [default = 12]; + repeated float bbox_reg_weights = 5; +} + +// Message that stores parameters used by SpatialTransformLayer +message SpatialTransformParameter { + optional uint32 output_h = 1 [default = 0]; + optional uint32 output_w = 2 [default = 0]; + optional float border_value = 3 [default = 0]; + repeated float affine_transform = 4; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; +} +message 
ROIAlignParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; + optional int32 sampling_ratio = 4 [default = -1]; + optional int32 roi_end_mode = 5 [default = 0]; +} + +message RegionParameter { + optional uint32 classes = 1 [default = 20]; // Category of classification + optional uint32 coords = 2 [default = 4]; // Coordinates of box + optional uint32 boxes = 3 [default = 1]; // Number of boxes predicted per grid + optional uint32 softmax = 4 [default = 0]; + optional string softmax_tree = 5 [default = ""]; + optional uint32 background = 6 [default = 0]; +} +message ReorgParameter{ + optional uint32 stride = 2 [default = 2]; + optional bool reverse = 1 [default = false]; +} +message ReverseParameter{ + repeated int32 axis = 1; +} +message InterpParameter{ + optional int32 height = 1 [default = 0];//Height of output + optional int32 width = 2 [default = 0];//Width of output + optional int32 zoom_factor = 3 [default = 1];//zoom factor + optional int32 shrink_factor = 4 [default = 1];//shrink factor + optional int32 pad_beg = 5 [default = 0];//padding at begin of input + optional int32 pad_end = 6 [default = 0];//padding at end of input +} +message ShuffleChannelParameter{ + optional uint32 group = 1[default = 1]; // The number of group +} +message UpsampleParameter{ + optional float scale = 1[default = 1]; + optional int32 stride = 2[default = 2]; + optional int32 stride_h = 3[default = 2]; + optional int32 stride_w = 4[default=2]; +} +message ROIPoolingParameter { + required int32 pooled_h = 1; + required int32 pooled_w = 2; + optional float spatial_scale = 3 [default=0.0625]; + optional float spatial_scale_h = 4; + optional float spatial_scale_w = 5; +} + +message YoloParameter { + optional int32 boxes = 1 [default = 3]; + optional int32 coords = 2 [default = 4]; + optional int32 classes = 3 [default = 80]; + optional string yolo_version = 4 [default = "V3"]; + optional bool softmax = 5 [default = false]; + optional bool background = 6 [default = false]; + optional bool softmaxtree = 7 [default = false]; +} + +message YoloV3DetectionOutputParameter { + optional int32 boxes = 1 [default = 3]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases_high = 9; + repeated float biases_mid = 10; + repeated float biases_low = 11; + optional int32 coords = 12 [default = 4]; + repeated float biases = 13; + optional bool resize_origin_img_to_net = 14 [default = false]; +} + +message YoloV3DetectionOutputV2Parameter { + optional int32 boxes = 1 [default = 3]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 
post_nms_topn = 8 [default = 1024]; + repeated float biases_high = 9; + repeated float biases_mid = 10; + repeated float biases_low = 11; + optional int32 coords = 12 [default = 4]; + repeated float biases = 13; + optional bool resize_origin_img_to_net = 14 [default = false]; + optional int32 out_box_dim = 15 [default = 3]; +} + +message ProposalParameter { + optional float feat_stride = 1 [default = 16]; + optional float base_size = 2 [default = 16]; + optional float min_size = 3 [default = 16]; + repeated float ratio = 4; + repeated float scale = 5; + optional int32 pre_nms_topn = 6 [default = 3000]; + optional int32 post_nms_topn = 7 [default = 304]; + optional float iou_threshold = 8 [default = 0.7]; + optional bool output_actual_rois_num = 9 [default = false]; +} + +message FSRDetectionOutputParameter { + required int32 num_classes = 1; + required float score_threshold = 2; + required float iou_threshold = 3; + optional int32 batch_rois = 4 [default = 1]; +} + +message SSDDetectionOutputParameter { + required int32 num_classes= 1 [default = 2]; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional float iou_threshold = 4 [default = 0.3]; + optional int32 top_k = 5 [default = 200]; + optional float eta = 6 [default = 1.0]; + optional bool variance_encoded_in_target = 7 [default = false]; + optional int32 code_type = 8 [default = 1]; + optional int32 keep_top_k = 9 [default = -1]; + optional float confidence_threshold = 10 [default = 0.0]; +} +message YoloV2DetectionOutputParameter { + optional int32 boxes = 1 [default = 5]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases = 9; + optional int32 coords = 10 [default = 4]; + optional bool resize_origin_img_to_net = 11 [default = false]; +} + +message QuantParameter { + optional float scale = 2; + optional bytes offset = 3; +} + +message BatchMatMulParameter{ + optional bool adj_x1 = 1 [default = false]; + optional bool adj_x2 = 2 [default = false]; +} + +message CondTakeParameter { + required string mode = 1; + required float val = 2; + optional float eps = 3 [default = 1e-06]; +} + +message MatrixInverseParameter { + optional bool adjoint = 1 [default = false]; +} + +message WarpPerspectiveParameter { + required int32 out_height = 1; + required int32 out_width = 2; + optional float constant = 3; + optional string border_type = 4 [default = 'BORDER_CONSTANT']; +} + +message SpatialTransformerParameter { + // How to use the parameter passed by localisation network + optional string transform_type = 1 [default = "affine"]; + // What is the sampling technique + optional string sampler_type = 2 [default = "bilinear"]; + + // If not set,stay same with the input dimension H and W + optional int32 output_H = 3; + optional int32 output_W = 4; + // If false, only compute dTheta, DO NOT compute dU + optional bool to_compute_dU = 5 [default = true]; + + // The default value for some parameters + optional double theta_1_1 = 6; + optional double theta_1_2 = 7; + optional double theta_1_3 = 8; + optional double theta_2_1 = 9; + optional double theta_2_2 = 10; + optional double theta_2_3 = 11; +} diff --git a/ge/proto/dump_task.proto b/ge/proto/dump_task.proto new 
file mode 100644 index 00000000..ecdf4792 --- /dev/null +++ b/ge/proto/dump_task.proto @@ -0,0 +1,127 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; +package toolkit.dumpdata; + +enum OutputDataType { + DT_UNDEFINED = 0; + DT_FLOAT = 1; + DT_FLOAT16 = 2; + DT_INT8 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_UINT16 = 6; + DT_INT32 = 7; + DT_INT64 = 8; + DT_UINT32 = 9; + DT_UINT64 = 10; + DT_BOOL = 11; + DT_DOUBLE = 12; + DT_STRING = 13; + DT_DUAL_SUB_INT8 = 14; + DT_DUAL_SUB_UINT8 = 15; + DT_COMPLEX64 = 16; + DT_COMPLEX128 = 17; + DT_QINT8 = 18; + DT_QINT16 = 19; + DT_QINT32 = 20; + DT_QUINT8 = 21; + DT_QUINT16 = 22; + DT_RESOURCE = 23; + DT_STRING_REF = 24; + DT_DUAL = 25; +} + +enum OutputFormat { + FORMAT_NCHW = 0; + FORMAT_NHWC = 1; + FORMAT_ND = 2; + FORMAT_NC1HWC0 = 3; + FORMAT_FRACTAL_Z = 4; + FORMAT_NC1C0HWPAD = 5; + FORMAT_NHWC1C0 = 6; + FORMAT_FSR_NCHW = 7; + FORMAT_FRACTAL_DECONV = 8; + FORMAT_C1HWNC0 = 9; + FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; + FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; + FORMAT_NC1HWC0_C04 = 12; + FORMAT_FRACTAL_Z_C04 = 13; + FORMAT_CHWN = 14; + FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; + FORMAT_HWCN = 16; + FORMAT_NC1KHKWHWC0 = 17; + FORMAT_BN_WEIGHT = 18; + FORMAT_FILTER_HWCK = 19; + FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; + FORMAT_HASHTABLE_LOOKUP_KEYS = 21; + FORMAT_HASHTABLE_LOOKUP_VALUE = 22; + FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; + FORMAT_HASHTABLE_LOOKUP_HITS=24; + FORMAT_C1HWNCoC0 = 25; + FORMAT_MD = 26; + FORMAT_NDHWC = 27; + FORMAT_FRACTAL_ZZ = 28; + FORMAT_FRACTAL_NZ = 29; + FORMAT_RESERVED = 30; +} + +message OriginalOp { + string name = 1; + uint32 output_index = 2; + OutputDataType data_type = 3; + OutputFormat format = 4; +} + +message Shape { + repeated uint64 dim = 1; +} + +message OpOutput { + OutputDataType data_type = 1; + OutputFormat format = 2; + Shape shape = 3; + OriginalOp original_op = 4; // the original op corresponding to the output + bytes data = 5; + uint64 size = 6; +} + +message OpInput { + OutputDataType data_type = 1; + OutputFormat format = 2; + Shape shape = 3; + bytes data = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + bytes data = 2; + uint64 size = 3; +} + +message DumpData{ + string version = 1; + uint64 dump_time = 2; + repeated OpOutput output = 3; + repeated OpInput input = 4; + repeated OpBuffer buffer = 5; +} diff --git a/src/proto/fusion_model.proto b/ge/proto/fusion_model.proto similarity index 100% rename from src/proto/fusion_model.proto rename to ge/proto/fusion_model.proto diff --git a/src/proto/fwk_adapter.proto b/ge/proto/fwk_adapter.proto similarity index 100% rename from src/proto/fwk_adapter.proto rename to ge/proto/fwk_adapter.proto diff --git a/ge/proto/ge_api.proto b/ge/proto/ge_api.proto new file mode 100644 index 00000000..ac5b3b3a --- /dev/null +++ b/ge/proto/ge_api.proto @@ -0,0 +1,104 @@ +/** + * Copyright 2019-2020 Huawei 
Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; +package ge.api_pb; + +import "ge_ir.proto"; + +// GE initialize +message GEInitialize { + map options = 1; +}; + +// initialize response +message GEInitializeResponse { + uint32 status = 1; + uint32 clientId = 2; +}; + +// GE finalize +message GEFinalize { + bool final = 1; + uint32 clientId = 2; +}; + +message GEFinalizeResponse { + uint32 status = 1; +}; + +// GE Session +message CreateSession{ + map options = 1; +}; + +message CreateSessionResponse { + uint32 status = 1; + uint64 sessionId = 2; +}; + +//GE AddGraph +//model serialize :: serializegraph +message SessionAddGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; + ge.proto.GraphDef graph = 3; +}; + +message SessionAddGraphResponse { + uint32 status = 1; +}; + +//GE SessionRemoveGraph +message SessionRemoveGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; +}; + +message SessionRemoveGraphResponse { + uint32 status = 1; +}; + +message SessionRunGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; + repeated ge.proto.TensorDef tensor = 3; +}; + +message SessionBuildGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; + repeated ge.proto.TensorDef tensor = 3; + string savePath = 4; +}; + +message SessionRunGraphResponse { + uint32 status = 1; + repeated ge.proto.TensorDef tensor = 2; +}; + +message SessionBuildGraphResponse { + uint32 status = 1; +}; + +message DestroySession{ + bool final = 1; + uint64 sessionId = 2; +}; + +message DestroySessionResponse { + uint32 status = 1; +}; diff --git a/ge/proto/ge_ir.proto b/ge/proto/ge_ir.proto new file mode 100644 index 00000000..87886c84 --- /dev/null +++ b/ge/proto/ge_ir.proto @@ -0,0 +1,206 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. 
+ DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs +{ + string name = 1; + map attr = 2; +} + +// Shape / dimension description, using row-major order +message ShapeDef +{ + repeated int64 dim = 1; // Size of each dimension +} + +// Multidimensional data description +message TensorDescriptor +{ + string name = 1; // Optional parameter, tensor name + + DataType dtype = 2; // tensor datatype + ShapeDef shape = 3; // Shape / dimension + string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" + + bool has_out_attr = 9; + int64 size = 10; + int64 weight_size = 11; + bool reuse_input = 12; + bool output_tensor = 13; + string device_type = 14; + bool input_tensor =15; + int64 real_dim_cnt = 16; + int64 reuse_input_index = 17; + int64 data_offset = 18; + int64 cmps_size = 19; + string cmps_tab = 20; + int64 cmps_tab_offset = 21; + + map attr = 5; // Set of extra parameter fields +} + +// GeTensor definition +message TensorDef +{ + TensorDescriptor desc = 1; // Tensor description + bytes data = 2; // Tensor data +} + + +// Operator description +message OpDef +{ + string name = 1; // name + string type = 2; // type + + repeated string input = 5; // input original op name + outgoing index. 
op_name:index + + map attr = 10; // Set of operator parameter fields + + bool has_out_attr = 20; + int64 id = 21; + int64 stream_id =22; + repeated string input_name = 23; + repeated string src_name = 24; + repeated int64 src_index = 25; + repeated string dst_name = 26; + repeated int64 dst_index = 27; + repeated int64 input_i = 28; + repeated int64 output_i = 29; + repeated int64 workspace = 30; + repeated int64 workspace_bytes = 31; + repeated bool is_input_const = 32; + repeated TensorDescriptor input_desc = 33; + repeated TensorDescriptor output_desc = 34; + repeated string subgraph_name = 35; +} + +// Graph definition +message GraphDef +{ + string name = 1; // name + + repeated string input = 4; // Graph input + repeated string output = 5; // Graph output + + repeated OpDef op = 6; // List of operators + + map attr = 11; // Extended field +} + +// model definition +message ModelDef +{ + string name = 1; // name + uint32 version = 2; // IR Proto verion + string custom_version = 3; // User model version number, passed in by user + + repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef + + map attr = 11; // Extended field +} + diff --git a/ge/proto/insert_op.proto b/ge/proto/insert_op.proto new file mode 100644 index 00000000..a059e122 --- /dev/null +++ b/ge/proto/insert_op.proto @@ -0,0 +1,152 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+syntax = "proto3";
+
+package domi;
+
+message InsertNewOps {
+  repeated AippOpParams aipp_op = 1;
+  repeated MultiShapeOpParams multi_shape_op = 2;
+}
+
+message AippOpParams {
+  enum InputFormat {
+    UNDEFINED = 0;
+    YUV420SP_U8 = 1;
+    XRGB8888_U8 = 2;
+    RGB888_U8 = 3;
+    YUV400_U8 = 4;
+    NC1HWC0DI_FP16 = 5;
+    NC1HWC0DI_S8 = 6;
+    ARGB8888_U8 = 7;
+    YUYV_U8 = 8;
+    YUV422SP_U8 = 9;
+    AYUV444_U8 = 10;
+    RAW10 = 11;
+    RAW12 = 12;
+    RAW16 = 13;
+    RAW24 = 14;
+    RGB16 = 15;
+    RGB20 = 16;
+    RGB24 = 17;
+    RGB8_IR = 18;
+    RGB16_IR = 19;
+    RGB24_IR = 20;
+  }
+
+  enum AippMode {
+    undefined = 0;
+    static = 1;
+    dynamic = 2;
+  }
+
+  // AIPP mode: distinguishes static AIPP from dynamic AIPP.
+  AippMode aipp_mode = 1;
+
+  // related_input_rank is required. It is an integer with a configurable range of >= 0 and
+  // <= the number of input Data operators; the default value is 0. It selects which model
+  // input is processed by AIPP. For example, if the model has two inputs and AIPP should be
+  // applied to the second one, set related_input_rank to 1.
+  uint32 related_input_rank = 2;
+
+  // input_edge_idx is optional. It is an integer >= 0. It allows different AIPP processing
+  // to be applied to different outputs of the Data operator. If it is not configured, AIPP
+  // is applied to all output edges of the model input selected by related_input_rank.
+  // The configured value must not exceed the number of output edges of the Data operator.
+  repeated uint32 input_edge_idx = 3;
+
+  // [Begin] dynamic AIPP parameters; ignored when static AIPP is configured
+  uint32 max_src_image_size = 4;
+
+  // Whether rotation is supported. Disabled by default; enabling it incurs extra memory and
+  // performance cost.
+  bool support_rotation = 5;
+
+  // [End] dynamic AIPP parameters
+
+
+  // [Begin] static AIPP parameters; ignored when dynamic AIPP is configured
+  InputFormat input_format = 51;
+  bool csc_switch = 52;
+  float cpadding_value = 53;
+  bool rbuv_swap_switch = 54;
+  bool ax_swap_switch = 55;
+  bool single_line_mode = 56;
+
+  int32 src_image_size_w = 57;
+  int32 src_image_size_h = 58;
+
+  bool crop = 59;
+  int32 load_start_pos_w = 60;
+  int32 load_start_pos_h = 61;
+  int32 crop_size_w = 62;
+  int32 crop_size_h = 63;
+
+  bool resize = 64;
+  int32 resize_output_w = 65;
+  int32 resize_output_h = 66;
+
+  bool padding = 67;
+  int32 left_padding_size = 68;
+  int32 right_padding_size = 69;
+  int32 top_padding_size = 70;
+  int32 bottom_padding_size = 71;
+
+  int32 mean_chn_0 = 10;
+  int32 mean_chn_1 = 11;
+  int32 mean_chn_2 = 12;
+  int32 mean_chn_3 = 19;
+  float min_chn_0 = 13;
+  float min_chn_1 = 14;
+  float min_chn_2 = 15;
+  float min_chn_3 = 20;
+  repeated float var_reci_chn_0 = 16;
+  repeated float var_reci_chn_1 = 17;
+  repeated float var_reci_chn_2 = 18;
+  repeated float var_reci_chn_3 = 21;
+
+  repeated int32 matrix_r0c0 = 30;
+  repeated int32 matrix_r0c1 = 31;
+  repeated int32 matrix_r0c2 = 32;
+  repeated int32 matrix_r1c0 = 33;
+  repeated int32 matrix_r1c1 = 34;
+  repeated int32 matrix_r1c2 = 35;
+  repeated int32 matrix_r2c0 = 36;
+  repeated int32 matrix_r2c1 = 37;
+  repeated int32 matrix_r2c2 = 38;
+  repeated int32 output_bias_0 = 39;
+  repeated int32 output_bias_1 = 40;
+  repeated int32 output_bias_2 = 41;
+  repeated int32 input_bias_0 = 42;
+  repeated int32 input_bias_1 = 43;
+  repeated int32 input_bias_2 = 44;
+
+  // [End] static AIPP parameters
+
+  // The n number that is used for raw/rgbir data into f16 transformation.
+  // The transformation equation is x/(2^n). If set to 0, no transform is performed.
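+  // For example (illustrative values): with raw_rgbir_to_f16_n set to 8, a raw input value
+  // of 512 is scaled to 512 / 2^8 = 2.0 before it is stored as fp16.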
+  uint32 raw_rgbir_to_f16_n = 45;
+}
+
+message MultiShapeOpParams {
+  enum MultiShapeMode {
+    batch = 0;       // dynamic batch
+    resolution = 1;  // dynamic resolution, reserved for extension
+  }
+
+  MultiShapeMode mode = 1;        // operator mode
+  uint32 related_input_rank = 2;  // which model input the new operator is inserted on
+
+
+  repeated uint32 batch_list = 11;  // batch_list values; the number of entries must be between 2 and 8
+}
diff --git a/ge/proto/om.proto b/ge/proto/om.proto
new file mode 100644
index 00000000..dd992191
--- /dev/null
+++ b/ge/proto/om.proto
@@ -0,0 +1,401 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package domi;
+
+enum TargetType
+{
+  MINI = 0;
+  TINY = 1;
+  LITE = 2;
+}
+
+// offline model
+message ModelDef {
+  string name = 1;
+  uint32 version = 2;
+
+  uint64 memory_size = 10;
+  uint32 stream_num = 11;
+  uint32 event_num = 12;
+  uint64 weight_size = 13;
+  uint32 label_num = 15;
+  repeated OpDef op = 20;
+  TargetType target_type = 23;
+
+  map attr = 30;
+};
+
+// operator define
+message OpDef {
+  string name = 1;
+  string type = 2;
+
+  uint32 id = 3;
+  uint32 stream_id = 4;
+
+  repeated string input_name = 5;
+
+  repeated string src_name = 8;
+  repeated int32 src_index = 9;
+  repeated int64 input = 10;
+  repeated int64 output = 11;
+  repeated TensorDescriptor input_desc = 12;
+  repeated TensorDescriptor output_desc = 13;
+  repeated WeightDef weights = 14;
+  repeated string dst_name = 15;
+  repeated int32 dst_index = 16;
+
+  repeated int64 workspace = 20;
+  repeated uint32 workspace_bytes = 21;
+
+  repeated string weight_name = 22;
+  repeated bool is_input_const = 23;
+
+  map attr = 30;
+
+  QuantizeFactorParams quantize_factor = 31;
+
+  oneof op_params {
+    // start at 100 here
+    SendOpParams sender_param = 100;
+    RecvOpParams receiver_param = 200;
+    ConvolutionOpParams convolution_param = 300;
+    PoolingOpParams pooling_param = 400;
+    EltwiseOpParams eltwise_param = 500;
+    BatchNormOpParams batchnorm_param = 600;
+    ScaleOpParams scale_param = 700;
+    FullConnectionOpParams full_connection_param = 800;
+    SoftmaxOpParams softmax_param = 900;
+    ActivationOpParams activation_param = 1000;
+    ReshapeOpParams reshape_param = 1100;
+  }
+};
+
+message SendOpParams {
+  uint32 event_id = 1;
+};
+
+message RecvOpParams {
+  uint32 event_id = 1;
+};
+
+enum QuantizeScaleType
+{
+  VECTOR_SCALE = 0;
+  SCALAR_SCALE = 1;
+}
+
+enum QuantizeScaleMode
+{
+  NORMAL_MODE = 0;
+  SQRT_MODE = 1;
+}
+
+enum QuantizeAlgorithm
+{
+  NON_OFFSET_ALGO = 0;
+  HALF_OFFSET_ALGO = 1;
+  ALL_OFFSET_ALGO = 2;
+}
+message QuantizeFactor
+{
+  QuantizeScaleMode scale_mode = 1;
+  bytes scale_value = 2;
+  int64 scale_offset = 3;
+  bytes offset_data_value = 4;
+  int64 offset_data_offset = 5;
+  bytes offset_weight_value = 6;
+  int64 offset_weight_offset = 7;
+  bytes offset_pad_value = 8;
+  int64 offset_pad_offset = 9;
+};
+
+message QuantizeCalcFactor
+{
+  bytes offsetw = 1;
+  int64 offsetw_offset = 2;
+  bytes offsetd = 3;
+  int64 offsetd_offset = 4;
+  bytes scalereq = 5;
+  int64 scaledreq_offset = 6;
+  
bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + 
CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. + CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/ge/proto/op_mapping_info.proto b/ge/proto/op_mapping_info.proto new file mode 100644 index 00000000..7b84a115 --- /dev/null +++ b/ge/proto/op_mapping_info.proto @@ -0,0 +1,89 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +syntax = "proto3"; +package aicpu.dump; + +message Shape { + repeated uint64 dim = 1; +} + +message Output { + int32 data_type = 1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + string original_name = 5; + int32 original_output_index = 6; + int32 original_output_data_type = 7; + int32 original_output_format = 8; + uint64 size = 9; +} + +message Input { + int32 data_type =1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + uint64 address = 2; + uint64 size = 3; +} + +message Op { + string op_name = 1; + string op_type = 2; +} + +message Task { + uint32 task_id = 1; + uint32 stream_id = 2; + Op op = 3; + repeated Output output = 4; + bool end_graph = 5; + repeated Input input = 6; + repeated OpBuffer buffer = 7; +} + +message OpMappingInfo { + string dump_path = 1; + oneof model_name_param { + string model_name = 2; + } + oneof model_id_param { + uint32 model_id = 3; + } + oneof step_id { + uint64 step_id_addr = 4; + } + oneof iterations_per_loop { + uint64 iterations_per_loop_addr = 5; + } + oneof loop_cond { + uint64 loop_cond_addr = 6; + } + uint32 flag = 7; // 0x01 load, 0x00 unload + repeated Task task = 8; + string dump_step = 9; +} \ No newline at end of file diff --git a/src/proto/optimizer_priority.proto b/ge/proto/optimizer_priority.proto similarity index 100% rename from src/proto/optimizer_priority.proto rename to ge/proto/optimizer_priority.proto diff --git a/ge/proto/task.proto b/ge/proto/task.proto new file mode 100644 index 00000000..50ea061b --- /dev/null +++ b/ge/proto/task.proto @@ -0,0 +1,170 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/ge/proto/tensorflow/attr_value.proto b/ge/proto/tensorflow/attr_value.proto new file mode 100644 index 00000000..1cc67d62 --- /dev/null +++ b/ge/proto/tensorflow/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option 
java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/ge/proto/tensorflow/function.proto b/ge/proto/tensorflow/function.proto new file mode 100644 index 00000000..075897c6 --- /dev/null +++ b/ge/proto/tensorflow/function.proto @@ -0,0 +1,100 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. + reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). 
diff --git a/ge/proto/tensorflow/function.proto b/ge/proto/tensorflow/function.proto
new file mode 100644
index 00000000..075897c6
--- /dev/null
+++ b/ge/proto/tensorflow/function.proto
@@ -0,0 +1,100 @@
+syntax = "proto3";
+
+package domi.tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "FunctionProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "attr_value.proto";
+import "node_def.proto";
+import "op_def.proto";
+
+// A library is a set of named functions.
+message FunctionDefLibrary {
+  repeated FunctionDef function = 1;
+  repeated GradientDef gradient = 2;
+}
+
+// A function can be instantiated when the runtime can bind every attr
+// with a value. When a GraphDef has a call to a function, it must
+// have binding for every attr defined in the signature.
+// * device spec, etc.
+message FunctionDef {
+  // The definition of the function's name, arguments, return values,
+  // attrs etc.
+  OpDef signature = 1;
+
+  // Attributes specific to this function definition.
+  map<string, AttrValue> attr = 5;
+
+  // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21.
+  reserved 2;
+
+  // In both of the following fields, there is the need to specify an
+  // output that is used as either the input to another node (in
+  // `node_def`) or as a return value of the function (in `ret`).
+  // Unlike the NodeDefs in GraphDef, we need to be able to specify a
+  // list in some cases (instead of just single outputs). Also, we
+  // need to be able to deal with lists of unknown length (so the
+  // output index may not be known at function definition time). So
+  // we use the following format instead:
+  // * "fun_in" where "fun_in" is the name of a function input arg in
+  //   the `signature` field above. This represents that input, whether
+  //   it is a single tensor or a list.
+  // * "fun_in:0" gives the first element of a function input arg (a
+  //   non-list input is considered a list of length 1 for these
+  //   purposes).
+  // * "node:out" where "node" is the name of a node in `node_def` and
+  //   "out" is the name one of its op's output arguments (the name
+  //   comes from the OpDef of the node's op). This represents that
+  //   node's output, whether it is a single tensor or a list.
+  //   Note: We enforce that an op's output arguments are never
+  //   renamed in the backwards-compatibility test.
+  // * "node:out:0" gives the first element of a node output arg (a
+  //   non-list output is considered a list of length 1 for these
+  //   purposes).
+  //
+  // NOT CURRENTLY SUPPORTED (but may be in the future):
+  // * "node:out:-1" gives last element in a node output list
+  // * "node:out:1:" gives a list with all but the first element in a
+  //   node output list
+  // * "node:out::-1" gives a list with all but the last element in a
+  //   node output list
+
+  // The body of the function. Unlike the NodeDefs in a GraphDef, attrs
+  // may have values of type `placeholder` and the `input` field uses
+  // the "output" format above.
+
+  // By convention, "op" in node_def is resolved by consulting with a
+  // user-defined library first. If not resolved, "func" is assumed to
+  // be a builtin op.
+  repeated NodeDef node_def = 3;
+
+  // A mapping from the output arg names from `signature` to the
+  // outputs from `node_def` that should be returned by the function.
+  map<string, string> ret = 4;
+}
+
+// GradientDef defines the gradient function of a function defined in
+// a function library.
+//
+// A gradient function g (specified by gradient_func) for a function f
+// (specified by function_name) must follow the following:
+//
+// The function 'f' must be a numerical function which takes N inputs
+// and produces M outputs. Its gradient function 'g', which is a
+// function taking N + M inputs and produces N outputs.
+//
+// I.e. if we have
+//    (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
+// then, g is
+//    (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
+//                                       dL/dy1, dL/dy2, ..., dL/dy_M),
+// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
+// loss function). dL/dx_i is the partial derivative of L with respect
+// to x_i.
+message GradientDef {
+  string function_name = 1;  // The function name.
+  string gradient_func = 2;  // The gradient function's name.
+}
diff --git a/ge/proto/tensorflow/graph.proto b/ge/proto/tensorflow/graph.proto
new file mode 100644
index 00000000..d639a7d6
--- /dev/null
+++ b/ge/proto/tensorflow/graph.proto
@@ -0,0 +1,56 @@
+syntax = "proto3";
+
+package domi.tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "GraphProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "node_def.proto";
+import "function.proto";
+import "versions.proto";
+
+// Represents the graph of operations
+message GraphDef {
+  repeated NodeDef node = 1;
+
+  // Compatibility versions of the graph.
See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/ge/proto/tensorflow/graph_library.proto b/ge/proto/tensorflow/graph_library.proto new file mode 100644 index 00000000..e393d38d --- /dev/null +++ b/ge/proto/tensorflow/graph_library.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package domi.tensorflow; + +import "graph.proto"; + +message GeGraphDef { + string name = 1; + GraphDef graph = 2; +} + +message GraphDefLibrary { + repeated GeGraphDef graph_def = 1; +}; \ No newline at end of file diff --git a/ge/proto/tensorflow/node_def.proto b/ge/proto/tensorflow/node_def.proto new file mode 100644 index 00000000..b9bc97ee --- /dev/null +++ b/ge/proto/tensorflow/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. 
+ // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // Add some examples here showing best practices. + map attr = 5; +}; diff --git a/ge/proto/tensorflow/op_def.proto b/ge/proto/tensorflow/op_def.proto new file mode 100644 index 00000000..3485d045 --- /dev/null +++ b/ge/proto/tensorflow/op_def.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. 
+ bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). 
+ bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/ge/proto/tensorflow/resource_handle.proto b/ge/proto/tensorflow/resource_handle.proto new file mode 100644 index 00000000..a3452351 --- /dev/null +++ b/ge/proto/tensorflow/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/ge/proto/tensorflow/tensor.proto b/ge/proto/tensorflow/tensor.proto new file mode 100644 index 00000000..d0a4d024 --- /dev/null +++ b/ge/proto/tensorflow/tensor.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. 
+ bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/ge/proto/tensorflow/tensor_shape.proto b/ge/proto/tensorflow/tensor_shape.proto new file mode 100644 index 00000000..4225a2e3 --- /dev/null +++ b/ge/proto/tensorflow/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package domi.tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. 
+ repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/ge/proto/tensorflow/types.proto b/ge/proto/tensorflow/types.proto new file mode 100644 index 00000000..ba7a72b3 --- /dev/null +++ b/ge/proto/tensorflow/types.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/ge/proto/tensorflow/versions.proto b/ge/proto/tensorflow/versions.proto new file mode 100644 index 00000000..48061218 --- /dev/null +++ b/ge/proto/tensorflow/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). 
A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +}; diff --git a/src/ge/session/inner_session.cc b/ge/session/inner_session.cc similarity index 100% rename from src/ge/session/inner_session.cc rename to ge/session/inner_session.cc diff --git a/src/ge/session/inner_session.h b/ge/session/inner_session.h similarity index 100% rename from src/ge/session/inner_session.h rename to ge/session/inner_session.h diff --git a/src/ge/session/omg.cc b/ge/session/omg.cc similarity index 100% rename from src/ge/session/omg.cc rename to ge/session/omg.cc diff --git a/ge/session/readme.txt b/ge/session/readme.txt new file mode 100644 index 00000000..d8d0f393 --- /dev/null +++ b/ge/session/readme.txt @@ -0,0 +1,3 @@ +GE +SessionManager +InnerSession diff --git a/src/ge/session/session_manager.cc b/ge/session/session_manager.cc similarity index 100% rename from src/ge/session/session_manager.cc rename to ge/session/session_manager.cc diff --git a/src/ge/session/session_manager.h b/ge/session/session_manager.h similarity index 100% rename from src/ge/session/session_manager.h rename to ge/session/session_manager.h diff --git a/src/ge/single_op/single_op.cc b/ge/single_op/single_op.cc similarity index 100% rename from src/ge/single_op/single_op.cc rename to ge/single_op/single_op.cc diff --git a/src/ge/single_op/single_op.h b/ge/single_op/single_op.h similarity index 100% rename from src/ge/single_op/single_op.h rename to ge/single_op/single_op.h diff --git a/src/ge/single_op/single_op_manager.cc b/ge/single_op/single_op_manager.cc similarity index 100% rename from src/ge/single_op/single_op_manager.cc rename to ge/single_op/single_op_manager.cc diff --git a/src/ge/single_op/single_op_manager.h b/ge/single_op/single_op_manager.h similarity index 100% rename from src/ge/single_op/single_op_manager.h rename to ge/single_op/single_op_manager.h diff --git a/src/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc similarity index 100% rename from src/ge/single_op/single_op_model.cc rename to ge/single_op/single_op_model.cc diff --git a/src/ge/single_op/single_op_model.h b/ge/single_op/single_op_model.h similarity index 100% rename from src/ge/single_op/single_op_model.h rename to ge/single_op/single_op_model.h diff --git a/src/ge/single_op/stream_resource.cc b/ge/single_op/stream_resource.cc similarity index 100% rename from src/ge/single_op/stream_resource.cc rename to ge/single_op/stream_resource.cc diff --git a/src/ge/single_op/stream_resource.h b/ge/single_op/stream_resource.h similarity index 100% rename from src/ge/single_op/stream_resource.h rename to ge/single_op/stream_resource.h diff --git a/src/ge/single_op/task/aicpu_kernel_task_builder.cc b/ge/single_op/task/aicpu_kernel_task_builder.cc similarity index 100% rename from src/ge/single_op/task/aicpu_kernel_task_builder.cc rename to ge/single_op/task/aicpu_kernel_task_builder.cc diff --git a/src/ge/single_op/task/aicpu_kernel_task_builder.h b/ge/single_op/task/aicpu_kernel_task_builder.h similarity index 100% rename from src/ge/single_op/task/aicpu_kernel_task_builder.h rename to ge/single_op/task/aicpu_kernel_task_builder.h diff --git 
a/src/ge/single_op/task/aicpu_task_builder.cc b/ge/single_op/task/aicpu_task_builder.cc similarity index 100% rename from src/ge/single_op/task/aicpu_task_builder.cc rename to ge/single_op/task/aicpu_task_builder.cc diff --git a/src/ge/single_op/task/aicpu_task_builder.h b/ge/single_op/task/aicpu_task_builder.h similarity index 100% rename from src/ge/single_op/task/aicpu_task_builder.h rename to ge/single_op/task/aicpu_task_builder.h diff --git a/src/ge/single_op/task/build_task_utils.cc b/ge/single_op/task/build_task_utils.cc similarity index 100% rename from src/ge/single_op/task/build_task_utils.cc rename to ge/single_op/task/build_task_utils.cc diff --git a/src/ge/single_op/task/build_task_utils.h b/ge/single_op/task/build_task_utils.h similarity index 100% rename from src/ge/single_op/task/build_task_utils.h rename to ge/single_op/task/build_task_utils.h diff --git a/src/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc similarity index 100% rename from src/ge/single_op/task/op_task.cc rename to ge/single_op/task/op_task.cc diff --git a/src/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h similarity index 100% rename from src/ge/single_op/task/op_task.h rename to ge/single_op/task/op_task.h diff --git a/src/ge/single_op/task/tbe_task_builder.cc b/ge/single_op/task/tbe_task_builder.cc similarity index 100% rename from src/ge/single_op/task/tbe_task_builder.cc rename to ge/single_op/task/tbe_task_builder.cc diff --git a/src/ge/single_op/task/tbe_task_builder.h b/ge/single_op/task/tbe_task_builder.h similarity index 100% rename from src/ge/single_op/task/tbe_task_builder.h rename to ge/single_op/task/tbe_task_builder.h diff --git a/src/ge/stub/Makefile b/ge/stub/Makefile similarity index 100% rename from src/ge/stub/Makefile rename to ge/stub/Makefile diff --git a/src/ge/stub/README b/ge/stub/README similarity index 100% rename from src/ge/stub/README rename to ge/stub/README diff --git a/src/ge/stub/README.md b/ge/stub/README.md old mode 100755 new mode 100644 similarity index 100% rename from src/ge/stub/README.md rename to ge/stub/README.md diff --git a/src/ge/stub/gen_stubapi.py b/ge/stub/gen_stubapi.py similarity index 99% rename from src/ge/stub/gen_stubapi.py rename to ge/stub/gen_stubapi.py index b6e1e70c..0c5e712b 100644 --- a/src/ge/stub/gen_stubapi.py +++ b/ge/stub/gen_stubapi.py @@ -64,7 +64,7 @@ max_code_len_per_line = 100 when DEBUG on """ white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h", - "ge_ir_build.h", "ge_api.h", "tensorflow_parser.h", "caffe_parser.h"] + "ge_ir_build.h", "ge_api.h", "ge_prof.h", "tensorflow_parser.h", "caffe_parser.h"] include_dir_key_words = ["ge", "graph", "parser"] DEBUG = True diff --git a/inc/common/blocking_queue.h b/inc/common/blocking_queue.h deleted file mode 100644 index 12b02773..00000000 --- a/inc/common/blocking_queue.h +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_COMMON_BLOCKING_QUEUE_H_ -#define INC_COMMON_BLOCKING_QUEUE_H_ - -#include -#include -#include -#include - -static const int kDefaultMaxQueueSize = 2048; - -template -class BlockingQueue { - public: - explicit BlockingQueue(uint32_t max_size = kDefaultMaxQueueSize) : max_size_(max_size), is_stoped_(false) {} - - ~BlockingQueue() {} - - bool Pop(T &item) { - std::unique_lock lock(mutex_); - - while (queue_.empty() && !is_stoped_) { - empty_cond_.wait(lock); - } - - if (is_stoped_) { - return false; - } - - item = std::move(queue_.front()); - queue_.pop_front(); - - full_cond_.notify_one(); - - return true; - } - - bool Push(const T &item, bool is_wait = true) { - std::unique_lock lock(mutex_); - - while (queue_.size() >= max_size_ && !is_stoped_) { - if (!is_wait) { - return false; - } - full_cond_.wait(lock); - } - - if (is_stoped_) { - return false; - } - - queue_.push_back(item); - - empty_cond_.notify_one(); - - return true; - } - - bool Push(T &&item, bool is_wait = true) { - std::unique_lock lock(mutex_); - - while (queue_.size() >= max_size_ && !is_stoped_) { - if (!is_wait) { - return false; - } - full_cond_.wait(lock); - } - - if (is_stoped_) { - return false; - } - - queue_.emplace_back(std::move(item)); - - empty_cond_.notify_one(); - - return true; - } - - void Stop() { - { - std::unique_lock lock(mutex_); - is_stoped_ = true; - } - - full_cond_.notify_all(); - empty_cond_.notify_all(); - } - - void Restart() { - std::unique_lock lock(mutex_); - is_stoped_ = false; - } - - // if the queue is stoped ,need call this function to release the unprocessed items - std::list GetRemainItems() { - std::unique_lock lock(mutex_); - - if (!is_stoped_) { - return std::list(); - } - - return queue_; - } - - bool IsFull() { - std::unique_lock lock(mutex_); - return queue_.size() >= max_size_; - } - - void Clear() { - std::unique_lock lock(mutex_); - queue_.clear(); - } - - private: - std::list queue_; - std::mutex mutex_; - std::condition_variable empty_cond_; - std::condition_variable full_cond_; - uint32_t max_size_; - - bool is_stoped_; -}; - -#endif // INC_COMMON_BLOCKING_QUEUE_H_ diff --git a/inc/common/dynamic_aipp.h b/inc/common/dynamic_aipp.h deleted file mode 100644 index a687853f..00000000 --- a/inc/common/dynamic_aipp.h +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_COMMON_DYNAMIC_AIPP_H_ -#define INC_COMMON_DYNAMIC_AIPP_H_ - -#include - -/** - * @ingroup dnn - * @brief struct define of dynamic aipp batch parameter. 
- */ -typedef struct tagAippDynamicBatchPara { - int8_t cropSwitch; // crop switch - int8_t scfSwitch; // resize switch - int8_t paddingSwitch; // 0: unable padding - // 1: padding config value,sfr_filling_hblank_ch0 ~ sfr_filling_hblank_ch2 - // 2: padding source picture data, single row/collumn copy - // 3: padding source picture data, block copy - // 4: padding source picture data, mirror copy - int8_t rotateSwitch; // rotate switch,0: non-ratate, - // 1: ratate 90° clockwise,2: ratate 180° clockwise,3: ratate 270° clockwise - int8_t reserve[4]; - int32_t cropStartPosW; // the start horizontal position of cropping - int32_t cropStartPosH; // the start vertical position of cropping - int32_t cropSizeW; // crop width - int32_t cropSizeH; // crop height - - int32_t scfInputSizeW; // input width of scf - int32_t scfInputSizeH; // input height of scf - int32_t scfOutputSizeW; // output width of scf - int32_t scfOutputSizeH; // output height of scf - - int32_t paddingSizeTop; // top padding size - int32_t paddingSizeBottom; // bottom padding size - int32_t paddingSizeLeft; // left padding size - int32_t paddingSizeRight; // right padding size - - int16_t dtcPixelMeanChn0; // mean value of channel 0 - int16_t dtcPixelMeanChn1; // mean value of channel 1 - int16_t dtcPixelMeanChn2; // mean value of channel 2 - int16_t dtcPixelMeanChn3; // mean value of channel 3 - - uint16_t dtcPixelMinChn0; // min value of channel 0 - uint16_t dtcPixelMinChn1; // min value of channel 1 - uint16_t dtcPixelMinChn2; // min value of channel 2 - uint16_t dtcPixelMinChn3; // min value of channel 3 - uint16_t dtcPixelVarReciChn0; // sfr_dtc_pixel_variance_reci_ch0 - uint16_t dtcPixelVarReciChn1; // sfr_dtc_pixel_variance_reci_ch1 - uint16_t dtcPixelVarReciChn2; // sfr_dtc_pixel_variance_reci_ch2 - uint16_t dtcPixelVarReciChn3; // sfr_dtc_pixel_variance_reci_ch3 - - int8_t reserve1[16]; // 32B assign, for ub copy -} kAippDynamicBatchPara; - -/** - * @ingroup dnn - * @brief struct define of dynamic aipp parameter. 
lite:64+96*batchNum byte ; tiny:64+64*batchNum byte - */ -typedef struct tagAippDynamicPara { - uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8 - int8_t cscSwitch; // csc switch - int8_t rbuvSwapSwitch; // rb/ub swap switch - int8_t axSwapSwitch; // RGBA->ARGB, YUVA->AYUV swap switch - int8_t batchNum; // batch parameter number - int8_t reserve1[3]; - int32_t srcImageSizeW; // source image width - int32_t srcImageSizeH; // source image height - int16_t cscMatrixR0C0; // csc_matrix_r0_c0 - int16_t cscMatrixR0C1; // csc_matrix_r0_c1 - int16_t cscMatrixR0C2; // csc_matrix_r0_c2 - int16_t cscMatrixR1C0; // csc_matrix_r1_c0 - int16_t cscMatrixR1C1; // csc_matrix_r1_c1 - int16_t cscMatrixR1C2; // csc_matrix_r1_c2 - int16_t cscMatrixR2C0; // csc_matrix_r2_c0 - int16_t cscMatrixR2C1; // csc_matrix_r2_c1 - int16_t cscMatrixR2C2; // csc_matrix_r2_c2 - int16_t reserve2[3]; - uint8_t cscOutputBiasR0; // output Bias for RGB to YUV, element of row 0, unsigned number - uint8_t cscOutputBiasR1; // output Bias for RGB to YUV, element of row 1, unsigned number - uint8_t cscOutputBiasR2; // output Bias for RGB to YUV, element of row 2, unsigned number - uint8_t cscInputBiasR0; // input Bias for YUV to RGB, element of row 0, unsigned number - uint8_t cscInputBiasR1; // input Bias for YUV to RGB, element of row 1, unsigned number - uint8_t cscInputBiasR2; // input Bias for YUV to RGB, element of row 2, unsigned number - uint8_t reserve3[2]; - int8_t reserve4[16]; // 32B assign, for ub copy - - kAippDynamicBatchPara aippBatchPara; // allow transfer several batch para. -} kAippDynamicPara; - -#endif // INC_COMMON_DYNAMIC_AIPP_H_ diff --git a/inc/common/npu_error_define.h b/inc/common/npu_error_define.h deleted file mode 100644 index a4515cf6..00000000 --- a/inc/common/npu_error_define.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_COMMON_NPU_ERROR_DEFINE_H_ -#define INC_COMMON_NPU_ERROR_DEFINE_H_ - -typedef enum tagHiAiNpuLocal { - HIAI_HOST = 1, - HIAI_DEVICE = 2, -} HiAiNpuLocal; - -typedef enum tagHiAiNpuCodeType { - ERROR_CODE = 1, - EXCEPTION_CODE = 2, -} HiAiNpuCodeType; - -typedef enum tagHiAiNpuErrLevel { - NONE_LEVEL = 0, - SUGGESTION_LEVEL = 1, - NORMAL_LEVEL = 2, - SERIOUS_LEVEL = 3, - CRITICAL_ERROR = 4, -} HiAiNpuErrLevel; - -typedef enum tagHiAiNpuModuleId { - HIAI_DRIVER = 1, - HIAI_CTRLCPU = 2, - HIAI_TS = 3, - HIAI_RUNTIME = 4, - HIAI_AICPU = 5, - HIAI_CCE = 6, - HIAI_TVM = 7, - HIAI_FRAMEWORK = 8, - HiAI_ENGINE = 9, - HIAI_DVPP = 10, - HIAI_AIPP = 11, - HIAI_LOWPOWER = 12, - HIAI_MDC = 13, - HIAI_COMPILE = 14, - HIAI_TOOLCHIAN = 15, - HIAI_ALG = 16, - HIAI_PROFILING = 17, - HIAI_HCCL = 18, - HIAI_SIMULATION = 19, - HIAI_BIOS = 20, - HIAI_SEC = 21, - HIAI_TINY = 22, - HIAI_DP = 23, -} HiAiNpuModuleId; - -/* bit 31-bit30 to be hiai local */ -#define HIAI_NPULOCAL_MASK 0xC0000000 -#define SHIFT_LOCAL_MASK 30 -#define HIAI_NPULOCAL_VAL_MASK 0x3 -/* bit 29 -bit28 to be hiai aicpu code type */ -#define HIAI_CODE_TYPE_MASK 0x30000000 -#define SHIFT_CODE_MASK 28 -#define HIAI_CODE_TYPE_VAL_MASK 0x3 -/* bit 27 -bit25 to be hiai error level */ -#define HIAI_ERROR_LEVEL_MASK 0x0E000000 -#define SHIFT_ERROR_LVL_MASK 25 -#define HIAI_ERROR_LEVEL_VAL_MASK 0x7 -/* bit 24 -bit17 to be hiai mod */ -#define HIAI_MODE_ID_MASK 0x01FE0000 -#define SHIFT_MODE_MASK 17 -#define HIAI_MODE_ID_VAL_MASK 0xFF - -#define HIAI_NPU_LOC_BIT(a) \ - (HIAI_NPULOCAL_MASK & ((unsigned int)((HiAiNpuLocal)(a)) & HIAI_NPULOCAL_VAL_MASK) << SHIFT_LOCAL_MASK) -#define HIAI_NPU_CODE_TYPE_BIT(a) \ - (HIAI_CODE_TYPE_MASK & ((unsigned int)((HiAiNpuCodeType)(a)) & HIAI_CODE_TYPE_VAL_MASK) << SHIFT_CODE_MASK) -#define HIAI_NPU_ERR_LEV_BIT(a) \ - (HIAI_ERROR_LEVEL_MASK & ((unsigned int)((HiAiNpuErrLevel)(a)) & HIAI_ERROR_LEVEL_VAL_MASK) << SHIFT_ERROR_LVL_MASK) -#define HIAI_NPU_MOD_ID_BIT(a) \ - (HIAI_MODE_ID_MASK & ((unsigned int)((HiAiNpuModuleId)(a)) & HIAI_MODE_ID_VAL_MASK) << SHIFT_MODE_MASK) - -#define HIAI_NPU_ERR_CODE_HEAD(npuLocal, codeType, errLevel, moduleId) \ - (HIAI_NPU_LOC_BIT(npuLocal) + HIAI_NPU_CODE_TYPE_BIT(codeType) + HIAI_NPU_ERR_LEV_BIT(errLevel) + \ - HIAI_NPU_MOD_ID_BIT(moduleId)) - -#endif // INC_COMMON_NPU_ERROR_DEFINE_H_ diff --git a/inc/common/opskernel/ge_task_info.h b/inc/common/opskernel/ge_task_info.h deleted file mode 100644 index 9f3c409d..00000000 --- a/inc/common/opskernel/ge_task_info.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ -#define INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ - -#include -#include -#include -#include - -using std::string; -namespace ge { -// when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD -struct GETaskKernelHcclInfo { - string input_name; - string hccl_type; - void *inputDataAddr; - void *outputDataAddr; - void *workSpaceAddr; - int32_t count; - int32_t dataType; - int32_t opType; - int64_t rootId; - uint64_t workSpaceMemSize; - std::vector dims; - std::vector hcclStreamList; -}; - -struct GETaskInfo { - uint32_t id; - uint16_t type; - uint32_t streamID; - void *stream; // rtKernelLaunch input argument - void *event; - void *privateDef; - uint32_t privateDefLen; - void *opsKernelStorePtr; - - std::vector kernelHcclInfo; -}; - -struct HcomOpertion { - std::string hcclType; - void *inputPtr; - void *outputPtr; - uint64_t count; - int32_t dataType; - int32_t opType; - int32_t root; -}; - -struct HcomRemoteAccessAddrInfo { - uint32_t remotetRankID; - uint64_t remoteAddr; // host embedding table address - uint64_t localAddr; // device HBM address - uint64_t length; // memory Length in Bytes -}; - -} // namespace ge -#endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ diff --git a/inc/common/opskernel/ops_kernel_info_store.h b/inc/common/opskernel/ops_kernel_info_store.h deleted file mode 100644 index ce1464d4..00000000 --- a/inc/common/opskernel/ops_kernel_info_store.h +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ -#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ - -#include -#include -#include -#include -#include "./ge_task_info.h" -#include "./ops_kernel_info_types.h" -#include "cce/aicpu_engine_struct.h" -#include "cce/fwk_adpt_struct.h" -#include "common/ge_inner_error_codes.h" -#include "graph/node.h" -#include "proto/task.pb.h" -using std::map; -using std::string; -using std::to_string; -using std::vector; - -namespace ge { -class OpDesc; - -class OpsKernelInfoStore { - public: - OpsKernelInfoStore() {} - - virtual ~OpsKernelInfoStore() {} - - // initialize opsKernelInfoStore - virtual Status Initialize(const map &options) = 0; /*lint -e148*/ - - // close opsKernelInfoStore - virtual Status Finalize() = 0; /*lint -e148*/ - - virtual Status CreateSession(const std::map &session_options) { return SUCCESS; } - - virtual Status DestroySession(const std::map &session_options) { return SUCCESS; } - - // get all opsKernelInfo - virtual void GetAllOpsKernelInfo(map &infos) const = 0; - - // whether the opsKernelInfoStore is supported based on the operator attribute - virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; - - virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, - bool realQuery = false) const { - return CheckSupported(opDescPtr, un_supported_reason); - } - // opsFlag opsFlag[0] indicates constant folding is supported or not - virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){}; - - // memory allocation requirement - virtual Status CalcOpRunningParam(Node &node) = 0; /*lint -e148*/ - - // generate task for op。 - virtual Status GenerateTask(const Node &node, RunContext &context, - std::vector &tasks) = 0; /*lint -e148*/ - - // only call fe engine interface to compile single op - virtual Status CompileOp(vector &node_vec) { return SUCCESS; } - virtual Status CompileOpRun(vector &node_vec) { return SUCCESS; } - // load task for op - virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } - - // only call aicpu interface to generate task struct - virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } - - // only call aicpu interface to generate task struct - virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } -}; -} // namespace ge -#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ diff --git a/inc/common/opskernel/ops_kernel_info_types.h b/inc/common/opskernel/ops_kernel_info_types.h deleted file mode 100644 index 684c1abc..00000000 --- a/inc/common/opskernel/ops_kernel_info_types.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ -#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ - -#include -#include -#include -#include "graph/buffer.h" -#include "runtime/rt_model.h" - -using std::string; - -namespace ge { -/*lint -e148*/ -struct RunContext { - rtModel_t model; - rtStream_t stream; - uint64_t sessionId; - uint64_t dataMemSize; - uint8_t *dataMemBase; - uint64_t weightMemSize; - uint8_t *weightMemBase; - ge::Buffer weightsBuffer; - std::vector graphStreamList; // all streams of graph, order by ge stream id(0,1,...) - std::vector graphEventList; // all events of graph, order by ge event id(0,1,...) - std::vector graphLabelList; // all labels of graph, order by ge label id(0,1,...) -}; - -/*lint +e148*/ - -struct Task { - uint32_t id; - uint16_t type; - void *stream; - void *event; -}; - -struct OpInfo { - string engine; // which engin - /*lint -e148*/ - string opKernelLib; // which opsKernelStore - int computeCost; // compute cost - bool flagPartial; // whether to support is related to shape - bool flagAsync; // Whether to support asynchronous - bool isAtomic; // whether to support atomic addr clean - string opFileName; // op file name - string opFuncName; // op function name -}; -} // namespace ge - -#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ diff --git a/inc/common/optimizer/graph_optimizer.h b/inc/common/optimizer/graph_optimizer.h deleted file mode 100644 index 253aaae1..00000000 --- a/inc/common/optimizer/graph_optimizer.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ -#define INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ - -#include -#include -#include "./graph_optimizer_types.h" -#include "common/ge_inner_error_codes.h" -#include "common/opskernel/ops_kernel_info_types.h" -#include "graph/compute_graph.h" - -using std::map; -using std::string; - -/*lint -e148*/ -namespace ge { -class GraphOptimizer { - public: - virtual ~GraphOptimizer() {} - - // initialize graphOptimizer - virtual Status Initialize(const map &options) = 0; - - // close graphOptimizer - virtual Status Finalize() = 0; - - // optimize original graph for FE quant optimize - virtual Status OptimizeGraphPrepare(ComputeGraph &graph) { return SUCCESS; } - - // optimize graph before build for RTS - virtual Status OptimizeGraphBeforeBuild(ComputeGraph &graph) { return SUCCESS; } - - // optimize original graph, using in graph preparation stage - virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; - - // optimize original graph, using for conversion operator insert in graph preparation stage - virtual Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) { return SUCCESS; } - - // optimize fused graph - virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; - - // optimize whole graph, using after graph merged stage - virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; - - // get attribute of graph optimizer - virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; - - // optimize streamed Graph - virtual Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) { return SUCCESS; } - - // op compile - virtual Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) { return SUCCESS; } -}; -} // namespace ge -/*lint +e148*/ -#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ diff --git a/inc/common/util/ai_core/common/aicore_util_attr_define.h b/inc/common/util/ai_core/common/aicore_util_attr_define.h deleted file mode 100644 index ba28d7b3..00000000 --- a/inc/common/util/ai_core/common/aicore_util_attr_define.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_ATTR_DEFINE_H_ -#define INC_COMMON_UTILS_AI_CORE_COMMON_ATTR_DEFINE_H_ - -#include - -namespace fe { -static const std::string SCOPE_ID_ATTR = "fusion_scope"; - -static const std::string FE_IMPLY_TYPE = "_fe_imply_type"; - -static const std::string PARENT_OP_TYPE = "parentOpType"; - -static const std::string ATTR_NAME_TASK_L2_FUSION_INFO_EXTEND_PTR = "task_l2_fusion_info_extend_content"; - -static const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; - -static const std::string ATTR_NAME_L2_FUSION_EXTEND_PTR = "l2_fusion_extend_content"; - -static const std::string L1_OPTIMIZED = "l1_optimized"; - -static const std::string L2_OPTIMIZED = "l2_optimized"; - -static const std::string OP_SLICE_INFO = "_op_slice_info"; -} // namespace fe -#endif diff --git a/inc/common/util/ai_core/common/aicore_util_types.h b/inc/common/util/ai_core/common/aicore_util_types.h deleted file mode 100644 index b2615dc9..00000000 --- a/inc/common/util/ai_core/common/aicore_util_types.h +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_TYPES_H_ -#define INC_COMMON_UTILS_AI_CORE_COMMON_TYPES_H_ - -#include "graph/anchor.h" -#include "graph/types.h" -#include "runtime/kernel.h" -#include -#include -#include - -namespace fe { -struct FusionOpSrc { - uint32_t src_op_id; - ge::AnchorPtr src_anchor; - int32_t fusion_src_index; - int32_t fusion_dst_index; -}; - -struct FusionOpDst { - uint32_t dst_op_id; - ge::AnchorPtr dst_anchor; -}; - -struct FusionDataFlow { - std::pair edge; - std::pair node_dataindex_pair; -}; - -typedef struct tagL2FusionData { - uint32_t l2Index; - uint64_t l2Addr; - uint64_t l2PageNum; -} L2FusionData_t; -typedef std::map L2FusionDataMap_t; - -typedef struct tagFeSmDesc { - rtL2Ctrl_t l2ctrl; - std::string nodeName[8]; - uint8_t outputIndex[8]; -} feSmDesc_t; - -typedef struct TagTaskL2FusionInfo { - std::string nodeName; - feSmDesc_t l2Info; - L2FusionDataMap_t input; - L2FusionDataMap_t output; - uint32_t isUsed; -} TaskL2FusionInfo_t; - -using L2FusionInfoPtr = std::shared_ptr; - -typedef struct ToOpStruct { - int64_t opL1Space = 0; - std::vector opL1FusionType; - int64_t opL1WorkspaceFlag = 0; // for workspace flag - int64_t opL1WorkspaceSize = 0; - std::vector> validInputShape; - std::vector> validOutputShape; - std::vector> sliceInputOffset; // conv & pooling & ReadSelect - std::vector> sliceOutputOffset; // WriteSelect - std::vector totalShape; - uint32_t splitIndex = 0; - ToOpStruct() { - // set invalid value for essential variable - opL1Space = -1; - opL1WorkspaceSize = -1; - } -} ToOpStruct_t; - -enum OpImplType { - EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op - EN_IMPL_CUSTOM_TIK, // custom tik op - EN_IMPL_CUSTOM_TBE, // custom tbe op - EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op - EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op - EN_IMPL_HW_TIK, // Huawei built-in tik op 
- EN_IMPL_HW_TBE, // Huawei built-in tbe op - EN_IMPL_RL, // RL op - EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op - EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op - EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op - EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op - EN_RESERVED // reserved value -}; - -static const std::map DATATYPE_SIZE_MAP{{ge::DT_FLOAT, sizeof(float)}, - {ge::DT_FLOAT16, sizeof(int16_t)}, - {ge::DT_INT8, sizeof(int8_t)}, - {ge::DT_INT32, sizeof(int32_t)}, - {ge::DT_UINT8, sizeof(uint8_t)}, - {ge::DT_UINT32, sizeof(uint32_t)}, - {ge::DT_INT16, sizeof(int16_t)}, - {ge::DT_UINT16, sizeof(uint16_t)}, - {ge::DT_INT64, sizeof(int64_t)}, - {ge::DT_UINT64, sizeof(uint64_t)}, - {ge::DT_DOUBLE, sizeof(double)}, - {ge::DT_BOOL, sizeof(bool)}, - {ge::DT_DUAL, sizeof(float) + sizeof(int8_t)}, - {ge::DT_DUAL_SUB_UINT8, sizeof(int8_t)}, - {ge::DT_DUAL_SUB_INT8, sizeof(int8_t)}}; -} // namespace fe -#endif diff --git a/inc/common/util/ai_core/common/graph_comm.h b/inc/common/util/ai_core/common/graph_comm.h deleted file mode 100644 index d672e056..00000000 --- a/inc/common/util/ai_core/common/graph_comm.h +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_GRAPH_COMMON_H_ -#define INC_COMMON_UTILS_AI_CORE_COMMON_GRAPH_COMMON_H_ - -#include "graph/compute_graph.h" -#include "common/aicore_util_types.h" -#include "register/graph_optimizer/graph_optimize_register_error_codes.h" - -#include -#include -#include -#include - -namespace fe { - -using kScopeNodeMap_t = std::map>; -using kScopeNodePair_t = std::pair>; - -class GraphCommImpl; -using GraphCommImplPtr = std::unique_ptr; - -class GraphComm { - public: - GraphComm(const string &engineName); - virtual ~GraphComm(); - GraphComm(const GraphComm &in) = delete; - GraphComm &operator=(const GraphComm &in) = delete; - - Status GetscopeNodeMap(ge::ComputeGraph &graph, kScopeNodeMap_t &fusionMap); - - Status CopyFusionOpNodes(vector &fusInputEdgeList, vector &fusOutputEdgeList, - vector &fusNodelist, ge::OpDescPtr fusionOpDesc, - ge::ComputeGraphPtr fusionGraph); - - Status CopyFusionOpEdges(ge::OpDescPtr fusionOpDesc, ge::ComputeGraph &origGraph, ge::ComputeGraphPtr fusionGraph); - - Status GetNodeDataFlowMap(const ge::NodePtr &fusNode, - std::map> &fusionOpAnchorsMap, - ge::kFusionDataFlowVec_t &fusDataflowList, const int &mapType); - - Status GetFusionNodeEdgeList(std::vector &fusNodelist, std::vector &fusInputEdgeList, - std::vector &fusOutputEdgeList); - void ClearFusionSrc(); - - void ClearFusionDst(); - - void AddFusionOutputSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, const int32_t &fusion_src_index, - std::pair &node_dataindex_pair); - - void AddFusionInputSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, const int32_t &fusion_dst_index, - std::pair &node_dataindex_pair); - - void SaveFusionDst(const uint32_t &dst_op_id, ge::AnchorPtr dst_anchor); - - bool IsFusionDstExist(const uint32_t &dst_op_id, const ge::AnchorPtr &dst_anchor); - - bool GetFusionSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, int32_t &fusion_src_index, - int32_t &fusion_dst_index); - - Status GetFusionNodeCtrlEdgeList(vector &fusNodelist, vector &fusInputCtrlEdgeList, - vector &fusOutputCtrlEdgeList); - - Status MergeFusionNodeEdgeList(ge::NodePtr &fusNode, vector &fusNodelist, - vector &fusInputEdgeList, vector &fusOutputEdgeList); - - Status MergeFusionNodeCtrlEdgeList(ge::NodePtr &fusNode, vector &fusNodelist, - vector &fusInputEdgeList, - vector &fusOutputEdgeList); - - string GetEngineName(); - - private: - Status MergeFusionNodeInputEdgeList(ge::NodePtr fusNode, std::vector &fusNodelist, - std::vector &fusInputEdgeList); - Status MergeFusionNodeOutputEdgeList(ge::NodePtr fusNode, std::vector &fusNodelist, - std::vector &fusOutputEdgeList); - - string engineName_; - - std::vector exist_fusion_src_list_; - std::vector exist_fusion_dst_list_; - - // std::vector> - ge::kFusionDataFlowVec_t fusion_input_dataflow_list_; - - // std::vector> - ge::kFusionDataFlowVec_t fusion_output_dataflow_list_; - - GraphCommImplPtr graphCommImplPtr_; -}; -} // namespace fe -#endif diff --git a/inc/common/util/ai_core/common/scope_allocator.h b/inc/common/util/ai_core/common/scope_allocator.h deleted file mode 100644 index 6cebb286..00000000 --- a/inc/common/util/ai_core/common/scope_allocator.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_SCOPE_ALLOCATOR_H_ -#define INC_COMMON_UTILS_AI_CORE_COMMON_SCOPE_ALLOCATOR_H_ - -#include "graph/op_desc.h" - -namespace fe { -class ScopeAllocator { - public: - ScopeAllocator(); - virtual ~ScopeAllocator(); - ScopeAllocator(const ScopeAllocator& in) = delete; - ScopeAllocator& operator=(const ScopeAllocator& in) = delete; - - public: - void Init(); - int64_t GetCurrentScopeId(); - int64_t AllocateScopeId(void); - bool HasScopeAttr(ge::ConstOpDescPtr opdef); - bool GetScopeAttr(ge::ConstOpDescPtr opdef, int64_t& scopeId); - bool SetScopeAttr(ge::OpDescPtr opdef, int64_t scopeId); - bool ResetScopeId(int64_t scopeId); - - private: - int64_t scopeId; -}; -} // namespace fe -#endif diff --git a/inc/common/util/ai_core/param_calculate/aicore_param_calculator.h b/inc/common/util/ai_core/param_calculate/aicore_param_calculator.h deleted file mode 100644 index c0c378fd..00000000 --- a/inc/common/util/ai_core/param_calculate/aicore_param_calculator.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef AICORE_PARAM_CALCULATOR -#define AICORE_PARAM_CALCULATOR - -#include "graph/node.h" -#include "graph_optimizer/graph_optimize_register_error_codes.h" - -namespace fe { -class AICoreParamCalculator { - public: - AICoreParamCalculator(); - - ~AICoreParamCalculator(); - - Status CalcOpRunningParam(ge::Node &node); -}; -} // namespace fe -#endif // AICORE_PARAM_CALCULATOR diff --git a/inc/common/util/ai_core/param_calculate/tensorsize_calculator.h b/inc/common/util/ai_core/param_calculate/tensorsize_calculator.h deleted file mode 100644 index c82cca4b..00000000 --- a/inc/common/util/ai_core/param_calculate/tensorsize_calculator.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
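For reference, a minimal usage sketch of the ScopeAllocator interface above, tagging two op descriptors with the same fusion scope id; the function name, the include path, and the way the OpDescPtr arguments are obtained are assumptions for illustration only.

#include <cstdint>
#include "graph/op_desc.h"
#include "common/scope_allocator.h"  // assumed include path for the header above

void TagSameFusionScope(const ge::OpDescPtr &op_a, const ge::OpDescPtr &op_b) {
  fe::ScopeAllocator allocator;
  allocator.Init();
  const int64_t scope_id = allocator.AllocateScopeId();  // new id for this fusion group
  (void)allocator.SetScopeAttr(op_a, scope_id);
  (void)allocator.SetScopeAttr(op_b, scope_id);

  int64_t read_back = 0;
  if (allocator.GetScopeAttr(op_a, read_back)) {
    // read_back == scope_id; later passes can group nodes that share this attribute
  }
}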
- */ - -#ifndef TENSORSIZE_CALCULATOR_H -#define TENSORSIZE_CALCULATOR_H - -#include "graph_optimizer/graph_optimize_register_error_codes.h" - -#include -#include -#include "graph/compute_graph.h" -#include "graph/op_desc.h" - -namespace fe { -class TensorSizeCalculator { - public: - /** - * Calculate the tensor size of input and output of each opdesc - * @param opDesc opdesc object - * @param opImplType op impl type - * @return status SUCCESS or FAILED - */ - static Status CalculateOpTensorSize(ge::OpDesc &opDesc); - - private: - static Status CalcInputOpTensorSize(ge::OpDesc &opDesc, int32_t &outputRealCalcFlag); - - static Status CalcOutputOpTensorSize(ge::OpDesc &opDesc, int32_t &outputRealCalcFlag); -}; -} // namespace fe - -#endif // TENSORSIZE_CALCULATOR_H diff --git a/inc/common/util/compress/compress.h b/inc/common/util/compress/compress.h deleted file mode 100644 index e350f9e5..00000000 --- a/inc/common/util/compress/compress.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMPRESS_H -#define COMPRESS_H - -#include - -enum CmpStatus { RET_SUCCESS = 0, RET_ERROR = -1 }; - -struct CompressConfig { - size_t inputSize; // length of data to compress - size_t engineNum; // how many decompress engines - size_t maxRatio; // how much size of a basic compression block, only 64 supported now (8x: 64 4x: 32) - size_t channel; // channels of L2 or DDR. For load balance - size_t fractalSize; // size of compressing block - bool isTight; // whether compose compressed data tightly - size_t init_offset; -}; - -CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, - size_t& compressedLength); - -#endif // COMPRESS_H diff --git a/inc/common/util/compress/compress_weight.h b/inc/common/util/compress/compress_weight.h deleted file mode 100644 index 34ea47d1..00000000 --- a/inc/common/util/compress/compress_weight.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
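For reference, a minimal usage sketch of the CompressWeights entry point above; the buffer sizes and most CompressConfig values are assumptions for illustration (a real caller derives them from the unzip engine spec), and only maxRatio = 64 is taken from the comment in the header.

#include <cstddef>
#include <vector>
#include "common/util/compress/compress.h"  // assumed include path for the header above

bool TryCompressWeights(char *weights, size_t weight_len) {
  CompressConfig config{};
  config.inputSize = weight_len;
  config.engineNum = 1;       // assumed: one decompress engine
  config.maxRatio = 64;       // only 64 is supported, per the header comment
  config.channel = 1;         // assumed
  config.fractalSize = 512;   // assumed compression block size
  config.isTight = true;
  config.init_offset = 0;

  std::vector<char> index_buf(1024);         // assumed size, illustration only
  std::vector<char> output_buf(weight_len);  // worst case: no gain
  size_t compressed_len = 0;
  return CompressWeights(weights, config, index_buf.data(), output_buf.data(),
                         compressed_len) == RET_SUCCESS;
}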
- */ - -#ifndef COMPRESS_WEIGHT_H -#define COMPRESS_WEIGHT_H - -#include "compress.h" - -const int SHAPE_SIZE_WEIGHT = 4; - -struct CompressOpConfig { - int64_t wShape[SHAPE_SIZE_WEIGHT]; - size_t compressTilingK; - size_t compressTilingN; - struct CompressConfig compressConfig; -}; - -extern "C" CmpStatus CompressWeightsConv2D(const char *const input, char *const zipBuffer, char *const infoBuffer, - CompressOpConfig *const param); -#endif // COMPRESS_WEIGHT_H diff --git a/inc/common/util/error_manager/error_manager.h b/inc/common/util/error_manager/error_manager.h deleted file mode 100644 index 438e68a7..00000000 --- a/inc/common/util/error_manager/error_manager.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef ERROR_MANAGER_H_ -#define ERROR_MANAGER_H_ - -#include -#include -#include - -class ErrorManager { - public: - /// - /// @brief Obtain ErrorManager instance - /// @return ErrorManager instance - /// - static ErrorManager &GetInstance(); - - /// - /// @brief init - /// @param [in] path: current so path - /// @return int 0(success) -1(fail) - /// - int Init(std::string path); - - /// - /// @brief Report error message - /// @param [in] error_code: error code - /// @param [in] args_map: parameter map - /// @return int 0(success) -1(fail) - /// - int ReportErrMessage(std::string error_code, const std::map &args_map); - - /// - /// @brief output error message - /// @param [in] handle: print handle - /// @return int 0(success) -1(fail) - /// - int OutputErrMessage(int handle); - - /// - /// @brief output message - /// @param [in] handle: print handle - /// @return int 0(success) -1(fail) - /// - int OutputMessage(int handle); - - /// - /// @brief Report error message - /// @param [in] key: vector parameter key - /// @param [in] value: vector parameter value - /// - void ATCReportErrMessage(std::string error_code, const std::vector &key = {}, - const std::vector &value = {}); - - private: - struct ErrorInfo { - std::string error_id; - std::string error_message; - std::vector arg_list; - }; - - ErrorManager() {} - ~ErrorManager() {} - - ErrorManager(const ErrorManager &) = delete; - ErrorManager(ErrorManager &&) = delete; - ErrorManager &operator=(const ErrorManager &) = delete; - ErrorManager &operator=(ErrorManager &&) = delete; - - int ParseJsonFile(std::string path); - - int ReadJsonFile(const std::string &file_path, void *handle); - - bool is_init_ = false; - std::map error_map_; - std::vector error_messages_; - std::vector warning_messages_; -}; - -#endif // ERROR_MANAGER_H_ diff --git a/inc/common/util/platform_info.h b/inc/common/util/platform_info.h deleted file mode 100644 index 8d2a0579..00000000 --- a/inc/common/util/platform_info.h +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
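For reference, a minimal usage sketch of the ErrorManager singleton above; the configuration directory, the error code "E10001", and the argument keys are placeholders, not values defined by the header.

#include <map>
#include <string>
#include "common/util/error_manager/error_manager.h"  // assumed include path for the header above

int ReportShapeMismatch(const std::string &conf_dir) {
  ErrorManager &mgr = ErrorManager::GetInstance();
  if (mgr.Init(conf_dir) != 0) {  // conf_dir: directory holding the error-code json (assumed)
    return -1;
  }
  const std::map<std::string, std::string> args = {{"opname", "Conv2D"}, {"reason", "shape mismatch"}};
  return mgr.ReportErrMessage("E10001", args);  // placeholder error code
}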
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PLATFORM_INFO_H -#define PLATFORM_INFO_H - -#include -#include -#include -#include "platform_info_def.h" - -using std::map; -using std::string; -using std::vector; - -namespace fe { -class PlatformInfoManager { - public: - PlatformInfoManager(const PlatformInfoManager &) = delete; - PlatformInfoManager &operator=(const PlatformInfoManager &) = delete; - - static PlatformInfoManager &Instance(); - uint32_t InitializePlatformInfo(); - uint32_t Finalize(); - - uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); - - uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); - - void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo); - - private: - PlatformInfoManager(); - ~PlatformInfoManager(); - - uint32_t LoadIniFile(string iniFileRealPath); - - void Trim(string &str); - - uint32_t LoadConfigFile(string realPath); - - string RealPath(const std::string &path); - - string GetSoFilePath(); - - void ParseVersion(map &versionMap, string &socVersion, PlatformInfo &platformInfoTemp); - - void ParseSocInfo(map &socInfoMap, PlatformInfo &platformInfoTemp); - - void ParseCubeOfAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); - - void ParseBufferOfAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); - - void ParseUBOfAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); - - void ParseUnzipOfAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); - - void ParseAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); - - void ParseBufferOfAICoreMemoryRates(map &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); - - void ParseAICoreMemoryRates(map &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); - - void ParseUBOfAICoreMemoryRates(map &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); - - void ParseAICoreintrinsicDtypeMap(map &aiCoreintrinsicDtypeMap, PlatformInfo &platformInfoTemp); - - void ParseVectorCoreSpec(map &vectorCoreSpecMap, PlatformInfo &platformInfoTemp); - - void ParseVectorCoreMemoryRates(map &vectorCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); - - void ParseCPUCache(map &CPUCacheMap, PlatformInfo &platformInfoTemp); - - void ParseVectorCoreintrinsicDtypeMap(map &vectorCoreintrinsicDtypeMap, - PlatformInfo &platformInfoTemp); - - uint32_t ParsePlatformInfoFromStrToStruct(map> &contentInfoMap, string &socVersion, - PlatformInfo &platformInfoTemp); - - uint32_t AssemblePlatformInfoVector(map> &contentInfoMap); - - private: - bool initFlag_; - map platformInfoMap_; - OptionalInfo optiCompilationInfo_; -}; -} // namespace fe -#endif diff --git a/inc/common/util/platform_info_def.h b/inc/common/util/platform_info_def.h deleted file mode 100644 index c660e8f1..00000000 --- a/inc/common/util/platform_info_def.h +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
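For reference, a minimal usage sketch of PlatformInfoManager above; the SoC version string comes from the caller, and the aiCoreCnt field access follows the SoCInfo definition in platform_info_def.h further below, so treat that field name and the include path as assumptions for illustration.

#include <cstdint>
#include <string>
#include "common/util/platform_info.h"  // assumed include path for the header above

uint32_t QueryAiCoreCount(const std::string &soc_version, uint32_t &ai_core_cnt) {
  fe::PlatformInfo platform_info;
  fe::OptionalInfo optional_info;
  auto &mgr = fe::PlatformInfoManager::Instance();
  uint32_t ret = mgr.InitializePlatformInfo();
  if (ret != 0) {
    return ret;
  }
  ret = mgr.GetPlatformInfo(soc_version, platform_info, optional_info);
  if (ret == 0) {
    ai_core_cnt = platform_info.socInfo.aiCoreCnt;
  }
  return ret;
}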
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PLATFORM_INFO_DEF_H -#define PLATFORM_INFO_DEF_H - -#include -#include -#include - -using std::map; -using std::string; -using std::vector; - -namespace fe { -enum MemoryType { DDR = 0, HBM }; - -enum L2Type { Cache = 0, Buff }; - -typedef struct tagStrInfo { - string aicVersion; - string ccecAICVersion; - string ccecAIVVersion; - string isSupportAIcpuCompiler; -} StrInfo; - -typedef struct tagSoCInfo { - uint32_t aiCoreCnt; - uint32_t vectorCoreCnt; - uint32_t aiCpuCnt; - MemoryType memoryType; - uint64_t memorySize; - L2Type l2Type; - uint64_t l2Size; - uint32_t l2PageNum; -} SoCInfo; - -typedef struct tagAiCoreSpec { - double cubeFreq; - uint64_t cubeMSize; - uint64_t cubeNSize; - uint64_t cubeKSize; - uint64_t vecCalcSize; - uint64_t l0ASize; - uint64_t l0BSize; - uint64_t l0CSize; - uint64_t l1Size; - uint64_t smaskBuffer; - uint64_t ubSize; - uint64_t ubblockSize; - uint64_t ubbankSize; - uint64_t ubbankNum; - uint64_t ubburstInOneBlock; - uint64_t ubbankGroupNum; - uint32_t unzipEngines; - uint32_t unzipMaxRatios; - uint32_t unzipChannels; - uint8_t unzipIsTight; -} AiCoreSpec; - -typedef struct tagAiCoreMemoryRates { - double ddrRate; - double ddrReadRate; - double ddrWriteRate; - double l2Rate; - double l2ReadRate; - double l2WriteRate; - double l1ToL0ARate; - double l1ToL0BRate; - double l1ToUBRate; - double l0CToUBRate; - double ubToL2Rate; - double ubToDdrRate; - double ubToL1Rate; -} AiCoreMemoryRates; - -typedef struct tagVectorCoreSpec { - double vecFreq; - uint64_t vecCalcSize; - uint64_t smaskBuffer; - uint64_t ubSize; - uint64_t ubblockSize; - uint64_t ubbankSize; - uint64_t ubbankNum; - uint64_t ubburstInOneBlock; - uint64_t ubbankGroupNum; - uint64_t vectorRegSize; - uint64_t predicateRegSize; - uint64_t addressRegSize; -} VectorCoreSpec; - -typedef struct tagVectorCoreMemoryRates { - double ddrRate; - double ddrReadRate; - double ddrWriteRate; - double l2Rate; - double l2ReadRate; - double l2WriteRate; - double ubToL2Rate; - double ubToDdrRate; -} VectorCoreMemoryRates; - -typedef struct tagCPUCache { - uint32_t AICPUSyncBySW; - uint32_t TSCPUSyncBySW; -} CPUCache; - -typedef struct tagPlatformInfo { - StrInfo strInfo; - SoCInfo socInfo; - AiCoreSpec aiCoreSpec; - AiCoreMemoryRates aiCoreMemoryRates; - map> aiCoreIntrinsicDtypeMap; - VectorCoreSpec vectorCoreSpec; - VectorCoreMemoryRates vectorCoreMemoryRates; - CPUCache cpucache; - map> vectorCoreIntrinsicDtypeMap; -} PlatformInfo; - -typedef struct tagOptionalInfo { - string socVersion; - string coreType; - uint32_t aiCoreNum; - string l1FusionFlag; -} OptionalInfo; -} // namespace fe -#endif diff --git a/inc/external/graph/attr_value.h b/inc/external/graph/attr_value.h deleted file mode 100644 index af430f9b..00000000 --- a/inc/external/graph/attr_value.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ -#define INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ - -#include -#include -#include -#include - -#include "./ge_error_codes.h" - -using std::make_shared; -using std::map; -using std::pair; -using std::string; -using std::to_string; -using std::unique_ptr; -using std::vector; - -namespace ge { -class AttrValueImpl; -/*lint -e148*/ -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { - public: - using INT = int64_t; - using FLOAT = float; - using STR = std::string; - - AttrValue(); - ~AttrValue() = default; - - // GetValue, not list type - template - graphStatus GetValue(DT &val) const { - T valGet; - auto status = GetValue(valGet); - if (status != GRAPH_SUCCESS) { - return status; - } - val = DT(valGet); - return GRAPH_SUCCESS; - } - - template - static T CreateFrom(DT &&val) { - return val; - } - - std::shared_ptr impl; - - private: -#define VALUE_SET_GET_DEC(DT) graphStatus GetValue(DT &val) const; - VALUE_SET_GET_DEC(AttrValue::STR) - VALUE_SET_GET_DEC(AttrValue::INT) - VALUE_SET_GET_DEC(AttrValue::FLOAT) -#undef VALUE_SET_GET_DEC -}; -/*lint +e148*/ -} // namespace ge -#endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ diff --git a/inc/external/graph/ge_error_codes.h b/inc/external/graph/ge_error_codes.h deleted file mode 100644 index d815a22d..00000000 --- a/inc/external/graph/ge_error_codes.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ -#define INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ - -namespace ge { -#ifdef HOST_VISIBILITY -#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_HOST_VISIBILITY -#endif -#ifdef DEV_VISIBILITY -#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_DEV_VISIBILITY -#endif - -using graphStatus = uint32_t; -const graphStatus GRAPH_FAILED = 0xFFFFFFFF; -const graphStatus GRAPH_SUCCESS = 0; -const graphStatus GRAPH_PARAM_INVALID = 50331649; -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ diff --git a/inc/external/graph/graph.h b/inc/external/graph/graph.h deleted file mode 100644 index 30886733..00000000 --- a/inc/external/graph/graph.h +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_GRAPH_H_ -#define INC_EXTERNAL_GRAPH_GRAPH_H_ - -#include -#include -#include -#include - -#include "./operator.h" - -namespace ge { -class GraphImpl; - -using GraphImplPtr = std::shared_ptr; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { - friend class GraphUtils; - - public: - explicit Graph(const std::string &name); - - Graph() = default; - - ~Graph() = default; - - Graph &SetInputs(const std::vector &inputs); - - Graph &SetOutputs(const std::vector &outputs); - - Graph &SetOutputs(const std::vector>> &output_indexs); - - Graph &SetOutputs(const std::vector> &outputs); - - Graph &SetTargets(const std::vector &targets); - - bool IsValid() const; - - graphStatus AddOp(const ge::Operator &op); - - graphStatus FindOpByName(const string &name, ge::Operator &op) const; - - graphStatus FindOpByType(const string &type, std::vector &ops) const; - - graphStatus GetAllOpName(std::vector &op_name) const; - - graphStatus SaveToFile(const string &file_name) const; - - graphStatus LoadFromFile(const string &file_name); - - const std::string &GetName() const; - - /// - /// Set is need train iteration. - /// If set true, it means this graph need to be run iteration some - /// times(according variant "npu_runconfig/iterations_per_loop"). - /// @param need_iteration need_iteration:whether to set iteration or not - /// - void SetNeedIteration(bool need_iteration); - - private: - GraphImplPtr impl_{nullptr}; -}; -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_GRAPH_H_ diff --git a/inc/external/graph/inference_context.h b/inc/external/graph/inference_context.h deleted file mode 100644 index 69079142..00000000 --- a/inc/external/graph/inference_context.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
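For reference, a minimal usage sketch of the Graph interface above, wiring two inputs into one output operator; the "Data"/"Add" operator types, the "x1"/"x2" port names, and the file name are placeholders, and Operator construction follows operator.h further below.

#include <vector>
#include "graph/graph.h"
#include "graph/operator.h"

ge::Graph BuildTinyGraph() {
  ge::Operator x("x", "Data");    // placeholder op types for illustration
  ge::Operator y("y", "Data");
  ge::Operator add("add", "Add");
  add.SetInput("x1", x).SetInput("x2", y);  // "x1"/"x2" are assumed input names of Add

  ge::Graph graph("tiny_graph");
  graph.SetInputs(std::vector<ge::Operator>{x, y})
      .SetOutputs(std::vector<ge::Operator>{add});
  (void)graph.SaveToFile("tiny_graph");  // placeholder file name
  return graph;
}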
- */ - -#ifndef INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ -#define INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ - -#include -#include -#include - -#include "./tensor.h" -#include "./types.h" - -namespace ge { -class InferenceContext; -using InferenceContextPtr = std::shared_ptr; - -class ShapeAndTypeImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ShapeAndType { - public: - ShapeAndType(); - ~ShapeAndType() = default; - - ShapeAndType(const Shape &shape, DataType dataType); - - void SetShape(const Shape &shape); - - void SetType(DataType dataType); - - Shape GetShape() const; - - DataType GetDataType() const; - - private: - std::shared_ptr shape_and_type_impl_; -}; - -class InferenceContextImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { - public: - ~InferenceContext() = default; - InferenceContext(const InferenceContext &context) = delete; - InferenceContext(const InferenceContext &&context) = delete; - InferenceContext &operator=(const InferenceContext &context) = delete; - InferenceContext &operator=(const InferenceContext &&context) = delete; - - void SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types); - const std::vector> &GetInputHandleShapesAndTypes() const; - const std::vector> &GetOutputHandleShapesAndTypes() const; - void SetOutputHandleShapesAndTypes(const std::vector> &shapes_and_types); - void SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types); - - void SetMarks(const std::vector &marks); - const std::vector &GetMarks() const; - - static std::unique_ptr Create(); - - private: - explicit InferenceContext(std::unique_ptr &impl); - std::shared_ptr inference_context_impl_; -}; -} // namespace ge -#endif // INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ diff --git a/inc/external/graph/operator.h b/inc/external/graph/operator.h deleted file mode 100644 index 81d726eb..00000000 --- a/inc/external/graph/operator.h +++ /dev/null @@ -1,289 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_OPERATOR_H_ -#define INC_EXTERNAL_GRAPH_OPERATOR_H_ - -#include -#include -#include -#include -#include - -#include "./ge_error_codes.h" -#include "./inference_context.h" -#include "./tensor.h" - -#ifndef USER_GE_LOGI -#define USER_GE_LOGI(...) -#endif // USER_GE_LOGI - -#ifndef USER_GE_LOGW -#define USER_GE_LOGW(...) -#endif // USER_GE_LOGW - -#ifndef USER_GE_LOGE -#define USER_GE_LOGE(...) 
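For reference, a minimal usage sketch of ShapeAndType and InferenceContext above, as an infer-shape callback might publish output handle shapes; the shape values are placeholders and the nested-vector argument shape is an assumption read from the declarations above.

#include <vector>
#include "graph/inference_context.h"

void PublishOutputHandleShape(const ge::InferenceContextPtr &ctx) {
  if (ctx == nullptr) {
    return;
  }
  ge::ShapeAndType st(ge::Shape({1, 16, 32, 32}), ge::DT_FLOAT);  // placeholder shape/type
  std::vector<std::vector<ge::ShapeAndType>> out_shapes{{st}};
  ctx->SetOutputHandleShapesAndTypes(out_shapes);
}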
-#endif // USER_GE_LOGE - -#define DYNAMIC_OUTPUT_TD_NUM(name) ("__dynamic_output_" + name + "_cnt") -#define DYNAMIC_INPUT_TD_NUM(name) ("__dynamic_input_" + name + "_cnt") - -namespace ge { -class Operator; -class OperatorImpl; -class NodeUtils; -class NamedAttrs; -class Graph; -class AttrValue; -class Node; - -using SubgraphBuilder = std::function; -using OperatorImplPtr = std::shared_ptr; -using OperatorPtr = std::shared_ptr; - -class OpIO; -using OutHandler = std::shared_ptr; -using InHandler = std::shared_ptr; - -using std::function; -using std::shared_ptr; -using std::string; - -/*lint -e148*/ -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { - public: - friend class OperatorImpl; - friend class GraphBuilderImpl; - friend class NodeUtils; - - using OpInt = int64_t; - using OpFloat = float; - using OpString = string; - using OpBool = bool; - using OpTensor = Tensor; - using OpType = ge::DataType; - using OpNamedAttrs = ge::NamedAttrs; - using OpListInt = std::vector; - using OpListFloat = std::vector; - using OpListString = std::vector; - using OpListBool = std::vector; - using OpListTensor = std::vector; - using OpBytes = std::vector; - using OpListListInt = std::vector>; - using OpListType = std::vector; - using OpListNamedAttrs = std::vector; - - Operator() {} - - explicit Operator(const string &type); - - Operator(const string &name, const string &type); // lint !e148 - - virtual ~Operator() = default; - - bool IsEmpty() const; - - string GetName() const; - - string GetOpType() const; - - // Only has one output index = 0 - Operator &SetInput(const string &dst_name, const Operator &src_oprt); - - Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148 - - Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index); - - Operator &AddControlInput(const Operator &src_oprt); - - graphStatus GetInputConstData(const string &dst_name, Tensor &data) const; - - TensorDesc GetInputDesc(const string &name) const; - - TensorDesc GetInputDesc(uint32_t index) const; - - int GetDynamicOutputNum(const string &name) const; - - int GetDynamicInputNum(const string &name) const; - - graphStatus TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const; - - graphStatus UpdateInputDesc(const string &name, const TensorDesc &tensor_desc); - - TensorDesc GetOutputDesc(const string &name) const; - - TensorDesc GetOutputDesc(uint32_t index) const; - - graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148 - - TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; - - graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 - - TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; - - graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 - - graphStatus InferShapeAndType(); // lint !e148 - - void SetInferenceContext(const InferenceContextPtr &inference_context); - InferenceContextPtr GetInferenceContext() const; - - graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148 - - size_t GetInputsSize() const; - - size_t GetOutputsSize() const; - - const std::map GetAllAttrNamesAndTypes() const; - - Operator &SetAttr(const string &name, int64_t attr_value); - Operator &SetAttr(const string &name, int32_t attr_value); - Operator &SetAttr(const string &name, uint32_t attr_value); - graphStatus 
GetAttr(const string &name, int64_t &attr_value) const; - graphStatus GetAttr(const string &name, int32_t &attr_value) const; - graphStatus GetAttr(const string &name, uint32_t &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - Operator &SetAttr(const string &name, const std::vector &attr_value); - Operator &SetAttr(const string &name, const std::vector &attr_value); - Operator &SetAttr(const string &name, std::initializer_list &&attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - Operator &SetAttr(const string &name, float attr_value); - graphStatus GetAttr(const string &name, float &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - Operator &SetAttr(const string &name, AttrValue &&attr_value); - graphStatus GetAttr(const string &name, AttrValue &attr_value) const; - - Operator &SetAttr(const string &name, const string &attr_value); - graphStatus GetAttr(const string &name, string &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - Operator &SetAttr(const string &name, bool attr_value); - graphStatus GetAttr(const string &name, bool &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - Operator &SetAttr(const string &name, const Tensor &attr_value); - graphStatus GetAttr(const string &name, Tensor &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - // Bytes type - Operator &SetAttr(const string &name, const OpBytes &attr_value); - // Bytes type - graphStatus GetAttr(const string &name, OpBytes &attr_value) const; - - Operator &SetAttr(const string &name, const std::vector> &attr_value); - graphStatus GetAttr(const string &name, std::vector> &attr_value) const; - - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - Operator &SetAttr(const string &name, const ge::DataType &attr_value); - graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; - - // func type - Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); - graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - void BreakConnect() const; - - size_t GetSubgraphNamesCount() const; - std::vector GetSubgraphNames() const; - SubgraphBuilder GetSubgraphBuilder(const string &name) const; - Graph GetSubgraph(const string &name) const; - SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; - Graph GetDynamicSubgraph(const string &name, uint32_t index) const; - - protected: - void AttrRegister(const string &name, float attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, int64_t attr_value); - void AttrRegister(const string &name, const std::vector 
&attr_value); - void AttrRegister(const string &name, const string &attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, bool attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, const Tensor &attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, const OpBytes &attr_value); - void AttrRegister(const string &name, const std::vector> &attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, const ge::DataType &attr_value); - void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - - explicit Operator(OperatorImplPtr &&op_impl); - - void InputRegister(const string &name); - - void OptionalInputRegister(const string &name); - - void InferFuncRegister(const std::function &func); - - void VerifierFuncRegister(const std::function &func); - - void InferFormatFuncRegister(const std::function &func); - - void OutputRegister(const string &name); - - void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); - - void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); - - void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); - - void RequiredAttrRegister(const string &name); - - graphStatus VerifyAll(); // lint !e148 - - // Only has one output index = 0 - Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); - - Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, - const string &name); // lint !e148 - - void SubgraphRegister(const string &ir_name, bool dynamic); - void SubgraphCountRegister(const string &ir_name, uint32_t count); - void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); - - private: - Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148 - - OutHandler GetOutput(const string &name) const; - - OutHandler GetOutput(uint32_t index) const; - - OperatorImplPtr GetOperatorImplPtr() const; - - OperatorImplPtr operator_impl_{nullptr}; - - graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const; - - std::shared_ptr GetNode() const; -}; -/*lint +e148*/ -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_OPERATOR_H_ diff --git a/inc/external/graph/operator_factory.h b/inc/external/graph/operator_factory.h deleted file mode 100644 index f9ec7669..00000000 --- a/inc/external/graph/operator_factory.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
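For reference, a minimal usage sketch of the public attribute accessors of Operator above; the attribute names and values are placeholders chosen to exercise the int64, string, and list-of-int overloads.

#include <cstdint>
#include <string>
#include <vector>
#include "graph/operator.h"

void ConfigureConv(ge::Operator &conv) {
  // SetAttr overloads are selected by value type and return *this, so calls chain.
  conv.SetAttr("strides", std::vector<int64_t>{1, 1, 2, 2})
      .SetAttr("pad_mode", std::string("SAME"))
      .SetAttr("groups", static_cast<int64_t>(1));

  std::vector<int64_t> strides;
  if (conv.GetAttr("strides", strides) == ge::GRAPH_SUCCESS) {
    // strides now holds {1, 1, 2, 2}
  }
}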
- */ - -#ifndef INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ -#define INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ - -#include -#include -#include -#include - -#include "./operator.h" -#include "./ge_error_codes.h" - -namespace ge { -using OpCreator = std::function; -using InferShapeFunc = std::function; -using InferFormatFunc = std::function; -using VerifyFunc = std::function; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactory { - public: - static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); - - static graphStatus GetOpsTypeList(std::vector &all_ops); - - static bool IsExistOp(const string &operator_type); -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorCreatorRegister { - public: - OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator); - ~OperatorCreatorRegister() = default; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferShapeFuncRegister { - public: - InferShapeFuncRegister(const std::string &operator_type, const InferShapeFunc &infer_shape_func); - ~InferShapeFuncRegister() = default; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferFormatFuncRegister { - public: - InferFormatFuncRegister(const std::string &operator_type, const InferFormatFunc &infer_format_func); - ~InferFormatFuncRegister() = default; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY VerifyFuncRegister { - public: - VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func); - ~VerifyFuncRegister() = default; -}; -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ diff --git a/inc/external/graph/operator_reg.h b/inc/external/graph/operator_reg.h deleted file mode 100644 index 759c70f2..00000000 --- a/inc/external/graph/operator_reg.h +++ /dev/null @@ -1,376 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ -#define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ - -#include -#include -#include -#include - -#include "graph/operator.h" -#include "graph/operator_factory.h" -#include "graph/tensor.h" -#include "graph/types.h" -#include "graph/graph.h" - -namespace ge { -using std::function; -using std::string; -using std::vector; - -class OpReg { - public: - OpReg &N() { return *this; } - - OpReg &ATTR() { return *this; } - - OpReg &REQUIRED_ATTR() { return *this; } - - OpReg &INPUT() { return *this; } - - OpReg &OPTIONAL_INPUT() { return *this; } - - OpReg &OUTPUT() { return *this; } - - OpReg &GRAPH() { return *this; } - - OpReg &DYNAMIC_GRAPH() { return *this; } - - OpReg &INFER_SHAPE_AND_TYPE() { return *this; } -}; - -#define REG_OP(x) \ - namespace op { \ - class x : public Operator { \ - typedef x _THIS_TYPE; \ - \ - public: \ - explicit x(const string &name) : Operator(name, #x) { __##x(); } \ - x() : Operator(#x) { __##x(); } \ - \ - private: \ - void __##x() { \ - OpReg() - -#define ATTR(x, Type, ...) 
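For reference, a minimal usage sketch of OperatorFactory above; IsExistOp guards against types that were never registered through OperatorCreatorRegister, and the op type/name strings are supplied by the caller.

#include <string>
#include "graph/operator_factory.h"

ge::Operator MakeIfRegistered(const std::string &op_type, const std::string &op_name) {
  if (!ge::OperatorFactory::IsExistOp(op_type)) {
    return ge::Operator();  // empty operator when the type is unknown
  }
  return ge::OperatorFactory::CreateOperator(op_name, op_type);
}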
\ - N(); \ - __attr_##x(); \ - } \ - \ - public: \ - static const string name_attr_##x() { return #x; } \ - Op##Type get_attr_##x() const { \ - Op##Type ret = __VA_ARGS__; \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - return ret; \ - } \ - return ret; \ - } \ - _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ - Operator::SetAttr(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_attr_##x(const function &v) { return *this; } \ - \ - private: \ - void __attr_##x() { \ - Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ - string attr_name(#x); \ - (void)OpReg() - -#define REQUIRED_ATTR(x, Type) \ - N(); \ - __required_attr_##x(); \ - } \ - \ - public: \ - static const string name_attr_##x() { return #x; } \ - Op##Type get_attr_##x() const { \ - Op##Type ret; \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - return ret; \ - } \ - return ret; \ - } \ - _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ - Operator::SetAttr(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_attr_##x(const function &v) { return *this; } \ - \ - private: \ - void __required_attr_##x() { \ - Operator::RequiredAttrRegister(#x); \ - string attr_name(#x); \ - (void)OpReg() - -#define INPUT(x, t) \ - N(); \ - __input_##x(); \ - } \ - \ - public: \ - static const string name_in_##x() { return #x; } \ - _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \ - Operator::SetInput(#x, v, srcName); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ - Operator::SetInput(#x, v, index); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v) { \ - Operator::SetInput(#x, v); \ - return *this; \ - } \ - TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \ - graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ - return Operator::UpdateInputDesc(#x, tensorDesc); \ - } \ - \ - private: \ - void __input_##x() { \ - Operator::InputRegister(#x); \ - (void)OpReg() - -#define OPTIONAL_INPUT(x, t) \ - N(); \ - __optional_input_##x(); \ - } \ - \ - public: \ - static const string name_in_##x() { return #x; } \ - _THIS_TYPE &set_input_##x(Operator &v) { \ - Operator::SetInput(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \ - Operator::SetInput(#x, v, srcName); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ - Operator::SetInput(#x, v, index); \ - return *this; \ - } \ - TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \ - graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ - return Operator::UpdateInputDesc(#x, tensorDesc); \ - } \ - \ - private: \ - void __optional_input_##x() { \ - Operator::OptionalInputRegister(#x); \ - (void)OpReg() - -#define OUTPUT(x, t) \ - N(); \ - __out_##x(); \ - } \ - \ - public: \ - static const string name_out_##x() { return #x; } \ - TensorDesc get_output_desc_##x() const { return Operator::GetOutputDesc(#x); } \ - graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \ - return Operator::UpdateOutputDesc(#x, tensorDesc); \ - } \ - \ - private: \ - void __out_##x() { \ - Operator::OutputRegister(#x); \ - (void)OpReg() - -#define DYNAMIC_INPUT(x, t) \ - N(); \ - __dy_input_##x(); \ - } \ - \ - public: \ - _THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \ - Operator::DynamicInputRegister(#x, num, isPushBack); \ - return *this; \ - } \ - _THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \ - 
Operator::DynamicInputRegisterByIndex(#x, num, index); \ - return *this; \ - } \ - TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { return Operator::GetDynamicInputDesc(#x, index); } \ - graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ - return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ - } \ - _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \ - Operator::SetInput(#x, dstIndex, v); \ - return *this; \ - } \ - _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \ - Operator::SetInput(#x, dstIndex, v, srcName); \ - return *this; \ - } \ - \ - private: \ - void __dy_input_##x() { \ - Operator::DynamicInputRegister(#x, 0, true); \ - (void)OpReg() - -#define DYNAMIC_OUTPUT(x, t) \ - N(); \ - __dy_output_##x(); \ - } \ - \ - public: \ - _THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \ - Operator::DynamicOutputRegister(#x, num, isPushBack); \ - return *this; \ - } \ - TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { return Operator::GetDynamicOutputDesc(#x, index); } \ - graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ - return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ - } \ - \ - private: \ - void __dy_output_##x() { \ - Operator::DynamicOutputRegister(#x, 0, true); \ - (void)OpReg() - -#define GRAPH(x) \ - N(); \ - __graph_##x(); \ - } \ - \ - public: \ - static const string name_graph_##x() { return #x; } \ - SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \ - _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ - Operator::SetSubgraphBuilder(#x, 0, v); \ - return *this; \ - } \ - Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \ - \ - private: \ - void __graph_##x() { \ - Operator::SubgraphRegister(#x, false); \ - Operator::SubgraphCountRegister(#x, 1); \ - (void)OpReg() - -#define DYNAMIC_GRAPH(x) \ - N(); \ - __graph_##x(); \ - } \ - \ - public: \ - static const string name_graph_##x() { return #x; } \ - _THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \ - Operator::SubgraphCountRegister(#x, num); \ - return *this; \ - } \ - SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \ - return Operator::GetDynamicSubgraphBuilder(#x, index); \ - } \ - Graph get_dynamic_subgraph_##x(uint32_t index) const { return Operator::GetDynamicSubgraph(#x, index); } \ - _THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \ - Operator::SetSubgraphBuilder(#x, index, v); \ - return *this; \ - } \ - \ - private: \ - void __graph_##x() { \ - Operator::SubgraphRegister(#x, true); \ - (void)OpReg() - -#define PASTE(g_register, y) g_register##y -#define __OP_END_IMPL__(x, y) \ - N(); \ - } \ - static_assert( \ - std::is_same::value, \ - "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ - } \ - ; \ - static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ - } -#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) - -// Specialized shape inferencer macro - -#define IMPLEMT_INFERFUNC(op_name, func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) - -#define IMPLEMT_COMMON_INFERFUNC(func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus 
func_name(Operator &op) - -#define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) - -// Specialized verifier macro - -#define IMPLEMT_VERIFIER(op_name, func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op) - -#define INFER_VERIFY_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } - -#define COMMON_INFER_VERIFY_FUNC(x) [&](Operator &v) { return x(v); } - -#define INFER_FORMAT_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } - -#define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x) - -#define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x) -// Infer format func register -#define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \ - static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x) - -// Shape inferencer & verifier register macro - -#define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) - -#define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC(x), __COUNTER__) - -#define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) - -// Infer format func reg -#define INFER_FORMAT_FUNC_REG(op_name, x) \ - __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__) - -// Common shape inferencer - -#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ - [](Operator op) -> graphStatus { \ - auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \ - auto x_type = op.GetInputDesc(in_name).GetDataType(); \ - TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ - op_output_desc.SetShape(ge::Shape(x_shape)); \ - op_output_desc.SetOriginShape(ge::Shape(x_shape)); \ - op_output_desc.SetDataType(x_type); \ - return op.UpdateOutputDesc(out_name, op_output_desc); \ - } - -graphStatus BroadCastInfer(const function()> &get_in1_shape, - const function()> &get_in2_shape, - const function &y_shape)> &set_out_shape); - -#define BROADCAST_INFER(in1_name, in2_name, out_name) \ - [](Operator op) -> graphStatus { \ - return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \ - [&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \ - [&](const vector &y_shape) { \ - TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ - op_output_desc.SetShape(ge::Shape(y_shape)); \ - (void)op.UpdateOutputDesc(out_name, op_output_desc); \ - }); \ - } -} // namespace ge -#endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ diff --git a/inc/external/graph/tensor.h b/inc/external/graph/tensor.h deleted file mode 100644 index 800e1037..00000000 --- a/inc/external/graph/tensor.h +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
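For reference, a sketch of how the registration macros above are normally combined to declare an IR operator; MyRelu, its ports, and its attribute are placeholders rather than an op defined by this patch, and the element-wise infer helper is the one declared above.

#include "graph/operator_reg.h"

namespace ge {
REG_OP(MyRelu)
    .INPUT(x, TensorType::RealNumberType())
    .OUTPUT(y, TensorType::RealNumberType())
    .ATTR(negative_slope, Float, 0.0)
    .OP_END_FACTORY_REG(MyRelu)

// Reuse the common element-wise shape/type propagation declared above.
COMMON_INFER_FUNC_REG(MyRelu, ELMTWISE_INFER_SHAPEANDTYPE("x", "y"));
}  // namespace ge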
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_TENSOR_H_ -#define INC_EXTERNAL_GRAPH_TENSOR_H_ - -#include -#include -#include -#include -#include - -#include "./ge_error_codes.h" -#include "./types.h" - -namespace ge { -class ShapeImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape { - public: - Shape(); - ~Shape() = default; - explicit Shape(const std::vector &dims); - - size_t GetDimNum() const; - // If the idx is invalid, return 0 - int64_t GetDim(size_t idx) const; - graphStatus SetDim(size_t idx, int64_t value); - std::vector GetDims() const; - int64_t GetShapeSize() const; - - private: - std::shared_ptr impl_; -}; - -class TensorDescImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { - public: - TensorDesc(); - ~TensorDesc() = default; - explicit TensorDesc(Shape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); - // Copy - TensorDesc(const TensorDesc &desc); - // Move - TensorDesc(TensorDesc &&desc); - // Copy - TensorDesc &operator=(const TensorDesc &desc); - // Move - TensorDesc &operator=(TensorDesc &&desc); - - void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); - Shape GetShape() const; - void SetShape(const Shape &shape); - // set shape with -2, it stand for unknown shape - graphStatus SetUnknownDimNumShape(); - // for unknown shape - graphStatus SetShapeRange(const std::vector> &range); - graphStatus GetShapeRange(std::vector> &range) const; - - Format GetFormat() const; - void SetFormat(Format format); - - Shape GetOriginShape() const; - void SetOriginShape(const Shape &originShape); - - Format GetOriginFormat() const; - void SetOriginFormat(Format originFormat); - - DataType GetDataType() const; - void SetDataType(DataType dt); - - std::string GetName() const; - void SetName(const std::string &name); - - // Attr acess - void SetSize(int64_t size); - int64_t GetSize() const; - - int64_t GetRealDimCnt() const; - void SetRealDimCnt(const int64_t realDimCnt); - - private: - std::shared_ptr impl; -}; - -class TensorImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor { - public: - Tensor(); - ~Tensor() = default; - explicit Tensor(const TensorDesc &tensorDesc); - Tensor(const TensorDesc &tensorDesc, const std::vector &data); - Tensor(const TensorDesc &tensorDesc, const uint8_t *data, size_t size); - Tensor(TensorDesc &&tensorDesc, std::vector &&data); - - TensorDesc GetTensorDesc() const; - graphStatus SetTensorDesc(const TensorDesc &tensorDesc); - - const uint8_t *GetData() const; - uint8_t *GetData(); - size_t GetSize() const; - - graphStatus SetData(std::vector &&data); - graphStatus SetData(const std::vector &data); - graphStatus SetData(const uint8_t *data, size_t size); - graphStatus SetData(const std::string &data); - graphStatus SetData(const std::vector &data); - graphStatus IsValid(); - - Tensor Clone() const; - - private: - std::shared_ptr impl; - friend class TensorAdapter; -}; -} // namespace ge -/*lint +e148*/ - -#endif // INC_EXTERNAL_GRAPH_TENSOR_H_ diff --git a/inc/external/graph/types.h b/inc/external/graph/types.h deleted file mode 100644 index a1245c9d..00000000 --- a/inc/external/graph/types.h +++ /dev/null @@ -1,240 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
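For reference, a minimal usage sketch of the Shape/TensorDesc/Tensor interface above; the 2x3 float shape and fill value are placeholders.

#include <cstdint>
#include <vector>
#include "graph/tensor.h"

ge::Tensor MakeFp32Tensor() {
  ge::TensorDesc desc(ge::Shape({2, 3}), ge::FORMAT_ND, ge::DT_FLOAT);
  std::vector<float> values(2 * 3, 1.0f);  // placeholder data
  ge::Tensor tensor(desc);
  (void)tensor.SetData(reinterpret_cast<const uint8_t *>(values.data()),
                       values.size() * sizeof(float));
  return tensor;
}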
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_TYPES_H_ -#define INC_EXTERNAL_GRAPH_TYPES_H_ - -#include -#include -#include - -namespace ge { -static const int64_t UNKNOWN_DIM = -1; -static const int64_t UNKNOWN_DIM_NUM = -2; -static const std::vector UNKNOWN_SHAPE = {-1}; -static const std::vector UNKNOWN_RANK = {-2}; - -#ifdef HOST_VISIBILITY -#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_HOST_VISIBILITY -#endif -#ifdef DEV_VISIBILITY -#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_DEV_VISIBILITY -#endif - -enum DataType { - DT_FLOAT = 0, // float type - DT_FLOAT16 = 1, // fp16 type - DT_INT8 = 2, // int8 type - DT_INT16 = 6, // int16 type - DT_UINT16 = 7, // uint16 type - DT_UINT8 = 4, // uint8 type - DT_INT32 = 3, // - DT_INT64 = 9, // int64 type - DT_UINT32 = 8, // unsigned int32 - DT_UINT64 = 10, // unsigned int64 - DT_BOOL = 12, // bool type - DT_DOUBLE = 11, // double type - DT_STRING = 13, // string type - DT_DUAL_SUB_INT8 = 14, // dual output int8 type - DT_DUAL_SUB_UINT8 = 15, // dual output uint8 type - DT_COMPLEX64 = 16, // complex64 type - DT_COMPLEX128 = 17, // complex128 type - DT_QINT8 = 18, // qint8 type - DT_QINT16 = 19, // qint16 type - DT_QINT32 = 20, // qint32 type - DT_QUINT8 = 21, // quint8 type - DT_QUINT16 = 22, // quint16 type - DT_RESOURCE = 23, // resource type - DT_STRING_REF = 24, // string ref type - DT_DUAL = 25, // dual output type - DT_UNDEFINED // Used to indicate a DataType field has not been set. -}; - -inline int GetSizeByDataType(DataType data_type) { - static int data_type_size[DT_UNDEFINED] = { - 4, // DT_FLOAT = 0, float type - 2, // DT_FLOAT16 = 1, fp16 type - 1, // DT_INT8 = 2, int8 type - 4, // DT_INT32 = 3, - 1, // DT_UINT8 = 4, uint8 type - -1, - 2, // DT_INT16 = 6, int16 type - 2, // DT_UINT16 = 7, uint16 type - 4, // DT_UINT32 = 8, unsigned int32 - 8, // DT_INT64 = 9, int64 type - 8, // DT_UINT64 = 10, unsigned int64 - 8, // DT_DOUBLE = 11, double type - 1, // DT_BOOL = 12, bool type - -1, // DT_STRING = 13, string type - 1, // DT_DUAL_SUB_INT8 = 14, dual output int8 type - 1, // DT_DUAL_SUB_UINT8 = 15, dual output uint8 type - 8, // DT_COMPLEX64 = 16, complex64 type - 16, // DT_COMPLEX128 = 17, complex128 type - 1, // DT_QINT8 = 18, qint8 type - 2, // DT_QINT16 = 19, qint16 type - 4, // DT_QINT32 = 20, qint32 type - 1, // DT_QUINT8 = 21, quint8 type - 2, // DT_QUINT16 = 22, quint16 type - -1, // DT_RESOURCE = 23, resource type - -1, // DT_STRING_REF = 24, string ref type - 5, // DT_DUAL = 25, dual output type (float + int8) - // DT_UNDEFINED Used to indicate a DataType field has not been set. 
- }; - if (data_type >= DT_UNDEFINED) { - return -1; - } - return data_type_size[data_type]; -} - -enum Format { - FORMAT_NCHW = 0, // NCHW - FORMAT_NHWC, // NHWC - FORMAT_ND, // Nd Tensor - FORMAT_NC1HWC0, // NC1HWC0 - FORMAT_FRACTAL_Z, // FRACTAL_Z - FORMAT_NC1C0HWPAD, - FORMAT_NHWC1C0, - FORMAT_FSR_NCHW, - FORMAT_FRACTAL_DECONV, - FORMAT_C1HWNC0, - FORMAT_FRACTAL_DECONV_TRANSPOSE, - FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, - FORMAT_NC1HWC0_C04, // NC1HWC0, C0 =4 - FORMAT_FRACTAL_Z_C04, // FRACZ, C0 =4 - FORMAT_CHWN, - FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, - FORMAT_HWCN, - FORMAT_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format - FORMAT_BN_WEIGHT, - FORMAT_FILTER_HWCK, // filter input tensor format - FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20, - FORMAT_HASHTABLE_LOOKUP_KEYS, - FORMAT_HASHTABLE_LOOKUP_VALUE, - FORMAT_HASHTABLE_LOOKUP_OUTPUT, - FORMAT_HASHTABLE_LOOKUP_HITS = 24, - FORMAT_C1HWNCoC0, - FORMAT_MD, - FORMAT_NDHWC, - FORMAT_FRACTAL_ZZ, - FORMAT_FRACTAL_NZ, - FORMAT_NCDHW, - FORMAT_DHWCN, // 3D filter input tensor format - FORMAT_NDC1HWC0, - FORMAT_FRACTAL_Z_3D, - FORMAT_CN, - FORMAT_NC, - FORMAT_DHWNC, - FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format - FORMAT_FRACTAL_ZN_LSTM, - FORMAT_FRACTAL_Z_G, - FORMAT_RESERVED, - FORMAT_ALL, - FORMAT_NULL -}; - -// for unknown shape op type -enum UnknowShapeOpType { - DEPEND_IN_SHAPE = 1, // op out shape get by input shape - DEPEND_CONST_VALUE = 2, // op out shape get by const op value - DEPEND_SHAPE_RANGE = 3, // op out shape get by range - DEPEND_COMPUTE = 4 // op out shape get by totally computing -}; - -struct TensorDescInfo { - Format format_ = FORMAT_RESERVED; // tbe op register support format - DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype -}; - -enum DeviceType { - NPU = 0, - CPU = 1, -}; - -class TensorTypeImpl; -struct TensorType { - explicit TensorType(DataType dt); - - TensorType(const std::initializer_list &types); - - static TensorType ALL() { - return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, - DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, - DT_QUINT8, DT_RESOURCE, DT_STRING, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType QuantifiedType() { return TensorType{DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, DT_QUINT8}; } - - static TensorType OrdinaryType() { - return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, - DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType BasicType() { - return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, - DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, - DT_QUINT16, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType NumberType() { - return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, - DT_INT8, DT_QINT32, DT_QINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType RealNumberType() { - return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, - DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType ComplexDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64}; } - - static TensorType IntegerDataType() { - return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static 
TensorType SignedDataType() { return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8}; } - - static TensorType UnsignedDataType() { return TensorType{DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; } - - static TensorType FloatingDataType() { return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } - - static TensorType IndexNumberType() { return TensorType{DT_INT32, DT_INT64}; } - - static TensorType UnaryDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } - - static TensorType FLOAT() { return TensorType{DT_FLOAT, DT_FLOAT16}; } - - std::shared_ptr tensor_type_impl_; -}; -} // namespace ge - -namespace domi { -enum class ImplyType : unsigned int { - BUILDIN = 0, // Built in operator, normally executed by OME - TVM, // Compile to TVM bin file for execution - CUSTOM, // User defined calculation logic, executed by CPU - AI_CPU, // AICPU - CCE, // Cce - GELOCAL, // GE local, do node need execute by device - HCCL, // Hccl - INVALID = 0xFFFFFFFF, -}; -} // namespace domi - -#endif // INC_EXTERNAL_GRAPH_TYPES_H_ diff --git a/inc/external/register/register.h b/inc/external/register/register.h deleted file mode 100644 index f3091fae..00000000 --- a/inc/external/register/register.h +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_EXTERNAL_REGISTER_REGISTER_H_ -#define INC_EXTERNAL_REGISTER_REGISTER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "graph/operator.h" -#include "register/register_error_codes.h" -#include "register/register_fmk_types.h" -#include "register/register_types.h" - -using std::make_shared; -using std::map; -using std::pair; -using std::string; -using std::to_string; -using std::unique_ptr; -using std::vector; - -/*lint -e148*/ -namespace ge { -class Operator; -class TensorDesc; -class Tensor; -class TBEPluginManager; -} // namespace ge - -namespace google { -namespace protobuf { -class Message; -} -} // namespace google - -namespace domi { -const int64_t kMaxNameLength = 1048576; // 1M - -enum DynamicType { kInvalid = 0, kInput = 1, kOutput = 2 }; -struct DynamicInputOutputInfo { - DynamicType type; // input/output - const char *port_name; - int64_t port_name_len; - const char *attr_name; - int64_t attr_name_len; - DynamicInputOutputInfo() - : type(kInvalid), port_name(nullptr), port_name_len(0), attr_name(nullptr), attr_name_len(0) {} - DynamicInputOutputInfo(DynamicType type, const char *port_name, int64_t port_name_len, const char *attr_name, - int64_t attr_name_len) - : type(type), - port_name(port_name), - port_name_len(port_name_len), - attr_name(attr_name), - attr_name_len(attr_name_len) {} -}; -Status AutoMappingByOpFn(const ge::Operator &op_src, ge::Operator &op); -Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op, - const vector &dynamic_name_attr_value); -Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); -Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, - std::map> dynamic_name_attr_value, - int in_pos = -1, int out_pos = -1); -Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function &input, - const std::function &output); -Status AutoMappingSubgraphIndex(const ge::Graph &graph, - const std::function &input, - const std::function &output); -using google::protobuf::Message; -class OpRegistrationDataImpl; - -using ParseParamFunc = std::function; -using ParseParamByOpFunc = std::function; -using FusionParseParamFunc = - std::function, ge::Operator &)>; -using FusionParseParamByOpFunc = std::function &, ge::Operator &)>; -using ParseSubgraphFunc = std::function; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { - public: - OpRegistrationData(const std::string &om_optype); - - ~OpRegistrationData(); - - OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type); - - OpRegistrationData &OriginOpType(const std::initializer_list &ori_optype_list); - - OpRegistrationData &OriginOpType(const std::string &ori_optype); - - OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); - - OpRegistrationData &ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn); - - OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); - - OpRegistrationData &FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn); - - OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); - - OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); - - OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); - - OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); - - OpRegistrationData 
&InputReorderVector(const vector &input_order); - - domi::ImplyType GetImplyType() const; - std::string GetOmOptype() const; - std::set GetOriginOpTypeSet() const; - domi::FrameworkType GetFrameworkType() const; - ParseParamFunc GetParseParamFn() const; - ParseParamByOpFunc GetParseParamByOperatorFn() const; - FusionParseParamFunc GetFusionParseParamFn() const; - FusionParseParamByOpFunc GetFusionParseParamByOpFn() const; - ParseSubgraphFunc GetParseSubgraphPostFn() const; - - private: - std::shared_ptr impl_; - friend class OpRegistry; - friend class OpRegistrationTbe; - friend class ge::TBEPluginManager; -}; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { - public: - OpReceiver(OpRegistrationData ®_data); - ~OpReceiver() {} -}; - -#define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name) -#define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name) -#define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \ - static OpReceiver register_op##ctr __attribute__((unused)) = OpRegistrationData(name) -} // namespace domi - -namespace ge { -using OpRegistrationData = domi::OpRegistrationData; -using OpReceiver = domi::OpReceiver; -} // namespace ge -/*lint +e148*/ -#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ diff --git a/inc/external/register/register_error_codes.h b/inc/external/register/register_error_codes.h deleted file mode 100644 index 5e0ed79f..00000000 --- a/inc/external/register/register_error_codes.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ -#define INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ - -#define SYSID_FWK 3 // Subsystem ID -#define MODID_COMMON 0 // Common module ID - -#define DECLARE_ERRORNO(sysid, modid, name, value) \ - const domi::Status name = \ - ((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); - -#define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) - -namespace domi { -using Status = uint32_t; - -// General error code -DECLARE_ERRORNO(0, 0, SUCCESS, 0); -DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); -DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 -DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201); -} // namespace domi - -#endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ diff --git a/inc/external/register/register_fmk_types.h b/inc/external/register/register_fmk_types.h deleted file mode 100644 index 97616060..00000000 --- a/inc/external/register/register_fmk_types.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ -#define INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ - -#include - -namespace domi { -/// -/// @ingroup domi_omg -/// @brief AI framework types -/// -enum FrameworkType { - CAFFE = 0, - MINDSPORE = 1, - TENSORFLOW = 3, - ANDROID_NN, - ONNX, - FRAMEWORK_RESERVED, -}; -} // namespace domi - -#endif // INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ diff --git a/inc/external/register/register_types.h b/inc/external/register/register_types.h deleted file mode 100644 index 08d72713..00000000 --- a/inc/external/register/register_types.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ -#define INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ - -namespace domi { -#ifdef HOST_VISIBILITY -#define FMK_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) -#else -#define FMK_FUNC_HOST_VISIBILITY -#endif -#ifdef DEV_VISIBILITY -#define FMK_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) -#else -#define FMK_FUNC_DEV_VISIBILITY -#endif - -/// CCE defined constant - -/// -/// @ingroup domi -/// @brief original tensor type -/// -typedef enum tagDomiTensorFormat { - DOMI_TENSOR_NCHW = 0, // < NCHW - DOMI_TENSOR_NHWC, // < NHWC - DOMI_TENSOR_ND, // < Nd Tensor - DOMI_TENSOR_NC1HWC0, // < NC1HWC0 - DOMI_TENSOR_FRACTAL_Z, // < FRACTAL_Z - DOMI_TENSOR_NC1C0HWPAD, - DOMI_TENSOR_NHWC1C0, - DOMI_TENSOR_FSR_NCHW, - DOMI_TENSOR_FRACTAL_DECONV, - DOMI_TENSOR_BN_WEIGHT, - DOMI_TENSOR_CHWN, // Android NN Depth CONV - DOMI_TENSOR_FILTER_HWCK, // filter input tensor format - DOMI_TENSOR_NDHWC, - DOMI_TENSOR_NCDHW, - DOMI_TENSOR_DHWCN, // 3D filter input tensor format - DOMI_TENSOR_DHWNC, - DOMI_TENSOR_RESERVED -} domiTensorFormat_t; -} // namespace domi - -#endif // INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ diff --git a/inc/external/register/scope/scope_fusion_pass_register.h b/inc/external/register/scope/scope_fusion_pass_register.h deleted file mode 100644 index 8e5605a7..00000000 --- a/inc/external/register/scope/scope_fusion_pass_register.h +++ /dev/null @@ -1,334 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ -#define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ - -#include -#include -#include -#include -#include -#include "ge/ge_api_error_codes.h" -#include "register/register_error_codes.h" -#include "register/register_types.h" -#include "graph/operator.h" - -#define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \ - do { \ - if (!(cond)) { \ - if ((fusion_rlt) != nullptr) { \ - (fusion_rlt)->SetType(ge::kScopeInvalidType); \ - } \ - return; \ - } \ - } while (0) - -namespace domi { -class TensorFlowModelParser; -} // namespace domi -namespace ge { -const int32_t kFusionDisableIndex = 99999; -const char *const kScopeToMultiNodes = "ScopeToMultiNodes"; -const char *const kScopeInvalidType = "ScopeInvalidType"; -const char *const kInputFromFusionScope = "InputFromFusionScope"; -const char *const kOutputToFusionScope = "OutputToFusionScope"; -class ScopePattern; -using ScopeFusionPatterns = std::vector>; - -class ScopePassManager; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { - public: - Scope(); - Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); - ~Scope(); - - const std::string &Name() const; - const std::string &SubType() const; - const std::unordered_map &AllNodesMap() const; - Scope *GetSubScope(const std::string &scope_name) const; - const std::string LastName() const; - const std::vector &GetAllSubScopes() const; - const Scope *GetFatherScope() const; - - private: - class ScopeImpl; - std::unique_ptr impl_; - friend class ScopeBasePass; - friend class ScopeTree; - friend class NodeOpTypeFeature; - friend class NodeAttrFeature; - friend class ScopeFeature; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult { - public: - FusionScopesResult(); - Status Init(); - ~FusionScopesResult(); - void SetName(const std::string &name); - void SetType(const std::string &type); - void SetDescription(const std::string &description); - const std::string &Name() const; - const std::vector &Nodes() const; - void InsertInputs(const std::string &inner_op_name, const std::vector &index_map); - void InsertOutputs(const std::string &inner_op_name, const std::vector &index_map); - - class InnerNodeInfo { - public: - explicit InnerNodeInfo(const std::string &fusion_node_name); - InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type); - InnerNodeInfo(InnerNodeInfo &&other) noexcept; - InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept; - InnerNodeInfo(const InnerNodeInfo &) = delete; - InnerNodeInfo &operator=(const InnerNodeInfo &) = delete; - ~InnerNodeInfo(); - InnerNodeInfo &SetName(const std::string &name); - InnerNodeInfo &SetType(const std::string &type); - InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx); - InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx); - ge::graphStatus BuildInnerNode(); - ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format); - ge::graphStatus 
SetOutputFormat(const std::string &output_name, const std::string &format); - ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format); - ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format); - ge::Operator *MutableOperator(); - - std::string GetName() const; - std::string GetType() const; - std::vector> GetInputs() const; - std::vector> GetOutputs() const; - - private: - class InnerNodeInfoImpl; - std::unique_ptr impl_; - }; - - InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type); - InnerNodeInfo *MutableRecentInnerNode(); - InnerNodeInfo *MutableInnerNode(uint32_t index); - ge::graphStatus CheckInnerNodesInfo(); - - private: - class FusionScopesResultImpl; - std::unique_ptr impl_; - friend class ScopeGraph; - friend class ScopeBasePass; - friend class TensorFlowModelParser; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree { - public: - ScopeTree(); - Status Init(); - ScopeTree(const ScopeTree &scopetree) = delete; - ScopeTree &operator=(const ScopeTree &scopetree) = delete; - ~ScopeTree(); - - const std::vector &GetAllScopes() const; - - private: - class ScopeTreeImpl; - std::unique_ptr impl_; - friend class ScopeGraph; - friend class ScopeBasePass; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph { - public: - ScopeGraph(); - Status Init(); - ScopeGraph(const ScopeGraph &scope_graph) = delete; - ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete; - ~ScopeGraph(); - - const ScopeTree *GetScopeTree() const; - const std::unordered_map &GetNodesMap() const; - - private: - class ScopeGraphImpl; - std::unique_ptr impl_; - friend class ScopePassManager; - friend class ScopeBasePass; - friend class TensorFlowModelParser; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue { - public: - ScopeAttrValue(); - ScopeAttrValue(ScopeAttrValue const &attr_value); - ScopeAttrValue &operator=(ScopeAttrValue const &attr_value); - ~ScopeAttrValue(); - - void SetIntValue(int64_t value); - void SetFloatValue(float value); - void SetStringValue(std::string value); - void SetBoolValue(bool value); - - private: - class ScopeAttrValueImpl; - std::unique_ptr impl_; - friend class NodeAttrFeature; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature { - public: - virtual bool Match(const Scope *scope) = 0; - virtual ~ScopeBaseFeature(){}; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature { - public: - NodeOpTypeFeature(std::string nodeType, int num, int step = 0); - NodeOpTypeFeature(NodeOpTypeFeature const &feature); - NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature); - ~NodeOpTypeFeature(); - bool Match(const Scope *scope) override; - - private: - class NodeOpTypeFeatureImpl; - std::unique_ptr impl_; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature { - public: - NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue &attr_value); - NodeAttrFeature(NodeAttrFeature const &feature); - NodeAttrFeature &operator=(NodeAttrFeature const &feature); - ~NodeAttrFeature(); - bool Match(const Scope *scope) override; - - private: - class NodeAttrFeatureImpl; - std::unique_ptr impl_; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature { - public: - ScopeFeature(std::string sub_type, int32_t num, 
std::string suffix = "", std::string sub_scope_mask = "", - int step = 0); - ScopeFeature(ScopeFeature const &feature); - ScopeFeature &operator=(ScopeFeature const &feature); - ~ScopeFeature(); - bool Match(const Scope *scope) override; - - private: - class ScopeFeatureImpl; - std::unique_ptr impl_; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern { - public: - ScopePattern(); - ~ScopePattern(); - - ScopePattern &SetSubType(const std::string &sub_type); - ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature); - ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature); - ScopePattern &AddScopeFeature(ScopeFeature feature); - - private: - class ScopePatternImpl; - std::unique_ptr impl_; - friend class ScopeBasePass; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult { - public: - ScopesResult(); - ScopesResult(ScopesResult const &result); - ScopesResult &operator=(ScopesResult const &result); - ~ScopesResult(); - - void SetScopes(std::vector &scopes); - void SetNodes(std::vector &nodes); - - private: - class ScopesResultImpl; - std::unique_ptr impl_; - friend class ScopeBasePass; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass { - public: - ScopeBasePass(); - virtual ~ScopeBasePass(); - - protected: - // Subclasses implement respective fusion strategies and build the Patterns - virtual std::vector DefinePatterns() = 0; - // Define the name of the scope pass - virtual std::string PassName() = 0; - // Subclasses implement respective multi-scope or operator fusion methods across scopes - virtual Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, - std::vector &results) = 0; - // Subclasses implement their own results and set the input and output of the final fusion operator - virtual void GenerateFusionResult(const std::vector &scopes, FusionScopesResult *fusion_rlt) = 0; - - private: - class ScopeBasePassImpl; - std::unique_ptr impl_; - friend class ge::ScopePassManager; - friend class ScopeBasePassImpl; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry { - public: - using CreateFn = ScopeBasePass *(*)(); - ~ScopeFusionPassRegistry(); - - static ScopeFusionPassRegistry &GetInstance() { - static ScopeFusionPassRegistry instance; - return instance; - } - - void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general); - - private: - ScopeFusionPassRegistry(); - class ScopeFusionPassRegistryImpl; - /*lint -e148*/ - std::unique_ptr impl_; - friend class TensorFlowModelParser; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil { - public: - static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value); - static void FreeScopePatterns(ScopeFusionPatterns &patterns); - static void FreeOneBatchPattern(std::vector &one_batch_pattern); -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar { - public: - ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), bool is_general); - ~ScopeFusionPassRegistrar() {} -}; - -#define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \ - REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, scope_pass, is_general) - -#define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \ - REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) - -#define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \ - static 
::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \ - ::ge::ScopeFusionPassRegistrar( \ - pass_name, []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, is_general) -} // namespace ge - -#endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ diff --git a/inc/framework/omg/parser/model_parser.h b/inc/framework/omg/parser/model_parser.h new file mode 100644 index 00000000..3a8aa6ce --- /dev/null +++ b/inc/framework/omg/parser/model_parser.h @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ +#define INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ + +#include +#include "framework/common/types.h" +#include "framework/omg/omg_inner_types.h" +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/ge_tensor.h" +#include "graph/graph.h" +#include "graph/op_desc.h" +#include "graph/operator.h" +#include "graph/range_vistor.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" + +using Status = domi::Status; + +namespace domi { +using GetGraphCallback = std::function( + const google::protobuf::Message *root_proto, const std::string &graph)>; +class ModelParser { + public: + ModelParser() {} + + virtual ~ModelParser() {} + + /** + * @ingroup domi_omg + * @brief Analyze network model data + * @param [in] file Network model file path + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status Parse(const char *file, ge::Graph &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Parse relevant data from memory and save it to graph + * @param [in] input Model file memory data + * @param [in|out] graph A graph for saving the model information after analysis + * @return SUCCESS + * @return FAILED + * @author + */ + virtual Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Analyze network model data + * @param [in] proto network model + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Analyze callback model data in subgraph + * @param [in] proto network model + * @param [in] callback callback of subgraph + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProtoWithSubgraph(const google::protobuf::Message *proto, + GetGraphCallback callback, + ge::ComputeGraphPtr &graph) = 0; + /** + * @ingroup domi_omg + * @brief Convert model files to JSON format + * @param [in] model_file Model file path to be converted + * @param [out] json_file 
Converted JSON file path + * @return SUCCESS + * @return Others failed + */ + virtual Status ToJson(const char *model_file, const char *json_file) { return domi::SUCCESS; } + + /* + * @ingroup domi_omg + * @brief Convert network data type + * @param [in] type Data type to be converted + * @return ge::DataType + */ + virtual ge::DataType ConvertToGeDataType(const uint32_t type) = 0; + + virtual Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) = 0; +}; +} // namespace domi + +#endif // INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ diff --git a/inc/framework/omg/parser/op_parser.h b/inc/framework/omg/parser/op_parser.h new file mode 100644 index 00000000..251c0447 --- /dev/null +++ b/inc/framework/omg/parser/op_parser.h @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ +#define INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ + +#include +#include "common/types.h" +#include "omg/omg_inner_types.h" +#include "proto/om.pb.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/utils/op_desc_utils.h" + +using google::protobuf::Message; +using Status = domi::Status; + +namespace ge { +/** + * @ingroup domi_omg + * @brief Used to analyze operator information + * + */ +class OpParser { + public: + /** + * @ingroup domi_omg + * @brief Deconstructor + */ + virtual ~OpParser() {} + + /** + * @ingroup domi_omg + * @brief Analytic operator parameters + * @param [in] op_src Parameter data to be resolved + * @param [out] graph Parsed parameter data + * @return SUCCESS + * @return FAILED + */ + virtual Status ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) = 0; + + /** + * @ingroup domi_omg + * @brief Analytic operator parameters + * @param [in] op_src Parameter data to be resolved + * @param [out] Operator parameter data + * @return SUCCESS + * @return FAILED + */ + virtual Status ParseParams(const Message *op_src, ge::Operator &op_dest) = 0; + + /** + * @ingroup domi_omg + * @brief Analytic operator weight information + * @param [in] op_src Weight data to be resolved + * @param [out] op_dest Weight data after analysis + * @return SUCCESS + * @return FAILED + */ + virtual Status ParseWeights(const Message *op_src, ge::NodePtr &node) = 0; + + /** + * @ingroup domi_omg + * @brief Get the format information according to the parameters in the operator + * @param [in] op_src Parameter data to be resolved + * @param [out] format Output the parsed format + * @return SUCCESS + * @return FAILED + */ + virtual Status GetFormat(const Message *op_src, domi::domiTensorFormat_t &format) { + (void)op_src; + // Indicates that the op does not provide a value for format + format = domi::DOMI_TENSOR_RESERVED; + return domi::SUCCESS; + } +}; +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ diff --git a/inc/common/optimizer/graph_optimizer_types.h b/inc/framework/omg/parser/parser_api.h similarity index 59% 
rename from inc/common/optimizer/graph_optimizer_types.h rename to inc/framework/omg/parser/parser_api.h index 9e1ec96b..382bdfde 100644 --- a/inc/common/optimizer/graph_optimizer_types.h +++ b/inc/framework/omg/parser/parser_api.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,21 +14,18 @@ * limitations under the License. */ -#ifndef INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ -#define INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ +#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ +#define INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ -#include +#include +#include #include -namespace ge { -enum OPTIMIZER_SCOPE { - UNIT = 0, - ENGINE, -}; +#include "ge/ge_api_error_codes.h" -struct GraphOptimizerAttribute { - std::string engineName; - OPTIMIZER_SCOPE scope; -}; +namespace ge { +// Initialize parser +Status ParserInitialize(const std::map& options); +// Finalize parser, release all resources +Status ParserFinalize(); } // namespace ge - -#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ +#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ diff --git a/inc/framework/omg/parser/parser_factory.h b/inc/framework/omg/parser/parser_factory.h new file mode 100644 index 00000000..90d441d7 --- /dev/null +++ b/inc/framework/omg/parser/parser_factory.h @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ +#define INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ + +#include +#include +#include +#include +#include "framework/common/types.h" +#include "framework/omg/omg_inner_types.h" + +using Status = domi::Status; + +namespace domi { +class WeightsParser; +class ModelParser; + +typedef std::shared_ptr (*MODEL_PARSER_CREATOR_FUN)(void); + +// Create modelparser for different frameworks +class ModelParserFactory { + public: + static ModelParserFactory *Instance(); + + /** + * @ingroup domi_omg + * @brief Create a modelparser based on the type entered + * @param [in] type Framework type + * @return Created modelparser + */ + std::shared_ptr CreateModelParser(const domi::FrameworkType type); + + /** + * @ingroup domi_omg + * @brief Register create function + * @param [in] type Framework type + * @param [in] fun ModelParser's create function + */ + void RegisterCreator(const domi::FrameworkType type, MODEL_PARSER_CREATOR_FUN fun); + + protected: + ModelParserFactory() {} + ~ModelParserFactory(); + + private: + std::map creator_map_; +}; // end class ModelParserFactory + +class ModelParserRegisterar { + public: + ModelParserRegisterar(const domi::FrameworkType type, MODEL_PARSER_CREATOR_FUN fun) { + ModelParserFactory::Instance()->RegisterCreator(type, fun); + } + ~ModelParserRegisterar() {} +}; + +// Registration macros for model parsers +#define REGISTER_MODEL_PARSER_CREATOR(type, clazz) \ + std::shared_ptr Creator_##type##_Model_Parser() { \ + std::shared_ptr ptr = nullptr; \ + try { \ + ptr = make_shared(); \ + } catch (...) { \ + ptr = nullptr; \ + } \ + return std::shared_ptr(ptr); \ + } \ + ModelParserRegisterar g_##type##_Model_Parser_Creator(type, Creator_##type##_Model_Parser) + +typedef std::shared_ptr (*WEIGHTS_PARSER_CREATOR_FUN)(void); + +// Create weightsparser for different frameworks +class WeightsParserFactory { + public: + static WeightsParserFactory *Instance(); + + /** + * @ingroup domi_omg + * @brief Create weightsparser based on the type entered + * @param [in] type Framework type + * @return Created weightsparser + */ + std::shared_ptr CreateWeightsParser(const domi::FrameworkType type); + + /** + * @ingroup domi_omg + * @brief Register create function + * @param [in] type Framework type + * @param [in] fun WeightsParser's create function + */ + void RegisterCreator(const domi::FrameworkType type, WEIGHTS_PARSER_CREATOR_FUN fun); + + protected: + WeightsParserFactory() {} + ~WeightsParserFactory(); + + private: + std::map creator_map_; +}; // end class WeightsParserFactory + +class WeightsParserRegisterar { + public: + WeightsParserRegisterar(const domi::FrameworkType type, WEIGHTS_PARSER_CREATOR_FUN fun) { + WeightsParserFactory::Instance()->RegisterCreator(type, fun); + } + ~WeightsParserRegisterar() {} +}; + +// Register macro of weight resolver +#define REGISTER_WEIGHTS_PARSER_CREATOR(type, clazz) \ + std::shared_ptr Creator_##type##_Weights_Parser() { \ + std::shared_ptr ptr = nullptr; \ + try { \ + ptr = make_shared(); \ + } catch (...) 
{ \ + ptr = nullptr; \ + } \ + return std::shared_ptr(ptr); \ + } \ + WeightsParserRegisterar g_##type##_Weights_Parser_Creator(type, Creator_##type##_Weights_Parser) +}; // namespace domi + +#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ diff --git a/inc/framework/omg/parser/parser_inner_ctx.h b/inc/framework/omg/parser/parser_inner_ctx.h new file mode 100644 index 00000000..53f79895 --- /dev/null +++ b/inc/framework/omg/parser/parser_inner_ctx.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ +#define INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ + +#include +#include +#include +#include +#include +#include +#include "external/register/register_fmk_types.h" +#include "external/register/register_types.h" +#include "framework/omg/omg_inner_types.h" + +namespace ge { +struct ParserContext { + std::unordered_map> input_dims; + domi::domiTensorFormat_t format = domi::DOMI_TENSOR_ND; + RunMode run_mode = ONLY_PRE_CHECK; + std::string custom_proto_path; // save caffe custom proto path, used by caffe parse + std::string caffe_proto_path; // save caffe proto path, used by caffe parse + std::string enable_scope_fusion_passes; // name of the pass that needs to take effect +}; + +ParserContext &GetParserContext(); +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ diff --git a/inc/framework/omg/parser/weights_parser.h b/inc/framework/omg/parser/weights_parser.h new file mode 100644 index 00000000..1b5216b3 --- /dev/null +++ b/inc/framework/omg/parser/weights_parser.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_ +#define INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_ + +#include "graph/graph.h" +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/operator.h" +#include "graph/range_vistor.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" + +namespace domi { +/** + * @ingroup domi_omg + * @brief Weight information resolver + * + */ +class WeightsParser { + public: + /** + * @ingroup domi_omg + * @brief Constructor + */ + WeightsParser() {} + + /** + * @ingroup domi_omg + * @brief Deconstructor + */ + virtual ~WeightsParser() {} + + /** + * @ingroup domi_omg + * @brief Analyze weight data + * @param [in] file Path of weight file after training + * @param [in|out] graph Graph for saving weight information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status Parse(const char *file, ge::Graph &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Parse relevant data from memory and save it to graph + * @param [in] input Model file memory data + * @param [in|out] graph A graph for saving the model information after analysis + * @return SUCCESS + * @return FAILED + * @author + */ + virtual Status ParseFromMemory(const char *input, uint32_t lengt, ge::ComputeGraphPtr &graph) = 0; +}; +} // namespace domi + +#endif // INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_ diff --git a/inc/graph/anchor.h b/inc/graph/anchor.h deleted file mode 100644 index 565f0843..00000000 --- a/inc/graph/anchor.h +++ /dev/null @@ -1,284 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_GRAPH_ANCHOR_H_ -#define INC_GRAPH_ANCHOR_H_ - -#include -#include -#include -#include "graph/ge_error_codes.h" -#include "graph/range_vistor.h" -#include "graph/types.h" - -namespace ge { -enum AnchorStatus { - ANCHOR_SUSPEND = 0, // dat null - ANCHOR_CONST = 1, - ANCHOR_DATA = 2, // Effective - ANCHOR_RESERVED = 3 -}; -using std::string; -using std::vector; - -class Node; - -using NodePtr = std::shared_ptr; - -class Edge; - -using EdgePtr = std::shared_ptr; - -class Anchor; - -using AnchorPtr = std::shared_ptr; - -class DataAnchor; - -using DataAnchorPtr = std::shared_ptr; - -class InDataAnchor; - -using InDataAnchorPtr = std::shared_ptr; - -class OutDataAnchor; - -using OutDataAnchorPtr = std::shared_ptr; - -class ControlAnchor; - -using ControlAnchorPtr = std::shared_ptr; - -class InControlAnchor; - -using InControlAnchorPtr = std::shared_ptr; - -class OutControlAnchor; - -using OutControlAnchorPtr = std::shared_ptr; - -using ConstAnchor = const Anchor; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable_shared_from_this { - friend class AnchorUtils; - - public: - using TYPE = const char *; - template - using Vistor = RangeVistor>; - - Anchor(const NodePtr &ownerNode, int idx); - - virtual ~Anchor() = default; - - protected: - // Whether the two anchor is equal - virtual bool Equal(AnchorPtr anchor) const = 0; - virtual bool IsTypeOf(TYPE type) const; - - public: - // Get all peer anchors connected to current anchor - Vistor GetPeerAnchors() const; - // Get peer anchor size - size_t GetPeerAnchorsSize() const; - // Get first peer anchor - AnchorPtr GetFirstPeerAnchor() const; - - // Get the anchor belong to which node - NodePtr GetOwnerNode() const; - - // Remove all links with the anchor - void UnlinkAll() noexcept; - - // Remove link with the given anchor - graphStatus Unlink(const AnchorPtr &peer); - - // Replace peer with new peers - graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); - - // Judge if the anchor is linked with the given anchor - bool IsLinkedWith(const AnchorPtr &peer); - - // Get anchor index of the node - int GetIdx() const; - - // set anchor index of the node - void SetIdx(int index); - - protected: - // All peer anchors connected to current anchor - vector> peer_anchors_; - // The owner node of anchor - std::weak_ptr owner_node_; - // The index of current anchor - int idx_; - template - static Anchor::TYPE TypeOf() { - static_assert(std::is_base_of::value, "T must be a Anchor!"); - return __PRETTY_FUNCTION__; - } - - public: - template - static std::shared_ptr DynamicAnchorCast(AnchorPtr anchorPtr) { - static_assert(std::is_base_of::value, "T must be a Anchor!"); - if (anchorPtr == nullptr || !anchorPtr->IsTypeOf()) { - return nullptr; - } - return std::static_pointer_cast(anchorPtr); - } - - template - bool IsTypeOf() { - return IsTypeOf(TypeOf()); - } -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY DataAnchor : public Anchor { - friend class AnchorUtils; - - public: - explicit DataAnchor(const NodePtr &ownerNode, int idx); - - virtual ~DataAnchor() = default; - - protected: - bool IsTypeOf(TYPE type) const override; - - private: - Format format_{FORMAT_ND}; - AnchorStatus status_{ANCHOR_SUSPEND}; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataAnchor { - friend class OutDataAnchor; - - friend class OutControlAnchor; - - public: - explicit InDataAnchor(const NodePtr &ownerNode, int idx); - - virtual 
~InDataAnchor() = default; - - // Get source out data anchor - OutDataAnchorPtr GetPeerOutAnchor() const; - - // Build connection from OutDataAnchor to InDataAnchor - graphStatus LinkFrom(const OutDataAnchorPtr &src); - - protected: - bool Equal(AnchorPtr anchor) const override; - bool IsTypeOf(TYPE type) const override; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchor : public DataAnchor { - friend class InDataAnchor; - - friend class AnchorUtils; - - public: - template - using Vistor = RangeVistor>; - - explicit OutDataAnchor(const NodePtr &ownerNode, int idx); - - virtual ~OutDataAnchor() = default; - // Get dst in data anchor(one or more) - Vistor GetPeerInDataAnchors() const; - uint32_t GetPeerInDataNodesSize() const; - - // Get dst in control anchor(one or more) - Vistor GetPeerInControlAnchors() const; - - // Build connection from OutDataAnchor to InDataAnchor - graphStatus LinkTo(const InDataAnchorPtr &dest); - - // Build connection from OutDataAnchor to InControlAnchor - graphStatus LinkTo(const InControlAnchorPtr &dest); - - protected: - bool Equal(AnchorPtr anchor) const override; - bool IsTypeOf(TYPE type) const override; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ControlAnchor : public Anchor { - public: - explicit ControlAnchor(const NodePtr &ownerNode); - - explicit ControlAnchor(const NodePtr &ownerNode, int idx); - - virtual ~ControlAnchor() = default; - - protected: - bool IsTypeOf(TYPE type) const override; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchor : public ControlAnchor { - friend class OutControlAnchor; - - friend class OutDataAnchor; - - public: - explicit InControlAnchor(const NodePtr &ownerNode); - - explicit InControlAnchor(const NodePtr &ownerNode, int idx); - - virtual ~InControlAnchor() = default; - - // Get source out control anchors - Vistor GetPeerOutControlAnchors() const; - bool IsPeerOutAnchorsEmpty() const { return peer_anchors_.empty(); } - - // Get source out data anchors - Vistor GetPeerOutDataAnchors() const; - - // Build connection from OutControlAnchor to InControlAnchor - graphStatus LinkFrom(const OutControlAnchorPtr &src); - - protected: - bool Equal(AnchorPtr anchor) const override; - bool IsTypeOf(TYPE type) const override; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchor : public ControlAnchor { - friend class InControlAnchor; - - public: - template - using Vistor = RangeVistor>; - - explicit OutControlAnchor(const NodePtr &ownerNode); - - explicit OutControlAnchor(const NodePtr &ownerNode, int idx); - - virtual ~OutControlAnchor() = default; - - // Get dst in control anchor(one or more) - Vistor GetPeerInControlAnchors() const; - // Get dst data anchor in control anchor(one or more) - Vistor GetPeerInDataAnchors() const; - - // Build connection from OutControlAnchor to InControlAnchor - graphStatus LinkTo(const InControlAnchorPtr &dest); - // Build connection from OutDataAnchor to InDataAnchor - graphStatus LinkTo(const InDataAnchorPtr &dest); - - protected: - bool Equal(AnchorPtr anchor) const override; - bool IsTypeOf(TYPE type) const override; -}; -} // namespace ge -#endif // INC_GRAPH_ANCHOR_H_ diff --git a/inc/graph/attr_value_serializable.h b/inc/graph/attr_value_serializable.h deleted file mode 100644 index a69beb96..00000000 --- a/inc/graph/attr_value_serializable.h +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * 
you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ -#define INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ - -#include -#include -#include "graph/ge_attr_value.h" - -namespace ge { - -class GeAttrValue; -class _GeSerializable { - public: - template - struct ge_serializable_int64_t_support_type { - using DT = typename std::remove_cv::type; - static const bool value = std::is_same::value // by cast - || std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value; - }; - - template - static GeAttrValue SaveItemAsAttrValue(const T &t) { - return GeAttrValue::CreateFrom(t); - } - - template - static GeAttrValue SaveItemAsAttrValue(const vector &t) { - return GeAttrValue::CreateFrom(t); - } - - template = 0, typename DT = typename std::remove_cv::type> - static GeAttrValue SaveItemAsAttrValue(const T &t) { - return GeAttrValue::CreateFrom
(t); - } - // int64_t support type - template ::value, int>::type = 0> - static GeAttrValue SaveItemAsAttrValue(const T &t) { - return GeAttrValue::CreateFrom(t); - } - // vector int64_t support type - template ::value, int>::type = 0> - static GeAttrValue SaveItemAsAttrValue(const vector &t) { - return GeAttrValue::CreateFrom(t); - } - - template - static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { - return attrVal.GetValue(t); - } - - template - static graphStatus LoadItemFromAttrValue(vector &t, GeAttrValue &attrVal) { - return attrVal.GetValue(t); - } - - template = 0, typename DT = typename std::remove_cv::type> - static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { - return attrVal.GetValue
(t); - } - - template ::value, int>::type = 0> - static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { - return attrVal.GetValue(t); - } - - template ::value, int>::type = 0> - static graphStatus LoadItemFromAttrValue(vector &t, GeAttrValue &attrVal) { - return attrVal.GetValue(t); - } - - template - static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { - GeAttrValue itemVal = SaveItemAsAttrValue(item); - (void)namedAttrs.SetAttr(itemName, itemVal); - SaveItem(namedAttrs, args...); - } - - static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) {} - - template - static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { - auto itemVal = namedAttrs.GetItem(itemName); - auto status = LoadItemFromAttrValue(item, itemVal); - if (status != GRAPH_SUCCESS) { - return status; - } - return LoadItem(namedAttrs, args...); - } - - static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) { - return GRAPH_SUCCESS; - } -}; - -#define _GE_FI(a) #a, a -#define _GE_MAP_FIELDS1(a1) _GE_FI(a1) -#define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) -#define _GE_MAP_FIELDS3(a1, a2, a3) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3) -#define _GE_MAP_FIELDS4(a1, a2, a3, a4) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4) -#define _GE_MAP_FIELDS5(a1, a2, a3, a4, a5) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5) -#define _GE_MAP_FIELDS6(a1, a2, a3, a4, a5, a6) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6) -#define _GE_MAP_FIELDS7(a1, a2, a3, a4, a5, a6, a7) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7) -#define _GE_MAP_FIELDS8(a1, a2, a3, a4, a5, a6, a7, a8) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8) -#define _GE_MAP_FIELDS9(a1, a2, a3, a4, a5, a6, a7, a8, a9) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9) -#define _GE_MAP_FIELDS10(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10) -#define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ - _GE_FI(a11) -#define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ - _GE_FI(a11), _GE_FI(a12) -#define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ - _GE_FI(a11), _GE_FI(a12), _GE_FI(a13) -#define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ - _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) -#define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ - _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), 
_GE_FI(a15) - -#define _GE_PRIVATE_ARGS_GLUE(x, y) x y - -#define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, \ - ...) \ - N -#define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL(args) _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT args -#define _GE_COUNT_MACRO_VAR_ARGS(...) \ - _GE_PRIVATE_MACRO_VAR_ARGS_IMPL((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)) - -#define _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) M##count -#define _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) -#define _GE_PRIVATE_MACRO_CHOOSE_HELPER(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) - -#define _GE_INVOKE_VAR_MACRO(...) \ - _GE_PRIVATE_ARGS_GLUE(_GE_PRIVATE_MACRO_CHOOSE_HELPER(_GE_MAP_FIELDS, _GE_COUNT_MACRO_VAR_ARGS(__VA_ARGS__)), \ - (__VA_ARGS__)) - -#define GE_SERIALIZABLE(...) \ - public: \ - friend class ge::GeAttrValue; \ - using __ge_serializable = int; \ - \ - private: \ - ge::graphStatus Save(GeAttrValue &ar) const { \ - GeAttrValue::NAMED_ATTRS named_attrs; \ - _GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ - return ar.SetValue(named_attrs); \ - } \ - ge::graphStatus Load(const GeAttrValue &ar) { \ - GeAttrValue::NAMED_ATTRS named_attrs; \ - ge::graphStatus status = ar.GetValue(named_attrs); \ - if (status != GRAPH_SUCCESS) { \ - return status; \ - } \ - return _GeSerializable::LoadItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ - } - -// end NamedAttrs Helper: GE_SERIALIZABLE -} // namespace ge -#endif // INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ diff --git a/inc/graph/buffer.h b/inc/graph/buffer.h deleted file mode 100644 index ca4355a7..00000000 --- a/inc/graph/buffer.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
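
The _GE_COUNT_MACRO_VAR_ARGS / _GE_INVOKE_VAR_MACRO machinery above is the classic preprocessor idiom for counting variadic arguments with a reversed number list and then token-pasting the count onto a per-arity macro name. A minimal standalone sketch of the same idiom follows; every name in it is illustrative and it is not GraphEngine code, only a reduced model of how GE_SERIALIZABLE turns its field list into "name", value pairs for the NAMED_ATTRS save/load helpers.

// Standalone sketch (C++17) of the counting/dispatch idiom used by
// _GE_COUNT_MACRO_VAR_ARGS and _GE_INVOKE_VAR_MACRO above. The reversed number
// list shifts so that slot N reports how many arguments were passed, and token
// pasting then selects MAP_FIELDS<N>. All names here are illustrative.
#include <iostream>

#define COUNT_IMPL(_1, _2, _3, _4, _5, N, ...) N
#define COUNT_ARGS(...) COUNT_IMPL(__VA_ARGS__, 5, 4, 3, 2, 1, 0)

#define GLUE(x, y) x y
#define CHOOSE2(M, count) M##count
#define CHOOSE1(M, count) CHOOSE2(M, count)
#define CHOOSE(M, count) CHOOSE1(M, count)

// Per-arity expansions, analogous to _GE_MAP_FIELDS1 .. _GE_MAP_FIELDS15.
#define FI(a) #a, a
#define MAP_FIELDS1(a1) FI(a1)
#define MAP_FIELDS2(a1, a2) FI(a1), FI(a2)
#define MAP_FIELDS3(a1, a2, a3) FI(a1), FI(a2), FI(a3)

// Dispatch: pick MAP_FIELDS<N> where N is the number of variadic arguments.
#define INVOKE_VAR_MACRO(...) \
  GLUE(CHOOSE(MAP_FIELDS, COUNT_ARGS(__VA_ARGS__)), (__VA_ARGS__))

template <typename... Ts>
void PrintPairs(Ts &&... ts) {
  // C++17 fold expression, just to show the expanded "name", value, ... sequence.
  ((std::cout << ts << ' '), ...);
  std::cout << '\n';
}

int main() {
  int x = 1, y = 2;
  // Expands to PrintPairs("x", x, "y", y): the same name/value pair stream that
  // the serialization helpers above store into and read back from NAMED_ATTRS.
  PrintPairs(INVOKE_VAR_MACRO(x, y));
  return 0;
}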
- */ - -#ifndef INC_GRAPH_BUFFER_H_ -#define INC_GRAPH_BUFFER_H_ - -#include -#include -#include -#include -#include "detail/attributes_holder.h" - -namespace ge { -#ifdef HOST_VISIBILITY -#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_HOST_VISIBILITY -#endif -#ifdef DEV_VISIBILITY -#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_DEV_VISIBILITY -#endif - -using std::shared_ptr; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { - public: - Buffer(); - Buffer(const Buffer &other); - - explicit Buffer(std::size_t bufferSize, std::uint8_t defualtVal = 0); - - ~Buffer() = default; - - Buffer &operator=(const Buffer &other); - - static Buffer CopyFrom(const std::uint8_t *data, std::size_t bufferSize); - - const std::uint8_t *GetData() const; - std::uint8_t *GetData(); - std::size_t GetSize() const; - void ClearBuffer(); - - // For compatibility - inline const std::uint8_t *data() const { return GetData(); } - inline std::uint8_t *data() { return GetData(); } // lint !e659 - inline std::size_t size() const { return GetSize(); } - inline void clear() { return ClearBuffer(); } - uint8_t operator[](size_t index) const { // lint !e1022 !e1042 - if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574 - return (uint8_t)(*buffer_)[index]; - } - return 0xff; - } - - private: - GeIrProtoHelper data_; - std::string *buffer_ = nullptr; - - // Create from protobuf obj - Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); - Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); - - friend class GeAttrValueImp; - friend class GeTensor; -}; -} // namespace ge -#endif // INC_GRAPH_BUFFER_H_ diff --git a/inc/graph/compute_graph.h b/inc/graph/compute_graph.h deleted file mode 100644 index 2ec6b663..00000000 --- a/inc/graph/compute_graph.h +++ /dev/null @@ -1,308 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
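
For reference while reviewing the deletion, here is a minimal usage sketch of the Buffer class whose declaration is removed above. It relies only on members visible in the deleted header (CopyFrom, GetSize, the vector-style data()/clear() aliases, and the bounds-checked operator[] that returns 0xff out of range); the include path and the exact copy semantics of CopyFrom are assumptions, not statements about the implementation.

// Minimal usage sketch for ge::Buffer, based only on the declarations in the
// deleted inc/graph/buffer.h above.
#include <cstdint>
#include <iostream>

#include "graph/buffer.h"  // assumed install path for the deleted header

int main() {
  const std::uint8_t raw[4] = {0xde, 0xad, 0xbe, 0xef};

  // CopyFrom() builds a Buffer from the caller's bytes (copy semantics assumed
  // from the name; the implementation lives outside this header).
  ge::Buffer buf = ge::Buffer::CopyFrom(raw, sizeof(raw));

  std::cout << "size: " << buf.GetSize() << '\n';  // expected: 4
  std::cout << std::hex << int(buf[0]) << '\n';    // bounds-checked element access
  std::cout << std::hex << int(buf[100]) << '\n';  // out of range: operator[] above returns 0xff

  // The std::vector-style aliases forward to GetData()/GetSize()/ClearBuffer(),
  // per the "For compatibility" block in the declaration above.
  const std::uint8_t *p = buf.data();
  (void)p;
  buf.clear();
  return 0;
}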
- */ - -#ifndef INC_GRAPH_COMPUTE_GRAPH_H_ -#define INC_GRAPH_COMPUTE_GRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include "detail/attributes_holder.h" -#include "graph/anchor.h" -#include "graph/node.h" -#include "graph/op_desc.h" -#include "graph/range_vistor.h" - -namespace ge { -class Node; -using NodePtr = std::shared_ptr; -class Edge; -using EdgePtr = std::shared_ptr; - -class InDataAnchor; -using InDataAnchorPtr = std::shared_ptr; - -class OutDataAnchor; -using OutDataAnchorPtr = std::shared_ptr; - -class ControlAnchor; -using ControlAnchorPtr = std::shared_ptr; -class InControlAnchor; -using InControlAnchorPtr = std::shared_ptr; -class OutControlAnchor; -using OutControlAnchorPtr = std::shared_ptr; -class GeAttrValue; -using AttrValuePtr = std::shared_ptr; -using ConstComputeGraph = const ComputeGraph; - -class OperatorImpl; -using OperatorImplPtr = std::shared_ptr; - -class ComputeGraph : public std::enable_shared_from_this, public AttrHolder { - friend class GraphUtils; - - public: - template - using Vistor = RangeVistor>; - - explicit ComputeGraph(const std::string &name); - ~ComputeGraph() override; - - std::string GetName() const; - void SetName(const std::string &name); - - using AttrHolder::DelAttr; - using AttrHolder::GetAttr; - using AttrHolder::HasAttr; - using AttrHolder::SetAttr; - - size_t GetAllNodesSize() const; - Vistor GetAllNodes() const; - // is_unknown_shape: false, same with GetAllNodes func - // is_unknown_shape: true, same with GetDirectNodes func - Vistor GetNodes(bool is_unknown_shape) const; - size_t GetDirectNodesSize() const; - Vistor GetDirectNode() const; - Vistor GetInputNodes() const; - Vistor GetOutputNodes() const; - - NodePtr FindNode(const std::string &name) const; - NodePtr FindFirstNodeMatchType(const std::string &name) const; - /*lint -e504*/ - // AddNode with NodePtr - NodePtr AddNode(NodePtr node); - NodePtr AddNode(OpDescPtr op); - NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize - NodePtr AddNodeFront(NodePtr node); - NodePtr AddNodeFront(const OpDescPtr &op); - NodePtr AddInputNode(NodePtr node); - NodePtr AddOutputNode(NodePtr node); - NodePtr AddOutputNodeByIndex(NodePtr node, int32_t index); - // insert node with specific pre_node - NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node); - NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node); - - graphStatus RemoveNode(const NodePtr &node); - graphStatus RemoveInputNode(const NodePtr &node); - graphStatus RemoveOutputNode(const NodePtr &node); - graphStatus RemoveConstInput(const NodePtr &node); - - /// Add a subgraph to this graph. The subgraph must has a parent graph and parent node, - /// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph - /// must be called before add it to the root graph. and subgraph->GetParentNode()->GetOwnerGraph() - /// must equal to subgraph->GetOwnerGraph(). - /// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph. 
- /// The subgraph's name SHOULD(not must) be the same as the parameter `name` - graphStatus AddSubgraph(const std::string &name, const std::shared_ptr &subgraph); - graphStatus AddSubgraph(const std::shared_ptr &subgraph); - - void RemoveSubgraph(const std::string &name); - void RemoveSubgraph(const std::shared_ptr &subgraph); - - std::shared_ptr GetSubgraph(const std::string &name) const; - std::vector> GetAllSubgraphs() const; - - // obsolete - std::shared_ptr AddSubGraph(std::shared_ptr sub_graph); - // obsolete - graphStatus RemoveSubGraph(const std::shared_ptr &sub_graph); - - /// - /// @brief Update input-mapping - /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input - /// @return graphStatus - /// - graphStatus UpdateInputMapping(const std::map &input_mapping); - - /// - /// @brief Update output-mapping - /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output - /// @return graphStatus - /// - graphStatus UpdateOutputMapping(const std::map &output_mapping); - - graphStatus TopologicalSorting(); - bool IsValid() const; - void InValid() { is_valid_flag_ = false; } - void Dump() const; - - void Swap(ComputeGraph &graph); - - graphStatus IsolateNode(const NodePtr &node); - graphStatus Verify(); - graphStatus InferShape(); - graphStatus InferOriginFormat(); - graphStatus InferShapeInNeed(); - graphStatus InsertEventNodes(); - bool operator==(const ComputeGraph &r_compute_graph) const; - - /*lint +e504*/ - const std::map, std::vector> &GetShareParamLayer() const { - return params_share_map_; - } - - void SetShareParamLayer(const std::map, std::vector> params_share_map) { - params_share_map_ = params_share_map; - } - - void SetInputsOrder(const std::vector &inputs_order) { inputs_order_ = inputs_order; } - - void SetGraphOutNodes(std::map> out_nodes_map) { out_nodes_map_ = out_nodes_map; } - - void AppendGraphOutNodes(std::map> out_nodes_map) { - for (auto &item : out_nodes_map) { - (void)out_nodes_map_.emplace(item.first, item.second); - } - } - - shared_ptr GetParentGraph(); - void SetParentGraph(const shared_ptr &parent); - shared_ptr GetParentNode(); - void SetParentNode(const shared_ptr &parent); - - const std::map> &GetGraphOutNodes() const { return out_nodes_map_; } - - void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } - - ComputeGraphPtr GetOrigGraph(void) { return origGraph_; } - void SetOutputSize(uint32_t size) { output_size_ = size; } - uint32_t GetOutputSize() const { return output_size_; } - void SetInputSize(uint32_t size) { input_size_ = size; } - uint32_t GetInputSize() const { return input_size_; } - - // false: known shape true: unknow shape - bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; } - void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; } - - /// - /// Set is need train iteration. - /// If set true, it means this graph need to be run iteration some - /// times(according variant "npu_runconfig/iterations_per_loop"). - /// @param need_iteration is need iteration - /// - void SetNeedIteration(bool need_iteration) { need_iteration_ = need_iteration; } - - void SetUserDefOutput(const std::string &output_name); - - const std::string GetOutput(); - - /// - /// Get is need train iteration. 
- /// @return is need iteration - /// - bool GetNeedIteration() const { return need_iteration_; } - - void SetGraphOpName(const std::map &op_name_map) { op_name_map_ = op_name_map; } - const std::map &GetGraphOpName() const { return op_name_map_; } - - const std::map &GetAllNodesInfo() const; - - void SetAllNodesInfo(const std::map &nodes) { all_nodes_infos_ = nodes; } - - void SetGraphOutNodesInfo(std::vector> &out_nodes_info) { - output_nodes_info_ = out_nodes_info; - } - - void AppendGraphOutNodesInfo(std::vector> &out_nodes_info) { - output_nodes_info_.insert(output_nodes_info_.end(), out_nodes_info.begin(), out_nodes_info.end()); - } - - const std::vector> &GetGraphOutNodesInfo() const { return output_nodes_info_; } - - void SetGraphTargetNodesInfo(const std::vector &target_nodes_info) { - target_nodes_info_ = target_nodes_info; - } - const std::vector &GetGraphTargetNodesInfo() const { return target_nodes_info_; } - - void SetSessionID(uint64_t session_id) { session_id_ = session_id; } - uint64_t GetSessionID() const { return session_id_; } - - void SetGraphID(uint32_t graph_id) { graph_id_ = graph_id; } - uint32_t GetGraphID() const { return graph_id_; } - - void SaveDataFormat(ge::Format data_format) { data_format_ = data_format; } - ge::Format GetDataFormat() const { return data_format_; } - bool IsSummaryGraph() const { return is_summary_graph_; } - void SetSummaryFlag(bool is_summary_graph) { is_summary_graph_ = is_summary_graph; } - // Graph Before BFE - ComputeGraphPtr origGraph_; - - protected: - ProtoAttrMapHelper MutableAttrMap() override; - ConstProtoAttrMapHelper GetAttrMap() const override; - - private: - graphStatus DFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, - std::vector &stack); - graphStatus BFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, - std::deque &stack); - graphStatus CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, - std::map &breadth_node_map); - graphStatus TopologicalSortingGraph(); - graphStatus SortNodes(std::vector &stack, std::map &mapInEdgeNum); - Vistor AllGraphNodes(std::vector> &subgraphs) const; - size_t GetInEdgeSize(const NodePtr &node); - size_t GetOutEdgeSize(const NodePtr &node); - graphStatus RemoveExtraOutEdge(const NodePtr &node); - bool GraphMembersAreEqual(const ComputeGraph &r_graph) const; - bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const; - bool VectorInputNodePtrIsEqual(const std::vector &r_node_ptr_vector, - const std::vector &l_node_ptr_vector) const; - - void SetNodesOwner(); - - friend class ModelSerializeImp; - friend class GraphDebugImp; - friend class OnnxUtils; - friend class TuningUtils; - - std::string name_; - uint32_t graph_id_ = 0; - ProtoAttrMapHelper attrs_; - std::vector nodes_; - std::map all_nodes_infos_; - std::vector target_nodes_info_; - - std::vector input_nodes_; - std::vector inputs_order_; - uint32_t input_size_ = 1; - std::map> out_nodes_map_; - uint32_t output_size_ = 1; - std::vector> output_nodes_info_; - - std::vector> sub_graph_; - std::map> names_to_subgraph_; - std::weak_ptr parent_graph_; - std::weak_ptr parent_node_; - - // the members followed should not in the ComputeGraph class - bool is_valid_flag_; - bool is_summary_graph_ = false; - // Indicates whether it is need iteration - bool need_iteration_ = false; - std::map, std::vector> params_share_map_; - // TaskIdx -> op_name Map - std::map op_name_map_; - uint64_t session_id_ = 0; - ge::Format data_format_ = ge::FORMAT_ND; - // unknown graph indicator, default is 
false, mean known shape - bool is_unknown_shape_graph_ = false; -}; -} // namespace ge -#endif // INC_GRAPH_COMPUTE_GRAPH_H_ diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h deleted file mode 100644 index 47b11ba8..00000000 --- a/inc/graph/debug/ge_attr_define.h +++ /dev/null @@ -1,1130 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*lint -e618*/ -#ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ -#define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ - -#include -#include "graph/types.h" - -namespace ge { -#ifdef HOST_VISIBILITY -#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_HOST_VISIBILITY -#endif -#ifdef DEV_VISIBILITY -#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_DEV_VISIBILITY -#endif -// Public attribute -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WORKSPACE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHT_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_QUANTIZE_FACTOR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALPHA; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BETA; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADMODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADMODES; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD_MODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SCALE; - -GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WINDOWS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GLOBAL_POOLING; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELU_FLAG; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALGO; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_FORMAT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_SHAPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER_FORMAT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_K; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_NORM_REGION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_LOCAL_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_ALPHA; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_BETA; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROADCAST; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TIDX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TPADDINGS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IMG_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IMG_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NET_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NET_W; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TMULTIPLES; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTIPLES; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_T; - -extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string ATTR_NAME_N; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TSHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_INPUTS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_OUTPUTS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DIMS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_DYNAMIC_AIPP_INPUT_DIMS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_RELATED_AIPP_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_AIPP_DATA_NAME_MAP; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_HAS_BEEN_ADDED; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_NODE_DEF; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_OP_DEF; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_TENSOR_DESC; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INFERRED_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; -GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_ORIGIN_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_INPUT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_OUTPUT; - -// to be deleted -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_LOC_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_CONF_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_OCR_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; - -// _Arg -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INDEX; -// _RetVal -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETVAL_ATTR_NAME_INDEX; -// Data -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DATA_ATTR_NAME_DATA_TYPE; - -// Send -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SEND_ATTR_EVENT_ID; - -// Recv -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RECV_ATTR_EVENT_ID; - -// Convolution -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COEF; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDES; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATIONS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_MODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_ALGO; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
CONV_ATTR_NAME_GROUP; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_PAD_MODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_PAD; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_STRIDE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_DILATION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_NUM_OUTPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_KERNEL; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_FILTER; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_BIAS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_RELU_FLAG; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_ADJ; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_TARGET_SHAPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_BEFORE_PAD; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_HAS_BIAS; - -// Pooling -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_NAN_OPT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_PAD_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_GLOBAL_POOLING; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_WINDOW; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_PAD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_STRIDE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_CEIL_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_DATA_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_BEFORE_PAD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_NAME_ALGO; - -// Eltwise -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_COEFF; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_WEIGHT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_RELU_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_ALPHA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_BETA; - -// BatchNorm -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_EPSILON; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_MEAN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; - -// Huberloss -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; - -// SSDRealDivTileMul -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; - -// SSDSumMulRealDivMean -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; -/// ConcatFive2Four -/// ConcatFour2Five -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; -// Scale -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; - -// FullConnection -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_FILTER; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_BIAS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_RELU_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_ATTR_NAME_ALGO; - -// SoftmaxOpParams -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_ALGO; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_MODE; - -// SparseSoftmaxCrossEntropy -extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE; -extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD; -// Attr labelSmoothing -extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING; - -// ApplyMomentum -extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string APPLYMENTUM_ATTR_IS_GRAPH_FUSION; - -// Activation -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ACTIVATION_ATTR_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ACTIVATION_ATTR_COEF; - -// Concat 
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_ATTR_NAME_AXIS; - -// Const -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_DATA_TRANSTYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; - -// Roipooling -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLED_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLED_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO; - -// DetectionOutput -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IMG_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IMG_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE; -// Ssd DetectionOutput -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_ETA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K; - -// Refinedet DetectionOutput -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE; - -// Yolo DetectionOutput -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_ClASSES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BIASES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_RELATIVE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K; 
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION; - -// DetectionPostprocess -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_CLS_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_POST_NMS_TOPN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT; - -// Spatialtransfrom -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_OUTPUT_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_OUTPUT_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM; - -// Proposal -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_BASE_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_MIN_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_RATIO; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_NMS_THRESH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_TOP_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_IMG_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_IMG_W; -// Softmax -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_AXIS; - -// Permute -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; - -// SSD Normalize -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_EPS; - -// Flatten -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_ATTR_END_AXIS; - -// SsdPRIORBOX -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_FLIP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_CLIP; -GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_IMG_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_IMG_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_STEP_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_STEP_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_VARIANCE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM; - -// PRelu -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PRELU_ATTR_CHANNEL_SHARED; - -// Psroi pooling -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_OUTPUT_DIM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_GROUP_SIZE; - -// Power -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_POWER; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; - -// Log -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; -// Pack -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; - -// Dynamic stitch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; -// Unpack -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; -// Gathernd -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND_ATTR_NAME_TINDICES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND_ATTR_NAME_TPARAMS; - -// Argmax -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY 
extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; - -// Upsample -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; -// Relu -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; - -// FreeSpaceExtract -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT; - -// Split -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_SLICE_POINT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_SIZE_SPLIT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_NUM_SPLIT; - -// Tvm -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_WORKSPACE_TYPE; - -// Squeeze -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_DIMS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_OP_NAME; - -// Stride slice -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_BEGIN_MASK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_END_MASK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; - -// Slice -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SLICE_ATTR_NAME_BEGINS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SLICE_ATTR_NAME_SIZES; - -// Roialign -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SPATIAL_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; - -// Generate_rpn_proposal -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string - GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string - GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; -// Decode_bbox -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DECODE_BBOX_ATTR_DECODECLIP; - -// Cast -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_DSTT; 
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_SRCT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_DST_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_TRUNCATE; - -// Fastrcnnn predications -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES; - -// REORG -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_ATTR_STRIDE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_ATTR_REVERSE; - -// MERGE -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; -static const std::string NOT_NET_OUTPUT = "not_net_output"; - -// ENTER -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_CONSTANT_FLAG; - -// Concatv2 -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_V2_ATTR_TIDX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_V2_ATTR_N; -// SUM -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_TIDX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_KEEP_DIMS; - -// ResizeBilinear -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_HEIGHT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_WIDTH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_PAD_END; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; - -// RetinaNet -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; -// MatMul -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_HAS_BIAS; 
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_ATTR_IS_TRAINING; - -// Flatten -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_START_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_END_AXIS; - -// Reshape -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NUM_AXES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_ALPHA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_BETA; - -// Frameoworkop -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string T_IN_DATATYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string T_OUT_DATATYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_N; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_C; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_PAD_DEPTH_CONV; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_PAD_CONV; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BEFORE_PAD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ANN_MEAN_KEEPDIMS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_ATTR_PADDINGDS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_ATTR_CONSTANT_VALUE; - -// ConvGradFilter -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE; -// ConvGradInput -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE; - -// Rnn -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_MODE_STATIC; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MUTI_RNN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CELL_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CNN_RNN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; 
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; - -// Upsample -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; - -// PadV2 -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; - -// MirrorPad -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; -// Filler -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; - -// Shufflechannel -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHUFFLE_CHANNEL_GROUP; - -// TopKV2 -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TOPKV2_ATTR_K; - -// Calibaration -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_H_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_W_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_TOP_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_BOTTOM_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_RIGHT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATION_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_EPSILON; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_POOLING_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CLASS_NUM; -// Model -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TARGET_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_STREAM_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_HUGE_STREAM_LIST; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; - 
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OUT_NODES_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; - -// Public attribute -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BYTE_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_INFERENCE_ID; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_OPDEF; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IO_OP; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_SCOPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPATTR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUFLAG; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SEQLEN_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_X_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONT_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_XSTATIC_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_MINI; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_TINY; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_LITE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_OUTPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; - -// Used for operators that do not generate task -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOTASK; - -// Used for operators that output reuse input -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_REUSE_INPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; - 
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; - -// L2_normalize -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; 
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; -// HCOM -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; - -// Log time stamp -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; -// SpaceToDepth/DepthToSpace -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; - -// SparseSoftmaxCrossEntropyWithLogits -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; - -// MaxPoolGradWithArgmax -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; - -// AvgPoolGrad -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; - -// Varible -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string VAR_ATTR_ADDR_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; - -// Assign -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VAR_NAME; - -// ShapeN -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; - -// Space2bacth batch2space -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; -// Depth_to_space space_to_depth -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; -// FakeQuantWithMinMaxVars -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; -// Mobilenet_ssd_conv_fusion -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; - -// Lsh project -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; - -// Control flow -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; - -// GatherV2 attr def -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; - -// Reshape attr def -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; - -// Axis attr def -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; -// The node link with SparseSoftmaxCrossEntropyWithLogits -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; - -GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; -// For constant folding -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; - -// Used for mark the active label list to find stream of activated node -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE; - -// Multi batch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_LABEL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_BATCH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER; - -// Control flow -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_DATA_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIG_NODE_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; - -// Function Op -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; - -// Used for mark the active node is for loop, type:bool -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_INPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_RANGE; - -// Atomic addr clean attrs -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_INPUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_OUTPUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_IS_FUSION_NODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_IS_ATOMIC_NODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_WORKSPACE_INFO; -GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_WORKSPACE_OFFSET; -// Used for find variable session_id -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MODEL_ATTR_SESSION_ID; - -// Source/dst format for Op FormatTransfer -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_SRC_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_DST_FORMAT; - -// For compile op by ge call -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEED_COMPILE; - -// For mutil-batch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_TYPE; - -// For inserted op -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; - -// For compress weight -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPRESS_WEIGHT; - -// For data dump -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_SUB_SPLITER_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_GROUP_OP_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; - -// used for lX fusion -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY 
extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ENGINE_NAME_FOR_LX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_LX_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPTIMIZE_GROUP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_COMPILE_STRATEGY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_BUFFER; - -// for unregistered op -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST; - -// op overflow dump -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE; - -// op dynamic input -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_INPUT_START; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_INPUT_END; - -// functional ops attr -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_ELSE_BRANCH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; - -// used for label switch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_END_NODE; - -// Variable -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; - -// HCOM -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; -GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; -// used for LX tiling -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_L1_SPACE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_TYPE_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST; - -// Dynamic stitch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; - -// Used for support Horovod -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INTER_EVENT_IDENTIFY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE; -// for gradient group -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG; - -// dynamic shape attrs -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX; - -// atc user def dtype&format -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT; - -// for fusion op plugin -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; - -// graph partition for aicpu -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME; - -// input and output memory type -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_VARIABLE_PLACEMENT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE; - -// input_output_offset -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET; -} // namespace ge - -#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ -/*lint +e618*/ diff --git a/inc/graph/def_types.h b/inc/graph/def_types.h deleted file mode 100644 index 6d70fb18..00000000 --- a/inc/graph/def_types.h +++ /dev/null @@ -1,195 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_DEF_TYPES_H_ -#define INC_GRAPH_DEF_TYPES_H_ - -#include -#include -#include -#include "graph/attr_value_serializable.h" -#include "graph/buffer.h" -namespace ge { -#define DEF_TYPE_DEC(type, name) \ - inline void set_##name(const type &value) { name = value; } \ - type *mutable_##name() { return &name; } - -#define DEF_TYPE_HAS_DEC(type, name) \ - inline void set_##name(const type &value) { name = value; } \ - \ - private: \ - bool has_mutable_##name{false}; \ - \ - public: \ - bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \ - type *mutable_##name() { \ - has_mutable_##name = true; \ - return &name; \ - } - -#define DEF_TYPE_VEC_DEC(type, name) \ - inline int name##_size() const { return name.size(); } \ - inline void clear_##name() { name.clear(); } \ - inline void set_##name(int index, type value) { name[index] = value; } \ - inline void add_##name(type value) { name.push_back(value); } \ - inline std::vector *mutable_##name() { return &name; } - -#define DEF_TYPE_BYTES_DEC(name) \ - inline void clear_##name() { name.ClearBuffer(); } \ - inline void set_##name(const void *value, size_t size) { name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ - inline Buffer *mutable_##name() { return &name; } - -struct CompressInfo { - public: - CompressInfo() {} - CompressInfo(int32_t blockRow, int32_t blockCol, int32_t fractalK, int32_t fractalN, int32_t lastFractalK, - int32_t lastFractalN, int32_t cubeSize, int32_t loadDir) { - blockrow = blockRow; - blockcol = blockCol; - fractalk = fractalK; - fractaln = fractalN; - lastfractalk = lastFractalK; - lastfractaln = lastFractalN; - cubesize = cubeSize; - loaddir = loadDir; - } - - int32_t blockrow{0}; // Block row - int32_t blockcol{0}; // Block col - int32_t fractalk{0}; // Fractal K - int32_t fractaln{0}; // Fractal N - int32_t lastfractalk{0}; // K of last fractal - int32_t lastfractaln{0}; // N of last fractal - int32_t cubesize{0}; // Cube's length - int32_t loaddir{0}; // Data load directtiono 0:col load 1:row load - DEF_TYPE_DEC(int32_t, blockrow); - DEF_TYPE_DEC(int32_t, blockcol); - DEF_TYPE_DEC(int32_t, fractalk); - DEF_TYPE_DEC(int32_t, fractaln); - DEF_TYPE_DEC(int32_t, lastfractalk); - DEF_TYPE_DEC(int32_t, lastfractaln); - DEF_TYPE_DEC(int32_t, cubesize); - DEF_TYPE_DEC(int32_t, loaddir); - - GE_SERIALIZABLE(blockrow, blockcol, fractalk, fractaln, lastfractalk, lastfractaln, cubesize, loaddir); -}; - -enum QuantizeScaleType { VECTOR_SCALE = 0, SCALAR_SCALE = 1 }; -enum QuantizeScaleMode { NORMAL_MODE = 0, SQRT_MODE = 1 }; -enum QuantizeAlgorithm { - NON_OFFSET_ALGO = 0, - HALF_OFFSET_ALGO = 1, - ALL_OFFSET_ALGO = 2, -}; -struct QuantizeFactor { - public: - // QuantizeScaleMode scale_mode; - uint32_t scale_mode{0}; - Buffer scale_value; - int64_t scale_offset{0}; - Buffer offset_data_value; - int64_t offset_data_offset{0}; - Buffer offset_weight_value; - int64_t offset_weight_offset{0}; - Buffer offset_pad_value; - int64_t offset_pad_offset{0}; - - DEF_TYPE_DEC(uint32_t, scale_mode); - DEF_TYPE_BYTES_DEC(scale_value); - - DEF_TYPE_DEC(int64_t, 
scale_offset); - DEF_TYPE_BYTES_DEC(offset_data_value); - DEF_TYPE_DEC(int64_t, offset_data_offset); - - DEF_TYPE_BYTES_DEC(offset_weight_value); - DEF_TYPE_DEC(int64_t, offset_weight_offset); - DEF_TYPE_BYTES_DEC(offset_pad_value); - DEF_TYPE_DEC(int64_t, offset_pad_offset); - - GE_SERIALIZABLE(scale_mode, scale_value, scale_offset, offset_data_value, offset_data_offset, offset_weight_value, - offset_weight_offset, offset_pad_value, offset_pad_offset) -}; - -static inline bool QuantizeFactorHasData(const QuantizeFactor &factor) { - return factor.scale_value.GetSize() > 0 || factor.offset_data_value.GetSize() > 0 || - factor.offset_weight_value.GetSize() > 0 || factor.offset_pad_value.GetSize() > 0; -} - -struct AllOffsetQuantizeInfo { - public: - AllOffsetQuantizeInfo() {} - AllOffsetQuantizeInfo(float s, int32_t o) : scale(s), offset(o) {} - float scale{0}; - int32_t offset{0}; - - DEF_TYPE_DEC(float, scale); - DEF_TYPE_DEC(int32_t, offset); - - GE_SERIALIZABLE(scale, offset) -}; - -struct QuantizeCalcFactor { - public: - Buffer offsetw; - int64_t offsetw_offset{0}; - Buffer offsetd; - int64_t offsetd_offset{0}; - Buffer scalereq; - int64_t scaledreq_offset{0}; - Buffer offsetdnext; - int64_t offsetdnext_offset{0}; - - DEF_TYPE_BYTES_DEC(offsetw); - DEF_TYPE_DEC(int64_t, offsetw_offset); - DEF_TYPE_BYTES_DEC(offsetd); - DEF_TYPE_DEC(int64_t, offsetd_offset); - DEF_TYPE_BYTES_DEC(scalereq); - DEF_TYPE_DEC(int64_t, scaledreq_offset); - DEF_TYPE_BYTES_DEC(offsetdnext); - DEF_TYPE_DEC(int64_t, offsetdnext_offset); - - GE_SERIALIZABLE(offsetw, offsetw_offset, offsetd, offsetd_offset, scalereq, scaledreq_offset, offsetdnext, - offsetdnext_offset); -}; - -static inline bool QuantizeFactorHasData(const QuantizeCalcFactor &factor) { - return factor.offsetw.GetSize() > 0 || factor.offsetd.GetSize() > 0 || factor.scalereq.GetSize() > 0 || - factor.offsetdnext.GetSize() > 0; -} - -struct QuantizeFactorParams { - uint32_t quantize_algo{0}; - uint32_t scale_type{0}; - QuantizeFactor quantize_param; - QuantizeFactor dequantize_param; - QuantizeFactor requantize_param; - QuantizeCalcFactor quantizecalc_param; - DEF_TYPE_DEC(uint32_t, quantize_algo); - DEF_TYPE_DEC(uint32_t, scale_type); - DEF_TYPE_HAS_DEC(QuantizeFactor, quantize_param); - DEF_TYPE_HAS_DEC(QuantizeFactor, dequantize_param); - DEF_TYPE_HAS_DEC(QuantizeFactor, requantize_param); - DEF_TYPE_HAS_DEC(QuantizeCalcFactor, quantizecalc_param); - - GE_SERIALIZABLE(quantize_algo, scale_type, quantize_param, dequantize_param, requantize_param, quantizecalc_param, - has_mutable_quantize_param, has_mutable_dequantize_param, has_mutable_requantize_param, - has_mutable_quantizecalc_param); -}; - -#undef DEF_TYPE_DEC -} // namespace ge - -#endif // INC_GRAPH_DEF_TYPES_H_ diff --git a/inc/graph/detail/any_map.h b/inc/graph/detail/any_map.h deleted file mode 100644 index 70533ea1..00000000 --- a/inc/graph/detail/any_map.h +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
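// Illustrative sketch: the DEF_TYPE_DEC / DEF_TYPE_HAS_DEC / DEF_TYPE_BYTES_DEC macros
// in def_types.h above generate protobuf-style accessors (set_x, mutable_x, has_x) for
// plain struct members. A minimal, hypothetical example of filling QuantizeFactorParams
// from that header; the chosen enum values are for illustration only.
#include "graph/def_types.h"

void FillQuantizeParams(ge::QuantizeFactorParams &params) {
  params.set_quantize_algo(ge::HALF_OFFSET_ALGO);
  params.set_scale_type(ge::VECTOR_SCALE);

  // mutable_quantize_param() also flips has_quantize_param(), mirroring protobuf
  // "has" semantics via the has_mutable_ flag generated by DEF_TYPE_HAS_DEC.
  ge::QuantizeFactor *quant = params.mutable_quantize_param();
  quant->set_scale_mode(ge::SQRT_MODE);
  quant->set_scale_offset(0);
}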
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_DETAIL_ANY_MAP_H_ -#define INC_GRAPH_DETAIL_ANY_MAP_H_ - -#include -#include -#include -#include - -namespace ge { -using std::shared_ptr; -using std::string; - -class TypeID { - public: - template - static TypeID Of() { - return TypeID(__PRETTY_FUNCTION__); - } - - ~TypeID() = default; - - bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; } - - private: - explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32 - - string type_; -}; - -class AnyMap { - public: - template - bool Set(const string &name, const DT &val); - - template - bool Get(const string &name, T &retValue) const; - - bool Has(const string &name) const { return anyValues_.find(name) != anyValues_.end(); } - - void Swap(AnyMap &other) { anyValues_.swap(other.anyValues_); } - - private: - class Placeholder { - public: - virtual ~Placeholder() = default; - - virtual const TypeID &GetTypeInfo() const = 0; - }; - - template - class Holder : public Placeholder { - public: - explicit Holder(const VT &value) : value_(value) {} - - ~Holder() override = default; - - const TypeID &GetTypeInfo() const override { - static const TypeID typeId = TypeID::Of(); - return typeId; - } - - const VT value_; - }; - - std::map> anyValues_; -}; - -template -bool AnyMap::Set(const string &name, const DT &val) { - auto it = anyValues_.find(name); - - std::shared_ptr> tmp; - try { - tmp = std::make_shared>(val); - } catch (std::bad_alloc &e) { - tmp = nullptr; - } catch (...) { - tmp = nullptr; - } - - if (it == anyValues_.end()) { - (void)anyValues_.emplace(name, tmp); - } else { - if (it->second && it->second->GetTypeInfo() == TypeID::Of
()) { - it->second = tmp; - } else { - return false; - } - } - return true; -} - -template -bool AnyMap::Get(const string &name, T &retValue) const { - auto it = anyValues_.find(name); - if (it != anyValues_.end() && it->second && it->second->GetTypeInfo() == TypeID::Of()) { - auto retPtr = std::static_pointer_cast>(it->second); - retValue = retPtr->value_; - return true; - } - return false; -} -} // namespace ge -#endif // INC_GRAPH_DETAIL_ANY_MAP_H_ diff --git a/inc/graph/detail/attributes_holder.h b/inc/graph/detail/attributes_holder.h deleted file mode 100644 index 49741143..00000000 --- a/inc/graph/detail/attributes_holder.h +++ /dev/null @@ -1,165 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ -#define INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ - -#include -#include -#include -#include -#include -#include -#include "graph/detail/any_map.h" -#include "graph/ge_error_codes.h" -#include "graph/types.h" - -namespace google { -namespace protobuf { -class Message; -template -class Map; -} // namespace protobuf -} // namespace google - -namespace ge { -using std::string; -class GeAttrValue; - -namespace proto { -class AttrDef; -class TensorDef; -class TensorDescriptor; -class ShapeDef; -class NamedAttrs; -class ModelDef; -class OpDef; -class GraphDef; -} // namespace proto - -using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073 -using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>; - -template -class GeIrProtoHelper { - public: - GeIrProtoHelper(const ProtoMsgOwner &protoOwner, ProtoType *protoMsg) - : protoOwner_(protoOwner), protoMsg_(protoMsg) {} - - GeIrProtoHelper() { - protoOwner_ = std::shared_ptr<::google::protobuf::Message>(nullptr); - protoMsg_ = nullptr; - } - virtual ~GeIrProtoHelper() = default; - - template - GeIrProtoHelper(const GeIrProtoHelper &other) { - protoOwner_ = other.protoOwner_; - protoMsg_ = other.protoMsg_; - } - template - GeIrProtoHelper &operator=(const GeIrProtoHelper &other) { - protoOwner_ = other.protoOnwer_; - protoMsg_ = other.protoMsg_; - return *this; - } - void InitDefault(); - template - bool operator==(const GeIrProtoHelper &other) const { - return protoOwner_ == other.protoOwner_ && protoMsg_ == other.protoMsg_; - } - - inline const ProtoMsgOwner &GetProtoOwner() const { return protoOwner_; } - inline ProtoType *GetProtoMsg() const { return protoMsg_; } - void CopyValueFrom(const GeIrProtoHelper &other) { - if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) { - *protoMsg_ = *other.protoMsg_; - } - } - void MoveValueFrom(GeIrProtoHelper &&other) { - if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) { - *protoMsg_ = std::move(*other.protoMsg_); - } - } - - void Swap(GeIrProtoHelper &other) { - protoOwner_.swap(other.protoOwner_); - - ProtoType *temp = protoMsg_; - protoMsg_ = other.protoMsg_; - other.protoMsg_ = temp; - } - - // protoMsg_ is part of 
protoOwner_, they have the same runtime - ProtoMsgOwner protoOwner_ = nullptr; - ProtoType *protoMsg_ = nullptr; - friend class GeIrProtoHelper::value, typename std::remove_const::type, const ProtoType>::type>; -}; - -using ProtoAttrMapHelper = GeIrProtoHelper; -using ConstProtoAttrMapHelper = GeIrProtoHelper; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { - public: - AttrHolder() = default; - virtual ~AttrHolder() = default; - - graphStatus SetAttr(const string &name, const GeAttrValue &value); - - graphStatus GetAttr(const string &name, GeAttrValue &value) const; - - bool HasAttr(const string &name) const; - - graphStatus DelAttr(const string &name); - - void CopyAttrsFrom(const AttrHolder &holder); - - void Swap(AttrHolder &holder) { - requiredAttrs_.swap(holder.requiredAttrs_); - extAttrs_.Swap(holder.extAttrs_); - } - - template - bool SetExtAttr(const string &name, const T &value) { - return extAttrs_.Set(name, value); - } - template - T TryGetExtAttr(const string &name, T defaultValue) const { - T ret(defaultValue); - (void)extAttrs_.Get(name, ret); - return ret; - } - - protected: - graphStatus AddRequiredAttr(const std::string &name); - const std::unordered_set GetAllAttrNames() const; - const std::map GetAllAttrs() const; // lint !e1073 - - virtual ProtoAttrMapHelper MutableAttrMap() = 0; - virtual ConstProtoAttrMapHelper GetAttrMap() const = 0; - - friend class ModelSerializeImp; - friend class AttrUtils; - friend class AttrUtilsHelper; - - std::vector requiredAttrs_; - - private: - AnyMap extAttrs_; -}; -} // namespace ge -#endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ diff --git a/inc/graph/detail/model_serialize_imp.h b/inc/graph/detail/model_serialize_imp.h deleted file mode 100644 index ff27335a..00000000 --- a/inc/graph/detail/model_serialize_imp.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
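// Illustrative sketch: AttrHolder (attributes_holder.h above) keeps two attribute
// stores: serialized attributes behind the protobuf attr map, and arbitrary typed
// "ext" attributes backed by the AnyMap from any_map.h. A minimal, hypothetical
// example using OpDesc, which derives from AttrHolder; the key and value are made up.
#include <string>
#include "graph/op_desc.h"

void CacheSideInfo(ge::OpDesc &op_desc) {
  // AnyMap type-checks retrieval through the TypeID comparison shown above.
  (void)op_desc.SetExtAttr("_example_cached_note", std::string("value"));

  // Returns the stored value when the key exists with the same type, else the default.
  std::string note = op_desc.TryGetExtAttr("_example_cached_note", std::string("fallback"));
  (void)note;
}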
- */ - -#ifndef INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ -#define INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ - -#include -#include -#include -#include -#include "graph/anchor.h" -#include "graph/detail/attributes_holder.h" -#include "graph/ge_tensor.h" -#include "graph/graph.h" -#include "graph/node.h" - -namespace ge { -using ComputeGraphPtr = std::shared_ptr; - -struct NodeNameGraphReq { - string node_name; - int32_t index; - ComputeGraphPtr graph; -}; - -struct NodeNameNodeReq { - string src_node_name; - int32_t src_out_index; - NodePtr dst_node; - int32_t dst_in_index; - string dst_node_name; -}; - -class ModelSerializeImp { - public: - bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false); - - bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false); - - bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); - - bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false); - - bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false); - - bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); - - bool UnserializeModel(Model &model, proto::ModelDef &modeProto); - - bool UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graphProto); - - bool UnserializeGraph(ComputeGraphPtr &graph, proto::GraphDef &graphProto); - - bool HandleNodeNameRef(); - - bool UnserializeOpDesc(OpDescPtr &opDesc, proto::OpDef &opDefProto); - void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_in, std::vector &key_out, - std::vector &value_in, std::vector &value_out, std::vector &opt); - void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto); - - bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &opDefProto); - - bool UnserializeTensor(GeTensorPtr &tensor, proto::TensorDef &tensorProto); - - bool ParseNodeIndex(const string &node_index, string &nodeName, int32_t &index); - - void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; } - - private: - bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map &subgraphs); - - std::vector graph_input_node_names_; - std::vector graph_output_node_names_; - std::vector node_input_node_names_; - std::map node_map_; - ProtoMsgOwner protobuf_owner_; -}; -} // namespace ge - -#endif // INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ diff --git a/inc/graph/ge_attr_value.h b/inc/graph/ge_attr_value.h deleted file mode 100644 index 0c265c20..00000000 --- a/inc/graph/ge_attr_value.h +++ /dev/null @@ -1,343 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
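// Illustrative sketch: ModelSerializeImp (model_serialize_imp.h above) converts a
// Model and its ComputeGraphs to and from the protobuf messages defined in ge_ir.proto.
// A minimal, hypothetical round trip; the "proto/ge_ir.pb.h" include path and the use
// of this implementation class directly (rather than the public ModelSerialize wrapper)
// are assumptions for illustration.
#include "graph/detail/model_serialize_imp.h"
#include "graph/model.h"
#include "proto/ge_ir.pb.h"

bool RoundTripModel(const ge::Model &model, ge::Model &restored) {
  ge::proto::ModelDef model_def;
  ge::ModelSerializeImp imp;
  if (!imp.SerializeModel(model, &model_def)) {
    return false;
  }
  return imp.UnserializeModel(restored, model_def);
}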
- */ - -#ifndef INC_GRAPH_GE_ATTR_VALUE_H_ -#define INC_GRAPH_GE_ATTR_VALUE_H_ - -#include -#include -#include -#include -#include -#include -#include "graph/buffer.h" -#include "detail/attributes_holder.h" -#include "graph/ge_error_codes.h" -#include "graph/ge_tensor.h" - -using std::map; -using std::string; -using std::vector; - -namespace ge { -class GeTensor; - -using GeTensorPtr = std::shared_ptr; -using ConstGeTensorPtr = std::shared_ptr; - -class ComputeGraph; -using ComputeGraphPtr = std::shared_ptr; -using ConstComputeGraphPtr = std::shared_ptr; - -class GeTensorDesc; -class GeAttrValue; -class GeAttrValueImp; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder { - public: - NamedAttrs(); - virtual ~NamedAttrs() = default; - void SetName(const std::string &name); - string GetName() const; - GeAttrValue GetItem(const string &key) const; - - protected: - ProtoAttrMapHelper MutableAttrMap() override; - ConstProtoAttrMapHelper GetAttrMap() const override; - - private: - // Create namedAttrs from protobuf obj - NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); - GeIrProtoHelper named_attrs_; - friend class GeAttrValueImp; - friend class GeAttrValue; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { - public: - using INT = int64_t; - using FLOAT = float; - using BOOL = bool; - using STR = std::string; - using TENSOR = GeTensorPtr; - using TENSOR_DESC = GeTensorDesc; - using GRAPH = ComputeGraphPtr; - using BYTES = Buffer; - using NAMED_ATTRS = ge::NamedAttrs; - using DATA_TYPE = ge::DataType; - - using LIST_INT = vector; - using LIST_FLOAT = vector; - using LIST_BOOL = vector; - using LIST_STR = vector; - using LIST_TENSOR = vector; - using LIST_TENSOR_DESC = vector; - using LIST_GRAPH = vector; - using LIST_BYTES = vector; - using LIST_NAMED_ATTRS = vector; - using LIST_LIST_INT = vector>; - using LIST_DATA_TYPE = vector; - - using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs). 
- - enum ValueType { - VT_NONE = 0, - VT_STRING, - VT_FLOAT, - VT_BOOL, - VT_INT, - VT_TENSOR_DESC, - VT_TENSOR, - VT_BYTES, - VT_GRAPH, - VT_NAMED_ATTRS, - VT_LIST_LIST_INT, - VT_DATA_TYPE, - - VT_LIST_BASE = 1000, - VT_LIST_STRING = VT_LIST_BASE + VT_STRING, - VT_LIST_FLOAT = VT_LIST_BASE + VT_FLOAT, - VT_LIST_BOOL = VT_LIST_BASE + VT_BOOL, - VT_LIST_INT = VT_LIST_BASE + VT_INT, - VT_LIST_TENSOR_DESC = VT_LIST_BASE + VT_TENSOR_DESC, - VT_LIST_TENSOR = VT_LIST_BASE + VT_TENSOR, - VT_LIST_BYTES = VT_LIST_BASE + VT_BYTES, - VT_LIST_GRAPH = VT_LIST_BASE + VT_GRAPH, - VT_LIST_NAMED_ATTRS = VT_LIST_BASE + VT_NAMED_ATTRS, - VT_LIST_DATA_TYPE = VT_LIST_BASE + VT_DATA_TYPE, - }; - - template - struct IsAttrTypeEnable { - using DT = typename std::remove_cv::type; - - static bool const VALUE = std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value; - - // Not has list type of NamedAttrs - static bool const LIST_VALUE = std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || - std::is_same::value || std::is_same::value; - }; - - template - // To cols - using enable_if_vector_type_valid_t = typename std::enable_if::LIST_VALUE, int>::type; - - template - using enable_if_one_type_valid_t = typename std::enable_if::VALUE, int>::type; - - template - using enable_if_type_valid_t = - typename std::enable_if::VALUE || IsAttrTypeEnable::LIST_VALUE, int>::type; - - template - using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; - - GeAttrValue(); - ~GeAttrValue() = default; - // SetValue, Set initializer_list - template = 0> - graphStatus SetValue(std::initializer_list
&&val) { - T vectorVal; - for (auto &item : val) { - vectorVal.push_back(item); - } - return SetValue(vectorVal); - } - - // SetValue, Set vector - template = 0> - graphStatus SetValue(const std::vector
&val) { - T vectorVal; - for (auto item : val) { - vectorVal.push_back(item); - } - return SetValue(vectorVal); - } - - // SetValue, not list type - template = 0> - graphStatus SetValue(DT &&val) { - return SetValue(T(std::forward
(val))); - } - - // GE_SERIALIZABLE - template = 0> - graphStatus SetValue(const T &t) { - return t.Save(*this); - } - - template = 0> - graphStatus SetValue(const vector &t) { - vector attrs; - for (auto &item : t) { - GeAttrValue val; - item.Save(val); - NamedAttrs attrsItem; - (void)val.GetValue(attrsItem); - attrs.push_back(attrsItem); - } - return SetValue(attrs); - } - - // GetValue, list value - template = 0, - typename std::enable_if::value, int>::type = 0> - graphStatus GetValue(std::vector
&val) const { - T valGet; - val.clear(); - auto status = GetValue(valGet); - if (status != GRAPH_SUCCESS) { - return status; - } - for (auto item : valGet) { - val.push_back(item); - } - return GRAPH_SUCCESS; - } - - // GetValue, not list type - template = 0, - typename std::enable_if::value, int>::type = 0> - graphStatus GetValue(DT &val) const { - T valGet; - auto status = GetValue(valGet); - if (status != GRAPH_SUCCESS) { - return status; - } - val = DT(valGet); - return GRAPH_SUCCESS; - } - - // GE_SERIALIZABLE - template = 0> - graphStatus GetValue(T &t) { - return t.Load(*this); - } - - template = 0> - graphStatus GetValue(vector &t) { - graphStatus status; - t.clear(); - vector attrs; - status = this->GetValue(attrs); - if (status != GRAPH_SUCCESS) { - return status; - } - for (auto &attr : attrs) { - T item; - GeAttrValue val; - (void)val.SetValue(attr); - status = item.Load(val); - if (status != GRAPH_SUCCESS) { - return status; - } - t.push_back(item); - } - return GRAPH_SUCCESS; - } - - template = 0> - static GeAttrValue CreateFrom(DT &&val) { - GeAttrValue valRet; - (void)valRet.SetValue(std::forward
(val)); - return valRet; - } - - template = 0> - static GeAttrValue CreateFrom(std::initializer_list
&&val) { - GeAttrValue valRet; - (void)valRet.SetValue(std::move(val)); - return valRet; - } - - template = 0> - static GeAttrValue CreateFrom(const T &val) { - GeAttrValue valRet; - (void)valRet.SetValue(val); - return valRet; - } - - template = 0> - static GeAttrValue CreateFrom(const vector &val) { - GeAttrValue valRet; - (void)valRet.SetValue(val); - return valRet; - } - - ValueType GetValueType() const; - - bool IsEmpty() const; - - GeAttrValue Copy() const; - - // For map key - bool operator==(const GeAttrValue &other) const { return value_ == other.value_; } - - graphStatus MutableTensor(GeTensorPtr &tensor); - graphStatus MutableListTensor(vector &list_tensor); - - private: -#define VALUE_SET_GET_DEC(DT) \ - graphStatus SetValue(const DT &val); \ - graphStatus GetValue(DT &val) const; - VALUE_SET_GET_DEC(GeAttrValue::STR) - VALUE_SET_GET_DEC(GeAttrValue::INT) - VALUE_SET_GET_DEC(GeAttrValue::FLOAT) - VALUE_SET_GET_DEC(GeAttrValue::BOOL) - VALUE_SET_GET_DEC(GeTensorDesc) - VALUE_SET_GET_DEC(GeAttrValue::TENSOR) - VALUE_SET_GET_DEC(GeAttrValue::GRAPH) - VALUE_SET_GET_DEC(BYTES) - VALUE_SET_GET_DEC(NamedAttrs) - VALUE_SET_GET_DEC(ge::DataType) // lint !e665 - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector>) // lint !e665 - VALUE_SET_GET_DEC(vector) // lint !e665 -#undef VALUE_SET_GET_DEC - - GeIrProtoHelper value_; - GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val); - - friend class AttrHolder; - friend class ModelSerializeImp; - friend class OnnxUtils; -}; - -class AttrValueImpl { - public: - AttrValueImpl() = default; - ~AttrValueImpl() = default; - - GeAttrValue geAttrValue_; -}; -} // namespace ge -#endif // INC_GRAPH_GE_ATTR_VALUE_H_ diff --git a/inc/graph/ge_context.h b/inc/graph/ge_context.h deleted file mode 100644 index 53985e9c..00000000 --- a/inc/graph/ge_context.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
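// Illustrative sketch: GeAttrValue (ge_attr_value.h above) is a tagged value over the
// VT_* kinds; CreateFrom/SetValue/GetValue convert between C++ types and the stored
// proto::AttrDef. A minimal, hypothetical example; the numbers are for illustration.
#include <cstdint>
#include "graph/ge_attr_value.h"

void BuildAttrValues() {
  ge::GeAttrValue scalar = ge::GeAttrValue::CreateFrom<ge::GeAttrValue::INT>(int64_t{7});
  ge::GeAttrValue dims = ge::GeAttrValue::CreateFrom<ge::GeAttrValue::LIST_INT>({1, 2, 3});

  // The stored kind can be inspected without extracting the value.
  if (dims.GetValueType() == ge::GeAttrValue::VT_LIST_INT) {
    // dims holds a LIST_INT (vector of int64_t).
  }
  (void)scalar;
}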
- */ - -#ifndef INC_GRAPH_GE_CONTEXT_H_ -#define INC_GRAPH_GE_CONTEXT_H_ - -#include -#include "graph/ge_error_codes.h" - -namespace ge { -class GEContext { - public: - graphStatus GetOption(const std::string &key, std::string &option); - bool GetHostExecFlag(); - uint64_t SessionId(); - uint32_t DeviceId(); - uint64_t TraceId(); - void Init(); - void SetSessionId(uint64_t session_id); - void SetCtxDeviceId(uint32_t device_id); - - private: - uint64_t session_id_ = 0; - uint32_t device_id_ = 0; - uint64_t trace_id_ = 0; -}; // class GEContext - -/// Get context -/// @return -GEContext &GetContext(); -} // namespace ge - -#endif // INC_GRAPH_GE_CONTEXT_H_ diff --git a/inc/graph/ge_global_options.h b/inc/graph/ge_global_options.h deleted file mode 100644 index b55192e2..00000000 --- a/inc/graph/ge_global_options.h +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_GE_GLOBAL_OPTIONS_H_ -#define INC_GRAPH_GE_GLOBAL_OPTIONS_H_ - -#include -#include - -namespace ge { -std::map &GetMutableGlobalOptions(); -} -#endif // INC_GRAPH_GE_GLOBAL_OPTIONS_H_ diff --git a/inc/graph/ge_local_context.h b/inc/graph/ge_local_context.h deleted file mode 100644 index b47098fb..00000000 --- a/inc/graph/ge_local_context.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_GE_LOCAL_CONTEXT_H_ -#define INC_GRAPH_GE_LOCAL_CONTEXT_H_ - -#include -#include -#include -#include "graph/ge_error_codes.h" - -using std::map; -using std::string; - -namespace ge { -class GEThreadLocalContext { - public: - graphStatus GetOption(const string &key, string &option); - void SetGraphOption(map options_map); - void SetSessionOption(map options_map); - void SetGlobalOption(map options_map); - - private: - map graph_options_; - map session_options_; - map global_options_; -}; // class GEThreadLocalContext - -GEThreadLocalContext &GetThreadLocalContext(); -} // namespace ge -#endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ diff --git a/inc/graph/ge_tensor.h b/inc/graph/ge_tensor.h deleted file mode 100644 index 834dca0b..00000000 --- a/inc/graph/ge_tensor.h +++ /dev/null @@ -1,193 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
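// A minimal sketch of option lookup through the two context singletons declared above;
// "ge.exec.deviceId" is only a placeholder key for illustration:
void ContextOptionSketch() {
  std::string value;
  if (ge::GetThreadLocalContext().GetOption("ge.exec.deviceId", value) == ge::GRAPH_SUCCESS) {
    // the key was found in this thread's graph/session/global option maps
  }
  const uint64_t session_id = ge::GetContext().SessionId();  // process-wide GEContext accessor
  (void)session_id;
}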
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_GE_TENSOR_H_ -#define INC_GRAPH_GE_TENSOR_H_ - -#include -#include -#include -#include -#include "detail/attributes_holder.h" -#include "graph/buffer.h" -#include "graph/ge_error_codes.h" -#include "graph/types.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { - public: - GeShape(); - ~GeShape() = default; - explicit GeShape(std::vector s); - - size_t GetDimNum() const; - // If the idx is invalid, return 0 - int64_t GetDim(size_t idx) const; - graphStatus SetDim(size_t idx, int64_t value); - std::vector GetDims() const; - - int64_t GetShapeSize() const; - std::string ToString() const; - - /// - /// @brief Check is unknown shape - /// @return bool - /// - bool IsUnknownShape() const; - - /// - /// @brief Check is a scalar - /// @return bool - /// - bool IsScalar() const; - - GeShape(const GeShape &other); - GeShape(GeShape &&other); - GeShape &operator=(const GeShape &other); - GeShape &operator=(GeShape &&other); - - private: - GeIrProtoHelper shape_def_; - friend class GeTensorDesc; - // Create from proto obj - GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); - - void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder { - friend class TensorUtils; - friend class GeAttrValue; - friend class ModelSerialize; - - public: - GeTensorDesc(); - explicit GeTensorDesc(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); - GeTensorDesc(const GeTensorDesc &desc); - GeTensorDesc(GeTensorDesc &&desc); - - ~GeTensorDesc() = default; - bool operator==(const GeTensorDesc &r_ge_tensor_desc) const; - - void Update(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); - - GeShape GetShape() const; - GeShape &MutableShape(); - void SetShape(GeShape shape); - - // set shape with -2, it stand for unknown shape - void SetUnknownDimNumShape(); - // for unknown shape - graphStatus SetShapeRange(const std::vector> &range); - graphStatus GetShapeRange(std::vector> &range) const; - - GeShape GetOriginShape() const; - void SetOriginShape(const GeShape &originShape); - - Format GetFormat() const; - void SetFormat(Format format); - - Format GetOriginFormat() const; - void SetOriginFormat(Format originFormat); - - void SetName(const std::string &name); - const std::string GetName() const; - - DataType GetDataType() const; - void SetDataType(DataType dt); - - DataType GetOriginDataType() const; - void SetOriginDataType(DataType originDataType); - - std::vector GetRefPortIndex() const; - void SetRefPortByIndex(const std::vector &index); - - GeTensorDesc Clone() const; - GeTensorDesc &operator=(const GeTensorDesc &desc); - GeTensorDesc &operator=(GeTensorDesc &&desc); - - graphStatus IsValid() const; - - protected: - ProtoAttrMapHelper MutableAttrMap() override; - ConstProtoAttrMapHelper GetAttrMap() const override; - - private: - bool GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const; - using AttrHolder::DelAttr; - using AttrHolder::GetAllAttrs; - using AttrHolder::GetAttr; - 
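// A minimal sketch of GeShape/GeTensorDesc usage, assuming the shape constructor takes
// std::vector<int64_t> dimensions (the template arguments in this hunk were stripped by
// extraction); FORMAT_ND and DT_FLOAT are the defaults visible in the constructor signature:
ge::GeTensorDesc MakeDescSketch() {
  ge::GeShape shape({1, 3, 16, 16});
  ge::GeTensorDesc desc(shape, ge::FORMAT_ND, ge::DT_FLOAT);
  desc.SetOriginShape(shape);
  desc.SetOriginFormat(ge::FORMAT_ND);
  const int64_t element_count = shape.GetShapeSize();  // product of the dims
  (void)element_count;
  return desc;
}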
using AttrHolder::HasAttr; - using AttrHolder::SetAttr; - - void Init(); - - // Create from proto obj - GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); - friend class GeTensor; - friend class GeAttrValueImp; - friend class ModelSerializeImp; - friend class OnnxUtils; - - GeIrProtoHelper tensor_descriptor_; - // Reference from tensorDescriptor_, do not direct use - mutable GeShape __shape_; - - void RefTo(const GeTensorDesc &tensorDesc) { tensor_descriptor_ = tensorDesc.tensor_descriptor_; } - GeShape &ShapeReference() const; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { - public: - GeTensor(); - explicit GeTensor(const GeTensorDesc &tensorDesc); - explicit GeTensor(const GeTensorDesc &tensorDesc, const std::vector &data); - explicit GeTensor(const GeTensorDesc &tensorDesc, const Buffer &data); - explicit GeTensor(const GeTensorDesc &tensorDesc, const uint8_t *data, size_t size); - explicit GeTensor(GeTensorDesc &&tensorDesc, std::vector &&data); - ~GeTensor() = default; - - GeTensorDesc GetTensorDesc() const; - GeTensorDesc &MutableTensorDesc(); - void SetTensorDesc(const GeTensorDesc &tensorDesc); - - const Buffer GetData() const; - Buffer MutableData(); - graphStatus SetData(std::vector &&data); - graphStatus SetData(const std::vector &data); - graphStatus SetData(const Buffer &data); - graphStatus SetData(const uint8_t *data, size_t size); - - GeTensor Clone() const; - - // Share value - GeTensor(const GeTensor &other); - // Share value - GeTensor &operator=(const GeTensor &other); - - private: - friend class GeAttrValueImp; - friend class ModelSerializeImp; - friend class OnnxUtils; - // Create from proto obj - GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); - GeIrProtoHelper tensor_def_; - // Reference from tensorDef_, do not direct use - mutable GeTensorDesc __desc_; - GeTensorDesc &DescReference() const; -}; -} // namespace ge -#endif // INC_GRAPH_GE_TENSOR_H_ diff --git a/inc/graph/graph_util.h b/inc/graph/graph_util.h deleted file mode 100644 index c39ecbc1..00000000 --- a/inc/graph/graph_util.h +++ /dev/null @@ -1,134 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
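// A minimal sketch of building a tensor, assuming the stripped data-vector element type is
// uint8_t (needs <vector> and <cstdint>):
ge::GeTensor MakeTensorSketch(const ge::GeTensorDesc &desc) {
  std::vector<uint8_t> bytes(16, 0);                  // 16 zero bytes of payload
  ge::GeTensor tensor(desc, bytes);                   // copies the descriptor and the data
  (void)tensor.SetData(bytes.data(), bytes.size());   // the payload can also be replaced afterwards
  return tensor;                                      // note: copy/assignment share the value
}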
- */ - -#ifndef INC_GRAPH_GRAPH_UTIL_H_ -#define INC_GRAPH_GRAPH_UTIL_H_ - -#include - -#include "proto/om.pb.h" - -namespace ge { -using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; -bool HasOpAttr(const OpDef *opdef, std::string attr_name); -bool GetOpAttr(const std::string &key, int32_t *value, const OpDef *opdef); - -static const char OP_TYPE_DATA[] = "Data"; -static const char OP_TYPE_INPUT[] = "Input"; -static const char ATTR_KEY_INPUT_FORMAT[] = "input_format"; -static const char ATTR_KEY_OUTPUT_FORMAT[] = "output_format"; -static const char OP_TYPE_ANN_DATA[] = "AnnData"; -} // namespace ge - -#if !defined(__ANDROID__) && !defined(ANDROID) -#include "toolchain/slog.h" -const char levelStr[4][8] = {"ERROR", "WARN", "INFO", "DEBUG"}; -#else -#include -#include -const char levelStr[8][8] = {"EMERG", "ALERT", "CRIT", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"}; -#endif - -#ifdef _MSC_VER -#define FUNC_NAME __FUNCTION__ -#else -#define FUNC_NAME __PRETTY_FUNCTION__ -#endif - -#if !defined(__ANDROID__) && !defined(ANDROID) -#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) \ - dlog_info(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) -#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) \ - dlog_warn(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) -#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) \ - dlog_error(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) -#else -#define D_GRAPH_LOG(level, format, ...) \ - do { \ - { \ - fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ - __FILE__, __LINE__, ##__VA_ARGS__); \ - syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ - ##__VA_ARGS__); \ - } \ - } while (0) -#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#endif - -#if !defined(__ANDROID__) && !defined(ANDROID) -#define GRAPH_LOGI(...) D_GRAPH_LOGI(GRAPH_MOD_NAME, __VA_ARGS__) -#define GRAPH_LOGW(...) D_GRAPH_LOGW(GRAPH_MOD_NAME, __VA_ARGS__) -#define GRAPH_LOGE(...) D_GRAPH_LOGE(GRAPH_MOD_NAME, __VA_ARGS__) -#else - -#define GRAPH_LOG(level, format, ...) \ - do { \ - { \ - fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ - __FILE__, __LINE__, ##__VA_ARGS__); \ - syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ - ##__VA_ARGS__); \ - } \ - } while (0) -#define GRAPH_LOGI(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define GRAPH_LOGW(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define GRAPH_LOGE(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#endif - -#define GRAPH_CHK_STATUS_RET_NOLOG(expr) \ - do { \ - const domi::graphStatus _status = (expr); \ - if (_status != domi::GRAPH_SUCCESS) { \ - return _status; \ - } \ - } while (0) - -#define GRAPH_CHK_BOOL_RET_STATUS(expr, _status, ...) 
\ - do { \ - bool b = (expr); \ - if (!b) { \ - GRAPH_LOGE(__VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -#define GRAPH_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ - { \ - bool b = (expr); \ - if (!b) { \ - exec_expr; \ - } \ - }; - -#define GRAPH_IF_BOOL_EXEC(expr, exec_expr) \ - { \ - if (expr) { \ - exec_expr; \ - } \ - } - -#define GRAPH_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ - do { \ - const ::domi::graphStatus _status = (expr); \ - if (_status) { \ - GRAPH_LOGE(__VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -#endif // INC_GRAPH_GRAPH_UTIL_H_ diff --git a/inc/graph/model.h b/inc/graph/model.h deleted file mode 100644 index 38ea501b..00000000 --- a/inc/graph/model.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_MODEL_H_ -#define INC_GRAPH_MODEL_H_ - -#include -#include -#include -#include -#include "detail/attributes_holder.h" -#include "graph/ge_attr_value.h" -#include "graph/graph.h" - -namespace ge { -using std::map; -using std::string; -using std::vector; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { - public: - Model(); - - ~Model() = default; - - Model(const string &name, const string &custom_version); - - string GetName() const; - void SetName(const string &name); - - uint32_t GetVersion() const; - - void SetVersion(uint32_t version) { version_ = version; } - - std::string GetPlatformVersion() const; - - void SetPlatformVersion(string version) { platform_version_ = version; } - - Graph GetGraph() const; - - void SetGraph(const Graph &graph); - - void SetAttr(const ProtoAttrMapHelper &attrs); - - using AttrHolder::GetAllAttrNames; - using AttrHolder::GetAllAttrs; - using AttrHolder::GetAttr; - using AttrHolder::HasAttr; - using AttrHolder::SetAttr; - - graphStatus Save(Buffer &buffer, bool is_dump = false) const; - - graphStatus SaveToFile(const string &file_name) const; - // Model will be rewrite - static graphStatus Load(const uint8_t *data, size_t len, Model &model); - graphStatus Load(ge::proto::ModelDef &model_def); - graphStatus LoadFromFile(const string &file_name); - - bool IsValid() const; - - protected: - ConstProtoAttrMapHelper GetAttrMap() const override; - ProtoAttrMapHelper MutableAttrMap() override; - - private: - void Init(); - ProtoAttrMapHelper attrs_; - friend class ModelSerializeImp; - friend class GraphDebugImp; - friend class OnnxUtils; - friend class ModelHelper; - friend class ModelBuilder; - string name_; - uint32_t version_; - std::string platform_version_{""}; - Graph graph_; -}; -} // namespace ge -using ModelPtr = std::shared_ptr; - -#endif // INC_GRAPH_MODEL_H_ diff --git a/inc/graph/model_serialize.h b/inc/graph/model_serialize.h deleted file mode 100644 index 16529512..00000000 --- a/inc/graph/model_serialize.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - 
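// A minimal save/load sketch over the Model API declared above, assuming ge::Buffer exposes
// GetData()/GetSize() and that `graph` is an already-built ge::Graph:
ge::graphStatus SaveAndReloadSketch(const ge::Graph &graph) {
  ge::Model model("example_model", "v0");         // placeholder name and custom version
  model.SetGraph(graph);
  ge::Buffer buffer;
  if (model.Save(buffer) != ge::GRAPH_SUCCESS) {  // serialize into an in-memory buffer
    return ge::GRAPH_FAILED;
  }
  ge::Model reloaded;
  return ge::Model::Load(buffer.GetData(), buffer.GetSize(), reloaded);
}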
* you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_MODEL_SERIALIZE_H_ -#define INC_GRAPH_MODEL_SERIALIZE_H_ - -#include -#include -#include "graph/buffer.h" -#include "graph/compute_graph.h" -#include "graph/model.h" - -namespace ge { -class ModelSerialize { - public: - Buffer SerializeModel(const Model &model, bool is_dump = false); - - Model UnserializeModel(const uint8_t *data, size_t len); - Model UnserializeModel(ge::proto::ModelDef &model_def); - - Buffer SerializeGraph(const ComputeGraphPtr &graph); - - ComputeGraphPtr UnserializeGraph(const uint8_t *data, size_t len); - - Buffer SerializeOpDesc(const ConstOpDescPtr &opDesc); - OpDescPtr UnserializeOpDesc(const uint8_t *data, size_t len); - - size_t GetSerializeModelSize(const Model &model); - - private: - static std::map &MutableTensorDescAttrMap(GeTensorDesc &tensorDesc); - - static const std::map &GetTensorDescAttrMap(const GeTensorDesc &tensorDesc); - - friend class ModelSerializeImp; - friend class GraphDebugImp; -}; -} // namespace ge -#endif // INC_GRAPH_MODEL_SERIALIZE_H_ diff --git a/inc/graph/node.h b/inc/graph/node.h deleted file mode 100644 index f4a1c6a8..00000000 --- a/inc/graph/node.h +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_GRAPH_NODE_H_ -#define INC_GRAPH_NODE_H_ - -#include -#include -#include -#include -#include -#include -#include "graph/ge_attr_value.h" -#include "utils/attr_utils.h" - -#include "graph/op_desc.h" -#include "graph/range_vistor.h" - -namespace ge { -class ComputeGraph; - -using ComputeGraphPtr = std::shared_ptr; - -class Node; - -using NodePtr = std::shared_ptr; -using ConstNodePtr = std::shared_ptr; -using NodeRef = std::weak_ptr; - -class Anchor; - -using AnchorPtr = std::shared_ptr; - -class InDataAnchor; - -using InDataAnchorPtr = std::shared_ptr; - -class OutDataAnchor; - -using OutDataAnchorPtr = std::shared_ptr; - -class ControlAnchor; - -using ControlAnchorPtr = std::shared_ptr; - -class InControlAnchor; - -using InControlAnchorPtr = std::shared_ptr; - -class OutControlAnchor; - -using OutControlAnchorPtr = std::shared_ptr; - -using OpDescPtr = std::shared_ptr; - -using ConstNode = const Node; - -typedef std::vector> kFusionDataFlowVec_t; - -// Node is a component of ComputeGraph -class Node : public std::enable_shared_from_this { - friend class ComputeGraph; - friend class ModelSerializeImp; - - public: - template - using Vistor = RangeVistor>; - ~Node(); - Node(const Node &) = delete; - Node &operator=(const Node &) = delete; - bool operator==(const Node &r_node) const; - - protected: - Node() = default; - Node(const OpDescPtr &op, const ComputeGraphPtr &ownerGraph); - - public: - graphStatus Init(); - - std::string GetName() const; - std::string GetType() const; - - ComputeGraphPtr GetOwnerComputeGraph() const; - graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph); - - Vistor GetAllInDataAnchors() const; - Vistor GetAllOutDataAnchors() const; - uint32_t GetAllInDataAnchorsSize() const; - uint32_t GetAllOutDataAnchorsSize() const; - Vistor GetAllOutAnchors() const; - Vistor GetAllInAnchors() const; - InDataAnchorPtr GetInDataAnchor(int idx) const; - OutDataAnchorPtr GetOutDataAnchor(int idx) const; - InControlAnchorPtr GetInControlAnchor() const; - OutControlAnchorPtr GetOutControlAnchor() const; - Vistor GetInNodes() const; - Vistor GetOutNodes() const; - AnchorPtr GetInAnchor(int idx) const; - AnchorPtr GetOutAnchor(int idx) const; - - bool IsAllInNodesSeen(std::unordered_set &nodes_seen) const; - - // All in Data nodes - Vistor GetInDataNodes() const; - // All in Control nodes - Vistor GetInControlNodes() const; - // GetInAllNodes = InDataNodes + InControlNodes - Vistor GetInAllNodes() const; - - // All out Data nodes - Vistor GetOutDataNodes() const; - uint32_t GetOutDataNodesSize() const; - // All out Control nodes - Vistor GetOutControlNodes() const; - // GetOutAllNodes = OutDataNodes + InControlNodes - Vistor GetOutAllNodes() const; - - // Get all in data nodes and its out-anchor - Vistor> GetInDataNodesAndAnchors() const; - - // Get all out data nodes and its in-anchor - Vistor> GetOutDataNodesAndAnchors() const; - - graphStatus InferShapeAndType() const; - graphStatus Verify() const; - - graphStatus InferOriginFormat() const; - - OpDescPtr GetOpDesc() const; - - graphStatus UpdateOpDesc(const OpDescPtr &op); - - graphStatus AddLinkFrom(const NodePtr &input_node); - - graphStatus AddLinkFrom(const uint32_t &index, NodePtr input_node); - - graphStatus AddLinkFrom(const string &name, NodePtr input_node); - - graphStatus AddLinkFromForParse(const NodePtr &input_node); - - void AddSendEventId(uint32_t event_id) { send_event_id_list_.push_back(event_id); } - - void AddRecvEventId(uint32_t event_id) { recv_event_id_list_.push_back(event_id); } - - 
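// A minimal sketch of walking a node's data-edge neighbours via the Vistor ranges declared
// above, assuming the stripped element type is NodePtr (std::shared_ptr<Node>, hence the checks):
size_t CountDataNeighboursSketch(const ge::NodePtr &node) {
  if (node == nullptr) {
    return 0;
  }
  size_t count = 0;
  for (const auto &in_node : node->GetInDataNodes()) {    // producers connected by data edges
    count += (in_node != nullptr) ? 1 : 0;
  }
  for (const auto &out_node : node->GetOutDataNodes()) {  // consumers connected by data edges
    count += (out_node != nullptr) ? 1 : 0;
  }
  return count;
}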
const std::vector &GetSendEventIdList() const { return send_event_id_list_; } - - const std::vector &GetRecvEventIdList() const { return recv_event_id_list_; } - void GetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { - fusion_input_list = fusion_input_dataflow_list_; - } - - void GetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { - fusion_output_list = fusion_output_dataflow_list_; - } - - void SetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { - fusion_input_dataflow_list_ = fusion_input_list; - } - - void SetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { - fusion_output_dataflow_list_ = fusion_output_list; - } - - bool GetHostNode() const { return host_node_; } - void SetHostNode(bool is_host) { host_node_ = is_host; } - - void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } - - NodePtr GetOrigNode() { return orig_node_; } - - private: - bool NodeMembersAreEqual(const Node &r_node) const; - bool NodeAttrsAreEqual(const Node &r_node) const; - bool NodeInConnectsAreEqual(const Node &r_node) const; - bool NodeOutConnectsAreEqual(const Node &r_node) const; - bool NodeAnchorIsEqual(const AnchorPtr &l_anchor, const AnchorPtr &r_anchor, size_t i) const; - OpDescPtr op_; - std::weak_ptr owner_graph_; - vector in_data_anchors_; - vector out_data_anchors_; - InControlAnchorPtr in_control_anchor_; - OutControlAnchorPtr out_control_anchor_; - map attrs_; // lint !e1073 - bool has_init_{false}; - bool host_node_{false}; - bool anchor_status_updated_{false}; - std::vector send_event_id_list_; - std::vector recv_event_id_list_; - - kFusionDataFlowVec_t fusion_input_dataflow_list_; - kFusionDataFlowVec_t fusion_output_dataflow_list_; - - NodePtr orig_node_; - friend class NodeUtils; - friend class OnnxUtils; - friend class TuningUtils; -}; -} // namespace ge - -#endif // INC_GRAPH_NODE_H_ diff --git a/inc/graph/op_desc.h b/inc/graph/op_desc.h deleted file mode 100644 index 4d724c42..00000000 --- a/inc/graph/op_desc.h +++ /dev/null @@ -1,329 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_GRAPH_OP_DESC_H_ -#define INC_GRAPH_OP_DESC_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "detail/attributes_holder.h" -#include "graph/range_vistor.h" - -#define DYNAMIN_INPUT_NAME(name, index) (((name)) + std::to_string((index))) -#define DYNAMIN_OUTPUT_NAME(name, index) (((name)) + std::to_string((index))) -namespace ge { -using std::map; -using std::pair; -using std::shared_ptr; -using std::string; -using std::vector; - -class Operator; -class GeTensorDesc; - -using GeTensorDescPtr = shared_ptr; -using ConstGeTensorDescPtr = shared_ptr; - -class OpDesc; - -using OpDescPtr = shared_ptr; -using ConstOpDescPtr = shared_ptr; - -class GeAttrValue; - -using ConstOpDesc = const OpDesc; - -enum SubgraphType { kStatic, kDynamic, kSubgraphTypeEnd }; - -class OpDesc : public std::enable_shared_from_this, public AttrHolder { - public: - template - using Vistor = RangeVistor>; - - friend class GraphBuilderImpl; - - friend class OperatorImpl; - - OpDesc(const string &name, const string &type); - - OpDesc(); - - ~OpDesc(); - - bool operator==(const OpDesc &r_op_desc) const; - - string GetName() const; - - void SetName(const string &name); - - string GetType() const; - - void SetType(const string &type); - - graphStatus AddInputDesc(const GeTensorDesc &input_desc); - - graphStatus AddInputDesc(const string &name, const GeTensorDesc &input_desc); - - graphStatus AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc); - - graphStatus AddInputDescForward(const string &name, const unsigned int num); - - graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); - - graphStatus AddOutputDescMiddle(const string &name, const unsigned int num, size_t index); - - graphStatus AddOutputDescForward(const string &name, const unsigned int num); - - graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); - - graphStatus UpdateInputDesc(uint32_t index, const GeTensorDesc &tensor_desc); - - graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc); - - bool InputIsSet(const string &name) const; - - GeTensorDesc GetInputDesc(uint32_t index) const; - - GeTensorDesc GetInputDesc(const string &name) const; - - Vistor GetAllInputNames() const; - - GeTensorDescPtr MutableInputDesc(uint32_t index) const; - - GeTensorDescPtr MutableInputDesc(const string &name) const; - - Vistor GetAllInputsDesc() const; - - Vistor GetAllInputsDescPtr() const; - - size_t GetInputsSize() const; - - size_t GetAllInputsSize() const; - - graphStatus AddOutputDesc(const GeTensorDesc &output_desc); - - graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); - - graphStatus UpdateOutputDesc(uint32_t index, const GeTensorDesc &tensor_desc); - - graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc); - - GeTensorDesc GetOutputDesc(uint32_t index) const; - - GeTensorDesc GetOutputDesc(const string &name) const; - - GeTensorDescPtr MutableOutputDesc(uint32_t index) const; - - GeTensorDescPtr MutableOutputDesc(const string &name) const; - - uint32_t GetAllOutputsDescSize() const; - - Vistor GetAllOutputsDesc() const; - - Vistor GetAllOutputsDescPtr() const; - - size_t GetOutputsSize() const; - - ConstGeTensorDescPtr GetOutputDescPtr(uint32_t index) const; - - ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; - - ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const; - - ConstGeTensorDescPtr GetInputDescPtr(const string &name) const; 
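// A minimal sketch of registering named input/output tensor descriptions via the overloads
// declared above; "my_add"/"Add" are placeholder names and <memory> is needed for make_shared:
ge::OpDescPtr MakeOpDescSketch(const ge::GeTensorDesc &desc) {
  auto op_desc = std::make_shared<ge::OpDesc>("my_add", "Add");
  (void)op_desc->AddInputDesc("x1", desc);
  (void)op_desc->AddInputDesc("x2", desc);
  (void)op_desc->AddOutputDesc("y", desc);
  (void)op_desc->UpdateOutputDesc(0, desc);  // descriptors can be refined later, e.g. after infer-shape
  return op_desc;
}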
- - graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); - - graphStatus AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index); - - graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); - - bool IsOptionalInput(const string &name) const; - - bool IsOptionalInput(uint32_t index) const; - - std::map GetAllInputName() const; - - std::map GetAllOutputName(); - - bool UpdateInputName(std::map inputNameIdx); - - bool UpdateOutputName(std::map outputNameIdx); - - void AddInferFunc(const std::function &func); - - std::function GetInferFunc() const; - - graphStatus InferShapeAndType(); - - void AddInferFormatFunc(const std::function &func); - - std::function GetInferFormatFunc() const; - - graphStatus DefaultInferFormat(); - - std::function GetVerifyFunc() const; - - void AddVerifierFunc(const std::function &func); - - graphStatus CallInferFormatFunc(Operator &op); - - graphStatus OpVerify(); - - graphStatus CommonVerify() const; - - graphStatus AddRegisterInputName(const string &name); - - graphStatus AddRegisterOutputName(const string &name); - - vector GetRegisterInputName() const; - - vector GetRegisterOutputName() const; - - using AttrHolder::AddRequiredAttr; - using AttrHolder::DelAttr; - using AttrHolder::GetAllAttrNames; - using AttrHolder::GetAllAttrs; - using AttrHolder::GetAttr; - using AttrHolder::HasAttr; - using AttrHolder::SetAttr; - - void SetId(int64_t id); - int64_t GetId() const; - void SetStreamId(int64_t stream_id); - int64_t GetStreamId() const; - void SetInputName(const vector &input_name); - vector GetInputName() const; - void SetSrcName(const vector &src_name); - vector GetSrcName() const; - void SetSrcIndex(const vector &src_index); - vector GetSrcIndex() const; - void SetInputOffset(const vector &input); - vector GetInputOffset() const; - void SetOutputOffset(const vector &input); - vector GetOutputOffset() const; - void SetDstName(const vector &dst_name); - vector GetDstName() const; - void SetDstIndex(const vector &dst_index); - vector GetDstIndex() const; - void SetWorkspace(const vector &workspace); - vector GetWorkspace() const; - void SetWorkspaceBytes(const vector &workspace_bytes); - vector GetWorkspaceBytes() const; - void SetIsInputConst(const vector &is_input_const); - vector GetIsInputConst() const; - - void SetOpInferDepends(const vector &depend_names); - vector GetOpInferDepends() const; - - string GetInputNameByIndex(uint32_t index) const; - string GetValidInputNameByIndex(uint32_t index) const; - int GetValidInputIndexByName(const string &name) const; - int GetInputIndexByName(const string &name) const; - - string GetOutputNameByIndex(uint32_t index) const; - - int GetOutputIndexByName(const string &name) const; - - graphStatus RestoreInputNameIdx(const string &name, const int &index); - - graphStatus RestoreOutputNameIdx(const string &name, const int &index); - - graphStatus CallInferFunc(Operator &op); - - void SetOpKernelLibName(const std::string &name); - - std::string GetOpKernelLibName() const; - - void SetOpEngineName(const std::string &name); - - std::string GetOpEngineName() const; - - void RegisterSubgraphIrName(const std::string &name, SubgraphType type); - const std::map &GetSubgraphIrNames() const; - SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; - - graphStatus AddSubgraphName(const std::string &name); - const std::map &GetSubgraphNameIndexes() const; - - std::string 
GetSubgraphInstanceName(uint32_t index) const; - const std::vector &GetSubgraphInstanceNames() const; - /// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, - /// because this kind of functions will only append a new subgraph instance name - /// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. - /// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. - /// \param index - /// \param name - /// \return - graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); - void RemoveSubgraphInstanceName(const std::string &name); - - graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const; - - protected: - ProtoAttrMapHelper MutableAttrMap() override; - ConstProtoAttrMapHelper GetAttrMap() const override; - - private: - OpDesc(const ProtoMsgOwner &proto_msg_owner, ge::proto::OpDef *op_def); - bool OpDescMembersAreEqual(const OpDesc &r_op_desc) const; - bool OpDescAttrsAreEqual(const OpDesc &r_op_desc) const; - bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; - - GeIrProtoHelper op_def_; - std::vector subgraph_instance_names_; - - // subgraph names to index, for a `if` operator: - // then_branch: 0 - // else_branch: 1 - // or for a `case` node: - // branches0: 0 - // branches1: 1 - // branches2: 2 - std::map subgraph_names_to_index_; - - // subgraph ir names to type, for a `if` operator: - // then_branch: static - // else_branch: static - // or for a `case` op: - // branches: dynamic - std::map subgraph_ir_names_to_type_; - - vector inputs_desc_{}; - map input_name_idx_{}; - vector register_input_name_{}; - std::unordered_set optional_input_names_{}; - vector outputs_desc_{}; - map output_name_idx_{}; - vector register_output_name_{}; - std::function infer_func_ = nullptr; - std::function infer_format_func_ = nullptr; - std::function verifier_func_ = nullptr; - string op_kernel_lib_name_; - string engine_name_; - friend class OpDescUtils; - friend class ModelSerializeImp; - friend class AttrUtils; - friend class GeAttrValueImp; - friend class OnnxUtils; -}; -} // namespace ge -#endif // INC_GRAPH_OP_DESC_H_ diff --git a/inc/graph/op_kernel_bin.h b/inc/graph/op_kernel_bin.h deleted file mode 100644 index 3970460a..00000000 --- a/inc/graph/op_kernel_bin.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
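// A minimal sketch of the ordering described in the comment above: AddSubgraphName must be
// called before an instance name can be bound to its index ("then_instance" is a placeholder):
void BindSubgraphSketch(const ge::OpDescPtr &if_op) {
  (void)if_op->AddSubgraphName("then_branch");               // registered with index 0
  (void)if_op->AddSubgraphName("else_branch");               // registered with index 1
  (void)if_op->SetSubgraphInstanceName(0, "then_instance");  // bind an instance to index 0
}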
- */ - -#ifndef INC_GRAPH_OP_KERNEL_BIN_H_ -#define INC_GRAPH_OP_KERNEL_BIN_H_ - -#include -#include -#include -#include - -namespace ge { -class OpKernelBin { - public: - OpKernelBin(std::string name, std::vector &&data) : name_(std::move(name)), data_(std::move(data)) {} - - ~OpKernelBin() = default; - - const std::string &GetName() const { return name_; } - const uint8_t *GetBinData() const { return (const uint8_t *)data_.data(); } - size_t GetBinDataSize() const { return data_.size(); } - OpKernelBin(const OpKernelBin &) = delete; - const OpKernelBin &operator=(const OpKernelBin &) = delete; - - private: - std::string name_; - std::vector data_; -}; - -using OpKernelBinPtr = std::shared_ptr; -const char *const OP_EXTATTR_NAME_TBE_KERNEL = "tbeKernel"; -const char *const OP_EXTATTR_CUSTAICPU_KERNEL = "cust_aicpu_kernel"; -} // namespace ge - -#endif // INC_GRAPH_OP_KERNEL_BIN_H_ diff --git a/inc/graph/operator_factory_impl.h b/inc/graph/operator_factory_impl.h deleted file mode 100644 index ea343ebc..00000000 --- a/inc/graph/operator_factory_impl.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ -#define INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ - -#include -#include -#include -#include -#include "graph/operator_factory.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl { - public: - static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); - - static graphStatus GetOpsTypeList(std::vector &all_ops); - - static bool IsExistOp(const string &operator_type); - - static InferShapeFunc GetInferShapeFunc(const std::string &operator_type); - - static InferFormatFunc GetInferFormatFunc(const std::string &operator_type); - - static VerifyFunc GetVerifyFunc(const std::string &operator_type); - - static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreator const &op_creator); - - static graphStatus RegisterInferShapeFunc(const std::string &operator_type, InferShapeFunc const infer_shape_func); - - static graphStatus RegisterInferFormatFunc(const std::string &operator_type, InferFormatFunc const infer_format_func); - - static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); - - static shared_ptr> operator_creators_; - static shared_ptr> operator_infershape_funcs_; - static shared_ptr> operator_inferformat_funcs_; - static shared_ptr> operator_verify_funcs_; -}; -} // namespace ge - -#endif // INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ diff --git a/inc/graph/opsproto_manager.h b/inc/graph/opsproto_manager.h deleted file mode 100644 index 06846573..00000000 --- a/inc/graph/opsproto_manager.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the 
License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_OPSPROTO_MANAGER_H_ -#define INC_GRAPH_OPSPROTO_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include - -namespace ge { -class OpsProtoManager { - public: - static OpsProtoManager *Instance(); - - bool Initialize(const std::map &options); - void Finalize(); - - private: - void LoadOpsProtoPluginSo(std::string &path); - - std::string pluginPath_; - std::vector handles_; - bool is_init_ = false; - std::mutex mutex_; -}; -} // namespace ge - -#endif // INC_GRAPH_OPSPROTO_MANAGER_H_ diff --git a/inc/graph/range_vistor.h b/inc/graph/range_vistor.h deleted file mode 100644 index 8635d413..00000000 --- a/inc/graph/range_vistor.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_RANGE_VISTOR_H_ -#define INC_GRAPH_RANGE_VISTOR_H_ - -#include - -template -class RangeVistor { - public: - /*lint -e151*/ - using Iterator = typename std::vector::iterator; - using ConstIterator = typename std::vector::const_iterator; - /*lint +e151*/ - - RangeVistor(O owner, const std::vector &vs) : owner_(owner), elements_(vs) {} - - ~RangeVistor() {} - - Iterator begin() { return elements_.begin(); } - - Iterator end() { return elements_.end(); } - - ConstIterator begin() const { return elements_.begin(); } - - ConstIterator end() const { return elements_.end(); } - - std::size_t size() const { return elements_.size(); } - - bool empty() const { return elements_.empty(); } - - /*lint -e659*/ - E &at(std::size_t index) { return elements_.at(index); } - /*lint +e659*/ - - const E &at(std::size_t index) const { return elements_.at(index); } - - private: - O owner_; - std::vector elements_; -}; - -#endif // INC_GRAPH_RANGE_VISTOR_H_ diff --git a/inc/graph/ref_relation.h b/inc/graph/ref_relation.h deleted file mode 100644 index 71457916..00000000 --- a/inc/graph/ref_relation.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef COMMON_GRAPH_REF_RELATION_H_ -#define COMMON_GRAPH_REF_RELATION_H_ - -#include -#include -#include -#include - -#include "graph/compute_graph.h" -#include "graph/types.h" -#include "graph/ge_error_codes.h" -#include "node.h" - -namespace ge { -enum InOutFlag { - NODE_IN = 0, // input flag - NODE_OUT = 1, // output flag -}; - -struct RefCell { - std::string node_name; - ge::NodePtr node = nullptr; - InOutFlag in_out = NODE_IN; - int in_out_idx = 0; - - bool operator==(const RefCell &c) const { - return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; - } - - RefCell() = default; - RefCell(std::string name, ge::NodePtr node_ptr, InOutFlag in_out_flag, int idx) { - node_name = name; - node = node_ptr; - in_out = in_out_flag; - in_out_idx = idx; - }; - ~RefCell() = default; -}; - -struct RefCellHash { - size_t operator()(const RefCell &c) const { - unsigned long number = reinterpret_cast(reinterpret_cast(c.node.get())); - string tmp = c.node_name + std::to_string(c.in_out) + std::to_string(c.in_out_idx) + std::to_string(number); - return std::hash()(tmp); - } -}; - -class RefRelations { - public: - graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set &result); - graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); - graphStatus Clear(); - - RefRelations(); - ~RefRelations() = default; - - public: - class Impl; - std::shared_ptr impl_ = nullptr; -}; - -} // namespace ge -#endif // COMMON_GRAPH_REF_RELATION_H_ diff --git a/inc/graph/runtime_inference_context.h b/inc/graph/runtime_inference_context.h deleted file mode 100644 index f0b38546..00000000 --- a/inc/graph/runtime_inference_context.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ -#define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ - -#include -#include -#include -#include -#include "external/graph/ge_error_codes.h" -#include "external/graph/tensor.h" -#include "ge_attr_value.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext { - public: - static graphStatus GetContext(const std::string &context_id, RuntimeInferenceContext **ctx); - static graphStatus CreateContext(const std::string &context_id); - static void DestroyContext(const std::string &context_id); - - graphStatus SetTensor(int64_t node_id, int output_id, Tensor &&tensor); - graphStatus GetTensor(int64_t node_id, int output_id, GeTensorPtr &tensor); - graphStatus GetTensor(int64_t node_id, int output_id, Tensor &tensor); - - private: - std::map> tensors_; - std::map> ge_tensors_; - std::mutex mu_; - - static std::map> contexts_; - static std::mutex ctx_mu_; -}; -} // namespace ge - -#endif // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ diff --git a/inc/graph/shape_refiner.h b/inc/graph/shape_refiner.h deleted file mode 100644 index 4f8783a3..00000000 --- a/inc/graph/shape_refiner.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_SHAPE_REFINER_H_ -#define INC_GRAPH_SHAPE_REFINER_H_ - -#include -#include "external/graph/inference_context.h" - -#include "external/graph/ge_error_codes.h" -#include "graph/node.h" - -namespace ge { -// ShapeRefiner performs shape inference for compute graphs -class ShapeRefiner { - public: - static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph); - static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); - static graphStatus InferShapeAndType(const NodePtr &node); - static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); - static void ClearContextMap(); - - private: - static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); -}; -} // namespace ge -#endif // INC_GRAPH_SHAPE_REFINER_H_ diff --git a/inc/graph/tuning_utils.h b/inc/graph/tuning_utils.h deleted file mode 100644 index 98262a23..00000000 --- a/inc/graph/tuning_utils.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
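// A minimal sketch of the per-run lifecycle implied by the static RuntimeInferenceContext
// interface above; "ctx_0" and the node/output ids are placeholders:
ge::graphStatus RuntimeCtxSketch(ge::Tensor &&value) {
  const std::string context_id = "ctx_0";
  if (ge::RuntimeInferenceContext::CreateContext(context_id) != ge::GRAPH_SUCCESS) {
    return ge::GRAPH_FAILED;
  }
  ge::RuntimeInferenceContext *ctx = nullptr;
  ge::graphStatus ret = ge::RuntimeInferenceContext::GetContext(context_id, &ctx);
  if (ret == ge::GRAPH_SUCCESS && ctx != nullptr) {
    ret = ctx->SetTensor(0, 0, std::move(value));           // cache a node output under (node_id, output_id)
  }
  ge::RuntimeInferenceContext::DestroyContext(context_id);  // release the per-run state
  return ret;
}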
- */ - -#ifndef MAIN_TUNING_UTILS_H -#define MAIN_TUNING_UTILS_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "framework/common/debug/ge_log.h" -#include "utils/attr_utils.h" -#include "utils/node_utils.h" -#include "external/ge/ge_api_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -namespace ge { -// Configure build mode, default value is "normal" -const char *const BUILD_MODE = "ge.buildMode"; -const char *const BUILD_STEP = "ge.buildStep"; -// Configure tuning path -const char *const TUNING_PATH = "ge.tuningPath"; -// for interface: aclgrphBuildModel -const std::set ir_builder_supported_options_for_lx_fusion = {BUILD_MODE, BUILD_STEP, TUNING_PATH}; - -// Build model -const char *const BUILD_MODE_NORMAL = "normal"; -const char *const BUILD_MODE_TUNING = "tuning"; -const char *const BUILD_MODE_BASELINE = "baseline"; -const std::set build_mode_options = {BUILD_MODE_NORMAL, BUILD_MODE_TUNING, BUILD_MODE_BASELINE}; - -// Build step -const char *const BUILD_STEP_BEFORE_UB_MATCH = "before_ub_match"; -const char *const BUILD_STEP_AFTER_UB_MATCH = "after_ub_match"; -const char *const BUILD_STEP_AFTER_BUILDER = "after_builder"; -const char *const BUILD_STEP_AFTER_BUILDER_SUB = "after_builder_sub"; -const char *const BUILD_STEP_AFTER_MERGE = "after_merge"; -const std::set build_step_options = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_UB_MATCH, - BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB, - BUILD_STEP_AFTER_MERGE}; - -using SubgraphCreateOutNode = std::unordered_map; -using NodetoNodeMap = std::unordered_map; -using NodeSet = std::set; -using NodeNametoNodeNameMap = std::unordered_map; -using NodetoNodeNameMap = std::unordered_map; -class TuningUtils { - public: - TuningUtils() = default; - ~TuningUtils() = default; - // Dump all the subgraphs and modify - // the subgraphs in them to be executable subgraphs if exe_flag is true - // `tuning_path` means path to save the graphs - static graphStatus ConvertGraphToFile(std::vector tuning_subgraphs, - std::vector non_tuning_subgraphs = {}, bool exe_flag = false, - const std::string &path = "", const std::string &user_path = ""); - // Recovery `graph` from graph dump files configured in options - static graphStatus ConvertFileToGraph(const map &options, ge::Graph &graph); - - private: - // part 1 - struct HelpInfo { - int64_t index; - bool exe_flag; - bool is_tuning_graph; - const std::string &path; - const std::string &user_path; - }; - static graphStatus MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info); - static graphStatus HandlePld(NodePtr &node); - static graphStatus HandleEnd(NodePtr &node); - static graphStatus ChangePld2Data(NodePtr &node, NodePtr &data_node); - static graphStatus ChangeEnd2NetOutput(NodePtr &node, NodePtr &out_node); - static graphStatus LinkEnd2NetOutput(NodePtr &node, NodePtr &out_node); - static graphStatus CreateDataNode(NodePtr &node, NodePtr &data_node); - static graphStatus CreateNetOutput(NodePtr &node, NodePtr &out_node); - static graphStatus AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node); - static graphStatus AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node); - static void DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path); - - static SubgraphCreateOutNode 
create_output_; - // part 2 - static graphStatus MergeAllSubGraph(std::vector &graphs, ComputeGraphPtr &graph); - static graphStatus MergeSubGraph(ComputeGraphPtr &graph); - // Deletes new data and output nodes added by call `MakeExeGraph()` func in part 1 - static graphStatus RemoveDataNetoutputEdge(ComputeGraphPtr &graph); - static graphStatus GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor, - AnchorPtr &src_out_anchor); - static NodeNametoNodeNameMap data_2_netoutput_; - static NodetoNodeNameMap data_node_2_netoutput_; - static NodetoNodeMap data_node_2_netoutput_node_; - static NodeSet netoutput_nodes_; - static NodeSet merged_graph_nodes_; - static std::mutex mutex_; - // for debug - static std::string PrintCheckLog(); - static std::string GetNodeNameByAnchor(const Anchor *anchor); -}; -} // namespace ge -#endif // MAIN_TUNING_UTILS_H diff --git a/inc/graph/usr_types.h b/inc/graph/usr_types.h deleted file mode 100644 index 90e02001..00000000 --- a/inc/graph/usr_types.h +++ /dev/null @@ -1,133 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_USR_TYPES_H_ -#define INC_GRAPH_USR_TYPES_H_ - -#include -#include -#include -namespace ge { -#define USR_TYPE_DEC(type, name) \ - inline void set_##name(const type &value) { name = value; } \ - type *mutable_##name() { return &name; } - -#define USR_TYPE_HAS_DEC(type, name) \ - inline void set_##name(const type &value) { name = value; } \ - \ - private: \ - bool has_mutable_##name{false}; \ - \ - public: \ - bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \ - type *mutable_##name() { \ - has_mutable_##name = true; \ - return &name; \ - } - -#define USR_TYPE_BYTES_DEC(name) \ - inline void clear_##name() { name.clear(); } \ - inline void set_##name(const void *value, size_t size) { \ - name.assign(reinterpret_cast(const_cast(value)), \ - reinterpret_cast(const_cast(value)) + size); \ - } - -enum UsrQuantizeScaleType { USR_VECTOR_SCALE = 0, USR_SCALAR_SCALE = 1 }; -enum UsrQuantizeScaleMode { USR_NORMAL_MODE = 0, USR_SQRT_MODE = 1 }; -enum UsrQuantizeAlgorithm { - USR_NON_OFFSET_ALGO = 0, - USR_HALF_OFFSET_ALGO = 1, - USR_ALL_OFFSET_ALGO = 2, -}; - -struct UsrQuantizeFactor { - public: - // QuantizeScaleMode scale_mode; - UsrQuantizeScaleMode scale_mode{USR_NORMAL_MODE}; - std::vector scale_value; - int64_t scale_offset{0}; - std::vector offset_data_value; - int64_t offset_data_offset{0}; - std::vector offset_weight_value; - int64_t offset_weight_offset{0}; - std::vector offset_pad_value; - int64_t offset_pad_offset{0}; - - USR_TYPE_DEC(UsrQuantizeScaleMode, scale_mode); - USR_TYPE_BYTES_DEC(scale_value); - - USR_TYPE_DEC(int64_t, scale_offset); - USR_TYPE_BYTES_DEC(offset_data_value); - USR_TYPE_DEC(int64_t, offset_data_offset); - - USR_TYPE_BYTES_DEC(offset_weight_value); - USR_TYPE_DEC(int64_t, offset_weight_offset); - USR_TYPE_BYTES_DEC(offset_pad_value); - 
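  // For reference, USR_TYPE_DEC(int64_t, scale_offset) above expands to an inline
  // set_scale_offset() setter plus a mutable_scale_offset() accessor, so the struct is used
  // much like a lightweight protobuf message, e.g.:
  //   UsrQuantizeFactor factor;
  //   factor.set_scale_mode(USR_SQRT_MODE);
  //   float scales[2] = {0.5f, 0.25f};
  //   factor.set_scale_value(scales, sizeof(scales));  // USR_TYPE_BYTES_DEC copies raw bytes
  //   factor.set_scale_offset(16);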
USR_TYPE_DEC(int64_t, offset_pad_offset); -}; - -static inline bool QuantizeFactorHasData(const UsrQuantizeFactor &factor) { - return factor.scale_value.size() > 0 || factor.offset_data_value.size() > 0 || - factor.offset_weight_value.size() > 0 || factor.offset_pad_value.size() > 0; -} - -struct UsrQuantizeCalcFactor { - public: - std::vector offsetw; - int64_t offsetw_offset{0}; - std::vector offsetd; - int64_t offsetd_offset{0}; - std::vector scalereq; - int64_t scaledreq_offset{0}; - std::vector offsetdnext; - int64_t offsetdnext_offset{0}; - - USR_TYPE_BYTES_DEC(offsetw); - USR_TYPE_DEC(int64_t, offsetw_offset); - USR_TYPE_BYTES_DEC(offsetd); - USR_TYPE_DEC(int64_t, offsetd_offset); - USR_TYPE_BYTES_DEC(scalereq); - USR_TYPE_DEC(int64_t, scaledreq_offset); - USR_TYPE_BYTES_DEC(offsetdnext); - USR_TYPE_DEC(int64_t, offsetdnext_offset); -}; - -static inline bool QuantizeFactorHasData(const UsrQuantizeCalcFactor &factor) { - return factor.offsetw.size() > 0 || factor.offsetd.size() > 0 || factor.scalereq.size() > 0 || - factor.offsetdnext.size() > 0; -} - -struct UsrQuantizeFactorParams { - UsrQuantizeAlgorithm quantize_algo{USR_NON_OFFSET_ALGO}; - UsrQuantizeScaleType scale_type{USR_VECTOR_SCALE}; - UsrQuantizeFactor quantize_param; - UsrQuantizeFactor dequantize_param; - UsrQuantizeFactor requantize_param; - UsrQuantizeCalcFactor quantizecalc_param; - USR_TYPE_DEC(UsrQuantizeAlgorithm, quantize_algo); - USR_TYPE_DEC(UsrQuantizeScaleType, scale_type); - USR_TYPE_HAS_DEC(UsrQuantizeFactor, quantize_param); - USR_TYPE_HAS_DEC(UsrQuantizeFactor, dequantize_param); - USR_TYPE_HAS_DEC(UsrQuantizeFactor, requantize_param); - USR_TYPE_HAS_DEC(UsrQuantizeCalcFactor, quantizecalc_param); -}; - -#undef USR_TYPE_DEC -#undef USR_TYPE_HAS_DEC -#undef USR_TYPE_BYTES_DEC -} // namespace ge - -#endif // INC_GRAPH_USR_TYPES_H_ diff --git a/inc/graph/utils/anchor_utils.h b/inc/graph/utils/anchor_utils.h deleted file mode 100644 index 35b3b035..00000000 --- a/inc/graph/utils/anchor_utils.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_GRAPH_UTILS_ANCHOR_UTILS_H_ -#define INC_GRAPH_UTILS_ANCHOR_UTILS_H_ - -#include "graph/anchor.h" -#include "graph/node.h" - -namespace ge { -class AnchorUtils { - public: - // Get anchor format - static Format GetFormat(const DataAnchorPtr &dataAnchor); - - // Set anchor format - static graphStatus SetFormat(const DataAnchorPtr &dataAnchor, Format dataFormat); - - // Get anchor status - static AnchorStatus GetStatus(const DataAnchorPtr &dataAnchor); - - // Set anchor status - static graphStatus SetStatus(const DataAnchorPtr &dataAnchor, AnchorStatus anchorStatus); - - static bool HasControlEdge(const AnchorPtr &anchor); - - static bool IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst); - - static int GetIdx(const AnchorPtr &anchor); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_ANCHOR_UTILS_H_ diff --git a/inc/graph/utils/attr_utils.h b/inc/graph/utils/attr_utils.h deleted file mode 100644 index 15a815d4..00000000 --- a/inc/graph/utils/attr_utils.h +++ /dev/null @@ -1,150 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_ -#define INC_GRAPH_UTILS_ATTR_UTILS_H_ - -#include -#include -#include -#include "graph/detail/attributes_holder.h" -#include "graph/ge_attr_value.h" -#include "graph/types.h" - -namespace ge { -class OpDesc; -using OpDescPtr = std::shared_ptr; -using ConstOpDescPtr = std::shared_ptr; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { - public: - class ConstAttrHolderAdapter; - class AttrHolderAdapter; - // Set - static bool HasAttr(ConstAttrHolderAdapter &&obj, const string &name); - - static bool SetInt(AttrHolderAdapter &&obj, const string &name, const int64_t &value); - static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value); - - static bool SetFloat(AttrHolderAdapter &&obj, const string &name, const float &value); - static bool SetListFloat(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetBool(AttrHolderAdapter &&obj, const string &name, const bool &value); - static bool SetListBool(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetStr(AttrHolderAdapter &&obj, const string &name, const string &value); - static bool SetListStr(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetTensorDesc(AttrHolderAdapter &&obj, const string &name, const GeTensorDesc &value); - static bool SetListTensorDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value); - static bool SetTensor(AttrHolderAdapter 
&&obj, const string &name, const ConstGeTensorPtr &value); - static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensor &value); - static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, - std::initializer_list &&value); - static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetGraph(AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value); - static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); - static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value); - static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, - const vector &value); - static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); - - // Get - static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int64_t &value); - static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value); - static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value); - static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetFloat(ConstAttrHolderAdapter &&obj, const string &name, float &value); - static bool GetListFloat(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetBool(ConstAttrHolderAdapter &&obj, const string &name, bool &value); - static bool GetListBool(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetStr(ConstAttrHolderAdapter &&obj, const string &name, string &value); - static bool GetListStr(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, GeTensorDesc &value); - static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value); - static bool MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value); - static bool GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetGraph(ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value); - static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); - static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value); - static bool GetListNamedAttrs(ConstAttrHolderAdapter 
&&obj, const string &name, - vector &value); - static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - // Value will be moved - static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); - static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); - // Value will be moved - static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector &listBuffer); - static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &listBuffer); - - static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector> &value); - static bool GetListListInt(ConstAttrHolderAdapter &&obj, const string &name, vector> &value); - - static bool SetListDataType(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool GetListDataType(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - - static bool SetDataType(AttrHolderAdapter &&obj, const string &name, const ge::DataType &value); - static bool GetDataType(ConstAttrHolderAdapter &&obj, const string &name, ge::DataType &value); - - static OpDescPtr CloneOpDesc(const ConstOpDescPtr &orgOpDesc); - - static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); - - static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); - - class AttrHolderAdapter { - public: - AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} - ~AttrHolderAdapter() {} - template - AttrHolderAdapter(const std::shared_ptr &obj) : obj_(obj.get()) {} - AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {} - operator bool() const { return obj_ != nullptr; } - AttrHolder *operator->() { return obj_; } - AttrHolder *get() { return obj_; } - - AttrHolder *obj_; - }; - - class ConstAttrHolderAdapter { - public: - ConstAttrHolderAdapter(const AttrHolder *obj) : obj_(obj) {} - ~ConstAttrHolderAdapter() {} - template - ConstAttrHolderAdapter(const std::shared_ptr obj) : obj_(obj.get()) {} - ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {} - operator bool() const { return obj_ != nullptr; } - const AttrHolder *operator->() const { return obj_; } - const AttrHolder *get() const { return obj_; } - - private: - const AttrHolder *obj_; - }; -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_ATTR_UTILS_H_ diff --git a/inc/graph/utils/graph_utils.h b/inc/graph/utils/graph_utils.h deleted file mode 100644 index fdcbe1a9..00000000 --- a/inc/graph/utils/graph_utils.h +++ /dev/null @@ -1,771 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_ -#define INC_GRAPH_UTILS_GRAPH_UTILS_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "graph/anchor.h" -#include "graph/node.h" -#include "graph/compute_graph.h" -#include "graph/utils/anchor_utils.h" -#include "graph/graph.h" -#include "graph/model.h" - -#define GE_DUMP(compute_graph, name) \ - do { \ - GraphUtils::DumpGEGraph(compute_graph, name); \ - GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); \ - uint64_t i = 0; \ - for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { \ - auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \ - GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \ - GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \ - } \ - } while (0) - -#define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ - do { \ - DataType ret; \ - attr.GetValue(ret); \ - } while (0) - -#define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ - do { \ - if (value_type == VT_ENUM) { \ - REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ - stream << ret; \ - } \ - } while (0) - -#define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ - do { \ - if (value_type == VT_ENUM) { \ - REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ - stream << "["; \ - for (int i = 0; i < ret.size(); i++) { \ - stream << ret[i]; \ - if (i + 1 != ret.size()) stream << ", "; \ - } \ - stream << "]"; \ - } \ - } while (0) - -#define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ - else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) - -#define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ - else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) - -#define PRINT_SHAPE(i_o, n, idx, stream) \ - do { \ - auto op = n->GetOpDesc(); \ - GeTensorDesc td = i_o == "input" ? 
op->GetInputDesc(idx) : op->GetOutputDesc(idx); \ - auto shape = td.GetShape().GetDims(); \ - stream << "["; \ - for (int i = 0; i < shape.size(); i++) { \ - stream << shape[i]; \ - if (i + 1 < shape.size()) stream << ", "; \ - } \ - stream << "]"; \ - } while (0) - -#define PRINT_ATTR_FUNC(stream) \ - [&](GeAttrValue attr) { \ - auto type = attr.GetValueType(); \ - PRINT_ATTR_VALUE_IF(type, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR, attr, stream) \ - PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT, attr, stream) \ - PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL, attr, stream) \ - PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT, attr, stream) \ - PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_STRING, GeAttrValue::LIST_STR, attr, stream) \ - PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_FLOAT, GeAttrValue::LIST_FLOAT, attr, stream) \ - PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_BOOL, GeAttrValue::LIST_BOOL, attr, stream) \ - PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_INT, GeAttrValue::LIST_INT, attr, stream) \ - else if (type == GeAttrValue::ValueType::VT_TENSOR_DESC) stream << "TENSOR_DESC"; \ - else if (type == GeAttrValue::ValueType::VT_TENSOR) stream << "TENSOR"; \ - else if (type == GeAttrValue::ValueType::VT_BYTES) stream << "BYTES"; \ - else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) stream << "LIST_TENSOR_DESC"; \ - else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR) stream << "LIST_TENSOR"; \ - else if (type == GeAttrValue::ValueType::VT_LIST_BYTES) stream << "LIST_BYTES"; \ - }; - -namespace ge { -enum IOType { kIn, kOut }; - -struct NodeIndexIO { - NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type) - : node_(std::move(node)), index_(index), io_type_(io_type) { - if (node_ != nullptr) { - value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_); - } - } - NodeIndexIO(ge::NodePtr node, int index, IOType io_type) - : node_(std::move(node)), index_(static_cast(index)), io_type_(io_type) { - if (node_ != nullptr) { - value_ = node_->GetName() + (io_type_ == kOut ? 
"_out_" : "_in_") + std::to_string(index_); - } - } - ~NodeIndexIO() {} - - NodePtr node_ = nullptr; - uint32_t index_ = 0; - IOType io_type_ = kOut; - std::string value_; - - const std::string &ToString() const { return value_; } -}; - -class GraphUtils { - public: - static ComputeGraphPtr GetComputeGraph(const Graph &graph); - - static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); - - static graphStatus RecoverGraphOperators(const Graph &graph); - - static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector &inputs); - - static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); - - static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst, - const Format &dst_format); - - static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst); - - static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); - - static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); - - // check whether src is link to dst and then remove - static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); - - static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst); - - static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); - - static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); - - static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, - const InDataAnchorPtr &new_dst); - - static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, - const InControlAnchorPtr &new_dst); - - static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, - const NodePtr &new_node); - - static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node); - - static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node); - - static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, - const std::vector &vec_op_desc); - - /// - /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst - /// @param [in] src - /// @param [in] dsts - /// @param [in] insert_node - /// @param [in] input_index - /// @param [in] output_index - /// @return graphStatus - /// - static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, - const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); - - static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); - - static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); - - static void RecordOriginalNames(std::vector original_nodes, const ge::NodePtr &node); - - static void RecordOriginalNames(std::vector names_tmp, const ge::NodePtr &node); - - static bool MatchDumpStr(const std::string &suffix); - - static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false, - const std::string &user_graph_name = ""); - - static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); - - static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph); - - static void BreakConnect(const std::map &all_nodes_infos); - - static void DumpGEGraphToOnnx(const 
ge::ComputeGraph &compute_graph, const std::string &suffix); - - static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph); - - static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message); - - static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path); - - static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node); - - /// - /// Isolating `node`, relinking data links from the in-anchor peer nodes to - /// the out-anchor peer nodes according to `io_map`, relinking control links - /// to ensure that input nodes of `node` are before out nodes - /// - /// Link the `io_map[i]` input anchor peer node to `i` output anchor peer - /// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0, - /// unlink all links from `i` output anchor without any relinking. - /// - /// @param node - /// @param io_map - /// @return - /// - static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list &io_map); - static graphStatus IsolateNode(const NodePtr &node, const std::vector &io_map); - - /// - /// Isolate `node` which must be one input one output, equivalent to - /// `IsolateNode(node, {0})` - /// @param node - /// @return - /// - static graphStatus IsolateNodeOneIO(const NodePtr &node); - - /// - /// The data anchors replacing behavior is the same with - /// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control - /// anchors with `new_node`'s. - /// @param new_node - /// @param old_node - /// @param inputs_map - /// @param outputs_map - /// @return - /// - static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, - std::initializer_list inputs_map, std::initializer_list outputs_map); - - static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, - const std::vector &inputs_map, const std::vector &outputs_map); - - /// - /// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`. - /// Replace the `i` in/out data anchor on `old_node` with - /// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`. - /// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in - /// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain - /// on `old_node`. 
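// A minimal sketch of the io_map contract documented for GraphUtils::IsolateNode.
// Assumptions: `graph` is a valid ComputeGraphPtr, `cast_node` is a
// single-input, single-output node already linked inside `graph`, and the
// helper name below is illustrative only.
// io_map = {0} relinks the peer of in-anchor 0 to every peer of out-anchor 0
// before all links of the node are removed (the case IsolateNodeOneIO wraps).
#include "graph/utils/graph_utils.h"

static ge::graphStatus BypassAndRemove(const ge::ComputeGraphPtr &graph, const ge::NodePtr &cast_node) {
  if (ge::GraphUtils::IsolateNode(cast_node, {0}) != ge::GRAPH_SUCCESS) {
    return ge::GRAPH_FAILED;
  }
  // After isolation the node holds no data or control links, so it can be
  // dropped without further relinking.
  return ge::GraphUtils::RemoveNodeWithoutRelink(graph, cast_node);
}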
- /// @param new_node - /// @param old_node - /// @param inputs_map - /// @param outputs_map - /// @return - /// - static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, - std::initializer_list inputs_map, - std::initializer_list outputs_map); - - static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, - const std::vector &inputs_map, const std::vector &outputs_map); - - /// - /// Copy all in-control edges from `src_node` to `dst_node` - /// @param src_node - /// @param dst_node - /// @return - /// - static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); - - static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); - - /// - /// Copy all out-control edges from `src_node` to `dst_node` - /// @param src_node - /// @param dst_node - /// @return success: GRAPH_SUCESS - /// - static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); - - /// - /// Move all out-control edges from `src_node` to `dst_node` - /// @param src_node - /// @param dst_node - /// @return success: GRAPH_SUCESS - /// - static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); - - /// - /// Copy all in-data edges from `src_node` to `dst_node` - /// @param src_node - /// @param dst_node - /// @return - /// - static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node); - - static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); - - /// - /// Make a copy of ComputeGraph. - /// @param graph: original graph. - /// @param prefix: node name prefix of new graph. - /// @return ComputeGraphPtr - /// - static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const string &prefix, - std::vector &input_nodes, std::vector &output_nodes); - - /// - /// Copy tensor attribute to new node. - /// @param [in] dst_desc: cloned node. - /// @param [in] src_node: original node. - /// @return success: GRAPH_SUCESS - /// - static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node); - - static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec); - - /// - /// Get reference-mapping of all data_anchors in graph - /// @param [in] graph - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus GetRefMapping(const ComputeGraphPtr &graph, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs - /// of the graph have UNKNOWN_SHAPE operators or not. - /// Note: This function will only look 'down' from the graph, not 'up'. 
For example, the following - /// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE - /// ROOT graph: A -----> B -----> C - /// K subgraph U - /// | - /// V - /// SUB graph: D --> E --> F - /// K K K - /// @param [in] graph - /// @return bool - /// - static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph); - - static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name); - - private: - /// - /// Get reference-mapping for in_data_anchors of node - /// @param [in] node - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus HandleInAnchorMapping(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Get reference-mapping for out_data_anchors of node - /// @param [in] node - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus HandleOutAnchorMapping(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Handle input of subgraph - /// @param [in] node - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus HandleSubgraphInput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Handle input of Merge op - /// @param [in] node - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus HandleMergeInput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Handle output of subgraph - /// @param [in] node - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus HandleSubgraphOutput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Relink all edges for cloned ComputeGraph. - /// @param [in] node: original node. - /// @param [in] prefix: node name prefix of new node. - /// @param [in] all_nodes: all nodes in new graph. 
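// A short sketch combining two of the queries declared above, assuming
// `root_graph` is a valid ComputeGraphPtr and `node_name` names a node in it;
// the helper itself is illustrative. IsUnknownShapeGraph only looks "down"
// into the graph and its subgraphs, and FindNodeFromAllNodes resolves a node
// by name from the graph's full node set.
#include <string>
#include "graph/utils/graph_utils.h"

static ge::NodePtr FindNodeIfShapesKnown(ge::ComputeGraphPtr &root_graph, const std::string &node_name) {
  if (ge::GraphUtils::IsUnknownShapeGraph(root_graph)) {
    // The graph (or one of its subgraphs) still contains unknown-shape operators.
    return nullptr;
  }
  return ge::GraphUtils::FindNodeFromAllNodes(root_graph, node_name);
}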
- /// @return success: GRAPH_SUCESS - /// - static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix, - const std::unordered_map &all_nodes); - - /// - /// Union ref-mapping - /// @param [in] exist_node_info1 - /// @param [in] exist_node_info2 - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @param [out] symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol, std::string &symbol); - - /// - /// Update symbol mapping with a new reference pair - /// @param [in] cur_node_info - /// @param [in] exist_node_info - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Check if out_data_anchor is reference of input - /// @param [in] out_data_anchor - /// @param [out] reuse_in_index - /// @return bool - /// - static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); -}; - -class ComputeGraphBuilder { - public: - ComputeGraphBuilder() : owner_graph_(nullptr) {} - ComputeGraphBuilder(const ComputeGraphBuilder &) = delete; - ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete; - ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete; - ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete; - ~ComputeGraphBuilder() = default; - - /// - /// @brief Add node to graph - /// @param [in] op_desc - /// @return ComputeGraphBuilder - /// - virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc); - - /// - /// @brief Add data-link among nodes in graph - /// @param [in] src_name - /// @param [in] out_anchor_ind - /// @param [in] dst_name - /// @param [in] in_anchor_ind - /// @return ComputeGraphBuilder - /// - virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, - const std::string &dst_name, uint32_t in_anchor_ind); - - /// - /// @brief Add ctrl-link among nodes in graph - /// @param [in] src_name - /// @param [in] dst_name - /// @return ComputeGraphBuilder - /// - virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name); - - /// - /// @brief Build graph - /// @param [out] error_code - /// @param [out] error_msg - /// @return ComputeGraphPtr - /// - virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0; - - /// @brief Get node with name - /// @param [in] name - /// @return NodePtr - /// - NodePtr GetNode(const std::string &name); - - /// @brief Get all nodes - /// @return std::vector - /// - std::vector GetAllNodes(); - - protected: - /// - /// @brief Build nodes - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void BuildNodes(graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Build data-links - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void BuildDataLinks(graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Build ctrl-links - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg); - - ComputeGraphPtr owner_graph_; - - // node_name -> node - std::map 
node_names_; - std::vector nodes_; - - // -> - std::vector, std::pair>> data_links_; - // src_node_name -> dst_node_name - std::vector> ctrl_links_; -}; - -class CompleteGraphBuilder : public ComputeGraphBuilder { - public: - explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {} - CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; - CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; - CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; - CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete; - ~CompleteGraphBuilder() = default; - - /// - /// @brief Add node to graph - /// @param [in] op_desc - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override; - - /// - /// @brief Add data-link among nodes in graph - /// @param [in] src_name - /// @param [in] out_anchor_ind - /// @param [in] dst_name - /// @param [in] in_anchor_ind - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, - uint32_t in_anchor_ind) override; - - /// - /// @brief Add ctrl-link among nodes in graph - /// @param [in] src_name - /// @param [in] dst_name - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; - - /// - /// @brief Set index_th input anchor for graph - /// @param [in] index - /// @param [in] node_names - /// @param [in] anchor_inds - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &SetInput(uint32_t index, const std::vector &node_names, - const std::vector &anchor_inds); - - /// - /// @brief Set index_th input of graph as useless - /// @param [in] index - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &SetUselessInput(uint32_t index); - - /// - /// @brief Add output anchor for graph - /// @param [in] owner_node_name - /// @param [in] anchor_ind - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); - - /// - /// @brief Add target for graph - /// @param [in] target_name - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &AddTarget(const std::string &target_name); - - /// - /// @brief Set parent-node of graph - /// @param [in] parent_node - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node); - - /// - /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node - /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &SetInputMapping(const std::map &input_mapping); - - /// - /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind - /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &SetOutputMapping(const std::map &output_mapping); - - /// - /// @brief Build graph - /// @param [out] error_code - /// @param [out] error_msg - /// @return ComputeGraphPtr - /// - ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; - - private: - /// - /// @brief Add data nodes - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void AddDataNodes(graphStatus &error_code, std::string &error_msg); - - /// - /// @brief 
Add data node - /// @param [in] index - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Add RetVal nodes - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void AddRetValNodes(graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Build target-nodes for graph - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void BuildGraphTargets(graphStatus &error_code, std::string &error_msg); - - std::string name_; - NodePtr parent_node_; - std::map, std::vector>> graph_inputs_; - std::vector> graph_outputs_; - std::vector graph_targets_; - - // index_of_graph_input -> in_anchor_index_of_parent_node - std::map input_mapping_; - // index_of_graph_output -> out_anchor_index_of_parent_node - std::map output_mapping_; -}; - -class PartialGraphBuilder : public ComputeGraphBuilder { - public: - PartialGraphBuilder() = default; - PartialGraphBuilder(const PartialGraphBuilder &) = delete; - PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete; - PartialGraphBuilder(const PartialGraphBuilder &&) = delete; - PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete; - ~PartialGraphBuilder() = default; - - /// - /// @brief Add node to graph - /// @param [in] op_desc - /// @return PartialGraphBuilder - /// - PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override; - - /// - /// @brief Add data-link among nodes in graph - /// @param [in] src_name - /// @param [in] out_anchor_ind - /// @param [in] dst_name - /// @param [in] in_anchor_ind - /// @return PartialGraphBuilder - /// - PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, - uint32_t in_anchor_ind) override; - - /// - /// @brief Add ctrl-link among nodes in graph - /// @param [in] src_name - /// @param [in] dst_name - /// @return PartialGraphBuilder - /// - PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; - - /// - /// @brief Set owner graph - /// @param [in] graph - /// @return PartialGraphBuilder - /// - PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph); - - /// - /// @brief Add exist node - /// @param [in] node - /// @return PartialGraphBuilder - /// - PartialGraphBuilder &AddExistNode(const NodePtr &node); - - /// - /// @brief Build multi nodes with links - /// @param [out] error_code - /// @param [out] error_msg - /// @return ComputeGraphPtr - /// - ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; - - private: - /// - /// @brief Build exist nodes - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void BuildExistNodes(graphStatus &error_code, std::string &error_msg); - - std::vector exist_nodes_; -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_ diff --git a/inc/graph/utils/node_utils.h b/inc/graph/utils/node_utils.h deleted file mode 100644 index bf57148d..00000000 --- a/inc/graph/utils/node_utils.h +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_NODE_UTILS_H_ -#define INC_GRAPH_UTILS_NODE_UTILS_H_ - -#include -#include -#include -#include "external/graph/operator.h" -#include "graph/node.h" - -namespace ge { -// Op types of Const like Opps. -extern const std::set kConstOpTypes; -// Op types of If like Opps. -extern const std::set kIfOpTypes; -// Op types of While like Opps. -extern const std::set kWhileOpTypes; -// Op types of Case like Opps. -extern const std::set kCaseOpTypes; -// Op types of For like Opps. -extern const std::set kForOpTypes; - -class NodeUtils { - public: - static graphStatus AddSendEventId(const NodePtr &node, const uint32_t &event_id); - static graphStatus AddRecvEventId(const NodePtr &node, const uint32_t &event_id); - static graphStatus GetSendEventIdList(const NodePtr &node, std::vector &vec_send); - static graphStatus GetRecvEventIdList(const NodePtr &node, std::vector &vec_recv); - - static graphStatus ClearSendInfo(); - static graphStatus ClearRecvInfo(); - - static graphStatus GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst); - - static graphStatus GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, - InControlAnchorPtr &in_control); - - static graphStatus ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor); - static graphStatus SetAllAnchorStatus(const NodePtr &nodePtr); - static graphStatus SetAllAnchorStatus(Node &node); - static bool IsAnchorStatusSet(const NodePtr &nodePtr); - static bool IsAnchorStatusSet(const Node &node); - - static graphStatus MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node); - - static void UpdateIsInputConst(const NodePtr &nodePtr); - static void UpdateIsInputConst(Node &node); - static bool IsConst(const Node &node); - static void UnlinkAll(const Node &node); - static graphStatus UpdatePeerNodeInputDesc(const NodePtr &node_ptr); - - static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t num); - static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t num); - - static graphStatus AppendOutputAnchor(const NodePtr &node, uint32_t num); - static graphStatus RemoveOutputAnchor(const NodePtr &node, uint32_t num); - - static bool IsInNodesEmpty(const Node &node); - static GeTensorDesc GetOutputDesc(const Node &node, uint32_t index); - static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); - static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); - static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); - // check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true; - // for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2, - // the out param "is_unknow" will be true too - static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); - - static std::string GetNodeType(const Node &node); - static std::string GetNodeType(const NodePtr &node); - - static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); - static 
graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); - - /// - /// Check if node is input of subgraph - /// @param [in] node - /// @return bool - /// - static bool IsSubgraphInput(const NodePtr &node); - - /// - /// Check if node is output of subgraph - /// @param [in] node - /// @return bool - /// - static bool IsSubgraphOutput(const NodePtr &node); - - /// - /// @brief Get subgraph original input node. - /// @param [in] node - /// @return Node - /// - static NodePtr GetParentInput(const Node &node); - static NodePtr GetParentInput(const NodePtr &node); - - /// - /// @brief Get is dynamic shape graph from node. - /// @param [in] node - /// @return bool - /// - static bool IsDynamicShape(const Node &node); - static bool IsDynamicShape(const NodePtr &node); - - /// - /// @brief Check is varying_input for while node - /// @param [in] node: Data node for subgraph - /// @return bool - /// - static bool IsWhileVaryingInput(const ge::NodePtr &node); - - /// - /// @brief Get subgraph input is constant. - /// @param [in] node - /// @param [out] string - /// @return bool - /// - static bool GetConstOpType(const NodePtr &node, std::string &type); - - /// - /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. - /// @param [in] node - /// @return return GRAPH_SUCCESS if remove successfully, other for failed. - /// - static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); - - /// - /// @brief Get subgraph input data node by index. - /// @param [in] node - /// @return Node - /// - static vector GetSubgraphDataNodesByIndex(const Node &node, int index); - - /// - /// @brief Get subgraph input data node by index. - /// @param [in] node - /// @return Node - /// - static vector GetSubgraphOutputNodes(const Node &node); - - static NodePtr GetInDataNodeByIndex(const Node &node, const int index); - - static vector> GetOutDataNodesWithAnchorByIndex(const Node &node, const int index); - - static ge::ConstNodePtr GetNodeFromOperator(const Operator &oprt); - - static graphStatus GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor); - - static graphStatus GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor); - - private: - static std::map> map_send_info_; - static std::map> map_recv_info_; -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_NODE_UTILS_H_ diff --git a/inc/graph/utils/op_desc_utils.h b/inc/graph/utils/op_desc_utils.h deleted file mode 100644 index daa95ebe..00000000 --- a/inc/graph/utils/op_desc_utils.h +++ /dev/null @@ -1,182 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_GRAPH_UTILS_OP_DESC_UTILS_H_ -#define INC_GRAPH_UTILS_OP_DESC_UTILS_H_ - -#include -#include -#include -#include "graph/def_types.h" -#include "graph/node.h" -#include "graph/op_desc.h" -#include "graph/operator.h" -#include "graph/range_vistor.h" - -namespace ge { -class OpDesc; -using OpDescPtr = std::shared_ptr; - -class OpDescUtils { - public: - template - using Vistor = RangeVistor>; - - OpDescUtils() = default; - ~OpDescUtils() = default; - static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); - static bool HasQuantizeFactorParams(const OpDesc& op_desc); - static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant); - static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant); - static graphStatus SetQuantizeFactorParams(const OpDescPtr& op_desc, const QuantizeFactorParams& quant); - static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant); - - static vector GetConstInputNode(const ge::Node& node); - static vector GetInputData(const vector& input_nodes); - - static vector GetWeights(const ge::Node& node); - static vector GetWeights(const ge::ConstNodePtr& node); - static vector MutableWeights(const ge::Node& node); - static vector MutableWeights(const ge::NodePtr node); - static graphStatus SetWeights(ge::Node& node, const vector& weights); - static graphStatus SetWeights(ge::NodePtr node, const vector& weights); - static graphStatus SetWeights(ge::Node& node, const map& weights_map); - static graphStatus ClearWeights(ge::NodePtr node); - - static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); - static bool ClearInputDesc(const ge::NodePtr& node); - static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index); - static bool ClearOutputDesc(const ge::NodePtr& node); - static vector GetConstInputs(const ge::Node& node); - static vector GetConstInputs(const ge::ConstNodePtr& node); - static size_t GetNonConstInputsSize(const ge::Node& node); - static size_t GetNonConstInputsSize(ge::ConstNodePtr node); - // Index: Indicates the index of all non const inputs - static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0); - static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0); - static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index); - static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index); - // Index: Indicates the index of all inputs - static bool IsNonConstInput(const ge::Node& node, size_t index = 0); - static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0); - - static vector GetNonConstTensorDesc(const ge::ConstNodePtr& node); - static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); - - static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); - static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); - static OpDescPtr GetOpDescFromOperator(const Operator& oprt); - - static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); - - static graphStatus SetSubgraphInstanceName(const std::string& subgraph_name, - const std::string& subgraph_instance_name, OpDescPtr& op_desc); - - private: - static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); - static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); - static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr 
weight); - static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); -}; - -class OpDescBuilder { - public: - OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} - OpDescBuilder(const OpDescBuilder&) = delete; - OpDescBuilder& operator=(const OpDescBuilder&) = delete; - OpDescBuilder(const OpDescBuilder&&) = delete; - OpDescBuilder& operator=(const OpDescBuilder&&) = delete; - ~OpDescBuilder() = default; - - /// - /// @brief Add input - /// @param [in] name - /// @return OpDescBuilder - /// - OpDescBuilder& AddInput(const std::string& name); - - /// - /// @brief Add input - /// @param [in] name - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddInput(const std::string& name, const GeTensorDesc& tensor); - - /// - /// @brief Add dynamic input - /// @param [in] name - /// @param [in] num - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); - - /// - /// @brief Add dynamic input - /// @param [in] name - /// @param [in] num - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); - - /// - /// @brief Add output - /// @param [in] name - /// @return OpDescBuilder - /// - OpDescBuilder& AddOutput(const std::string& name); - - /// - /// @brief Add output - /// @param [in] name - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddOutput(const std::string& name, const GeTensorDesc& tensor); - - /// - /// @brief Add dynamic output - /// @param [in] name - /// @param [in] num - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); - - /// - /// @brief Add dynamic output - /// @param [in] name - /// @param [in] num - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); - - /// - /// @brief Build op_desc - /// @return OpDescPtr - /// - OpDescPtr Build(); - - private: - std::string name_; - std::string type_; - std::vector> inputs_; - std::vector> outputs_; -}; -} // namespace ge - -#endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ diff --git a/inc/graph/utils/tensor_adapter.h b/inc/graph/utils/tensor_adapter.h deleted file mode 100644 index a7355553..00000000 --- a/inc/graph/utils/tensor_adapter.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ -#define INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ - -#include -#include "graph/ge_tensor.h" -#include "graph/tensor.h" - -namespace ge { -using GeTensorPtr = std::shared_ptr; -using ConstGeTensorPtr = std::shared_ptr; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorAdapter { - public: - static GeTensorDesc TensorDesc2GeTensorDesc(const TensorDesc &tensorDesc); - static TensorDesc GeTensorDesc2TensorDesc(const GeTensorDesc &geTensorDesc); - static GeTensorPtr Tensor2GeTensor(const Tensor &tensor); - static Tensor GeTensor2Tensor(const ConstGeTensorPtr &geTensor); - - static ConstGeTensorPtr AsGeTensorPtr(const Tensor &tensor); // Share value - static GeTensorPtr AsGeTensorPtr(Tensor &tensor); // Share value - static const GeTensor AsGeTensor(const Tensor &tensor); // Share value - static GeTensor AsGeTensor(Tensor &tensor); // Share value - static const Tensor AsTensor(const GeTensor &tensor); // Share value - static Tensor AsTensor(GeTensor &tensor); // Share value -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ diff --git a/inc/graph/utils/tensor_utils.h b/inc/graph/utils/tensor_utils.h deleted file mode 100644 index caa80dcf..00000000 --- a/inc/graph/utils/tensor_utils.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
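// A brief sketch of the TensorAdapter conversions declared above, assuming
// `external_tensor` is a ge::Tensor received through the public API; the
// helper is illustrative. Tensor2GeTensor returns a converted GeTensorPtr,
// while the As* helpers are annotated "Share value" in the header and expose
// the same underlying value without a conversion copy.
#include "graph/utils/tensor_adapter.h"

static void InspectTensor(const ge::Tensor &external_tensor) {
  ge::GeTensorPtr converted = ge::TensorAdapter::Tensor2GeTensor(external_tensor);
  ge::ConstGeTensorPtr shared = ge::TensorAdapter::AsGeTensorPtr(external_tensor);
  (void)converted;
  (void)shared;
}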
- */ - -#ifndef INC_GRAPH_UTILS_TENSOR_UTILS_H_ -#define INC_GRAPH_UTILS_TENSOR_UTILS_H_ - -#include -#include "graph/def_types.h" -#include "graph/ge_error_codes.h" -#include "graph/ge_tensor.h" - -namespace ge { -class TensorUtils { - public: - static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size); - static void SetSize(GeTensorDesc &tensorDesc, int64_t size); - static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); - static uint32_t GetWeightSize(const GeTensor &tensor); - static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); - static uint8_t *GetWeightAddr(const ConstGeTensorPtr &tensorPtr, uint8_t *base); - static uint8_t *GetWeightAddr(const GeTensor &tensor, uint8_t *base); - static void SetWeightSize(GeTensorDesc &tensorDesc, uint32_t size); - static ge::graphStatus GetReuseInput(const GeTensorDesc &tensorDesc, bool &flag); - static void SetReuseInput(GeTensorDesc &tensorDesc, bool flag); - static ge::graphStatus GetOutputTensor(const GeTensorDesc &tensorDesc, bool &flag); - static void SetOutputTensor(GeTensorDesc &tensorDesc, bool flag); - static graphStatus GetDeviceType(const GeTensorDesc &tensorDesc, DeviceType &type); - static void SetDeviceType(GeTensorDesc &tensorDesc, DeviceType type); - static ge::graphStatus GetInputTensor(const GeTensorDesc &tensorDesc, bool &flag); - static void SetInputTensor(GeTensorDesc &tensorDesc, bool flag); - static ge::graphStatus GetRealDimCnt(const GeTensorDesc &tensorDesc, uint32_t &cnt); - static void SetRealDimCnt(GeTensorDesc &tensorDesc, uint32_t cnt); - static ge::graphStatus GetReuseInputIndex(const GeTensorDesc &tensorDesc, uint32_t &idx); - static void SetReuseInputIndex(GeTensorDesc &tensorDesc, uint32_t idx); - static ge::graphStatus GetDataOffset(const GeTensorDesc &tensorDesc, int64_t &offset); - static void SetDataOffset(GeTensorDesc &tensorDesc, int64_t offset); - static ge::graphStatus GetCmpsSize(const GeTensorDesc &tensorDesc, uint32_t &cmp_size); - static void SetCmpsSize(GeTensorDesc &tensorDesc, uint32_t cmp_size); - static ge::graphStatus GetCmpsTab(const GeTensorDesc &tensorDesc, vector &vec); - static void SetCmpsTab(GeTensorDesc &tensorDesc, const uint8_t *data, size_t size); - static ge::graphStatus GetCmpsTabOffset(const GeTensorDesc &tensorDesc, int64_t &tab_offset); - static void SetCmpsTabOffset(GeTensorDesc &tensorDesc, int64_t tab_offset); - static ge::graphStatus GetCmpsInfo(const GeTensorDesc &tensorDesc, CompressInfo &info); - static void SetCmpsInfo(GeTensorDesc &tensorDesc, const CompressInfo &info); - static bool HasAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc); - static ge::graphStatus GetAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc, AllOffsetQuantizeInfo &info); - static void SetAlloffsetQuantizeInfo(GeTensorDesc &tensorDesc, const AllOffsetQuantizeInfo &info); - static ge::graphStatus GetRC(const GeTensorDesc &tensorDesc, uint32_t &rc); - static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); - - /// - /// calculate tensor mem size. 
- /// @param shape tensor shape - /// @param format tensor format - /// @param data_type tensor data type - /// @param mem_size -1 means unknown shape,other means mem size - /// @return GRAPH_SUCCESS:success, other:failed - /// - static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); - static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); - static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ diff --git a/inc/graph/utils/type_utils.h b/inc/graph/utils/type_utils.h deleted file mode 100644 index 38509b9a..00000000 --- a/inc/graph/utils/type_utils.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_TYPE_UTILS_H_ -#define INC_GRAPH_UTILS_TYPE_UTILS_H_ - -#include -#include -#include -#include "graph/def_types.h" -#include "graph/ge_error_codes.h" -#include "graph/types.h" -#include "graph/usr_types.h" -#include "register/register_types.h" -#include "external/register/register_fmk_types.h" - -namespace ge { -class TypeUtils { - public: - static bool IsDataTypeValid(DataType dt); - static bool IsFormatValid(Format format); - static bool IsInternalFormat(Format format); - - static std::string ImplyTypeToSerialString(domi::ImplyType imply_type); - static std::string DataTypeToSerialString(DataType data_type); - static DataType SerialStringToDataType(const std::string &str); - static std::string FormatToSerialString(Format format); - static Format SerialStringToFormat(const std::string &str); - static Format DataFormatToFormat(const std::string &str); - static Format DomiFormatToFormat(domi::domiTensorFormat_t domi_format); - static std::string FmkTypeToSerialString(domi::FrameworkType fmk_type); - - static graphStatus Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def); - static graphStatus Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr); - - static bool GetDataTypeLength(ge::DataType data_type, uint32_t &length); - static bool CheckUint64MulOverflow(uint64_t a, uint32_t b); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_TYPE_UTILS_H_ diff --git a/src/common/graph/CMakeLists.txt b/src/common/graph/CMakeLists.txt deleted file mode 100755 index 7608e2b3..00000000 --- a/src/common/graph/CMakeLists.txt +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# libgraph.so -# compiling proto files generates some warnings, use no-unused-variable to suppress them -set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") -# add all proto files, generate corresponding .h and .cc files -file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "../../proto/om.proto" - "../../proto/ge_ir.proto" - "../../proto/insert_op.proto" - "../../proto/task.proto" - "../../proto/fwk_adapter.proto" - "../../proto/op_mapping_info.proto" - "../../proto/dump_task.proto" - "../../proto/onnx.proto" -) - -file(GLOB_RECURSE ONNX_PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "${onnx_INC}/onnx/onnx.proto" -) - - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -#protobuf_generate(ge PROTO_ONNX_SRCS PROTO_ONNX_HDRS ${ONNX_PROTO_LIST}) -# need to remove dependencies on pb files later -file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "*.cc" - "utils/*.cc" - "opsproto/*.cc" - "detail/*.cc" - "debug/*.cc" - "option/*.cc" - ) - -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${GE_SOURCE_DIR}) -include_directories(${GE_SOURCE_DIR}/src) -include_directories(${GE_SOURCE_DIR}/src/ge) -include_directories(${GE_SOURCE_DIR}/src/common) -include_directories(${GE_SOURCE_DIR}/src/common/graph) -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/inc/external/graph) -include_directories(${GE_SOURCE_DIR}/inc/graph) -include_directories(${GE_SOURCE_DIR}/inc/common) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) -include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) -include_directories(${GE_SOURCE_DIR}/build) - -######### libgraph.so ############# -add_library(graph SHARED ${SRC_LIST} ${PROTO_SRCS}) -target_compile_definitions(graph PRIVATE - DAVINCI_CLOUD - FMK_SUPPORT_DUMP - Werror) -target_link_libraries(graph PRIVATE - #${PROTOBUF_LIBRARY} - protobuf - ${c_sec} - ${slog} - ${error_manager} - rt - dl) diff --git a/src/common/graph/anchor.cc b/src/common/graph/anchor.cc deleted file mode 100644 index f02037e5..00000000 --- a/src/common/graph/anchor.cc +++ /dev/null @@ -1,371 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/anchor.h" -#include -#include -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/node.h" - -namespace ge { -Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), idx_(idx) {} - -bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf(), type) == 0; } - -size_t Anchor::GetPeerAnchorsSize() const { return peer_anchors_.size(); } - -Anchor::Vistor Anchor::GetPeerAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - ret.push_back(anchor.lock()); - } - return Anchor::Vistor(shared_from_this(), ret); -} - -AnchorPtr Anchor::GetFirstPeerAnchor() const { - if (peer_anchors_.empty()) { - return nullptr; - } else { - return Anchor::DynamicAnchorCast(peer_anchors_.begin()->lock()); - } -} - -NodePtr Anchor::GetOwnerNode() const { return owner_node_.lock(); } - -void Anchor::UnlinkAll() noexcept { - if (!peer_anchors_.empty()) { - do { - auto peer_anchor_ptr = peer_anchors_.begin()->lock(); - if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { - GELOGW("unlink peer_anchor_ptr failed."); - } - } while (!peer_anchors_.empty()); - } -} - -graphStatus Anchor::Unlink(const AnchorPtr &peer) { - if (peer == nullptr) { - GELOGE(GRAPH_FAILED, "peer anchor is invalid."); - return GRAPH_FAILED; - } - auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr &an) { - auto anchor = an.lock(); - return peer->Equal(anchor); - }); - - GE_IF_BOOL_EXEC(it == peer_anchors_.end(), GELOGW("this anchor is not connected to peer"); return GRAPH_FAILED); - - auto it_peer = - std::find_if(peer->peer_anchors_.begin(), peer->peer_anchors_.end(), [this](const std::weak_ptr &an) { - auto anchor = an.lock(); - return Equal(anchor); - }); - - GE_CHK_BOOL_RET_STATUS(it_peer != peer->peer_anchors_.end(), GRAPH_FAILED, "peer is not connected to this anchor"); - - (void)peer_anchors_.erase(it); - (void)peer->peer_anchors_.erase(it_peer); - return GRAPH_SUCCESS; -} - -graphStatus Anchor::ReplacePeer(const AnchorPtr &old_peer, const AnchorPtr &first_peer, const AnchorPtr &second_peer) { - GE_CHK_BOOL_RET_STATUS(old_peer != nullptr, GRAPH_FAILED, "this old peer anchor is nullptr"); - GE_CHK_BOOL_RET_STATUS(first_peer != nullptr, GRAPH_FAILED, "this first peer anchor is nullptr"); - GE_CHK_BOOL_RET_STATUS(second_peer != nullptr, GRAPH_FAILED, "this second peer anchor is nullptr"); - auto this_it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [old_peer](const std::weak_ptr &an) { - auto anchor = an.lock(); - return old_peer->Equal(anchor); - }); - - GE_CHK_BOOL_RET_STATUS(this_it != peer_anchors_.end(), GRAPH_FAILED, "this anchor is not connected to old_peer"); - - auto old_it = std::find_if(old_peer->peer_anchors_.begin(), old_peer->peer_anchors_.end(), - [this](const std::weak_ptr &an) { - auto anchor = an.lock(); - return Equal(anchor); - }); - - GE_CHK_BOOL_RET_STATUS(old_it != old_peer->peer_anchors_.end(), GRAPH_FAILED, - "old_peer is not connected to this anchor"); - *this_it = first_peer; - first_peer->peer_anchors_.push_back(shared_from_this()); - *old_it = second_peer; - second_peer->peer_anchors_.push_back(old_peer); - return GRAPH_SUCCESS; -} - -bool Anchor::IsLinkedWith(const AnchorPtr &peer) { - auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr &an) { - auto anchor = an.lock(); - GE_CHK_BOOL_RET_STATUS(peer != nullptr, false, "this old peer anchor is nullptr"); - return peer->Equal(anchor); - }); - return (it != 
peer_anchors_.end()); -} - -int Anchor::GetIdx() const { return idx_; } - -void Anchor::SetIdx(int index) { idx_ = index; } - -DataAnchor::DataAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {} - -bool DataAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return Anchor::IsTypeOf(type); -} - -InDataAnchor::InDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {} - -OutDataAnchorPtr InDataAnchor::GetPeerOutAnchor() const { - if (peer_anchors_.empty()) { - return nullptr; - } else { - return Anchor::DynamicAnchorCast(peer_anchors_.begin()->lock()); - } -} - -graphStatus InDataAnchor::LinkFrom(const OutDataAnchorPtr &src) { - // InDataAnchor must be only linkfrom once - if (src == nullptr || !peer_anchors_.empty()) { - GELOGE(GRAPH_FAILED, "src anchor is invalid or the peerAnchors is not empty."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(src); - src->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -bool InDataAnchor::Equal(AnchorPtr anchor) const { - auto in_data_anchor = Anchor::DynamicAnchorCast(anchor); - if (in_data_anchor != nullptr) { - if (GetOwnerNode() == in_data_anchor->GetOwnerNode() && GetIdx() == in_data_anchor->GetIdx()) { - return true; - } - } - return false; -} - -bool InDataAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return DataAnchor::IsTypeOf(type); -} - -OutDataAnchor::OutDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {} - -OutDataAnchor::Vistor OutDataAnchor::GetPeerInDataAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_data_anchor != nullptr) { - ret.push_back(in_data_anchor); - } - } - return OutDataAnchor::Vistor(shared_from_this(), ret); -} - -uint32_t OutDataAnchor::GetPeerInDataNodesSize() const { - uint32_t out_nums = 0; - for (const auto &anchor : peer_anchors_) { - auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_data_anchor != nullptr && in_data_anchor->GetOwnerNode() != nullptr) { - out_nums++; - } - } - return out_nums; -} - -OutDataAnchor::Vistor OutDataAnchor::GetPeerInControlAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto in_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_control_anchor != nullptr) { - ret.push_back(in_control_anchor); - } - } - return OutDataAnchor::Vistor(shared_from_this(), ret); -} - -graphStatus OutDataAnchor::LinkTo(const InDataAnchorPtr &dest) { - if (dest == nullptr || !dest->peer_anchors_.empty()) { - GELOGE(GRAPH_FAILED, "dest anchor is invalid or the peerAnchors is not empty."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(dest); - dest->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -graphStatus OutDataAnchor::LinkTo(const InControlAnchorPtr &dest) { - if (dest == nullptr) { - GELOGE(GRAPH_FAILED, "dest anchor is invalid."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(dest); - dest->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -graphStatus OutControlAnchor::LinkTo(const InDataAnchorPtr &dest) { - if (dest == nullptr) { - GELOGE(GRAPH_FAILED, "dest anchor is invalid."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(dest); - dest->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -bool OutDataAnchor::Equal(AnchorPtr anchor) const 
{ - CHECK_FALSE_EXEC(anchor != nullptr, return false); - auto out_data_anchor = Anchor::DynamicAnchorCast(anchor); - if (out_data_anchor != nullptr) { - if (GetOwnerNode() == out_data_anchor->GetOwnerNode() && GetIdx() == out_data_anchor->GetIdx()) { - return true; - } - } - return false; -} - -bool OutDataAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return DataAnchor::IsTypeOf(type); -} - -ControlAnchor::ControlAnchor(const NodePtr &owner_node) : Anchor(owner_node, -1) {} - -ControlAnchor::ControlAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {} - -bool ControlAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return Anchor::IsTypeOf(type); -} - -InControlAnchor::InControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} - -InControlAnchor::InControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {} - -InControlAnchor::Vistor InControlAnchor::GetPeerOutControlAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto out_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (out_control_anchor != nullptr) { - ret.push_back(out_control_anchor); - } - } - return InControlAnchor::Vistor(shared_from_this(), ret); -} - -InControlAnchor::Vistor InControlAnchor::GetPeerOutDataAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto out_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (out_data_anchor != nullptr) { - ret.push_back(out_data_anchor); - } - } - return InControlAnchor::Vistor(shared_from_this(), ret); -} - -graphStatus InControlAnchor::LinkFrom(const OutControlAnchorPtr &src) { - if (src == nullptr) { - GELOGE(GRAPH_FAILED, "src anchor is invalid."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(src); - src->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -bool InControlAnchor::Equal(AnchorPtr anchor) const { - CHECK_FALSE_EXEC(anchor != nullptr, return false); - auto in_control_anchor = Anchor::DynamicAnchorCast(anchor); - if (in_control_anchor != nullptr) { - if (GetOwnerNode() == in_control_anchor->GetOwnerNode()) { - return true; - } - } - return false; -} - -bool InControlAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return ControlAnchor::IsTypeOf(type); -} - -OutControlAnchor::OutControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} - -OutControlAnchor::OutControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {} - -OutControlAnchor::Vistor OutControlAnchor::GetPeerInControlAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto in_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_control_anchor != nullptr) { - ret.push_back(in_control_anchor); - } - } - return OutControlAnchor::Vistor(shared_from_this(), ret); -} - -OutControlAnchor::Vistor OutControlAnchor::GetPeerInDataAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_data_anchor != nullptr) { - ret.push_back(in_data_anchor); - } - } - return OutControlAnchor::Vistor(shared_from_this(), ret); -} - -graphStatus OutControlAnchor::LinkTo(const InControlAnchorPtr &dest) { - if (dest == nullptr) { - GELOGE(GRAPH_FAILED, "dest anchor is invalid."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(dest); - 
dest->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -bool OutControlAnchor::Equal(AnchorPtr anchor) const { - auto out_control_anchor = Anchor::DynamicAnchorCast(anchor); - if (out_control_anchor != nullptr) { - if (GetOwnerNode() == out_control_anchor->GetOwnerNode()) { - return true; - } - } - return false; -} - -bool OutControlAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return ControlAnchor::IsTypeOf(type); -} -} // namespace ge diff --git a/src/common/graph/attr_value.cc b/src/common/graph/attr_value.cc deleted file mode 100644 index 066767c2..00000000 --- a/src/common/graph/attr_value.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "external/graph/attr_value.h" -#include "debug/ge_log.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/ge_attr_value.h" - -namespace ge { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue::AttrValue() { impl = ComGraphMakeShared(); } - -#define ATTR_VALUE_SET_GET_IMP(type) \ - graphStatus AttrValue::GetValue(type &val) const { \ - if (impl != nullptr) { \ - GELOGW("GetValue failed."); \ - return impl->geAttrValue_.GetValue(val); \ - } \ - return GRAPH_FAILED; \ - } - -ATTR_VALUE_SET_GET_IMP(AttrValue::STR) -ATTR_VALUE_SET_GET_IMP(AttrValue::INT) -ATTR_VALUE_SET_GET_IMP(AttrValue::FLOAT) -} // namespace ge diff --git a/src/common/graph/buffer.cc b/src/common/graph/buffer.cc deleted file mode 100644 index 48cdd397..00000000 --- a/src/common/graph/buffer.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
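The anchor code above keeps every link symmetric: LinkTo/LinkFrom append each endpoint to the other's peer list, Unlink erases the entry on both sides, and the lists hold weak_ptr so that anchors never own each other. A minimal sketch of that bookkeeping, assuming a generic Endpoint type rather than the GE anchor hierarchy:

// Sketch of symmetric peer bookkeeping: a link or unlink always updates both
// sides, and peers are stored as weak_ptr to avoid ownership cycles.
#include <algorithm>
#include <memory>
#include <vector>

struct Endpoint {
  std::vector<std::weak_ptr<Endpoint>> peers;

  static void Link(const std::shared_ptr<Endpoint> &a, const std::shared_ptr<Endpoint> &b) {
    a->peers.push_back(b);  // weak_ptr: no ownership cycle between endpoints
    b->peers.push_back(a);
  }

  static void Unlink(const std::shared_ptr<Endpoint> &a, const std::shared_ptr<Endpoint> &b) {
    auto erase_peer = [](std::vector<std::weak_ptr<Endpoint>> &vec, const std::shared_ptr<Endpoint> &p) {
      vec.erase(std::remove_if(vec.begin(), vec.end(),
                               [&p](const std::weak_ptr<Endpoint> &w) { return w.lock() == p; }),
                vec.end());
    };
    erase_peer(a->peers, b);  // remove the entry on both sides, as Anchor::Unlink does
    erase_peer(b->peers, a);
  }
};

Keeping weak references on both sides is what lets UnlinkAll above drain the peer list by repeatedly locking and unlinking the front entry without creating reference cycles between nodes.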
- */ - -#include "graph/buffer.h" -#include "proto/ge_ir.pb.h" -#include "framework/common/debug/ge_log.h" - -namespace ge { -Buffer::Buffer() { - data_.InitDefault(); - if (data_.GetProtoMsg()) { - buffer_ = data_.GetProtoMsg()->mutable_bt(); - } -} - -Buffer::Buffer(const Buffer &other) { - // Share data - data_ = other.data_; - buffer_ = other.buffer_; -} - -Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default - auto proto_msg = data_.GetProtoMsg(); - if (proto_msg != nullptr) { - try { - proto_msg->set_bt(std::string(buffer_size, default_val)); - buffer_ = proto_msg->mutable_bt(); - } catch (std::bad_alloc &e) { - GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size); - buffer_ = nullptr; - } - } -} - -Buffer Buffer::CopyFrom(const std::uint8_t *data, std::size_t buffer_size) { - Buffer buffer; - auto proto_msg = buffer.data_.GetProtoMsg(); - if (proto_msg != nullptr && data != nullptr) { - try { - proto_msg->set_bt(data, buffer_size); - buffer.buffer_ = proto_msg->mutable_bt(); - } catch (std::bad_alloc &e) { - GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size); - buffer.buffer_ = nullptr; - } - } - return buffer; -} - -Buffer::Buffer(const std::shared_ptr &proto_owner, proto::AttrDef *buffer) - : data_(proto_owner, buffer) { - if (data_.GetProtoMsg() != nullptr) { - buffer_ = data_.GetProtoMsg()->mutable_bt(); - } -} - -Buffer::Buffer(const std::shared_ptr &proto_owner, std::string *buffer) - : data_(proto_owner, nullptr) { - buffer_ = buffer; -} - -Buffer &Buffer::operator=(const Buffer &other) { - if (&other != this) { - // Share data - data_ = other.data_; - buffer_ = other.buffer_; - } - return *this; -} - -const std::uint8_t *Buffer::GetData() const { - if (buffer_ != nullptr) { - return (const std::uint8_t *)buffer_->data(); - } - return nullptr; -} - -std::uint8_t *Buffer::GetData() { - if (buffer_ != nullptr && !buffer_->empty()) { - // Avoid copy on write - (void)(*buffer_)[0]; - return reinterpret_cast(const_cast(buffer_->data())); - } - return nullptr; -} - -std::size_t Buffer::GetSize() const { - if (buffer_ != nullptr) { - return buffer_->size(); - } - return 0; -} - -void Buffer::ClearBuffer() { - if (buffer_ != nullptr) { - buffer_->clear(); - } -} -} // namespace ge diff --git a/src/common/graph/compute_graph.cc b/src/common/graph/compute_graph.cc deleted file mode 100644 index bae4d362..00000000 --- a/src/common/graph/compute_graph.cc +++ /dev/null @@ -1,1314 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
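Buffer above distinguishes two copy behaviours: the copy constructor and assignment operator share the underlying protobuf bytes field, while CopyFrom allocates fresh storage (and the non-const GetData touches the first byte to defeat protobuf's copy-on-write before handing out a mutable pointer). A small model of that share-versus-copy split, assuming a shared_ptr<std::string> in place of the protobuf-backed storage:

// SharedBytes is an illustrative stand-in for Buffer, not the GE class: copies
// share storage, CopyFrom() makes a deep copy.
#include <cstdint>
#include <memory>
#include <string>

class SharedBytes {
 public:
  SharedBytes() : bytes_(std::make_shared<std::string>()) {}
  SharedBytes(const SharedBytes &other) = default;            // shares storage, like Buffer(const Buffer &)
  SharedBytes &operator=(const SharedBytes &other) = default;

  static SharedBytes CopyFrom(const std::uint8_t *data, std::size_t size) {
    SharedBytes b;                                             // deep copy, like Buffer::CopyFrom
    if (data != nullptr) {
      b.bytes_->assign(reinterpret_cast<const char *>(data), size);
    }
    return b;
  }

  const std::uint8_t *GetData() const { return reinterpret_cast<const std::uint8_t *>(bytes_->data()); }
  std::size_t GetSize() const { return bytes_->size(); }

 private:
  std::shared_ptr<std::string> bytes_;
};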
- */ - -#include "graph/compute_graph.h" -#include -#include "./format_refiner.h" -#include "./ge_context.h" -#include "debug/ge_attr_define.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "ge/ge_api_types.h" -#include "graph/shape_refiner.h" -#include "proto/ge_ir.pb.h" -#include "utils/ge_ir_utils.h" -#include "utils/graph_utils.h" -#include "utils/node_utils.h" -#include "utils/op_desc_utils.h" -#include "utils/string_utils.h" -#include "utils/tensor_utils.h" - -namespace ge { -namespace { -const size_t OUTPUT_PARAM_SIZE = 2; -const std::string alias_name_attr = "_aliasName"; -bool IsUseBFS() { - string run_mode; - const int base = 10; - if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { - if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) { - return true; - } - } else { - GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); - } - return false; -} -} // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) - : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { - attrs_.InitDefault(); -} - -ComputeGraph::~ComputeGraph() {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { - return GetAllNodes().size(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetAllNodes() const { - std::vector> subgraphs; - return AllGraphNodes(subgraphs); -} - -ComputeGraph::Vistor ComputeGraph::AllGraphNodes(std::vector> &subgraphs) const { - std::vector all_nodes; - std::deque candidates; - - candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); - while (!candidates.empty()) { - NodePtr node = candidates.front(); - all_nodes.emplace_back(node); - candidates.pop_front(); - - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - - const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); - for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { - auto subgraph = GetSubgraph(*name_iter); - if (subgraph != nullptr) { - subgraphs.emplace_back(subgraph); - candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); - } - } - } - - return Vistor(shared_from_this(), all_nodes); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetNodes( - bool is_unknown_shape) const { - if (is_unknown_shape) { - return GetDirectNode(); - } else { - return GetAllNodes(); - } -} - -size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetDirectNode() const { - return Vistor(shared_from_this(), nodes_); -} - -ComputeGraph::Vistor ComputeGraph::GetInputNodes() const { - return Vistor(shared_from_this(), input_nodes_); -} - -ComputeGraph::Vistor ComputeGraph::GetOutputNodes() const { - std::vector result; - for (auto iter = output_nodes_info_.begin(); iter != output_nodes_info_.end(); ++iter) { - result.push_back(iter->first); - } - return Vistor(shared_from_this(), 
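AllGraphNodes above flattens the root graph and every nested subgraph into one list: direct nodes are taken from a work deque, and when a node owns subgraph instances their nodes are spliced in at the front (iterating the subgraph names in reverse), so each subgraph body is listed immediately after its parent node. A self-contained sketch of that traversal, with GraphLike/NodeLike as illustrative stand-ins for ComputeGraph/Node:

// Flattening traversal: pop from the front, append to the result, and push the
// nodes of any owned subgraphs to the front of the work queue.
#include <deque>
#include <memory>
#include <vector>

struct NodeLike;
struct GraphLike { std::vector<std::shared_ptr<NodeLike>> nodes; };
struct NodeLike { std::vector<std::shared_ptr<GraphLike>> subgraphs; };

std::vector<std::shared_ptr<NodeLike>> FlattenAllNodes(const GraphLike &root) {
  std::vector<std::shared_ptr<NodeLike>> all_nodes;
  std::deque<std::shared_ptr<NodeLike>> candidates(root.nodes.begin(), root.nodes.end());
  while (!candidates.empty()) {
    auto node = candidates.front();
    candidates.pop_front();
    all_nodes.push_back(node);
    // Reverse iteration + front insertion keeps the first subgraph frontmost,
    // so a subgraph's body follows its parent node in the flattened order.
    for (auto it = node->subgraphs.rbegin(); it != node->subgraphs.rend(); ++it) {
      candidates.insert(candidates.begin(), (*it)->nodes.begin(), (*it)->nodes.end());
    }
  }
  return all_nodes;
}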
result); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const { - for (const auto &node : nodes_) { - if (node == nullptr) { - continue; - } - if (node->GetName() == name) { - return node; - } - std::vector out_alias_name; - if (AttrUtils::GetListStr(node->GetOpDesc(), alias_name_attr, out_alias_name)) { - for (const auto &alias_name : out_alias_name) { - if (alias_name == name) { - return node; - } - } - } - } - return nullptr; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr -ComputeGraph::FindFirstNodeMatchType(const std::string &name) const { - for (const auto &node : nodes_) { - if (node == nullptr) { - continue; - } - if (node->GetType() == name) { - return node; - } - } - return nullptr; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreEqual( - const ComputeGraph &r_graph) const { - // ProtoMsgOwner <::google::protobuf::Message> is temporarily ignored - if ((this->attrs_.protoMsg_ != nullptr) && (r_graph.attrs_.protoMsg_ != nullptr)) { - const auto &proto_attr_map = *(this->attrs_.protoMsg_); - const auto &r_proto_attr_map = *(r_graph.attrs_.protoMsg_); - // 1.Verify graph's ProtoAttrMap size - if (proto_attr_map.size() != r_proto_attr_map.size()) { - GELOGE(GRAPH_FAILED, "Size of compute graph's ProtoAttrMap verify failed, graph name: %s.", - this->GetName().c_str()); - return false; - } - // 2.Verify graph's ProtoAttrMap key, verify values is temporarily not implemented - for (const auto &it : proto_attr_map) { - if (r_proto_attr_map.count(it.first) == 0) { - GELOGE(GRAPH_FAILED, "Key of compute graph's ProtoAttrMap verify failed, graph name: %s key name: %s.", - this->GetName().c_str(), it.first.c_str()); - return false; - } - } - return true; - } - return ((this->attrs_.protoMsg_ == nullptr) && (r_graph.attrs_.protoMsg_ == nullptr)); -} - -/// Since there may be different input nodes -/// chosen by user in the same graph, special judgment is needed -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNodePtrIsEqual( - const std::vector &left_nodes, const std::vector &right_nodes) const { - const auto left_nodes_size = left_nodes.size(); - const auto right_nodes_size = right_nodes.size(); - if (left_nodes_size != right_nodes_size) { - GELOGE(GRAPH_FAILED, - "Check failed with graph input_nodes_: " - "left inputNodes size %zu is different with right inputNodes size %zu .", - left_nodes_size, right_nodes_size); - return false; - } - for (size_t j = 0; j < left_nodes_size; j++) { - if (left_nodes.at(j) == nullptr || right_nodes.at(j) == nullptr) { - GELOGE(GRAPH_FAILED, "left_nodes.at(%zu) or right_nodes.at(%zu) is nullptr", j, j); - return false; - } - const auto &left_input_name = left_nodes.at(j)->GetName(); - const auto &right_input_name = right_nodes.at(j)->GetName(); - if (left_input_name != right_input_name) { - GELOGE(GRAPH_FAILED, - "Check failed with graph input_nodes_: " - "left inputNode name %s is different with right inputNode name %s at inputNodes index %zu.", - left_input_name.c_str(), right_input_name.c_str(), j); - return false; - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( - const ComputeGraph &r_graph) const { - return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") && - IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && - VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && - 
IsEqual(this->name_, r_graph.name_, "graph.name_") && - IsEqual(this->is_valid_flag_, r_graph.is_valid_flag_, "graph.is_valid_flag_") && - IsEqual(this->need_iteration_, r_graph.need_iteration_, "graph.need_iteration_") && - IsEqual(this->params_share_map_, r_graph.params_share_map_, "graph.params_share_map_") && - IsEqual(this->out_nodes_map_, r_graph.out_nodes_map_, "graph.out_nodes_map_") && - IsEqual(this->inputs_order_, r_graph.inputs_order_, "graph.inputs_order_") && - IsEqual(this->output_size_, r_graph.output_size_, "graph.output_size_") && - IsEqual(this->input_size_, r_graph.input_size_, "graph.input_size_") && - IsEqual(this->output_nodes_info_, r_graph.output_nodes_info_, "graph.output_nodes_info_")); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::operator==(const ComputeGraph &r_graph) const { - // Firstly: Graph's members equal - if ((!GraphMembersAreEqual(r_graph)) || (!GraphAttrsAreEqual(r_graph))) { - return false; - } - - // Secondly: Node equal means the link relationship between node and node itself equal - for (const auto &left_node : nodes_) { - if (left_node == nullptr) { - GELOGE(GRAPH_FAILED, "left_node is nullptr"); - return false; - } - const auto &node_name = left_node->GetName(); - // After TopologicalSorting, node order can change, so find node by name - const auto &right_node = r_graph.FindNode(node_name); - GE_IF_BOOL_EXEC(right_node == nullptr, GELOGE(GRAPH_FAILED, "right_node is NULL!!!"); return false); - if (!(*right_node == *left_node)) { - GELOGE(GRAPH_FAILED, "Compare graph failed, node name: %s.", node_name.c_str()); - return false; - } - } - - // Thirdly: Recursively determine whether the sub graphs are equal - for (size_t i = 0; i < this->sub_graph_.size(); i++) { - if (!(*((this->sub_graph_)[i]) == *((r_graph.sub_graph_)[i]))) { - return false; - } - } - return true; -} - -NodePtr ComputeGraph::AddNodeFront(NodePtr node) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); - return nullptr; - } - node->SetHostNode(is_valid_flag_); - node->GetOpDesc()->SetId(nodes_.size()); - if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { - (void)nodes_.insert(nodes_.begin() + 1, node); - } else { - (void)nodes_.insert(nodes_.begin(), node); - } - return node; -} - -NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) { - if (op == nullptr) { - GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(nodes_.size()); - NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); - GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); - GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); - return AddNodeFront(node_ptr); -} - -NodePtr ComputeGraph::AddNodeAfter(NodePtr node, const NodePtr &pre_node) { - if (node == nullptr || node->GetOpDesc() == nullptr || pre_node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); - return nullptr; - } - node->SetHostNode(is_valid_flag_); - node->GetOpDesc()->SetId(nodes_.size()); - auto node_iter = std::find(nodes_.begin(), nodes_.end(), pre_node); - if (node_iter != nodes_.end()) { - nodes_.insert(node_iter + 1, node); - } else { - GELOGE(GRAPH_FAILED, "Cannot find pre_node in nodes_."); - return nullptr; - } - - return node; -} - -NodePtr ComputeGraph::AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node) { - if (op == nullptr) { 
- GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(nodes_.size()); - NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); - GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); - GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init failed."); return nullptr); - return AddNodeAfter(node_ptr, pre_node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(NodePtr node) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return nullptr; - } - node->SetHostNode(is_valid_flag_); - node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize()); - nodes_.push_back(node); - return node; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpDescPtr op) { - if (op == nullptr) { - GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(GetDirectNodesSize()); - NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); - GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); - GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); - return AddNode(node_ptr); -} - -NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. - if (op == nullptr) { - GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(id); - NodePtr node = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); - GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); - GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); - node->SetHostNode(is_valid_flag_); - nodes_.push_back(node); - return node; -} - -NodePtr ComputeGraph::AddInputNode(NodePtr node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return nullptr; - } - input_nodes_.push_back(node); - if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { - GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed"); - } - return node; -} - -NodePtr ComputeGraph::AddOutputNode(NodePtr node) { return AddOutputNodeByIndex(node, 0); } - -NodePtr ComputeGraph::AddOutputNodeByIndex(NodePtr node, int32_t index) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null."); - return nullptr; - } - - bool already_have = false; - NodePtr result = node; - // [output_nodes_info_ : should not be null] - for (const auto &item : output_nodes_info_) { - if (item.first->GetName() == node->GetName() && item.second == index) { - already_have = true; - result = item.first; - break; - } - } - - if (!already_have) { - output_nodes_info_.emplace_back(std::make_pair(node, index)); - GELOGI("Push back node name:%s, index:%ld, into output_nodes_info_.", node->GetName().c_str(), index); - } - - if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { - GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed"); - } - return result; -} - -graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr || 
out_anchor->GetOwnerNode() == nullptr) { - continue; - } - if (out_anchor->GetOwnerNode()->GetType() == CONSTANT || out_anchor->GetOwnerNode()->GetType() == CONSTANTOP) { - GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED, - "Remove edge from const op failed."); - if (out_anchor->GetOwnerNode()->GetOutNodes().size() == 0) { - GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str()); - auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode()); - if (iter != nodes_.end()) { - (void)nodes_.erase(iter); - } - } - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return GRAPH_FAILED; - } - - // delete const op for this node - (void)RemoveConstInput(node); - - // if the node save as input node, delete it - (void)RemoveInputNode(node); - - // if the node save as input node, delete it - (void)RemoveOutputNode(node); - - if (GRAPH_SUCCESS != IsolateNode(node)) { - GELOGE(GRAPH_FAILED, "Isolate node failed, node name: %s.", node->GetName().c_str()); - return GRAPH_FAILED; - } - - auto iter = find(nodes_.begin(), nodes_.end(), node); - if (iter != nodes_.end()) { - (void)nodes_.erase(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -// Used in sub_graph scenes -graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return GRAPH_FAILED; - } - - auto iter = find(input_nodes_.begin(), input_nodes_.end(), node); - if (iter != input_nodes_.end()) { - (void)input_nodes_.erase(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -// Used in sub_graph scenes -graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return GRAPH_FAILED; - } - - auto iter = output_nodes_info_.begin(); - bool find_node = false; - // [output_nodes_info_ : should not be null] - while (iter != output_nodes_info_.end()) { - if (node->GetName() == iter->first->GetName()) { - iter = output_nodes_info_.erase(iter); - find_node = true; - } else { - ++iter; - } - } - GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED); - return GRAPH_SUCCESS; -} - -std::shared_ptr ComputeGraph::AddSubGraph(std::shared_ptr sub_graph) { - if (sub_graph == nullptr) { - GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); - return nullptr; - } - sub_graph_.push_back(sub_graph); - names_to_subgraph_[sub_graph->GetName()] = sub_graph; - return sub_graph; -} - -graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr &sub_graph) { - if (sub_graph == nullptr) { - GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); - return GRAPH_FAILED; - } - - names_to_subgraph_.erase(sub_graph->GetName()); - auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph); - if (iter != sub_graph_.end()) { - (void)sub_graph_.erase(iter); - return GRAPH_SUCCESS; - } else { - GELOGW("find sub_graph failed"); - return GRAPH_SUCCESS; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr &subgraph) { - if (subgraph == nullptr) { - GE_LOGE("Try to add a null subgraph, name %s", name.c_str()); - return GRAPH_PARAM_INVALID; - } - auto parent_graph = subgraph->GetParentGraph(); - if (parent_graph 
== nullptr) { - GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str()); - return GRAPH_PARAM_INVALID; - } - auto parent_node = subgraph->GetParentNode(); - if (parent_node == nullptr) { - GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str()); - return GRAPH_PARAM_INVALID; - } - if (parent_node->GetOwnerComputeGraph() != parent_graph) { - GE_LOGE( - "Try to add a subgraph which parent node's parent graph is not equal to " - "the subgraph's parent graph, subgraph name %s, parent node name %s", - subgraph->GetName().c_str(), parent_graph->GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - if (!this->parent_graph_.expired()) { - GELOGW("The subgraphs should only be added to the root graph"); - } - if (name != subgraph->GetName()) { - GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); - } - if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) { - GE_LOGE("The subgraph %s existed", name.c_str()); - return GRAPH_PARAM_INVALID; - } - sub_graph_.push_back(subgraph); - names_to_subgraph_[name] = subgraph; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::AddSubgraph(const std::shared_ptr &subgraph) { - if (subgraph == nullptr) { - return GRAPH_PARAM_INVALID; - } - return AddSubgraph(subgraph->GetName(), subgraph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) { - auto iter = names_to_subgraph_.find(name); - if (iter == names_to_subgraph_.end()) { - return; - } - for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) { - if (*vec_iter == iter->second) { - sub_graph_.erase(vec_iter); - break; - } - } - names_to_subgraph_.erase(iter); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph( - const std::shared_ptr &subgraph) { - if (subgraph != nullptr) { - RemoveSubgraph(subgraph->GetName()); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr ComputeGraph::GetSubgraph( - const std::string &name) const { - std::shared_ptr parent = parent_graph_.lock(); - if (parent == nullptr) { - auto iter = names_to_subgraph_.find(name); - return iter == names_to_subgraph_.end() ? 
nullptr : iter->second; - } else { - return parent->GetSubgraph(name); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector> -ComputeGraph::GetAllSubgraphs() const { - return sub_graph_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr ComputeGraph::GetParentGraph() { - return parent_graph_.lock(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph( - const shared_ptr &parent) { - parent_graph_ = parent; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr ComputeGraph::GetParentNode() { - return parent_node_.lock(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr &parent) { - parent_node_ = parent; -} - -/// -/// @brief Update input-mapping -/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input -/// @return graphStatus -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::UpdateInputMapping(const std::map &input_mapping) { - for (auto &input : nodes_) { - if (input->GetType() == DATA) { - uint32_t cur_index = 0; - if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { - continue; - } - auto iter = input_mapping.find(cur_index); - if (iter == input_mapping.end()) { - continue; - } - if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { - GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; -} - -/// -/// @brief Update output-mapping -/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output -/// @return graphStatus -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::UpdateOutputMapping(const std::map &output_mapping) { - NodePtr net_output = FindFirstNodeMatchType(NETOUTPUT); - if (net_output == nullptr) { - GE_LOGE("UpdateOutputMapping failed: node type %s not exist in graph.", NETOUTPUT); - return GRAPH_FAILED; - } - OpDescPtr op_desc = net_output->GetOpDesc(); - if (op_desc == nullptr) { - GE_LOGE("UpdateOutputMapping failed: op_desc is NULL."); - return GRAPH_FAILED; - } - - size_t num = op_desc->GetAllInputsSize(); - for (size_t i = 0; i < num; i++) { - GeTensorDesc tensor = op_desc->GetInputDesc(i); - uint32_t cur_index = 0; - if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { - continue; - } - auto iter = output_mapping.find(cur_index); - if (iter == output_mapping.end()) { - continue; - } - if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { - GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); - return GRAPH_FAILED; - } - if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) { - GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { - std::vector node_vec = nodes_; - for (const auto &node : GetDirectNode()) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGW("node or OpDescPtr is nullptr."); - continue; - } - GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should not be null."); return GRAPH_FAILED); - if (node->GetOpDesc()->GetType() == RECV) { - auto iter = find(node_vec.begin(), node_vec.end(), node); - if (iter == node_vec.end()) { - GELOGW("no node 
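UpdateInputMapping and UpdateOutputMapping above rewrite the parent-node-index attribute carried by a subgraph's Data nodes and NetOutput inputs: each current index is looked up in an old-to-new map and replaced only when an entry exists. The core of that remapping, with the attribute modelled as a plain uint32_t rather than a GE attribute:

// Remap a parent-node index; untouched indices are left as-is, as above.
#include <cstdint>
#include <map>

bool RemapParentIndex(const std::map<uint32_t, uint32_t> &mapping, uint32_t &parent_index) {
  auto it = mapping.find(parent_index);
  if (it == mapping.end()) {
    return false;  // no entry: keep the existing index
  }
  parent_index = it->second;
  return true;
}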
found."); - } else { - (void)node_vec.erase(iter); - } - - auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0)); - (void)node_vec.insert(dst_iter, node); - } - if (node->GetOpDesc()->GetType() == SEND) { - auto iter = find(node_vec.begin(), node_vec.end(), node); - if (iter == node_vec.end()) { - GELOGW("no node found."); - } else { - (void)node_vec.erase(iter); - } - - auto src_iter = find(node_vec.begin(), node_vec.end(), node->GetInControlNodes().at(0)); - (void)node_vec.insert(src_iter + 1, node); - } - } - nodes_.clear(); - for (size_t i = 0; i < node_vec.size(); ++i) { - NodePtr node = node_vec[i]; - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGW("node or OpDescPtr is nullptr."); - } else { - node->GetOpDesc()->SetId((int64_t)i); - nodes_.push_back(node); - } - } - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, - std::map &map_in_edge_num, - std::vector &stack) { - GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); - // Record the number of non data nodes but no input nodes - GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); - - // Only data nodes here - while (!stack.empty()) { - NodePtr node = stack.back(); - stack.pop_back(); - node_vec.push_back(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); - for (const auto &anchor : node->GetAllOutDataAnchors()) { - GE_CHECK_NOTNULL(anchor); - for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && --iter->second == 0) { - stack.push_back(peer_in_anchor->GetOwnerNode()); - } - } - for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && --iter->second == 0) { - stack.push_back(peer_in_anchor->GetOwnerNode()); - } - } - } - GE_IF_BOOL_EXEC( - node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor - : node->GetOutControlAnchor()->GetPeerAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && --iter->second == 0) { - stack.push_back(peer_in_anchor->GetOwnerNode()); - } - }) - } - - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraph::BFSTopologicalSorting(std::vector &node_vec, - std::map &map_in_edge_num, - std::deque &stack) { - GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); - std::vector stack_input; - std::map breadth_node_map; - // Record the number of non data nodes but no input nodes - GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); - - // Only data nodes here - while (!stack_input.empty() || !stack.empty()) { - NodePtr node = nullptr; - if (!stack.empty()) { - node = stack.back(); - stack.pop_back(); - } else { - node = stack_input.back(); - stack_input.pop_back(); - } - - node_vec.push_back(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); - CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); - - for (const auto &name_node : breadth_node_map) { - (void)stack.push_front(name_node.second); - } - breadth_node_map.clear(); - } - return 
GRAPH_SUCCESS; -} - -graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, - std::map &breadth_node_map) { - for (const auto &anchor : node->GetAllOutDataAnchors()) { - for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && 0 == --iter->second) { - (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); - } - } - - for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && 0 == --iter->second) { - (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); - } - } - } - if (node->GetOutControlAnchor() != nullptr) { - for (AnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerAnchors()) { - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && 0 == --iter->second) { - (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); - } - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { - auto ret = TopologicalSortingGraph(); - if (ret != SUCCESS) { - GraphUtils::DumpGEGraphToOnnx(*this, "black_box"); - GELOGE(ret, "Graph [%s] topological sort failed, saved to file black_box", name_.c_str()); - return ret; - } - - if (sub_graph_.empty()) { - return SUCCESS; - } - - // partition sub graph - for (const auto &sub_graph : sub_graph_) { - ret = sub_graph->TopologicalSortingGraph(); - if (ret != SUCCESS) { - GELOGE(ret, "Sub graph topological sort Failed"); - return ret; - } - } - - std::vector> subgraphs; - auto nodes = AllGraphNodes(subgraphs); - for (size_t i = 0; i < nodes.size(); i++) { - NodePtr node = nodes.at(i); // [node: should not be null] - node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] - } - if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original - GELOGW("Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), subgraphs.size()); - return SUCCESS; - } - sub_graph_.swap(subgraphs); - return SUCCESS; -} - -graphStatus ComputeGraph::TopologicalSortingGraph() { - std::vector node_vec; - std::map map_in_edge_num; - bool use_BFS = IsUseBFS(); - if (use_BFS) { - std::deque stack; - if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } else { - std::vector stack; - if (DFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - - // If they are not equal, there is a closed loop - if (node_vec.size() != nodes_.size()) { - std::set itered_nodes_set; - for (auto &node : node_vec) { - itered_nodes_set.insert(node.get()); - } - GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", nodes_.size(), - node_vec.size()); - for (auto &node : nodes_) { - if (itered_nodes_set.count(node.get()) == 0) { - GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); - } - } - return GRAPH_FAILED; - } - - nodes_.clear(); - for (size_t i = 0; i < node_vec.size(); i++) { - NodePtr node = node_vec[i]; // [node: should not be null] - node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] - 
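Both DFSTopologicalSorting and BFSTopologicalSorting above are in-degree (Kahn-style) sorts: SortNodes seeds the worklist with zero-in-degree nodes, and each emitted node decrements the in-degree of its data and control successors, pushing any that reach zero. A minimal, self-contained sketch of the same scheme over plain adjacency lists (integer ids instead of NodePtr, no data/control distinction):

// Kahn-style topological sort: adjacency[u] lists the successors of node u.
// Returns an empty vector if the graph contains a cycle, mirroring the
// "closed loop" failure path above.
#include <cstddef>
#include <vector>

std::vector<int> TopoSortSketch(const std::vector<std::vector<int>> &adjacency) {
  const std::size_t n = adjacency.size();
  std::vector<int> in_degree(n, 0);
  for (const auto &succs : adjacency) {
    for (int v : succs) ++in_degree[v];
  }
  std::vector<int> stack, order;
  for (std::size_t u = 0; u < n; ++u) {
    if (in_degree[u] == 0) stack.push_back(static_cast<int>(u));  // zero-in-degree seeds
  }
  while (!stack.empty()) {
    int u = stack.back();
    stack.pop_back();
    order.push_back(u);
    for (int v : adjacency[u]) {
      if (--in_degree[v] == 0) stack.push_back(v);  // all inputs emitted, schedule v
    }
  }
  if (order.size() != n) order.clear();  // some in-degree never reached zero: cycle
  return order;
}

As in TopologicalSortingGraph above, an output shorter than the node count means the graph contains a closed loop.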
nodes_.push_back(node); - } - - is_valid_flag_ = true; - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraph::SortNodes(std::vector &stack, std::map &map_in_edge_num) { - // Record the number of non data nodes but no input nodes - uint32_t spec_node_size = 0; - bool verify_isolated = false; - string run_mode; - const int base = 10; - // Need verify isolated point in PREDICTION mode. - if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { - if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) < TRAIN) { - verify_isolated = true; - } - } - for (const auto &node : GetDirectNode()) { - GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); - map_in_edge_num[node] = static_cast(GetInEdgeSize(node)); - if (map_in_edge_num[node] == 0) { - if ((node->GetOpDesc()->GetType() != DATA) && (node->GetOpDesc()->GetType() != AIPPDATA) && - (node->GetOpDesc()->GetType() != INPUT_TYPE) && (node->GetOpDesc()->GetType() != ANN_DATA)) { - // At present, can only judge the isolated point without input and output. - // It is impossible to judge the situation with multiple output nodes. - if (verify_isolated && GetOutEdgeSize(node) == 0) { - GELOGE(GRAPH_FAILED, "May has isolated nodes in graph, node name: %s.", node->GetName().c_str()); - return GRAPH_FAILED; - } - (void)stack.insert(stack.begin(), node); - spec_node_size++; - continue; - } - // Need to insert the data nodes in reverse order - (void)stack.insert(stack.begin() + spec_node_size, node); - } - } - - /// Make sure the inputs order matches with user-designated - /// 1. Get the index of two input nodes in the user-inputs-order(inputs_order_) - /// 2. Compare two indices, if not match, swap the positions of two inputs - /// *: Remind: stack is reverse-order - for (size_t i = 0; i < stack.size(); ++i) { - // If not found in 'inputs_order_', skip it - auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); - GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); - auto inx_i = it_i - inputs_order_.begin(); - for (size_t j = i + 1; j < stack.size(); ++j) { - // If not found in 'inputs_order_', skip it - auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); - GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); - - // Compare index, swap them if it should be - auto inx_j = it_j - inputs_order_.begin(); - GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); - } - } - - return GRAPH_SUCCESS; -} - -size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { - size_t in_edge_size = 0; - if (node == nullptr) { - return in_edge_size; - } - for (const auto &anchor : node->GetAllInDataAnchors()) { - in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize(); - // Break flow control data loop. 
- OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); - if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) { - NodePtr out_node = out_anchor->GetOwnerNode(); - if (out_node == nullptr) { - GELOGW("out node is nullptr"); - continue; - } - if ((out_node->GetType() == NEXTITERATION) || (out_node->GetType() == REFNEXTITERATION)) { - GE_IF_BOOL_EXEC(in_edge_size == 0, GELOGE(GRAPH_FAILED, "If [in_edge_size = 0], the result will be reversed"); - return in_edge_size); - in_edge_size -= 1; - } - } - } - if (node->GetInControlAnchor() != nullptr) { - in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize(); - } - return in_edge_size; -} - -size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { - size_t out_edge_size = 0; - if (node == nullptr) { - return out_edge_size; - } - - // Break flow control data loop. - if ((node->GetType() != NEXTITERATION) && (node->GetType() != REFNEXTITERATION)) { - for (const auto &anchor : node->GetAllOutDataAnchors()) { - if (anchor != nullptr) { - out_edge_size = out_edge_size + anchor->GetPeerAnchors().size(); - } - } - } - if (node->GetOutControlAnchor() != nullptr) { - if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { - return 0; - } - out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size(); - } - return out_edge_size; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { - GELOGI("graph name = %s.", GetName().c_str()); - for (const auto &node : GetAllNodes()) { - GELOGI("node name = %s.", node->GetName().c_str()); - for (const auto &anchor : node->GetAllOutDataAnchors()) { - for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, - GELOGI("node name = %s, out data node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { - GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, - GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - } - auto out_control_anchor = node->GetOutControlAnchor(); - if (out_control_anchor != nullptr) { - for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) { - GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, - GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, - GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Swap(ComputeGraph &graph) { - this->AttrHolder::Swap(graph); - - origGraph_.swap(graph.origGraph_); - - name_.swap(graph.name_); - std::swap(graph_id_, graph.graph_id_); - attrs_.Swap(graph.attrs_); - nodes_.swap(graph.nodes_); - all_nodes_infos_.swap(graph.all_nodes_infos_); - target_nodes_info_.swap(graph.target_nodes_info_); - - 
input_nodes_.swap(graph.input_nodes_); - inputs_order_.swap(graph.inputs_order_); - std::swap(input_size_, graph.input_size_); - out_nodes_map_.swap(graph.out_nodes_map_); - std::swap(output_size_, graph.output_size_); - output_nodes_info_.swap(graph.output_nodes_info_); - - sub_graph_.swap(graph.sub_graph_); - names_to_subgraph_.swap(graph.names_to_subgraph_); - parent_graph_.swap(graph.parent_graph_); - parent_node_.swap(graph.parent_node_); - - // the members followed should not in the ComputeGraph class - std::swap(is_valid_flag_, graph.is_valid_flag_); - std::swap(is_summary_graph_, graph.is_summary_graph_); - std::swap(need_iteration_, graph.need_iteration_); - params_share_map_.swap(graph.params_share_map_); - op_name_map_.swap(graph.op_name_map_); - std::swap(session_id_, graph.session_id_); - std::swap(data_format_, graph.data_format_); - std::swap(is_unknown_shape_graph_, graph.is_unknown_shape_graph_); - - // Update Node owner. - SetNodesOwner(); - graph.SetNodesOwner(); -} - -void ComputeGraph::SetNodesOwner() { - for (const auto &node : nodes_) { - if (node == nullptr) { - continue; - } - node->SetOwnerComputeGraph(shared_from_this()); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - auto next_nodes = node->GetOutAllNodes(); - // If there is input data side - for (size_t i = 0; i < node->GetAllInDataAnchors().size(); i++) { - auto in_data_anchor = node->GetInDataAnchor(static_cast(i)); - auto pre_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - if (pre_out_data_anchor != nullptr) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_IF_BOOL_EXEC(pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANT || - pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANTOP, - continue); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - } - auto out_ctrl_anchor = node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(out_ctrl_anchor); - auto pre_out_ctrl_anchor = pre_out_data_anchor->GetOwnerNode()->GetOutControlAnchor(); - GE_CHECK_NOTNULL(pre_out_ctrl_anchor); - for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - } - } - - // If there is an input control side - auto in_ctrl_anchor = node->GetInControlAnchor(); - GE_CHECK_NOTNULL(in_ctrl_anchor); - for (const auto &pre_out_ctrl_anchor : 
in_ctrl_anchor->GetPeerOutControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_ctrl_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED, - "remove edge failed"); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - } - auto out_ctrl_anchor = node->GetOutControlAnchor(); - if (out_ctrl_anchor != nullptr) { - for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - } - } - - for (const auto &out_peer_data_anchor : in_ctrl_anchor->GetPeerOutDataAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_peer_data_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED, - "remove edge failed"); - for (const auto &next_node : next_nodes) { - auto next_in_control_anchor = next_node->GetInControlAnchor(); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(out_peer_data_anchor, next_in_control_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - } - - return RemoveExtraOutEdge(node); -} - -graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - // Remove redundant output edges - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - } - - for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - } - } - auto out_ctrl_anchor = node->GetOutControlAnchor(); - if (out_ctrl_anchor != nullptr) { - for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - } - } - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraph::Verify() { - bool is_unknown_graph = GetGraphUnknownFlag(); - for (const auto &node_ptr : GetAllNodes()) { - GE_CHECK_NOTNULL(node_ptr); - GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); - GE_IF_BOOL_EXEC(is_unknown_graph, continue); - GE_CHK_BOOL_EXEC(node_ptr->GetOpDesc()->CommonVerify() == GRAPH_SUCCESS, return GRAPH_FAILED, - "Verifying %s failed.", node_ptr->GetName().c_str()); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferOriginFormat() { - return ge::FormatRefiner::InferOrigineFormat(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferShapeInNeed() { - GE_CHK_BOOL_ONLY_LOG(TopologicalSorting() == GRAPH_SUCCESS, "Verifying failed."); - for (const auto &node_ptr : GetAllNodes()) { - GE_CHECK_NOTNULL(node_ptr); - auto 
op_desc = node_ptr->GetOpDesc();
-    bool is_need_infer = false;
-    (void)ge::AttrUtils::GetBool(op_desc, NEED_INFER, is_need_infer);
-    if (is_need_infer) {
-      GE_CHK_BOOL_EXEC(node_ptr->Verify() == GRAPH_SUCCESS, return GRAPH_FAILED, "Verifying %s failed.",
-                       node_ptr->GetName().c_str());
-
-      graphStatus status = node_ptr->InferShapeAndType();
-      GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == DATA || GRAPH_PARAM_INVALID != status, break,
-                            "Op %s does not have the IMPLEMT_INFERFUNC definition,"
-                            " and subsequent operators no longer perform shape inference.",
-                            node_ptr->GetName().c_str());
-      GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS, return GRAPH_FAILED, "Inferring %s failed.",
-                       node_ptr->GetName().c_str());
-
-      for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
-        GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc());
-        auto output_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx());
-        ge::TensorUtils::SetRealDimCnt(output_tensor, output_tensor.GetShape().GetDims().size());
-        (void)out_anchor->GetOwnerNode()->GetOpDesc()->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor);
-        for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
-          (void)peer_anchor->GetOwnerNode()->GetOpDesc()->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor);
-        }
-      }
-    }
-  }
-  return GRAPH_SUCCESS;
-}
-
-ProtoAttrMapHelper ComputeGraph::MutableAttrMap() { return attrs_; }
-
-ConstProtoAttrMapHelper ComputeGraph::GetAttrMap() const {
-  return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg());
-}
-
-const std::map<OperatorImplPtr, NodePtr> &ComputeGraph::GetAllNodesInfo() const { return all_nodes_infos_; }
-
-void ComputeGraph::SetUserDefOutput(const std::string &output_name) {
-  if (output_name.empty()) {
-    return;
-  }
-
-  vector<string> nodes = StringUtils::Split(output_name, ';');
-  for (string node : nodes) {
-    vector<string> item = StringUtils::Split(node, ':');
-    if (item.size() != OUTPUT_PARAM_SIZE) {
-      GELOGW("invalid output param!input:%s", output_name.c_str());
-      continue;
-    }
-
-    int32_t index;
-    try {
-      index = stoi(StringUtils::Trim(item[1]));
-    } catch (const std::out_of_range &) {
-      GELOGW("output_name caused an out-of-range exception!output_name:%s", output_name.c_str());
-      continue;
-    } catch (const std::invalid_argument &) {
-      GELOGW("output_name caused an invalid-argument exception!output_name:%s", output_name.c_str());
-      continue;
-    } catch (...) {
-      GELOGW("stoi fail! 
output_name:%s", output_name.c_str()); - continue; - } - auto iter = out_nodes_map_.find(item[0]); - if (iter == out_nodes_map_.end()) { - out_nodes_map_[item[0]] = std::vector(1, index); - } else { - auto idx_iter = std::find(iter->second.begin(), iter->second.end(), index); - if (idx_iter == iter->second.end()) { - iter->second.push_back(index); - } - } - } -} - -const std::string ComputeGraph::GetOutput() { - static const int resultDefaultSize = 2048; - string result; - result.reserve(resultDefaultSize); - auto iter = out_nodes_map_.begin(); - while (iter != out_nodes_map_.end()) { - auto idxes = iter->second; - for (auto idx : idxes) { - (void)result.append(iter->first).append(":").append(std::to_string(idx)).append(";"); - } - ++iter; - } - - return result.substr(0, result.length() - 1); -} -} // namespace ge diff --git a/src/common/graph/debug/ge_log.h b/src/common/graph/debug/ge_log.h deleted file mode 100644 index 14a66709..00000000 --- a/src/common/graph/debug/ge_log.h +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_DEBUG_GE_LOG_H_ -#define COMMON_GRAPH_DEBUG_GE_LOG_H_ - -#include "graph/ge_error_codes.h" -#include "framework/common/debug/ge_log.h" - -#define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) - -#define GE_LOGI_IF(condition, ...) \ - if ((condition)) { \ - GELOGI(__VA_ARGS__); \ - } - -#define GE_LOGW_IF(condition, ...) \ - if ((condition)) { \ - GELOGW(__VA_ARGS__); \ - } - -#define GE_LOGE_IF(condition, ...) \ - if ((condition)) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - } - -#define GE_CHK_STATUS_RET_NOLOG(expr) \ - do { \ - const ge::graphStatus _status = (expr); \ - if (ge::SUCCESS != _status) { \ - return _status; \ - } \ - } while (0) - -#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -#define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ - { \ - bool b = (expr); \ - if (!b) { \ - exec_expr; \ - } \ - } - -#define GE_IF_BOOL_EXEC(expr, exec_expr) \ - { \ - if (expr) { \ - exec_expr; \ - } \ - } - -#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ - do { \ - const ge::graphStatus _status = (expr); \ - if (_status) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -// If expr is true, the log is printed and a custom statement is executed -#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ - { \ - bool b = (expr); \ - if (b) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - exec_expr; \ - } \ - } - -// Only check error log -#define GE_CHK_BOOL_ONLY_LOG(expr, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - GELOGI(__VA_ARGS__); \ - } \ - } while (0) - -// If expr is not true, do not print the log and return the specified status -#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) 
\ - do { \ - bool b = (expr); \ - if (!b) { \ - return _status; \ - } \ - } while (0) - -// If expr is not true, the log is printed and a custom statement is executed -#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ - { \ - bool b = (expr); \ - if (!b) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - exec_expr; \ - } \ - } - -// If expr is not true, the log is printed and a custom statement is executed -#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ - { \ - bool b = (expr); \ - if (!b) { \ - GELOGI(__VA_ARGS__); \ - exec_expr; \ - } \ - } - -// If expr is not GRAPH_SUCCESS, print the log and return the same value -#define GE_CHK_STATUS_RET(expr, ...) \ - do { \ - const ge::graphStatus _status = (expr); \ - if (ge::SUCCESS != _status) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ - try { \ - exec_expr0; \ - } catch (...) { \ - GELOGE(ge::FAILED, "Make shared failed"); \ - exec_expr1; \ - } - -#endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ diff --git a/src/common/graph/debug/ge_op_types.h b/src/common/graph/debug/ge_op_types.h deleted file mode 100644 index dff87331..00000000 --- a/src/common/graph/debug/ge_op_types.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ -#define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ - -namespace ge { -#define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name __attribute__((unused)) = str_name - -GE_REGISTER_OPTYPE(DATA, "Data"); -GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); -GE_REGISTER_OPTYPE(MATMUL, "MatMul"); -GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); -GE_REGISTER_OPTYPE(PERMUTE, "Permute"); -GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); -GE_REGISTER_OPTYPE(_WHILE, "_While"); -GE_REGISTER_OPTYPE(WHILE, "While"); -GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); -GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); -GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); -GE_REGISTER_OPTYPE(SWITCH, "Switch"); -GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); -GE_REGISTER_OPTYPE(SWITCHN, "SwitchN"); -GE_REGISTER_OPTYPE(MERGE, "Merge"); -GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); -GE_REGISTER_OPTYPE(ENTER, "Enter"); -GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); -GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); -GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); -GE_REGISTER_OPTYPE(CONSTANT, "Const"); -GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); -GE_REGISTER_OPTYPE(END, "End"); -GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); -GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); -GE_REGISTER_OPTYPE(INITDATA, "InitData"); -GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); -GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); - -GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); -GE_REGISTER_OPTYPE(VARIABLE, "Variable"); -GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); - -GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); - -// Horovod operator -GE_REGISTER_OPTYPE(HVDCALLBACKALLREDUCE, "hvdCallbackAllreduce"); -GE_REGISTER_OPTYPE(HVDCALLBACKALLGATHER, "hvdCallbackAllgather"); -GE_REGISTER_OPTYPE(HVDCALLBACKBROADCAST, "hvdCallbackBroadcast"); -GE_REGISTER_OPTYPE(HVDWAIT, "hvdWait"); - -GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); - -GE_REGISTER_OPTYPE(RECV, "Recv"); -GE_REGISTER_OPTYPE(SEND, "Send"); -}; // namespace ge -#endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ diff --git a/src/common/graph/debug/ge_util.h b/src/common/graph/debug/ge_util.h deleted file mode 100644 index 4c6ae051..00000000 --- a/src/common/graph/debug/ge_util.h +++ /dev/null @@ -1,274 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef COMMON_GRAPH_DEBUG_GE_UTIL_H_ -#define COMMON_GRAPH_DEBUG_GE_UTIL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "framework/common/debug/ge_log.h" -#include "graph/debug/ge_log.h" -#include "graph/ge_error_codes.h" - -#if !defined(__ANDROID__) && !defined(ANDROID) -#define GE_DYNAMIC_CAST dynamic_cast -#define GE_DYNAMIC_POINTER_CAST std::dynamic_pointer_cast -#else -#define GE_DYNAMIC_CAST static_cast -#define GE_DYNAMIC_POINTER_CAST std::static_pointer_cast -#endif - -#define GE_RETURN_IF_ERROR(expr) \ - do { \ - const ::ge::optStatus _status = (expr); \ - if (_status) return _status; \ - } while (0) - -#define GE_RETURN_WITH_LOG_IF_INFO(expr, ...) \ - do { \ - const ::ge::optStatus _status = (expr); \ - if (_status) { \ - GELOGI(__VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -// Verify whether the parameter is true. If yes, return graph failed and record the error log -#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ - do { \ - if (condition) { \ - GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ - return ge::GRAPH_FAILED; \ - } \ - } while (0) - -// Verify whether the parameter is false. If yes, return graph failed and record the error log -#define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ - do { \ - bool _condition = (condition); \ - if (!_condition) { \ - GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ - return ge::GRAPH_FAILED; \ - } \ - } while (0) - -// Verify whether the parameter is true. If yes, return GRAPH_PARAM_INVALID and record the error log -#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ - do { \ - if (condition) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Verify whether the parameter is false. If yes, return GRAPH_PARAM_INVALID and record the error log -#define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \ - do { \ - bool _condition = (condition); \ - if (!_condition) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log -#define GE_CHECK_NOTNULL(val) \ - do { \ - if (val == nullptr) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log -#define GE_CHECK_NOTNULL_EXEC(val, expr) \ - do { \ - if (val == nullptr) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ - expr; \ - } \ - } while (0) - -// Verify whether the parameter is null. 
If yes, return false and record the error log -#define GE_RT_FALSE_CHECK_NOTNULL(val) \ - do { \ - if (val == nullptr) { \ - GELOGE(ge::GRAPH_FAILED, "param[%s] must not be null.", #val); \ - return false; \ - } \ - } while (0) - -// Check whether the parameter is out of range -#define GE_CHECK_SIZE(size) \ - do { \ - if (size == 0) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -/// -/// @ingroup GE_common -/// eg:GE_DEFINE_BYTE_SIZE(filter_byte, filter.data().size(), sizeof(float)); -/// -#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ - uint32_t _var_name; \ - do { \ - uint32_t _expr_size = (_expr); \ - uint32_t _sizeof_size = (_sizeof); \ - if (_expr_size > (0xffffffff) / _sizeof_size) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "byte size : %s is out of range", #_var_name); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - _var_name = _sizeof_size * _expr_size; \ - } while (0); - -// Check whether the container is empty -#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ - do { \ - if (vector.empty()) { \ - GELOGE(ge::GRAPH_FAILED, "param[#vector] is empty", #vector); \ - return ge::GRAPH_FAILED; \ - } \ - } while (0) - -// Check whether the container is empty and return the specified status code -#define GE_CHECK_VECTOR_NOT_EMPTY_RET_STATUS(vector, _status) \ - do { \ - if (vector.empty()) { \ - GELOGE(_status, "param[%s] is empty", #vector); \ - return _status; \ - } \ - } while (0) - -/// -/// @ingroup GE_common -/// @brief This macro provides the ability to disable copying constructors and assignment operators. -/// It is usually placed under private -/// -#define GE_DISALLOW_COPY_AND_ASSIGN(TypeName) \ - TypeName(const TypeName &) = delete; \ - void operator=(const TypeName &) = delete - -/// Check whether the size is 0 or out of range -/// @param:size:Size to be verified -#define GE_CHECK_SIZE_RANGE(size) \ - do { \ - if (size == 0 || size >= UINT_MAX / 4) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -#define GE_CHECK_SHORT_SIZE_RANGE(size) \ - do { \ - if (size == 0 || size >= UINT_MAX / 2) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ - do { \ - if (size <= 0) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not a positive number", #size); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -#define GE_CHECK_POSITIVE_SHORT_SIZE_RANGE(size) \ - do { \ - if (size <= 0 || size == 0 || size >= UINT_MAX / 4) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Verify that the value on the left is greater than or equal to the value on the right -#define GE_CHECK_GE(lhs, rhs) \ - do { \ - if (lhs < rhs) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is less than[%s]", #lhs, #rhs); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Check whether the parameters are equal -#define GE_CHECK_EQ(val1, val2) \ - do { \ - if (val1 != val2) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not equals to[%s]", #val1, #val2); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Verify that the value on the left is less than or equal to the value on the right -#define GE_CHECK_LE(lhs, rhs) \ - do { \ - if (lhs > rhs) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is 
greater than[%s]", #lhs, #rhs); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Check whether the parameters are equal -#define GE_CHECK_EQ_WITH_LOG(val1, val2, ...) \ - do { \ - if (val1 != val2) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// If expr is false, the custom statement is executed -#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - exec_expr; \ - } \ - } while (0) - -#define GE_DELETE_NEW_SINGLE(var) \ - do { \ - if (var != nullptr) { \ - delete var; \ - var = nullptr; \ - } \ - } while (0) - -#define GE_DELETE_NEW_ARRAY(var) \ - do { \ - if (var != nullptr) { \ - delete[] var; \ - var = nullptr; \ - } \ - } while (0) - -template -static inline std::shared_ptr ComGraphMakeShared(Args &&... args) { - using T_nc = typename std::remove_const::type; - std::shared_ptr ret(new (std::nothrow) T_nc(std::forward(args)...)); - return ret; -} - -#endif // COMMON_GRAPH_DEBUG_GE_UTIL_H_ diff --git a/src/common/graph/debug/graph_debug.cc b/src/common/graph/debug/graph_debug.cc deleted file mode 100644 index 7ce9db37..00000000 --- a/src/common/graph/debug/graph_debug.cc +++ /dev/null @@ -1,246 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/debug/graph_debug.h" -#include -#include -#include -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" - -#define TAB " " -#define STR_FMT(str) (" \"" + std::string(str) + "\" ") -#define INPUT_ANCHOR_PORT(name) ("__input__" + (name)) -#define OUTPUT_ANCHOR_PORT(name) ("__output__" + (name)) - -namespace ge { -std::unordered_set control_anchor; -std::vector types = { - "DT_FLOAT", "DT_FLOAT16", "DT_INT8", "DT_INT32", "DT_UINT8", "", - "DT_INT16", "DT_UINT16", "DT_UINT32", "DT_INT64", "DT_UINT64", "DT_DOUBLE", - "DT_BOOL", "DT_DUAL", "DT_DUAL_SUB_INT8", "DT_DUAL_SUB_UINT8", "DT_UNDEFINED"}; - -std::vector formats = {"FORMAT_NCHW", - "FORMAT_NHWC", - "FORMAT_ND", - "FORMAT_NC1HWC0", - "FORMAT_FRACTAL_Z", - "FORMAT_NC1C0HWPAD", - "FORMAT_NHWC1C0", - "FORMAT_FSR_NCHW", - "FORMAT_FRACTAL_DECONV", - "FORMAT_C1HWNC0", - "FORMAT_FRACTAL_DECONV_TRANSPOSE", - "FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS", - "FORMAT_NC1HWC0_C04", - "FORMAT_FRACTAL_Z_C04", - "FORMAT_CHWN", - "FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS", - "FORMAT_HWCN", - "FORMAT_NC1KHKWHWC0", - "FORMAT_BN_WEIGHT", - "FORMAT_FILTER_HWCK", - "FORMAT_HASHTABLE_LOOKUP_LOOKUPS", - "FORMAT_HASHTABLE_LOOKUP_KEYS", - "FORMAT_HASHTABLE_LOOKUP_VALUE", - "FORMAT_HASHTABLE_LOOKUP_OUTPUT", - "FORMAT_HASHTABLE_LOOKUP_HITS", - "FORMAT_RESERVED"}; - -std::vector data_nodes = {"Const", "Data"}; - -void GraphDebugPrinter::DumpNodeToDot(const NodePtr node, std::ostringstream &out_) { - if (node == nullptr) { - GELOGI("Some nodes are null."); - return; - } - - bool in_control = false; - auto name = node->GetName(); - out_ << TAB << STR_FMT(name); - auto input_cnt = std::max(static_cast(1), node->GetAllInDataAnchors().size()); - auto output_cnt = std::max(static_cast(1), node->GetAllOutDataAnchors().size()); - if (control_anchor.find(node->GetName()) != control_anchor.end()) { - input_cnt++; - in_control = true; - } - auto max_col = input_cnt * output_cnt; - out_ << "[\n"; - if (find(data_nodes.begin(), data_nodes.end(), node->GetType()) != data_nodes.end()) { - out_ << TAB << TAB << "shape=plaintext, color=goldenrod\n"; - } else { - out_ << TAB << TAB << "shape=plaintext, color=deepskyblue\n"; - } - out_ << TAB << TAB << "label=<\n"; - out_ << TAB << TAB << R"(" << std::endl; - - auto input_anchors = node->GetAllInDataAnchors(); - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return ); - if (!input_anchors.empty()) { - out_ << TAB << TAB << ""; - } - for (const auto &anchor : input_anchors) { - string anchor_text = op_desc->GetInputNameByIndex(anchor->GetIdx()); - - out_ << ""; - } - if (in_control) { - string anchor_text = "ctrl"; - out_ << ""; - } - if (!input_anchors.empty()) { - out_ << "\n"; - } - // Node type - out_ << TAB << TAB << "\n"; - // Output - auto output_anchors = node->GetAllOutDataAnchors(); - if (!output_anchors.empty()) { - out_ << TAB << TAB << ""; - } - for (const auto &anchor : output_anchors) { - string anchor_text = op_desc->GetOutputNameByIndex(anchor->GetIdx()); - - out_ << ""; - } - - if (!output_anchors.empty()) { - out_ << "\n"; - } - out_ << TAB << TAB << "
" - << anchor_text << "" - << anchor_text << "
" - << "" << node->GetType() << "
" - << anchor_text << "
\n" << TAB << ">];\n"; -} - -void GraphDebugPrinter::DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag) { - if (node == nullptr) { - GELOGI("Some nodes are null."); - return; - } - auto all_out_anchor = node->GetAllOutDataAnchors(); - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return ); - for (const auto &anchor : all_out_anchor) { - auto src_anchor = anchor; - auto src_node_name = node->GetName(); - auto src_anchor_index = op_desc->GetOutputNameByIndex(static_cast(src_anchor->GetIdx())); - auto des_anchors = anchor->GetPeerAnchors(); - for (const auto &peer_in_anchor : des_anchors) { - auto in_data_anchor = Anchor::DynamicAnchorCast(peer_in_anchor); - std::string dst_node_name; - out_ << TAB << STR_FMT(src_node_name); - out_ << ":" << OUTPUT_ANCHOR_PORT(src_anchor_index); - auto op = peer_in_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op, continue); - if (in_data_anchor != nullptr) { - dst_node_name = in_data_anchor->GetOwnerNode()->GetName(); - string des_anchor_index = op->GetInputNameByIndex(static_cast(in_data_anchor->GetIdx())); - out_ << " -> " << STR_FMT(dst_node_name); - out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index); - out_ << "["; - } - auto in_control_anchor = Anchor::DynamicAnchorCast(peer_in_anchor); - if (in_control_anchor != nullptr) { - dst_node_name = in_control_anchor->GetOwnerNode()->GetName(); - string des_anchor_index = "ctrl"; - out_ << " -> " << STR_FMT(dst_node_name); - out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index); - out_ << "["; - out_ << " style=dashed "; - } - if (flag != DOT_NOT_SHOW_EDGE_LABEL && in_data_anchor) { - string label; - auto src_ops = src_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(src_ops, return ); - auto src_shape = src_ops->GetOutputDesc(src_anchor->GetIdx()).GetShape(); - auto dim = src_shape.GetDims(); - std::ostringstream tensor_info; - if (dim.size() > 0) { - for (size_t i = 0; i < dim.size(); i++) { - if (i != dim.size() - 1) { - tensor_info << dim[i] << "x"; - } else { - tensor_info << dim[i]; - } - } - } else { - tensor_info << "?"; - } - auto src_tensor_desc = src_ops->GetOutputDescPtr(src_anchor->GetIdx()); - GE_CHECK_NOTNULL_EXEC(src_tensor_desc, return ); - auto format = src_tensor_desc->GetFormat(); - auto datatype = src_tensor_desc->GetDataType(); - tensor_info << " : " << formats[format] << " : " << types[datatype]; - label = tensor_info.str(); - out_ << "label=" << STR_FMT(label); - } - out_ << "]" << std::endl; - } - } -} - -graphStatus GraphDebugPrinter::DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name, - uint32_t flag) { - auto compute_graph = GraphUtils::GetComputeGraph(graph); - if (compute_graph == nullptr) { - GELOGI("Compute graph is NULL ."); - return GRAPH_SUCCESS; - } - return DumpGraphDotFile(compute_graph, output_dot_file_name, flag); -} - -graphStatus GraphDebugPrinter::DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name, - uint32_t flag) { - if (graph == nullptr) { - GELOGI("graph is null."); - return GRAPH_SUCCESS; - } - std::ostringstream out_; - out_ << "digraph G{\n"; - out_ << TAB << R"(ratio=compress;size="8, 100")" << std::endl; - out_ << TAB << R"(node[fontname="Consolas"])" << std::endl; - out_ << TAB << R"(edge[fontsize = "8" fontname = "Consolas" color="dimgray" ])" << std::endl; - auto all_nodes = graph->GetAllNodes(); - for (const auto &node : all_nodes) { - for (const auto &temp : node->GetAllOutDataAnchors()) { - for (const auto &peer : 
temp->GetPeerAnchors()) { - auto temp_control_anchor = Anchor::DynamicAnchorCast(peer); - if (temp_control_anchor) { - (void)control_anchor.insert(peer->GetOwnerNode()->GetName()); - } - } - } - } - for (const auto &node : all_nodes) { - DumpNodeToDot(node, out_); - } - for (const auto &node : all_nodes) { - DumpEdgeToDot(node, out_, flag); - } - out_ << "}"; - std::ofstream output_file(output_dot_file_name); - if (output_file.is_open()) { - output_file << out_.str(); - } else { - GELOGW("%s open error.", output_dot_file_name.c_str()); - } - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/src/common/graph/debug/graph_debug.h b/src/common/graph/debug/graph_debug.h deleted file mode 100644 index 29de632a..00000000 --- a/src/common/graph/debug/graph_debug.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ -#define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ -#include -#include -#include -#include -#include -#include "external/graph/graph.h" -#include "./ge_error_codes.h" -#include "graph/compute_graph.h" -#include "graph/debug/ge_log.h" -#include "graph/node.h" -#include "utils/graph_utils.h" - -namespace ge { -enum DotFileFlag { - // Show nodes, edges, size, type and format - DOT_FLAG_DEFAULT = 0, - DOT_NOT_SHOW_EDGE_LABEL = 1, -}; -class GraphDebugPrinter { - public: - static graphStatus DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name, - uint32_t flag = DOT_FLAG_DEFAULT); - static graphStatus DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name, - uint32_t flag = DOT_FLAG_DEFAULT); - static void DumpNodeToDot(const NodePtr node, std::ostringstream &out_); - static void DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag = DOT_FLAG_DEFAULT); -}; -} // namespace ge - -#endif // COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ diff --git a/src/common/graph/detail/attributes_holder.cc b/src/common/graph/detail/attributes_holder.cc deleted file mode 100644 index 7e3b6de9..00000000 --- a/src/common/graph/detail/attributes_holder.cc +++ /dev/null @@ -1,241 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "detail/attributes_holder.h" -#include -#include "debug/ge_log.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/ge_attr_value.h" -#include "proto/ge_ir.pb.h" - -namespace ge { -using std::map; -using std::unordered_set; -void AttrHolder::CopyAttrsFrom(const AttrHolder &holder) { MutableAttrMap().CopyValueFrom(holder.GetAttrMap()); } -graphStatus AttrHolder::SetAttr(const std::string &name, const GeAttrValue &value) { - if (value.IsEmpty()) { - GELOGE(GRAPH_FAILED, "value is empty, key of the attr is %s", name.c_str()); - return GRAPH_FAILED; - } - auto proto_map = MutableAttrMap().GetProtoMsg(); - auto proto_val = value.value_.GetProtoMsg(); - if (proto_map == nullptr || proto_val == nullptr) { - return GRAPH_FAILED; - } - auto it = proto_map->find(name); - if (it != proto_map->end()) { - if (it->second.value_case() != proto::AttrDef::VALUE_NOT_SET && - it->second.value_case() != proto_val->value_case()) { - return GRAPH_FAILED; - } - } - (*proto_map)[name] = *proto_val; - return GRAPH_SUCCESS; -} - -graphStatus AttrHolder::AddRequiredAttr(const std::string &name) { - if (HasAttr(name)) { - return GRAPH_FAILED; - } - requiredAttrs_.push_back(name); - return GRAPH_SUCCESS; -} - -graphStatus AttrHolder::GetAttr(const std::string &name, GeAttrValue &value) const { - auto proto_map = GetAttrMap().GetProtoMsg(); - auto proto_val = value.value_.GetProtoMsg(); - if (proto_map == nullptr || proto_val == nullptr) { - return GRAPH_FAILED; - } - auto it = proto_map->find(name); - if (it != proto_map->end()) { - *proto_val = it->second; - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -bool AttrHolder::HasAttr(const std::string &name) const { - auto proto_map = GetAttrMap().GetProtoMsg(); - if (proto_map != nullptr) { - if (proto_map->find(name) != proto_map->end()) { - return true; - } - } - return std::find(requiredAttrs_.begin(), requiredAttrs_.end(), name) != requiredAttrs_.end(); -} - -graphStatus AttrHolder::DelAttr(const std::string &name) { - auto proto_map = MutableAttrMap().GetProtoMsg(); - if (proto_map == nullptr) { - return GRAPH_FAILED; - } - auto it = proto_map->find(name); - if (it != proto_map->end()) { - (void)proto_map->erase(it); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -const std::map AttrHolder::GetAllAttrs() const { - std::map attr_value_map; - auto proto_map = GetAttrMap().GetProtoMsg(); - if (proto_map != nullptr) { - auto proto_owner = GetAttrMap().GetProtoOwner(); - GE_CHK_BOOL_EXEC(proto_owner != nullptr, return attr_value_map, "proto_owner is nullptr"); - for (const auto &it : *proto_map) { - attr_value_map[it.first] = GeAttrValue(proto_owner, const_cast(&it.second)); - } - } - return attr_value_map; -} - -const std::unordered_set AttrHolder::GetAllAttrNames() const { - std::unordered_set names; - auto proto_map = GetAttrMap().GetProtoMsg(); - if (proto_map != nullptr) { - for (const auto &it : *proto_map) { - (void)names.insert(it.first); - } - } - for (const string &it : requiredAttrs_) { - (void)names.insert(it); - } - return names; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::AttrDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - 
GELOGE(GRAPH_FAILED, "proto::TensorDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::ShapeDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::NamedAttrs make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); - return; - } - protoMsg_ = proto_owner->mutable_attr(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); - return; - } - protoMsg_ = &proto_owner->attr(); - protoOwner_ = proto_owner; -} -} // namespace ge diff --git a/src/common/graph/format_refiner.cc b/src/common/graph/format_refiner.cc deleted file mode 100644 index c716825a..00000000 --- a/src/common/graph/format_refiner.cc +++ /dev/null @@ -1,508 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "format_refiner.h" - -#include -#include -#include -#include -#include - -#include "graph/ref_relation.h" -#include "./compute_graph.h" -#include "./ge_error_codes.h" -#include "./graph/ge_tensor.h" -#include "./operator.h" -#include "./operator_factory.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "utils/node_utils.h" -#include "utils/op_desc_utils.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" - -using namespace ge; -using namespace std; -namespace ge { -namespace { -const std::unordered_set kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; -const string kIsGraphInferred = "_is_graph_inferred"; -thread_local RefRelations reflection_builder; -} // namespace - -graphStatus ReflectionProcess(const std::unordered_set &reflection, - std::deque &nodes, ge::Format to_be_set_format) { - for (const auto &cell : reflection) { - auto node = cell.node; - auto in_out_idx = cell.in_out_idx; - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - if (cell.in_out == ge::NODE_IN) { - auto desc = node->GetOpDesc()->GetInputDesc(static_cast(in_out_idx)); - desc.SetOriginFormat(to_be_set_format); - desc.SetFormat(to_be_set_format); - (void)node->GetOpDesc()->UpdateInputDesc(static_cast(in_out_idx), desc); - } else { - auto desc = node->GetOpDesc()->GetOutputDesc(static_cast(in_out_idx)); - desc.SetOriginFormat(to_be_set_format); - desc.SetFormat(to_be_set_format); - (void)node->GetOpDesc()->UpdateOutputDesc(static_cast(in_out_idx), desc); - } - nodes.push_back(cell.node); - } - - return GRAPH_SUCCESS; -} - -graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) { - // 5 meas dim num - if (node_ptr->GetType() != "BiasAdd") { - return GRAPH_SUCCESS; - } - std::unordered_map kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}}; - for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) { - auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i); - GE_CHECK_NOTNULL(in_desc); - if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num - continue; - } - auto format = in_desc->GetOriginFormat(); - auto key = TypeUtils::FormatToSerialString(format); - auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; - in_desc->SetOriginFormat(fixed_format); - in_desc->SetFormat(fixed_format); - GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i, - node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), - TypeUtils::FormatToSerialString(fixed_format).c_str()); - } - for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) { - auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i); - GE_CHECK_NOTNULL(out_desc); - if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num - continue; - } - auto format = out_desc->GetOriginFormat(); - auto key = TypeUtils::FormatToSerialString(format); - auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; - out_desc->SetOriginFormat(fixed_format); - out_desc->SetFormat(fixed_format); - GELOGD("fix the %zu'th output of node[%s]. 
Origin format is %s , after fixed it is %s", i, - node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), - TypeUtils::FormatToSerialString(fixed_format).c_str()); - } - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { - GE_CHECK_NOTNULL(graph); - GE_CHECK_NOTNULL(op_desc); - if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) { - ConstGeTensorPtr tensor_value; - if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { - GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str()); - return GRAPH_FAILED; - } - GE_CHECK_NOTNULL(tensor_value); - (void)op_desc->UpdateOutputDesc(0, tensor_value->GetTensorDesc()); - } - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points, - std::vector &data_nodes, - std::unordered_map &node_status) { - if (graph == nullptr) { - GELOGE(GRAPH_FAILED, "input graph is null"); - return GRAPH_FAILED; - } - anchor_points.clear(); - // Get all anchor point nodes and switch nodes - for (auto &node_ptr : graph->GetAllNodes()) { - if (node_ptr == nullptr) { - return GRAPH_FAILED; - } - auto op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - return GRAPH_FAILED; - } - graphStatus status = RefreshConstantOutProcess(graph, op_desc); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "refresh constant out process failed!"); - return GRAPH_FAILED; - } - // consider special node save process - // get all input desc format - bool node_is_all_nd = false; - auto input_size = static_cast(op_desc->GetAllInputsSize()); - for (uint32_t i = 0; i < input_size; i++) { - // Operator pre-set format but not origin format - GE_IF_BOOL_EXEC(op_desc->MutableInputDesc(i) == nullptr, continue); - auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); - // Pre-save data node (only main graph data) and default infer fail - if (node_ptr->GetType() == DATA) { - data_nodes.push_back(node_ptr); - } - if (input_format != FORMAT_ND && input_format != FORMAT_RESERVED) { - node_is_all_nd = true; - } - } - // Get all output desc format - auto output_size = static_cast(op_desc->GetOutputsSize()); - for (uint32_t i = 0; i < output_size; i++) { - GE_IF_BOOL_EXEC(op_desc->MutableOutputDesc(i) == nullptr, continue); - auto output_format = op_desc->MutableOutputDesc(i)->GetFormat(); - if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { - node_is_all_nd = true; - } - } - // check anchor point valid - if (!node_is_all_nd) { - continue; - } - // special process for biasAdd op - // In tensorflow, biasAdd's format is alwayse NHWC even though set the arg - // "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism - // so here do special process - status = BiasAddFormatFixProcess(node_ptr); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "fix biasAdd process failed!"); - return GRAPH_FAILED; - } - - GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str()); - anchor_points.push_back(node_ptr); - } - GELOGI("anchor_points number is %zu", anchor_points.size()); - return GRAPH_SUCCESS; -} -graphStatus FormatRefiner::AnchorProcess(const ge::NodePtr &anchor_node, - std::unordered_map &node_status) { - if (anchor_node == nullptr) { - GELOGE(GRAPH_FAILED, "anchor node is null!"); - return GRAPH_FAILED; - } - std::deque nodes; - nodes.push_back(anchor_node); - while (!nodes.empty()) { - ge::NodePtr node = nodes.front(); - 
nodes.pop_front(); - graphStatus status = BackInferProcess(nodes, node, node_status); - if (status != GRAPH_SUCCESS && node != nullptr) { - GELOGE(status, "BackInferProcess failed!node name [%s]", node->GetName().c_str()); - return status; - } - status = ForwardInferProcess(nodes, node, node_status); - if (status != GRAPH_SUCCESS && node != nullptr) { - GELOGE(status, "ForwardInferProcess failed!node name [%s]", node->GetName().c_str()); - return status; - } - } - return GRAPH_SUCCESS; -} -graphStatus FormatRefiner::BackInferProcess(std::deque &nodes, ge::NodePtr &node, - std::unordered_map &node_status) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - - GELOGD("Enter back infer process!Node is [%s]", (node->GetName()).c_str()); - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); - auto in_data_anchor_idx = in_anchor->GetIdx(); - auto input_desc = node->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); - GE_IF_BOOL_EXEC(input_desc == nullptr, continue); - auto to_be_set_format = input_desc->GetOriginFormat(); - if (to_be_set_format == FORMAT_ND) { - GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); - continue; - } - auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_data_anchor == nullptr) { - GELOGW("Node[%s] %dth in data anchor's peer_out_anchor is null", (node->GetName()).c_str(), in_data_anchor_idx); - continue; - } - auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); - if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { - GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (node->GetName()).c_str()); - continue; - } - // Check format whether have been set - int idx = peer_out_data_anchor->GetIdx(); - // do peer_out_node name and index as key to lookup reflections - ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx); - std::unordered_set reflection; - auto status = reflection_builder.LookUpRefRelations(key, reflection); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge", - (peer_out_data_node->GetName()).c_str(), idx); - return GRAPH_FAILED; - } - - auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast(idx)); - if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { - auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); - if (dim_num == 0) { - GELOGD("node name:%s idx:%d out is scalar. stop back infer!", peer_out_data_node->GetName().c_str(), idx); - continue; - } - /// Check whether node to change dims () - /// Because some node will calculate with 5D, C dim maybe multi meaning - auto peer_out_data_node_type = peer_out_data_node->GetType(); - auto iter1 = kChangeDimNodes.find(peer_out_data_node_type); - // 4 means dims num - if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { - GELOGD("Node[%s] is change dim node and shape is smaller than 4. 
do not modify format", - (peer_out_data_node->GetName()).c_str()); - continue; - } - - if (reflection.empty()) { - ge_tensor_desc.SetOriginFormat(to_be_set_format); - ge_tensor_desc.SetFormat(to_be_set_format); - (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast(idx), ge_tensor_desc); - - // Call operator infer format api (forward) to get out format - GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); - status = peer_out_data_node->InferOriginFormat(); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); - return GRAPH_FAILED; - } - nodes.push_back(peer_out_data_node); - } else { - auto status = ReflectionProcess(reflection, nodes, to_be_set_format); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "reflection process failed!"); - return GRAPH_FAILED; - } - } - } - } - return GRAPH_SUCCESS; -} -graphStatus FormatRefiner::ForwardInferProcess(std::deque &nodes, ge::NodePtr &node, - std::unordered_map &node_status) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - GELOGD("Enter forward infer process!Node is [%s]", (node->GetName()).c_str()); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); - GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); - auto out_data_anchor_idx = out_data_anchor->GetIdx(); - auto to_be_set_format = - node->GetOpDesc()->MutableOutputDesc(static_cast(out_data_anchor_idx))->GetOriginFormat(); - if (to_be_set_format == FORMAT_ND) { - GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); - continue; - } - for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue); - - auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); - GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue); - GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue); - - // Check format whether have been set - int idx = peer_in_data_anchor->GetIdx(); - // do peer_out_node name and index as key to lookup reflections - ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx); - std::unordered_set reflection; - auto status = reflection_builder.LookUpRefRelations(key, reflection); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge", - (peer_in_data_node->GetName()).c_str(), idx); - return GRAPH_FAILED; - } - auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast(idx)); - if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { - auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); - if (dim_num == 0) { - GELOGI("node name:%s idx:%d in is scalar. stop forward infer!", peer_in_data_node->GetName().c_str(), idx); - continue; - } - /// Check whether node to change dims () - /// Because some node will calculate with 5D, C dim maybe multi meaning - auto peer_in_data_node_type = peer_in_data_node->GetType(); - auto iter1 = kChangeDimNodes.find(peer_in_data_node_type); - // 4 means dims num - if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { - GELOGD("Node[%s] is change dim node. 
do not infer origin format", (peer_in_data_node->GetName()).c_str()); - continue; - } - - if (reflection.empty()) { - ge_tensor_desc.SetOriginFormat(to_be_set_format); - ge_tensor_desc.SetFormat(to_be_set_format); - (void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast(idx), ge_tensor_desc); - - /// Because netoutput node added before infer format ,so netoutput is end condition - /// must set netoutput format , because saved result depend on format - if (peer_in_data_node_type == NETOUTPUT) { - continue; - } - - // Call operator infer format api (forward) to get out format - GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); - status = peer_in_data_node->InferOriginFormat(); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); - return GRAPH_FAILED; - } - nodes.push_back(peer_in_data_node); - } else { - auto status = ReflectionProcess(reflection, nodes, to_be_set_format); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "reflection process failed!"); - return GRAPH_FAILED; - } - } - } - } - } - return GRAPH_SUCCESS; -} - -void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector &anchor_points) { - for (const auto &node : anchor_points) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - continue; - } - for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { - if (input_desc != nullptr) { - input_desc->SetOriginFormat(input_desc->GetFormat()); - } - } - for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { - if (output_desc != nullptr) { - output_desc->SetOriginFormat(output_desc->GetFormat()); - } - } - } -} - -graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, - ge::Format data_format, - std::unordered_map &node_status) { - if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) { - GELOGI("no necessary to do DataNodeFormatProcess. 
is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph), - TypeUtils::FormatToSerialString(data_format).c_str()); - return GRAPH_SUCCESS; - } - GELOGD("Enter DataNodeFormatProcess"); - std::vector uninfered_data_nodes; - // Check and renew data nodes format - for (const auto &data_node : data_nodes) { - GE_CHECK_NOTNULL(data_node); - auto op_desc = data_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0)); - auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat(); - if (curr_format != FORMAT_ND) { - // Data format has been infered , continue - continue; - } - // Set format for un-infered data node - auto input_descs = op_desc->GetAllInputsDescPtr(); - auto output_descs = op_desc->GetAllOutputsDescPtr(); - - for (const auto &input_desc : input_descs) { - if (input_desc != nullptr) { - input_desc->SetOriginFormat(data_format); - input_desc->SetFormat(data_format); - } - } - for (const auto &output_desc : output_descs) { - if (output_desc != nullptr) { - output_desc->SetOriginFormat(data_format); - output_desc->SetFormat(data_format); - } - } - uninfered_data_nodes.push_back(data_node); - } - // Reinfer format from uninfered data nodes - for (const auto &node : uninfered_data_nodes) { - if (node == nullptr) { - continue; - } - GELOGD("data node [%s] start infer format process", node->GetName().c_str()); - auto status = AnchorProcess(node, node_status); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "data node [%s] infer format process failed!", node->GetName().c_str()); - return GRAPH_FAILED; - } - } - GELOGD("DataNodeFormatProcess success"); - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) { - GELOGI("Enter InferOrigineFormat process!"); - - // True: infered false:no-infered - std::unordered_map node_status; - std::vector anchor_points; - std::vector data_nodes; - // global net format - - if (graph == nullptr) { - GELOGE(GRAPH_FAILED, "input graph is null"); - return GRAPH_FAILED; - } - // build reflection relations of boundary - (void)reflection_builder.Clear(); - auto status = reflection_builder.BuildRefRelations(*graph); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!"); - return GRAPH_FAILED; - } - // User set global net format - status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!"); - return GRAPH_FAILED; - } - // Refresh origin format of anchor point - RefreshOriginFormatOfAnchor(anchor_points); - // Infer format process - for (const auto &anchor_node : anchor_points) { - if (anchor_node == nullptr) { - continue; - } - status = AnchorProcess(anchor_node, node_status); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Anchor node [%s] process failed!", anchor_node->GetName().c_str()); - return GRAPH_FAILED; - } - } - /// According to discuss with sys-enginer, data node default format is ND.Its format - /// should be set by infered.But if some data-node can not be got by infer, set context's - /// format for these data nodes. 
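The InferOrigineFormat flow above boils down to a queue-driven propagation: anchor nodes whose origin format is already known seed a deque, BackInferProcess/ForwardInferProcess copy that format onto neighbouring ND tensors and push those neighbours back onto the queue, and DataNodeFormatProcess finally falls back to the graph-level data format for Data nodes the walk never reached. The stand-alone C++ sketch below illustrates only that propagation pattern; ToyNode, PropagateOriginFormat and the Fmt enum are invented for the example and are not the GraphEngine types or API.

#include <deque>
#include <unordered_map>
#include <vector>

// Toy stand-ins for ge::NodePtr and formats, for illustration only.
enum class Fmt { ND, NCHW, NHWC };

struct ToyNode {
  Fmt origin_format = Fmt::ND;       // ND means "not decided yet"
  std::vector<ToyNode *> producers;  // analogue of the peer out-data anchors
  std::vector<ToyNode *> consumers;  // analogue of the peer in-data anchors
};

// Breadth-first propagation of a known origin format in both directions,
// followed by a graph-level fallback for data nodes that were never reached.
void PropagateOriginFormat(const std::vector<ToyNode *> &anchor_points,
                           const std::vector<ToyNode *> &data_nodes,
                           Fmt graph_data_format) {
  std::unordered_map<ToyNode *, bool> node_status;  // true once a node has been processed
  std::deque<ToyNode *> nodes(anchor_points.begin(), anchor_points.end());
  while (!nodes.empty()) {
    ToyNode *node = nodes.front();
    nodes.pop_front();
    if (node == nullptr || node_status[node]) {
      continue;
    }
    node_status[node] = true;
    Fmt to_be_set_format = node->origin_format;
    if (to_be_set_format == Fmt::ND) {
      continue;  // nothing useful to propagate from this node
    }
    for (ToyNode *peer : node->producers) {  // "backward" step
      if (peer != nullptr && peer->origin_format == Fmt::ND) {
        peer->origin_format = to_be_set_format;
        nodes.push_back(peer);
      }
    }
    for (ToyNode *peer : node->consumers) {  // "forward" step
      if (peer != nullptr && peer->origin_format == Fmt::ND) {
        peer->origin_format = to_be_set_format;
        nodes.push_back(peer);
      }
    }
  }
  // Fallback: give still-undecided data nodes the graph-level data format.
  for (ToyNode *data_node : data_nodes) {
    if (data_node != nullptr && data_node->origin_format == Fmt::ND) {
      data_node->origin_format = graph_data_format;
    }
  }
}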
- /// Notice: ignore 5D formats - auto data_format = graph->GetDataFormat(); - status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status); - - (void)AttrUtils::SetBool(graph, kIsGraphInferred, true); - - return status; -} - -bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { - bool is_graph_inferred = false; - return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); -} -} // namespace ge diff --git a/src/common/graph/format_refiner.h b/src/common/graph/format_refiner.h deleted file mode 100644 index eca93bae..00000000 --- a/src/common/graph/format_refiner.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_FORMAT_REFINER_H_ -#define COMMON_GRAPH_FORMAT_REFINER_H_ - -#include -#include -#include -#include -#include "./compute_graph.h" -#include "./external/graph/types.h" -#include "./ge_error_codes.h" - -namespace ge { -// ShapeRefiner performs shape inference for compute graphs -class FormatRefiner { - public: - static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); - - private: - static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); - static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points, - std::vector &data_nodes, - std::unordered_map &node_status); - static graphStatus AnchorProcess(const ge::NodePtr &anchor_node, std::unordered_map &node_status); - static void RefreshOriginFormatOfAnchor(std::vector &anchor_points); - static graphStatus BackInferProcess(std::deque &nodes, ge::NodePtr &node, - std::unordered_map &node_status); - static graphStatus ForwardInferProcess(std::deque &nodes, ge::NodePtr &node, - std::unordered_map &node_status); - static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, - ge::Format data_format, std::unordered_map &node_status); - static bool IsGraphInferred(const ComputeGraphPtr &graph); -}; -} // namespace ge -#endif // COMMON_GRAPH_FORMAT_REFINER_H_ diff --git a/src/common/graph/ge_attr_define.cc b/src/common/graph/ge_attr_define.cc deleted file mode 100644 index 9b723bb3..00000000 --- a/src/common/graph/ge_attr_define.cc +++ /dev/null @@ -1,1086 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -namespace ge { -// Public attribute -const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; - -const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; - -const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE = "_unknown_shape_type"; - -const std::string ATTR_NAME_NAME = "name"; - -const std::string ATTR_NAME_TYPE = "type"; - -const std::string ATTR_NAME_WEIGHT_NAME = "weight_name"; - -const std::string ATTR_NAME_IS_QUANTIZE_FACTOR = "quantize_factor"; - -const std::string ATTR_NAME_ALPHA = "alpha"; - -const std::string ATTR_NAME_BETA = "beta"; - -const std::string ATTR_NAME_PADMODE = "pad_mode"; - -const std::string ATTR_NAME_PADMODES = "padding"; - -const std::string ATTR_NAME_MODE = "mode"; - -const std::string ATTR_NAME_FILTER = "filter"; - -const std::string ATTR_NAME_BIAS = "bias"; - -const std::string ATTR_NAME_BIAS_TERM = "bias_term"; - -const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; - -const std::string ATTR_NAME_PAD = "pad"; - -const std::string ATTR_NAME_PADS = "pad"; - -const std::string ATTR_NAME_PAD_SIZE = "pad size"; - -const std::string ATTR_NAME_PAD_MODE = "pad mode"; - -const std::string ATTR_NAME_SCALE = "scale"; - -const std::string ATTR_NAME_WINDOWS = "windows"; - -const std::string ATTR_NAME_GLOBAL_POOLING = "global_pooling"; - -const std::string ATTR_NAME_CEIL_MODE = "ceil_mode"; - -const std::string ATTR_NAME_RELUMODE = "relu_mode"; - -const std::string ATTR_NAME_STRIDE_SIZE = "stride size"; - -const std::string ATTR_NAME_RELU_FLAG = "relu_flag"; - -const std::string ATTR_NAME_ALGO = "algo"; - -const std::string ATTR_NAME_FORMAT = "format"; - -const std::string ATTR_NAME_STORAGE_FORMAT = "storage_format"; - -const std::string ATTR_NAME_STORAGE_SHAPE = "storage_shape"; - -const std::string ATTR_NAME_FILTER_FORMAT = "filter_format"; - -const std::string ATTR_NAME_LRN_K = "lrn_k"; - -const std::string ATTR_NAME_LRN_NORM_REGION = "lrn_normregion"; - -const std::string ATTR_NAME_LRN_LOCAL_SIZE = "lrn_localsize"; - -const std::string ATTR_NAME_LRN_ALPHA = "lrn_alpha"; - -const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; - -const std::string ATTR_NAME_AXIS = "axis"; -const std::string ATTR_NAME_BROADCAST = "broadcast"; - -const std::string ATTR_NAME_OUTPUT = "output"; -const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; -const std::string ATTR_NAME_TIDX = "t_idx"; - -const std::string ATTR_NAME_TPADDINGS = "t_paddings"; -const std::string ATTR_IMG_H = "img_h"; -const std::string ATTR_IMG_W = "img_w"; -const std::string ATTR_NET_H = "net_h"; -const std::string ATTR_NET_W = "net_w"; - -const std::string ATTR_NAME_TMULTIPLES = "t_multiples"; - -const std::string ATTR_NAME_MULTIPLES = "multiples"; - -const std::string ATTR_NAME_T = "T"; -const std::string ATTR_NAME_N = "N"; - -const std::string ATTR_NAME_TSHAPE = "Tshape"; -const std::string ATTR_NAME_NAN_OPT = "nan_opt"; - -const std::string ATTR_NAME_AIPP = "aipp"; -const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; - -const std::string ATTR_NAME_AIPP_INPUTS = "_aipp_inputs"; -const std::string ATTR_NAME_AIPP_OUTPUTS = "_aipp_outputs"; - -const std::string ATTR_NAME_INPUT_DIMS = "input_dims"; -const std::string ATTR_DYNAMIC_AIPP_INPUT_DIMS = "_dynamic_aipp_input_dims"; -const std::string ATTR_DATA_RELATED_AIPP_MODE = "_data_related_aipp_mode"; -const std::string ATTR_DATA_AIPP_DATA_NAME_MAP = "_data_aipp_data_name_map"; - -const std::string ATTR_NAME_GRAPH_HAS_BEEN_ADDED = "_graph_has_been_added"; - -const std::string 
ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; -const std::string ATTR_NAME_PARENT_GRAPH_NAME = "_parent_graph_name"; - -const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; -const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; -const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; - -const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; -const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; - -const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; -const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; -const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; -const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; -const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; - -const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; -const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; - -const std::string ATTR_NAME_INFERRED_FORMAT = "inferred_format"; -const std::string ATTR_NAME_PRED_PERMUTE_DELETED = "pred_permute_deleted"; -const std::string ATTR_NAME_IGNORE_PRED_FORMAT = "ignore_pred_format"; -const std::string ATTR_NAME_WEIGHTS = "value"; -const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; -const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; -const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; -const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; -const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; -const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; -const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; -const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; -const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; -const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS = "_dynamic_output_dims"; -const std::string ATTR_NAME_INPUT_ORIGIN_SIZE = "input_origin_size"; - -// Identify node connecting to input and output -const std::string ATTR_NAME_NODE_CONNECT_INPUT = "_is_connected_to_data"; -const std::string ATTR_NAME_NODE_CONNECT_OUTPUT = "_is_connected_to_netoutput"; - -// To be deleted -const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; -const std::string PERMUTE_RESHAPE_FUSION = "permute_reshape_fusion"; -const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL = "fusion_conv_proposal"; -const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX = "fusion_conv_decodebbox"; -const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM = "box_type_num"; -const std::string SSD_MBOX_LOC_FUSION = "permute_flatten_fusion"; -const std::string SSD_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; -const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; -const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; -const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; - -// Refinedet -const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; - -const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; -const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; -const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; -const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; - -// _Arg -const std::string ATTR_NAME_INDEX = "index"; -// _RetVal -const std::string RETVAL_ATTR_NAME_INDEX = "retval_index"; -// Data -const 
std::string DATA_ATTR_NAME_DATA_TYPE = "data_type"; - -// Send -const std::string SEND_ATTR_EVENT_ID = "event_id"; - -// Recv -const std::string RECV_ATTR_EVENT_ID = "event_id"; - -// convolution -const std::string ATTR_NAME_COEF = "coef"; - -const std::string ATTR_NAME_STRIDE = "stride"; - -const std::string ATTR_NAME_STRIDES = "stride"; - -const std::string ATTR_NAME_DILATION = "dilation"; - -const std::string ATTR_NAME_DILATIONS = "dilation"; - -const std::string CONV_ATTR_NAME_MODE = "mode"; - -const std::string CONV_ATTR_NAME_ALGO = "algo"; - -const std::string CONV_ATTR_NAME_GROUP = "group"; - -const std::string CONV_ATTR_NAME_PAD_MODE = "pad_mode"; - -const std::string CONV_ATTR_NAME_PAD = "pad"; - -const std::string CONV_ATTR_NAME_STRIDE = "stride"; - -const std::string CONV_ATTR_NAME_DILATION = "dilation"; - -const std::string CONV_ATTR_NAME_NUM_OUTPUT = "num_output"; - -const std::string CONV_ATTR_NAME_KERNEL = "kernel"; - -const std::string CONV_ATTR_NAME_FILTER = "filter"; - -const std::string CONV_ATTR_NAME_BIAS = "bias"; - -const std::string CONV_ATTR_NAME_RELU_FLAG = "relu_flag"; - -const std::string CONV_ATTR_NAME_ADJ = "adj"; - -const std::string CONV_ATTR_NAME_TARGET_SHAPE = "target_shape"; - -const std::string CONV_ATTR_NAME_BEFORE_PAD = "before_pad"; - -const std::string CONV_ATTR_NAME_HAS_BIAS = "has_bias"; - -const std::string NEED_INFER = "isNeedInfer"; - -// Pooling -const std::string POOLING_ATTR_MODE = "mode"; -const std::string POOLING_ATTR_NAN_OPT = "nan_opt"; -const std::string POOLING_ATTR_PAD_MODE = "pad_mode"; -const std::string POOLING_ATTR_GLOBAL_POOLING = "global_pooling"; -const std::string POOLING_ATTR_WINDOW = "window"; -const std::string POOLING_ATTR_PAD = "pad"; -const std::string POOLING_ATTR_STRIDE = "stride"; -const std::string POOLING_ATTR_CEIL_MODE = "ceil_mode"; -const std::string POOLING_ATTR_DATA_MODE = "data_mode"; -const std::string POOLING_ATTR_BEFORE_PAD = "before_pad"; -const std::string POOLING_ATTR_NAME_ALGO = "algo"; - -// Eltwise -const std::string ELTWISE_ATTR_MODE = "mode"; -const std::string ELTWISE_ATTR_COEFF = "coeff"; -const std::string ELTWISE_ATTR_WEIGHT = "weight"; -const std::string ELTWISE_ATTR_RELU_FLAG = "relu_flag"; -const std::string ELTWISE_ATTR_ALPHA = "alpha"; -const std::string ELTWISE_ATTR_BETA = "beta"; - -// BatchNorm -const std::string BATCHNORM_ATTR_MODE = "mode"; -const std::string BATCHNORM_ATTR_EPSILON = "epsilon"; -const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS = "use_global_stats"; -const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION = "moving_average_fraction"; -const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; -const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; -const std::string BATCHNORM_ATTR_SCALE = "scale"; -const std::string BATCHNORM_ATTR_BIAS = "bias"; -const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; -const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; -const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; - -// huberloss -const std::string HUBER_LOSS_ATTR_DELTA = "delta"; - -// SSDRealDivTileMul -const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; - -// SSDSumMulRealDivMean -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; - -// 
ConcatFive2Four -// ConcatFour2Five -const std::string SSD_BOX_TYPE_NUM = "box_type_num"; -const std::string SSD_CLASS_NUM = "class_num"; -const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; -const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; -const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; -const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; - -// Scale -const std::string SCALE_ATTR_SCALE = "scale"; -const std::string SCALE_ATTR_BIAS = "bias"; - -// FullConnection -const std::string FULL_CONNECTION_ATTR_FILTER = "filter"; -const std::string FULL_CONNECTION_ATTR_BIAS = "bias"; -const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT = "num_output"; -const std::string FULL_CONNECTION_ATTR_RELU_FLAG = "relu_flag"; -const std::string FULL_ATTR_NAME_ALGO = "algo"; - -// SoftmaxOpParams -const std::string SOFTMAX_ATTR_ALGO = "algo"; -const std::string SOFTMAX_ATTR_MODE = "mode"; - -// SparseSoftmaxCrossEntropy -const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE = "cross_entropy_mode"; -const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD = "cross_entropy_is_grad"; -// Attr labelSmoothing -const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING = "labelSmoothing"; - -// ApplyMomentum -const std::string APPLYMENTUM_ATTR_IS_GRAPH_FUSION = "applymomentum_is_graph_fusion"; - -// Activation -const std::string ACTIVATION_ATTR_MODE = "mode"; -const std::string ACTIVATION_ATTR_COEF = "coef"; - -// Concat -const std::string CONCAT_ATTR_NAME_AXIS = "axis"; - -// Const -const std::string CONST_ATTR_NAME_DATA_TRANSTYPE = "data_transtype"; -const std::string CONST_ATTR_NAME_OUTPUT_FORMAT = "output_format"; -const std::string CONST_ATTR_NAME_OUTPUT_TYPE = "output_type"; - -// Roipooling -const std::string ROIPOOLING_ATTR_NAME_POOLED_H = "pooled_h"; -const std::string ROIPOOLING_ATTR_NAME_POOLED_W = "pooled_w"; -const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE = "spatial_scale"; -const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE = "rio_pooling_mode"; -const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE = "pooling_mode"; -const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO = "sampling_ratio"; - -// DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES = "num_classes"; -const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES = "ocr_num_classes"; -const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD = "nms_threshold"; -const std::string DETECTIONOUTPUT_ATTR_TOP_K = "top_k"; -const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD = "confidence_threshold"; -const std::string DETECTIONOUTPUT_ATTR_IMG_H = "img_h"; -const std::string DETECTIONOUTPUT_ATTR_IMG_W = "img_w"; -const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE = "batch_size"; -// Ssd DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_ETA = "eta"; -const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION = "shared_location"; -const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID = "background_label_id"; -const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE = "code_type"; -const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET = "variance_encoded_in_target"; -const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K = "keep_top_k"; -// Refinedet DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE = "objectness_score"; -// yolo DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_ClASSES = "classes"; -const std::string DETECTIONOUTPUT_ATTR_BIASES = "biases"; -const std::string DETECTIONOUTPUT_ATTR_RELATIVE = "relative"; -const std::string 
DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD = "objectness_threshold"; -const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD = "class_threshold"; -const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K = "post_top_k"; -const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY = "iou_threshold_decay"; -const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR = "coor_scale_factor"; -const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION = "yolo_version"; - -// DetectionPostprocess -const std::string POSTPROCESS_ATTR_NAME_CLS_NUM = "cls_num"; -const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH = "conf_thresh"; -const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH = "nms_thresh"; -const std::string POSTPROCESS_ATTR_POST_NMS_TOPN = "post_nms_topn"; -const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT = "bbox_reg_weights"; - -// Spatialtransfrom -const std::string SPTIALTF_ATTR_NAME_OUTPUT_H = "output_h"; -const std::string SPTIALTF_ATTR_NAME_OUTPUT_W = "output_w"; -const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE = "border_value"; -const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM = "affine_transform"; - -// Proposa -const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE = "feat_stride"; -const std::string PROPOSAL_ATTR_NAME_BASE_SIZE = "base_size"; -const std::string PROPOSAL_ATTR_NAME_MIN_SIZE = "min_size"; -const std::string PROPOSAL_ATTR_NAME_RATIO = "ratio"; -const std::string PROPOSAL_ATTR_NAME_SCALE = "scale"; -const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN = "pre_nms_topn"; -const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN = "post_nms_topn"; -const std::string PROPOSAL_ATTR_NAME_NMS_THRESH = "nms_thresh"; -const std::string PROPOSAL_ATTR_NAME_TOP_SIZE = "top_size"; -const std::string PROPOSAL_ATTR_IMG_H = "img_h"; -const std::string PROPOSAL_ATTR_IMG_W = "img_w"; -// Softmax -const std::string SOFTMAX_ATTR_AXIS = "axis"; - -// Permute -const std::string PERMUTE_ATTR_ORDER = "order"; -const std::string PERMUTE_ATTR_PERM = "perm"; - -// SSD Normalize -const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; -const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED = "channel_shared"; -const std::string SSDNORMALIZE_ATTR_EPS = "eps"; - -// Flatten -const std::string FLATTEN_ATTR_AXIS = "axis"; -const std::string FLATTEN_ATTR_END_AXIS = "end_axis"; - -// SsdPRIORBOX -const std::string SSD_PRIOR_BOX_ATTR_FLIP = "flip"; -const std::string SSD_PRIOR_BOX_ATTR_CLIP = "clip"; -const std::string SSD_PRIOR_BOX_ATTR_IMG_H = "img_h"; -const std::string SSD_PRIOR_BOX_ATTR_IMG_W = "img_w"; -const std::string SSD_PRIOR_BOX_ATTR_STEP_H = "step_h"; -const std::string SSD_PRIOR_BOX_ATTR_STEP_W = "step_w"; -const std::string SSD_PRIOR_BOX_ATTR_OFFSET = "offset"; -const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE = "min_size"; -const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE = "max_size"; -const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM = "min_size_num"; -const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM = "max_size_num"; -const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO = "aspect_ratio"; -const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; -const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; -const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; - -// RefinedetDetectionOutput -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; - -// PRelu -const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; - -// Psroi pooling -const std::string 
PSROIPOOLING_ATTR_SPATIAL_SCALE = "spatial_scale"; -const std::string PSROIPOOLING_ATTR_OUTPUT_DIM = "output_dim"; -const std::string PSROIPOOLING_ATTR_GROUP_SIZE = "group_size"; - -// Power -const std::string POWER_ATTR_NAME_POWER = "power"; -const std::string POWER_ATTR_NAME_SCALE = "scale"; -const std::string POWER_ATTR_NAME_SHIFT = "shift"; - -// log -const std::string LOG_ATTR_NAME_SCALE = "scale"; -const std::string LOG_ATTR_NAME_SHIFT = "shift"; -const std::string LOG_ATTR_NAME_BASE = "base"; -// Pack -const std::string PACK_ATTR_NAME_NUM = "N"; - -// Unpack -const std::string UNPACK_ATTR_NAME_NUM = "num"; -const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; -// Gathernd -const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; -const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; - -// Argmax -const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; -const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; -const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; -const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; -const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; -const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; -const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; - -// upsample -const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; -const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; - -// Relu -const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; - -// FreeSpaceExtract -const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT = "org_height"; - -// Split -const std::string SPLIT_ATTR_NAME_SLICE_POINT = "slice_point"; -const std::string SPLIT_ATTR_NAME_SIZE_SPLIT = "size_split"; -const std::string SPLIT_ATTR_NAME_NUM_SPLIT = "num_split"; - -// Tvm -const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; -const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; -const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; -const std::string TVM_ATTR_NAME_WORKSPACE_TYPE = "tvm_workspace_type"; - -// Squeeze -const std::string SQUEEZE_ATTR_AXIS = "axis"; -const std::string SQUEEZE_ATTR_DIMS = "squeeze_dims"; -const std::string SQUEEZE_OP_NAME = "Squeeze"; - -// Stride slice -const std::string STRIDE_SLICE_ATTR_BEGIN_MASK = "begin_mask"; -const std::string STRIDE_SLICE_ATTR_END_MASK = "end_mask"; -const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK = "ellipsis_mask"; -const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK = "new_axis_mask"; -const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK = "shrink_axis_mask"; - -// Slice -const std::string SLICE_ATTR_NAME_BEGINS = "begins"; -const std::string SLICE_ATTR_NAME_SIZES = "sizes"; - -// Roialign -const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; -const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; -const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; -const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; - -// Generate_rpn_proposal -const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK = "post_nms_topk"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE = "rpn_mini_size"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH = "rpn_proposal_nms_thresh"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH = "rpn_proposal_filter_thresh"; -// Decode_bbox -const std::string DECODE_BBOX_ATTR_DECODECLIP = "decodeClip"; - -// Cast -const std::string CAST_ATTR_DSTT = "DstT"; -const std::string CAST_ATTR_SRCT = 
"SrcT"; -const std::string CAST_ATTR_DST_TYPE = "dst_type"; -const std::string CAST_ATTR_TRUNCATE = "truncate"; - -// Fastrcnnn predications -const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK = "fsr_topk"; -const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD = "fsr_score_thres"; -const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD = "fsr_nms_thres"; -const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES = "fsr_num_classes"; - -// REORG -const std::string REORG_ATTR_STRIDE = "stride"; -const std::string REORG_ATTR_REVERSE = "reverse"; - -// MERGE -const std::string MERGE_DEAD_INDEX = "merge_dead_index"; -const std::string MERGE_PRENODE_FLAG = "merge_prenode_flag"; -const std::string TO_BE_OUTPUT = "to_be_output"; - -// ENTER -const std::string ENTER_ATTR_FRAME_NAME = "frame_name"; -const std::string ENTER_ATTR_CONSTANT_FLAG = "is_constant"; - -// Concatv2 -const std::string CONCAT_V2_ATTR_TIDX = "Tidx"; -const std::string CONCAT_V2_ATTR_N = "N"; -// SUM -const std::string SUM_ATTR_TIDX = "Tidx"; -const std::string SUM_ATTR_AXIS = "axis"; -const std::string SUM_ATTR_KEEP_DIMS = "keep_dims"; - -// ResizeBilinear -const std::string RESIZE_BILINEAR_ATTR_MODE = "mode"; -const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS = "align_corners"; -const std::string RESIZE_BILINEAR_ATTR_HEIGHT = "height"; -const std::string RESIZE_BILINEAR_ATTR_WIDTH = "width"; -const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR = "zoom_factor"; -const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR = "shrink_factor"; -const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN = "pad_begin"; -const std::string RESIZE_BILINEAR_ATTR_PAD_END = "pad_end"; -const std::string RESIZE_BILINEAR_ATTR_ALPHA = "alpha"; -const std::string RESIZE_BILINEAR_ATTR_BETA = "beta"; - -// RetinaNet -const std::string RETINANET_FILTER_BACKGROUND_TRUE = "retina_conv_filter_background"; -const std::string RETINANET_ANCHOR_FUSION = "retina_anchor_fusion"; - -// MatMul -const std::string MATMUL_TRANSPOSE_X = "transposeX"; -const std::string MATMUL_TRANSPOSE_W = "transposeW"; -const std::string MATMUL_HAS_BIAS = "has_bias"; -const std::string MATMUL_ATTR_IS_TRAINING = "matmul_is_training"; - -// Flatten -const std::string FLATTEN_START_AXIS = "start_axis"; -const std::string FLATTEN_END_AXIS = "end_axis"; - -// Reshape -const std::string RESHAPE_ATTR_AXIS = "axis"; -const std::string RESHAPE_ATTR_NUM_AXES = "num_axes"; -const std::string RESHAPE_ATTR_FORMAT = "format"; -const std::string RESHAPE_ATTR_SHAPE = "shape"; -const std::string RESHAPE_ATTR_ALPHA = "alpha"; -const std::string RESHAPE_ATTR_BETA = "beta"; - -// Frameoworkop -const std::string T_IN_DATATYPE = "t_in_datatype"; -const std::string T_OUT_DATATYPE = "t_out_datatype"; -const std::string ATTR_NAME_OUT_N = "out_n"; -const std::string ATTR_NAME_OUT_C = "out_c"; -const std::string ATTR_NAME_OUT_H = "out_h"; -const std::string ATTR_NAME_OUT_W = "out_w"; -const std::string ATTR_PAD_DEPTH_CONV = "pad_depth_conv"; -const std::string ATTR_PAD_CONV = "pad_conv"; - -const std::string ATTR_NAME_BEFORE_PAD = "before_pad"; -const std::string ANN_MEAN_KEEPDIMS = "AnnMeanKeepDims"; -const std::string PAD_ATTR_PADDINGDS = "paddings"; -const std::string PAD_ATTR_CONSTANT_VALUE = "padvalue"; - -// ConvGradFilter -const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape"; -// ConvGradInput -const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; - -// Rnn -const std::string RNN_MODE_STATIC = "rnn_static"; -const std::string MUTI_RNN = "multi_rnn"; 
-const std::string CNN_RNN = "cnn_rnn"; -const std::string RNN_MODE_ = "rnn_"; - -const std::string CELL_MODE = "mode"; -const std::string LSTM_CELL = "lstm_cell"; -const std::string GRU_CELL = "gru_cell"; -const std::string RNN_HT = "ht"; -const std::string RNN_XT_HT = "xt_ht"; -const std::string RNN_BATCH_SIZE = "batch_size"; -const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; -const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; -const std::string LSTM_ACTIVATE = "lstm_activate"; -const std::string LSTM_OUT_MAP = "lstm_out_map"; -const std::string LSTM_OUT_MODE = "lstm_out_mode"; -const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; -const std::string LSTM_TIME_MAJOR = "lstm_time_major"; -const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; - -// Upsample -const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; - -// PadV2 -const std::string PADV2_ATTR_NAME_MODE = "mode"; -const std::string PADV2_ATTR_NAME_PADS = "paddings"; -const std::string PADV2_ATTR_NAME_T = "T"; -const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; -const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; - -// MirrorPad -const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; -const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; -const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; -const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; -// Filler -const std::string FILLER_TYPE = "filler_type"; -const std::string FILLER_VALUE = "filler_value"; - -// Shufflechannel -const std::string SHUFFLE_CHANNEL_GROUP = "group"; - -// TopKV2 -const std::string TOPKV2_ATTR_K = "k"; - -// Calibaration -const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; -const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; -const std::string PAD_TOP_INDEX = "PAD_TOP_INDEX"; -const std::string PAD_BOTTOM_INDEX = "PAD_BOTTOM_INDEX"; -const std::string PAD_RIGHT_INDEX = "PAD_RIGHT_INDEX"; -const std::string PAD_LEFT_INDEX = "PAD_LEFT_INDEX"; -const std::string QUANTIZE_ALGO_ATTR = "quantize_algo"; -const std::string SCALE_TYPE_ATTR = "scale_type"; - -const std::string QUANTIZE_SCALE_MODE = "quantize_scale_mode"; -const std::string QUANTIZE_SCALE_VALUE = "quantize_scale_value"; -const std::string QUANTIZE_SCALE_OFFSET = "quantize_scale_offset"; -const std::string QUANTIZE_OFFSET_DATA_VALUE = "quantize_offset_data_value"; -const std::string QUANTIZE_OFFSET_DATA_OFFSET = "quantize_offset_data_offset"; -const std::string QUANTIZE_OFFSET_WEIGHT_VALUE = "quantize_offset_weight_value"; -const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET = "quantize_offset_weight_offset"; -const std::string QUANTIZE_OFFSET_PAD_VALUE = "quantize_offset_pad_value"; -const std::string QUANTIZE_OFFSET_PAD_OFFSET = "quantize_offset_pad_offset"; - -const std::string DEQUANTIZE_SCALE_MODE = "dequantize_scale_mode"; -const std::string DEQUANTIZE_SCALE_VALUE = "dequantize_scale_value"; -const std::string DEQUANTIZE_SCALE_OFFSET = "dequantize_scale_offset"; -const std::string DEQUANTIZE_OFFSET_DATA_TYPE = "dequantize_offset_data_value"; -const std::string DEQUANTIZE_OFFSET_DATA_OFFSET = "dequantize_offset_data_offset"; -const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE = "dequantize_offset_weight_value"; -const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET = "dequantize_offset_weight_offset"; -const std::string DEQUANTIZE_OFFSET_PAD_VALUE = "dequantize_offset_pad_value"; -const std::string DEQUANTIZE_OFFSET_PAD_OFFSET = "dequantize_offset_pad_offset"; - -const std::string REQUANTIZE_SCALE_MODE = 
"requantize_scale_mode"; -const std::string REQUANTIZE_SCALE_VALUE = "requantize_scale_value"; -const std::string REQUANTIZE_SCALE_OFFSET = "requantize_scale_offset"; -const std::string REQUANTIZE_OFFSET_DATA_VALUE = "requantize_offset_data_value"; -const std::string REQUANTIZE_OFFSET_DATA_OFFSET = "requantize_offset_data_offset"; -const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE = "requantize_offset_weight_value"; -const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET = "requantize_offset_weight_offset"; -const std::string REQUANTIZE_OFFSET_PAD_VALUE = "requantize_offset_pad_value"; -const std::string REQUANTIZE_OFFSET_PAD_OFFSET = "requantize_offset_pad_offset"; - -const std::string ATTR_NAME_IS_CONST = "attr_name_is_const"; - -const std::string ATTR_NAME_GROUP = "group"; -const std::string ATTR_NAME_DILATION_SIZE = "dilation_size"; -const std::string ATTR_NAME_EPSILON = "epsilon"; -const std::string ATTR_NAME_POOLING_MODE = "mode"; -const std::string ATTR_NAME_CLASS_NUM = "class_num"; -// model -const std::string ATTR_MODEL_TARGET_TYPE = "target_type"; - -const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; - -const std::string ATTR_MODEL_EVENT_NUM = "event_num"; - -const std::string ATTR_MODEL_HUGE_STREAM_LIST = "huge_stream_list"; - -const std::string ATTR_MODEL_LABEL_NUM = "label_num"; - -const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; - -const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; - -const std::string ATTR_MODEL_OUT_NODES_NAME = "attr_model_out_nodes_name"; - -const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; - -const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; - -const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; - -const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR = "task_gen_variable_addr"; - -const std::string ATTR_MODEL_VAR_SIZE = "variable_size"; - -const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; - -const std::string ATTR_MODEL_CORE_TYPE = "core_type"; - -const std::string ATTR_MODEL_ATC_VERSION = "atc_version"; - -const std::string ATTR_MODEL_OPP_VERSION = "opp_version"; - -// Public attribute -const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; - -const std::string ATTR_NAME_BYTE_SIZE = "op_byte_size"; - -const std::string ATTR_NAME_FUSION_INFERENCE_ID = "fusion_inference_id"; - -const std::string ATTR_NAME_FUSION_OPDEF = "fusion_opdef"; - -const std::string ATTR_NAME_IO_OP = "io_op"; - -const std::string ATTR_NAME_FUSION_SCOPE = "fusion_scope"; - -const std::string ATTR_NAME_OPATTR = "opattr"; - -const std::string ATTR_NAME_RELUFLAG = "relu_flag"; - -const std::string ATTR_NAME_SEQLEN_INDEX = "seqlen_index"; - -const std::string ATTR_NAME_X_INDEX = "x_index"; - -const std::string ATTR_NAME_CONT_INDEX = "cont_index"; - -const std::string ATTR_NAME_XSTATIC_INDEX = "xstatic_index"; - -const std::string TARGET_TYPE_MINI = "MINI"; - -const std::string TARGET_TYPE_TINY = "TINY"; - -const std::string TARGET_TYPE_LITE = "LITE"; - -// l2_normalize -const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; -const std::string L2_NORMALIZE_ATTR_EPS = "eps"; - -const std::string POOL_PARAMA_ATTR_WINDOW = "window"; -const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; -const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; -const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; -const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; -const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; - -// HCOM -const std::string HCOM_ATTR_ROOT_RANK = 
"root_rank"; -const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; - -const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; -const std::string HCOM_ATTR_GROUP = "group"; -const std::string HCOM_ATTR_SR_TAG = "sr_tag"; -const std::string HCOM_ATTR_SRC_RANK = "src_rank"; -const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; -const std::string HCOM_ATTR_FUSION = "fusion"; -const std::string HCOM_ATTR_SHAPE = "shape"; -const std::string HCOM_ATTR_DATA_TYPE = "dtype"; - -// SpaceToDepth/DepthToSpace -const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; - -// SparseSoftmaxCrossEntropyWithLogits -const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; - -// MaxPoolGradWithArgmax -const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; - -// AvgPoolGrad -const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; - -// Pad -const std::string ATTR_PAD_FORMAT = "attr_pad_format"; - -// Varible -const std::string VAR_ATTR_FORMAT = "_var_format"; -const std::string VAR_ATTR_NAME = "var_name"; -const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; -const std::string VAR_ATTR_4D_FORMAT = "4D"; -const std::string VAR_ATTR_5D_FORMAT = "5D"; -const std::string VAR_ATTR_DATA_TYPE = "data_format"; -const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; -const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; -const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; -const std::string VAR_ATTR_SHAPE = "shape"; -const std::string HALF_VAR_NAME_END = "_fp16"; -const std::string VAR_ATTR_INITED = "var_is_inited"; - -const std::string VAR_ATTR_CONTAINER = "container"; -const std::string VAR_ATTR_SHARED_NAME = "shared_name"; -const std::string VAR_ATTR_DTYPE = "dtype"; - -const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; -const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; -const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; -const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; -const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; -const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; - -// Assign -const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; -const std::string ASSIGN_VAR_NAME = "_assign_var_name"; - -// space2bacth batch2space -const std::string BATCH_SPACE_ATTR_BLOCK = "block"; -const std::string BATCH_SPACE_ATTR_PADDING = "padding"; - -// depth_to_space space_to_depth -const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; - -// FakeQuantWithMinMaxVars -const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; -const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; - -// mobilenet_ssd_conv_fusion -const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; -const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; -const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; - -// lsh project -const std::string LSH_PROJ_TYPE = "lsh_project_type"; - -// log time stamp -const std::string LOG_TIME_STAMP_LOGID = "logid"; -const std::string LOG_TIME_STAMP_NOTIFY = "notify"; - -// ShapeN -const std::string SHAPEN_ATTR_N = "N"; -const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; -const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; - -// GatherV2 attr def -const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; -const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; -const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; - -// Reshape attr def -const std::string 
RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; -const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; - -// axis attr def -const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; - -const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; - -const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; -const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; - -// For constant folding -const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; - -const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; - -const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC = "continuous_input_alloc"; - -const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; - -const std::string ATTR_NAME_REFERENCE = "reference"; - -const std::string ATTR_NAME_NOTASK = "_no_task"; - -const std::string ATTR_NAME_OUTPUT_REUSE_INPUT = "_output_reuse_input"; - -const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX = "_reuse_input_on_dim_index"; - -const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT = "_no_padding_continuous_input"; - -const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT = "_no_padding_continuous_output"; - -const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; - -// Used for mark the active label list stream of activated node -const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; - -// Used for l2cache, true: the memory of all inputs is used for the last time. -const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_lifecycle"; - -// Multi batch -const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; -const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; -const std::string ATTR_NAME_BATCH_LABEL = "_batch_label"; -const std::string ATTR_NAME_COMBINED_BATCH = "_combined_batch"; - -// Control flow -const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; -const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; -const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; -const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; -const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; -const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; -const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE = "subgraph_first_active"; -const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS = "combined_dynamic_dims"; - -const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; -const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; -const std::string ATTR_NAME_SWITCH_DATA_TYPE = "_switch_data_type"; -const std::string ATTR_NAME_ORIG_NODE_NAME = "_original_node_name"; -const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG = "_cyclic_dependence_flag"; - -const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; - -// Function Op -const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; - -// Used for mark the active node is for loop, type:bool -const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; - -const std::string ATTR_NAME_MEMORY_TYPE_INPUT = "memory_type_input"; - -const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT = "memory_type_output"; - -const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; - -const std::string ATTR_NAME_MEMORY_TYPE_RANGE = "_memory_type_range"; - -const std::string MODEL_ATTR_SESSION_ID = "session_id"; - -// lx fusion -const std::string ATTR_NAME_L1_FUSION_GROUP_ID = 
"_l1_fusion_group_id"; -const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; -const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; -const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; -const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; -const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; -const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; -const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; -const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; -const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; -const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; -const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; -const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; -const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; -const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; -const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; -const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; -const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; -const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; -const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; -const std::string ATTR_NAME_ENGINE_NAME_FOR_LX = "_lxfusion_engine_name"; -const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX = "_lxfusion_op_kernel_lib_name"; -const std::string ATTR_NAME_NEED_LX_FUSION = "_lx_fusion"; -const std::string ATTR_NAME_OPTIMIZE_GROUP = "_optimize_group"; -const std::string ATTR_NAME_OP_COMPILE_STRATEGY = "_op_compile_strategy"; -const std::string ATTR_NAME_TBE_KERNEL_NAME = "_tbe_kernel_name"; -const std::string ATTR_NAME_TBE_KERNEL_BUFFER = "_tbe_kernel_buffer"; - -// Op debug attrs -const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; -const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; - -// Atomic addr clean attrs -const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; -const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; -const std::string ATOMIC_ATTR_IS_FUSION_NODE = "is_fusion_node"; -const std::string EXT_ATTR_ATOMIC_WORKSPACE_INFO = "sub_node_workspace_info"; -const std::string EXT_ATTR_ATOMIC_WORKSPACE_OFFSET = "sub_node_workspace_offset"; -const std::string ATOMIC_ATTR_IS_ATOMIC_NODE = "is_atomic_node"; - -// Source/dst format for Op FormatTransfer -const std::string FORMAT_TRANSFER_SRC_FORMAT = "src_format"; -const std::string FORMAT_TRANSFER_DST_FORMAT = "dst_format"; - -// For compile op by ge call -const std::string ATTR_NEED_COMPILE = "_node_need_compile"; - -const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; - -const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; - -const std::string ATTR_DYNAMIC_TYPE = "mbatch_dynamic_type"; - -const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER = "user_designate_shape_order"; - -// For inserted op -const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; - -// For compress weight -const std::string ATTR_NAME_COMPRESS_WEIGHT = "_is_compress_weight"; - -// For data dump -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; -const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; -const std::string ATTR_NAME_DATA_DUMP_SUB_SPLITER_INDEX = "_datadump_sub_spliter_index"; 
-const std::string ATTR_NAME_DATA_DUMP_GROUP_OP_NAME = "_datadump_group_op_name"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_NAME = "_datadump_origin_name"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_output_index"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; - -// functional ops attr -const std::string ATTR_NAME_IF_THEN_BRANCH = "then_branch"; -const std::string ATTR_NAME_IF_ELSE_BRANCH = "else_branch"; -const std::string ATTR_NAME_WHILE_COND = "cond"; -const std::string ATTR_NAME_WHILE_BODY = "body"; - -// used for label switch -const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; -const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; -const std::string ATTR_NAME_SUBGRAPH_END_NODE = "_subgraph_end_node"; - -const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; -const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; - -// used for LX tiling -const std::string ATTR_NAME_OP_L1_SPACE = "_l1_space"; -const std::string ATTR_NAME_FUSION_TYPE_LIST = "_fusion_type_list"; -const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_list_list"; -const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; -const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; -const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_output_offset_list_list"; - -// for unregistered op -const std::string ATTR_NAME_UNREGST_OPPATH = "_unregst_oppath"; -const std::string ATTR_NAME_UNREGST_ATTRLIST = "_unregst_attrlist"; - -// used for Horovod -const std::string ATTR_INTER_EVENT_IDENTIFY = "event_id"; -const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; -// used for allreduce tailing optimization -const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; -const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; - -// dynamic shape attr -const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr"; -const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index"; - -// op dynamic input -const std::string ATTR_NAME_DYNAMIC_INPUT_START = "_dynamic_input_index_start"; -const std::string ATTR_NAME_DYNAMIC_INPUT_END = "_dynamic_input_index_end"; - -// atc user def dtype&format -const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type"; -const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format"; - -// for fusion op plugin -const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; - -// graph partition for aicpu -const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME = "pld_front_node_engine_name"; -const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME = "end_rear_node_engine_name"; - -// input and output memory type -const std::string ATTR_VARIABLE_PLACEMENT = "_variable_placement"; -const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type"; -const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type"; - -// input_output_offset -const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset"; -const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset"; -} // namespace ge diff --git a/src/common/graph/ge_attr_value.cc b/src/common/graph/ge_attr_value.cc deleted file mode 100644 index a8490470..00000000 --- a/src/common/graph/ge_attr_value.cc +++ /dev/null @@ -1,1289 +0,0 @@ -/** - 
* Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/ge_attr_value.h" -#include "graph/ge_tensor.h" -#include "external/graph/graph.h" -#include "utils/attr_utils.h" -#include "framework/common/debug/ge_log.h" -#include "graph/model_serialize.h" -#include "proto/ge_ir.pb.h" -#include "detail/model_serialize_imp.h" -#include "debug/ge_attr_define.h" -#include "debug/ge_log.h" -#include "debug/ge_util.h" - -using std::map; -using std::string; -using std::vector; - -namespace ge { -NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } - -NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) - : named_attrs_(owner, proto_msg) {} // lint !e1744 - -void NamedAttrs::SetName(const std::string &name) { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_name(name); - } -} - -string NamedAttrs::GetName() const { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->name(); - } - return string(); -} - -GeAttrValue NamedAttrs::GetItem(const string &key) const { - GeAttrValue value; - (void)GetAttr(key, value); - return value; -} - -ProtoAttrMapHelper NamedAttrs::MutableAttrMap() { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr()); - } - return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); -} - -ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr()); - } - return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); -} - -class GeAttrValueImp { - public: - static map attr_val_one_type_map_; - static map attr_val_list_type_map_; - - static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::INT val); - static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::FLOAT val); - static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::BOOL val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::STR &val); - static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val); - static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::TENSOR_DESC &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::BYTES &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::NAMED_ATTRS &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::GRAPH &val); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_FLOAT &val); - static bool 
SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BOOL &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_STR &val); - static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); - static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_TENSOR_DESC &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BYTES &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_NAMED_ATTRS &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_GRAPH &val); - - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::INT &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::FLOAT &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BOOL &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::STR &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::TENSOR &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::TENSOR_DESC &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BYTES &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::NAMED_ATTRS &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::GRAPH &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_INT &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_FLOAT &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_BOOL &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_STR &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_TENSOR &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_TENSOR_DESC &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_BYTES &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_NAMED_ATTRS &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_GRAPH &val); - // Value will be moved - static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer); - static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer); - // Value will be moved - static bool SetZeroCopyListBytes(proto::AttrDef &attr_def, const ProtoMsgOwner 
&proto_msg_owner,
-                                   vector<Buffer> &list_buffer);
-  static bool GetZeroCopyListBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
-                                   vector<Buffer> &list_buffer);
-
-  static bool SetValue(proto::AttrDef &attr_def, const vector<vector<int64_t>> &value);
-  static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
-                       vector<vector<int64_t>> &value);
-  static bool SetValue(proto::AttrDef &attr_def, const vector<ge::DataType> &value);
-  static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner,
-                       vector<ge::DataType> &value);
-
-  static bool SetValue(proto::AttrDef &attr_def, const ge::DataType &value);
-  static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ge::DataType &value);
-};
-
-map<proto::AttrDef::ValueCase, GeAttrValue::ValueType> GeAttrValueImp::attr_val_one_type_map_ = {
-  {proto::AttrDef::kI, GeAttrValue::VT_INT},
-  {proto::AttrDef::kF, GeAttrValue::VT_FLOAT},
-  {proto::AttrDef::kB, GeAttrValue::VT_BOOL},
-  {proto::AttrDef::kS, GeAttrValue::VT_STRING},
-  {proto::AttrDef::kT, GeAttrValue::VT_TENSOR},
-  {proto::AttrDef::kTd, GeAttrValue::VT_TENSOR_DESC},
-  {proto::AttrDef::kG, GeAttrValue::VT_GRAPH},
-  {proto::AttrDef::kBt, GeAttrValue::VT_BYTES},
-  {proto::AttrDef::kFunc, GeAttrValue::VT_NAMED_ATTRS},
-  {proto::AttrDef::kListListInt, GeAttrValue::VT_LIST_LIST_INT},
-  {proto::AttrDef::kDt, GeAttrValue::VT_DATA_TYPE},
-};
-map<proto::AttrDef_ListValue_ListValueType, GeAttrValue::ValueType> GeAttrValueImp::attr_val_list_type_map_ = {
-  {proto::AttrDef_ListValue_ListValueType_VT_LIST_INT, GeAttrValue::VT_LIST_INT},
-  {proto::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT, GeAttrValue::VT_LIST_FLOAT},
-  {proto::AttrDef_ListValue_ListValueType_VT_LIST_BOOL, GeAttrValue::VT_LIST_BOOL},
-  {proto::AttrDef_ListValue_ListValueType_VT_LIST_STRING, GeAttrValue::VT_LIST_STRING},
-  {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, GeAttrValue::VT_LIST_TENSOR},
-  {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, GeAttrValue::VT_LIST_TENSOR_DESC},
-  {proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, GeAttrValue::VT_LIST_GRAPH},
-  {proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, GeAttrValue::VT_LIST_BYTES},
-  {proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, GeAttrValue::VT_LIST_NAMED_ATTRS},
-  {proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, GeAttrValue::VT_LIST_DATA_TYPE},
-};
-
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue::GeAttrValue() { value_.InitDefault(); }
-
-GeAttrValue::GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val) : value_(proto_owner, val) {}
-
-GeAttrValue::ValueType GeAttrValue::GetValueType() const {
-  auto proto_msg = value_.GetProtoMsg();
-  if (proto_msg != nullptr) {
-    auto val_case = proto_msg->value_case();
-    if (val_case != proto::AttrDef::kList) {
-      auto it = GeAttrValueImp::attr_val_one_type_map_.find(val_case);
-      if (it != GeAttrValueImp::attr_val_one_type_map_.end()) {
-        return it->second;
-      }
-    } else {
-      auto it = GeAttrValueImp::attr_val_list_type_map_.find(proto_msg->list().val_type());
-      if (it != GeAttrValueImp::attr_val_list_type_map_.end()) {
-        return it->second;
-      }
-    }
-  }
-  return GeAttrValue::VT_NONE;
-}
-
-bool GeAttrValue::IsEmpty() const { return GetValueType() == VT_NONE; }
-
-GeAttrValue GeAttrValue::Copy() const {
-  GeAttrValue valueRet;
-  auto proto_msg = value_.GetProtoMsg();
-  auto proto_msg_ret = valueRet.value_.GetProtoMsg();
-  if (proto_msg != nullptr && proto_msg_ret != nullptr) {
-    *proto_msg_ret = *proto_msg;
-  }
-  return valueRet;
-}
-
-#define ATTR_VALUE_SET_GET_IMP(type) \
-  graphStatus GeAttrValue::SetValue(const type &val) { \
-    auto proto_msg = value_.GetProtoMsg(); \
-    if (proto_msg) { \
-      if (GeAttrValueImp::SetValue(*proto_msg, val)) { \
-        return GRAPH_SUCCESS; \
-      } \
-    } \
-    return GRAPH_FAILED; \
-  } \
- \
-  graphStatus GeAttrValue::GetValue(type &val) const { \
-    auto proto_msg = value_.GetProtoMsg(); \
-    if (proto_msg) { \
-      if (GeAttrValueImp::GetValue(*proto_msg, value_.GetProtoOwner(), val)) { \
-        return GRAPH_SUCCESS; \
-      } \
-    } \
-    return GRAPH_FAILED; \
-  }
-
-ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR)
-ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>)
-ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT)
-ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>)
-ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT)  // lint !e524
-ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>)
-ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL)
-ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>)
-ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR_DESC)
-ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR_DESC>)
-ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR)
-ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::TENSOR>)
-ATTR_VALUE_SET_GET_IMP(GeAttrValue::GRAPH)
-ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::GRAPH>)
-ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES)
-ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>)
-ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS)
-ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>)
-/*lint -e665*/
-ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>)
-/*lint +e665*/
-ATTR_VALUE_SET_GET_IMP(vector<DataType>)        // lint !e665
-ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE)  // lint !e665
-
-#undef ATTR_VALUE_SET_GET_IMP
-
-graphStatus GeAttrValue::MutableTensor(GeTensorPtr &tensor) { return GetValue(tensor); }
-
-graphStatus GeAttrValue::MutableListTensor(vector<GeTensorPtr> &list_tensor) { return GetValue(list_tensor); }
-
-class AttrUtilsHelper {
- public:
-  inline static bool GetValueCheckType(const proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) {
-    if (attr_def.value_case() != proto_case) {
-      GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case);
-      return false;
-    }
-    return true;
-  }
-
-  inline static bool GetValueCheckListType(
-    const proto::AttrDef &attr_def, proto::AttrDef_ListValue_ListValueType proto_list_case,
-    const std::function<bool(const proto::AttrDef &)> item_check_fun) {
-    if (attr_def.value_case() != proto::AttrDef::kList) {
-      GELOGW("Check ListType Failed, value_case %u", attr_def.value_case());
-      return false;
-    }
-    auto &list = attr_def.list();
-    if (list.val_type() == proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE) {
-      return item_check_fun(attr_def);
-    }
-    if (list.val_type() != proto_list_case) {
-      GELOGW("Check ListType Failed, val_type %u, expected %u", list.val_type(), proto_list_case);
-      return false;
-    }
-    return true;
-  }
-
-  inline static bool SetValueCheckType(proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) {
-    if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto_case) {
-      GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case);
-      return false;
-    }
-    return true;
-  }
-
-  inline static bool SetValueCheckAndSetListType(proto::AttrDef &attr_def,
-                                                 proto::AttrDef_ListValue_ListValueType proto_list_case) {
-    if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto::AttrDef::kList) {
-      GELOGW("AttrUtils::Check Type Failed, value_case %u", attr_def.value_case());
-      return false;
-    }
-    auto list = attr_def.mutable_list();
-    if (list == nullptr) {
-      GELOGE(GRAPH_FAILED, "list is nullptr");
-      return false;
-    }
-    if (list->val_type() != proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE &&
-        list->val_type() != proto_list_case) {
-      GELOGW("AttrUtils::Check ListType Type Failed,
val_type %d, expected %d", static_cast(list->val_type()), - static_cast(proto_list_case)); - return false; - } - list->set_val_type(proto_list_case); - return true; - } - - static bool GetAttrMapItem(const AttrHolder *obj, const string &name, const proto::AttrDef *&attr_def) { - if (obj == nullptr) { - GELOGE(FAILED, "%s obj is nullptr", name.c_str()); - return false; - } - auto attr_map = obj->GetAttrMap().GetProtoMsg(); - if (attr_map == nullptr) { - GELOGE(FAILED, "%s attr map is nullptr", name.c_str()); - return false; - } - auto it = attr_map->find(name); - if (it == attr_map->end()) { - return false; - } - attr_def = &it->second; - return true; - } - - inline static bool MutableAttrMapItem(AttrHolder *obj, const string &name, proto::AttrDef *&attr_def) { - if (obj == nullptr) { - GELOGE(FAILED, " %s obj is nullptr", name.c_str()); - return false; - } - auto attr_map = obj->MutableAttrMap().GetProtoMsg(); - if (attr_map == nullptr) { - GELOGE(FAILED, "%s attr map is nullptr", name.c_str()); - return false; - } - // Get or add - attr_def = &((*attr_map)[name]); - return true; - } -}; - -#define ATTR_VALUE_IMP_SET_ONE(ValType, proto_case, protoItem) \ - bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \ - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \ - return false; \ - } \ - proto_attr_val.set_##protoItem(value); \ - return true; \ - } - -#define ATTR_VALUE_IMP_SET_LIST(ValType, proto_list_case, protoItem) \ - bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \ - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, \ - proto::AttrDef_ListValue_ListValueType_##proto_list_case)) { \ - return false; \ - } \ - auto list = proto_attr_val.mutable_list(); \ - list->clear_##protoItem(); \ - for (const auto &item : value) { \ - list->add_##protoItem(item); \ - } \ - return true; \ - } - -ATTR_VALUE_IMP_SET_ONE(int64_t, kI, i) -ATTR_VALUE_IMP_SET_ONE(float, kF, f) -ATTR_VALUE_IMP_SET_ONE(const string &, kS, s) -ATTR_VALUE_IMP_SET_ONE(bool, kB, b) - -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_FLOAT, f) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_STRING, s) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_BOOL, b) - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensorDesc &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) { - return false; - } - auto proto_msg = value.tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_attr_val.mutable_td() = *proto_msg; - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_td(); - for (const auto &item : value) { - auto proto_msg = item.tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - proto_attr_val.clear_list(); - return false; - } - *list->add_td() = *proto_msg; - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ConstGeTensorPtr &value) { - if (value) { - return SetValue(proto_attr_val, *value); - } else { 
- return SetValue(proto_attr_val, GeTensor()); - } -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensor &val) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { - return false; - } - auto proto_msg = val.tensor_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - GELOGE(FAILED, "Proto msg is nullptr"); - return false; - } - *proto_attr_val.mutable_t() = *proto_msg; - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - vector constList(value.size()); - std::copy(value.begin(), value.end(), constList.begin()); - return SetValue(proto_attr_val, constList); -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_t(); - for (const auto &item : value) { - if (item == nullptr) { - GELOGE(GRAPH_FAILED, "AttrUtils::SetListTensor item is nullptr"); - proto_attr_val.clear_list(); - return false; - } - auto proto_msg = item->tensor_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - GELOGE(FAILED, "Proto msg is nullptr"); - proto_attr_val.clear_list(); - return false; - } - *list->add_t() = *proto_msg; - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_t(); - for (const auto &item : value) { - auto proto_msg = item.tensor_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - GELOGE(FAILED, "Proto msg is nullptr"); - proto_attr_val.clear_list(); - return false; - } - *list->add_t() = *proto_msg; - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::BYTES &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - size_t val_size = value.GetSize(); - proto_attr_val.set_bt(value.GetData(), val_size); - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_bt(); - for (const auto &item : value) { - list->add_bt(item.GetData(), item.GetSize()); - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { - return false; - } - auto proto_msg = value.named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - GELOGE(FAILED, "Proto msg is nullptr"); - return false; - } - *proto_attr_val.mutable_func() = *proto_msg; - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - 
list->clear_na(); - for (const auto &item : value) { - auto proto_msg = item.named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - proto_attr_val.clear_list(); - return false; - } - *list->add_na() = *proto_msg; - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::ComputeGraphPtr &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kG)) { - return false; - } - ModelSerializeImp imp; - if (!imp.SerializeGraph(value, proto_attr_val.mutable_g())) { - GELOGE(GRAPH_FAILED, "AttrUtils::SetGraph SerializeGraph Failed"); - proto_attr_val.clear_g(); - return false; - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_g(); - - ModelSerializeImp imp; - for (const auto &item : value) { - if (!imp.SerializeGraph(item, list->add_g())) { - GELOGE(GRAPH_FAILED, "AttrUtils::SetListGraph SerializeGraph"); - proto_attr_val.clear_list(); - return false; - } - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector> &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) { - return false; - } - proto_attr_val.clear_list_list_int(); - auto list_list_int = proto_attr_val.mutable_list_list_int(); - GE_CHECK_NOTNULL_EXEC(list_list_int, return false); - for (auto &list_int : value) { - auto list_item = list_list_int->add_list_list_i(); - GE_CHECK_NOTNULL_EXEC(list_item, return false); - for (auto &int_item : list_int) { - list_item->add_list_i(int_item); - } - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_dt(); - for (const auto &item : value) { - list->add_dt(static_cast(item)); - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::DataType &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) { - return false; - } - proto_attr_val.set_dt(static_cast(value)); - - return true; -} - -#define ATTR_VALUE_IMP_GET_ONE(ValType, proto_case, protoItem) \ - bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ValType value) { \ - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \ - return false; \ - } \ - value = proto_attr_val.protoItem(); \ - return true; \ - } - -#define ListValueItemCheck(protoItem) \ - [](const proto::AttrDef &proto_attr_val) { return proto_attr_val.list().protoItem##_size() > 0; } - -#define ATTR_VALUE_IMP_GET_LIST(ValType, proto_list_case, protoItem) \ - bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, vector &value) { \ - value.clear(); \ - if (!AttrUtilsHelper::GetValueCheckListType( \ - proto_attr_val, proto::AttrDef_ListValue_ListValueType_##proto_list_case, ListValueItemCheck(protoItem))) { \ - return false; \ - } \ - auto &list = proto_attr_val.list(); \ - for (const auto &item : list.protoItem()) { \ - value.push_back(item); 
\ - } \ - return true; \ - } - -ATTR_VALUE_IMP_GET_ONE(int64_t &, kI, i) -ATTR_VALUE_IMP_GET_ONE(float &, kF, f) -ATTR_VALUE_IMP_GET_ONE(string &, kS, s) -ATTR_VALUE_IMP_GET_ONE(bool &, kB, b) - -ATTR_VALUE_IMP_GET_LIST(int64_t, VT_LIST_INT, i) -ATTR_VALUE_IMP_GET_LIST(float, VT_LIST_FLOAT, f) -ATTR_VALUE_IMP_GET_LIST(string, VT_LIST_STRING, s) -ATTR_VALUE_IMP_GET_LIST(bool, VT_LIST_BOOL, b) - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeTensorDesc &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) { - return false; - } - auto proto_msg = value.tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = proto_attr_val.td(); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - if (!AttrUtilsHelper::GetValueCheckListType( - proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, ListValueItemCheck(td))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.td()) { - value.emplace_back(GeTensorDesc()); - auto proto_msg = value.back().tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = item; - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - GeTensorPtr &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { - return false; - } - value = std::shared_ptr(new (std::nothrow) - GeTensor(proto_owner, const_cast(proto_attr_val).mutable_t())); - GE_CHK_BOOL_RET_STATUS(value != nullptr, false, "value is nullptr"); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, - ListValueItemCheck(t))) { - return false; - } - auto list = const_cast(proto_attr_val).mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - for (auto &item : *(list->mutable_t())) { - std::shared_ptr temp_value = std::shared_ptr(new (std::nothrow) GeTensor(proto_owner, &item)); - GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); - value.push_back(temp_value); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeAttrValue::BYTES &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - auto &proto_val = proto_attr_val.bt(); - GE_LOGI_IF(proto_val.size() == 0, "size res is 0."); - value = Buffer::CopyFrom(reinterpret_cast(proto_val.data()), proto_val.size()); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, - ListValueItemCheck(bt))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.bt()) { - value.push_back(Buffer::CopyFrom((const uint8_t *)item.data(), item.size())); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - GeAttrValue::NAMED_ATTRS &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, 
proto::AttrDef::kFunc)) { - return false; - } - auto proto_msg = value.named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = proto_attr_val.func(); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType( - proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.na()) { - value.emplace_back(GeAttrValue::NAMED_ATTRS()); - if (value.empty()) { - return false; - } - auto proto_msg = value.back().named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = item; - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ComputeGraphPtr &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kG)) { - return false; - } - ComputeGraphPtr graph = nullptr; - std::shared_ptr graph_def; - graph_def = ComGraphMakeShared(proto_attr_val.g()); - if (graph_def == nullptr) { - GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); - graph_def = nullptr; - return false; // lint !e665 - } else { - ModelSerializeImp imp; - imp.SetProtobufOwner(graph_def); - if (!imp.UnserializeGraph(graph, *graph_def)) { - GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); - return false; - } // lint !e514 - value = graph; - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, - ListValueItemCheck(g))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.g()) { - std::shared_ptr graph_def; - graph_def = ComGraphMakeShared(item); - if (graph_def == nullptr) { - GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); - graph_def = nullptr; - return false; // lint !e665 - } else { - ComputeGraphPtr graph = nullptr; - ModelSerializeImp imp; - imp.SetProtobufOwner(graph_def); - if (!imp.UnserializeGraph(graph, *graph_def)) { - GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); - return false; - } // lint !e514 - value.push_back(graph); - } - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector> &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) { - return false; - } - - auto &list_listint = proto_attr_val.list_list_int().list_list_i(); - for (auto &list_int : list_listint) { - vector list_item(list_int.list_i().size()); - if (!list_int.list_i().empty()) { - (void)std::copy(list_int.list_i().begin(), list_int.list_i().end(), list_item.begin()); - } - value.push_back(list_item); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, - ListValueItemCheck(dt))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.dt()) { - value.emplace_back(static_cast(item)); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ge::DataType &value) 
{ - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) { - return false; - } - value = static_cast(proto_attr_val.dt()); - return true; -} - -GE_FUNC_HOST_VISIBILITY bool GeAttrValueImp::SetZeroCopyBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - Buffer &&buffer) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - auto proto_msg = buffer.data_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - proto_attr_val.set_bt(std::move(*proto_msg->mutable_bt())); - return true; -} - -bool GeAttrValueImp::GetZeroCopyBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - Buffer &buffer) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - buffer = Buffer(proto_owner, &const_cast(proto_attr_val)); - return true; -} - -bool GeAttrValueImp::SetZeroCopyListBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &list_buffer) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_bt(); - for (auto &item : list_buffer) { - auto proto_msg = item.data_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - list->add_bt(std::move(*proto_msg->mutable_bt())); - } - return true; -} - -bool GeAttrValueImp::GetZeroCopyListBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - vector &list_buffer) { - list_buffer.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, - ListValueItemCheck(bt))) { - return false; - } - auto list = const_cast(proto_attr_val).mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - for (auto &item : *(list->mutable_bt())) { - list_buffer.emplace_back(Buffer(proto_owner, &item)); - } - return true; -} - -bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const string &name) { - if (!obj) { - return false; - } - return obj->HasAttr(name); -} - -#define ATTR_UTILS_SET_IMP(FuncName, Type) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Set##FuncName( \ - AttrHolderAdapter &&obj, const string &name, const Type &value) { \ - proto::AttrDef *proto_attr_val = nullptr; \ - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ - return false; \ - } \ - if (!GeAttrValueImp::SetValue(*proto_attr_val, value)) { \ - GELOGW("Set" #FuncName " failed key %s", name.c_str()); \ - return false; \ - } \ - return true; \ - } - -#define ATTR_UTILS_GET_IMP(FuncName, Type) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Get##FuncName(ConstAttrHolderAdapter &&obj, \ - const string &name, Type &value) { \ - const proto::AttrDef *proto_attr_val = nullptr; \ - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ - return false; \ - } \ - if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value)) { \ - GELOGW("Get" #FuncName " failed key %s", name.c_str()); \ - return false; \ - } \ - return true; \ - } - -#define ATTR_UTILS_SET_GET_IMP(FuncName, Type) \ - ATTR_UTILS_SET_IMP(FuncName, Type) \ - ATTR_UTILS_GET_IMP(FuncName, Type) - -ATTR_UTILS_SET_GET_IMP(Int, int64_t) -ATTR_UTILS_SET_GET_IMP(Float, float) -ATTR_UTILS_SET_GET_IMP(Bool, 
bool)
-ATTR_UTILS_SET_GET_IMP(Str, string)
-ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc)
-ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr)
-ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr)
-ATTR_UTILS_SET_IMP(Tensor, GeTensor)
-ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS)
-ATTR_UTILS_SET_GET_IMP(Bytes, Buffer)
-ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr)
-/*lint -e665*/
-ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>)
-/*lint +e665*/
-
-ATTR_UTILS_SET_GET_IMP(ListInt, vector<int64_t>)
-ATTR_UTILS_SET_IMP(ListInt, vector<int32_t>)
-ATTR_UTILS_SET_IMP(ListInt, vector<uint32_t>)
-ATTR_UTILS_SET_GET_IMP(ListFloat, vector<float>)
-ATTR_UTILS_SET_GET_IMP(ListBool, vector<bool>)
-ATTR_UTILS_SET_GET_IMP(ListStr, vector<string>)
-ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector<GeTensorDesc>)
-ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensorPtr>)
-ATTR_UTILS_SET_IMP(ListTensor, vector<ConstGeTensorPtr>)
-ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>)
-ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>)
-ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>)
-ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>)
-ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>)  // lint !e665
-ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType)              // lint !e665
-
-bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name,
-                              std::initializer_list<ConstGeTensorPtr> &&value) {
-  return SetListTensor(std::move(obj), name, vector<ConstGeTensorPtr>(value));
-}
-
-bool AttrUtils::GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value) {
-  const proto::AttrDef *proto_attr_val = nullptr;
-  if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
-    return false;
-  }
-  GeTensorPtr tensor;
-  if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) {
-    return false;
-  }
-  value = tensor;
-  return true;
-}
-
-bool AttrUtils::GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector<ConstGeTensorPtr> &value) {
-  value.clear();
-  const proto::AttrDef *proto_attr_val = nullptr;
-  if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
-    return false;
-  }
-  vector<GeTensorPtr> tensor;
-  if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) {
-    return false;
-  }
-  value.insert(value.begin(), tensor.begin(), tensor.end());
-  return true;
-}
-
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::MutableTensor(AttrHolderAdapter &&obj,
-                                                                             const string &name, GeTensorPtr &value) {
-  const proto::AttrDef *proto_attr_val = nullptr;
-  if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
-    return false;
-  }
-  return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value);
-}
-
-bool AttrUtils::MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector<GeTensorPtr> &value) {
-  value.clear();
-  const proto::AttrDef *proto_attr_val = nullptr;
-  if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
-    return false;
-  }
-  return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value);
-}
-
-bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list<int64_t> &&value) {
-  proto::AttrDef *proto_attr_val = nullptr;
-  if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) {
-    return false;
-  }
-  return GeAttrValueImp::SetValue(*proto_attr_val, value);
-}
-
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name,
-                                                                      int32_t &value) {
-  int64_t
int64_val = 0; - if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { - return false; - } - if (int64_val > INT32_MAX) { - GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to int32_t", int64_val); - return false; - } - value = static_cast(int64_val); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, - uint32_t &value) { - int64_t int64_val = 0; - if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { - return false; - } - if (int64_val > UINT32_MAX) { - GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to uint32_t", int64_val); - return false; - } - value = static_cast(int64_val); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, - const string &name, vector &value) { - value.clear(); - vector int64_list; - if (!GetListInt(std::move(obj), name, int64_list)) { - return false; - } - - for (size_t i = 0; i < int64_list.size(); ++i) { - if (int64_list[i] > INT32_MAX) { - GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); - return false; - } - } - value.insert(value.begin(), int64_list.begin(), int64_list.end()); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, - const string &name, vector &value) { - value.clear(); - vector int64_list; - if (!GetListInt(std::move(obj), name, int64_list)) { - return false; - } - - for (size_t i = 0; i < int64_list.size(); ++i) { - if (int64_list[i] > UINT32_MAX) { - GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); - return false; - } - } - value.insert(value.begin(), int64_list.begin(), int64_list.end()); - return true; -} - -bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value) { - if (obj) { - vector bytes_vals; - for (auto &item : value) { - ModelSerialize serialize; - auto buffer = serialize.SerializeOpDesc(item); - if (buffer.GetSize() == 0) { - return false; - } - bytes_vals.push_back(buffer); - } - return SetZeroCopyListBytes(std::move(obj), name, bytes_vals); - } - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, - const string &name, - const vector &value) { - if (obj) { - vector bytes_vals; - for (auto &item : value) { - ModelSerialize serialize; - auto buffer = serialize.SerializeOpDesc(item); - if (buffer.GetSize() == 0) { - return false; - } - bytes_vals.push_back(buffer); - } - return SetZeroCopyListBytes(std::move(obj), name, bytes_vals); - } - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(ConstAttrHolderAdapter &&obj, - const string &name, - vector &value) { - value.clear(); - - vector bytes_vals; - if (!GetZeroCopyListBytes(std::move(obj), name, bytes_vals)) { - return false; - } - for (const auto &item : bytes_vals) { - ModelSerialize serialize; - auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732 - value.push_back(op_desc); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetZeroCopyBytes(AttrHolderAdapter &&obj, - const string &name, Buffer &&buffer) { - // Value will be moved - proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return 
GeAttrValueImp::SetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), std::move(buffer)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, - const string &name, Buffer &buffer) { - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::GetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), buffer); -} - -bool AttrUtils::SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector &list_buffer) { - // Value will be moved - proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::SetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer); -} - -bool AttrUtils::GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &list_buffer) { - list_buffer.clear(); - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::GetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { - if (org_op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "org_op_desc is null"); - return nullptr; - } - std::shared_ptr op_def; - op_def = ComGraphMakeShared(); - if (op_def == nullptr) { - GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); - return nullptr; // lint !e665 - } - ModelSerializeImp imp; - (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - - imp.SetProtobufOwner(op_def); - OpDescPtr op_desc = nullptr; - GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); - op_desc->extAttrs_ = org_op_desc->extAttrs_; - - // This function may be called by some passes of fusion engine, in this condition, do not need these attribute - if (!op_desc->input_name_idx_.empty()) { - op_desc->input_name_idx_.clear(); - } - if (!op_desc->output_name_idx_.empty()) { - op_desc->output_name_idx_.clear(); - } - if (!op_desc->optional_input_names_.empty()) { - op_desc->optional_input_names_.clear(); - } - - return op_desc; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { - if (org_op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "org_op_desc is null"); - return nullptr; - } - std::shared_ptr op_def = ComGraphMakeShared(); - if (op_def == nullptr) { - GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); - return nullptr; - } - ModelSerializeImp imp; - (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - - imp.SetProtobufOwner(op_def); - OpDescPtr op_desc = nullptr; - GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); - - op_desc->extAttrs_ = org_op_desc->extAttrs_; - - op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end()); - op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(), - org_op_desc->optional_input_names_.end()); - op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); - - op_desc->infer_func_ = 
org_op_desc->infer_func_; - op_desc->infer_format_func_ = org_op_desc->infer_format_func_; - op_desc->verifier_func_ = org_op_desc->verifier_func_; - - return op_desc; -} -std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) { - auto holder = obj.get(); - if (holder == nullptr) { - return ""; - } - auto attrs_map = holder->GetAttrMap(); - if (attrs_map.GetProtoMsg() == nullptr) { - return ""; - } - - std::map ordered_attrs; - for (auto &attr : *(attrs_map.GetProtoMsg())) { - ordered_attrs[attr.first] = attr.second.SerializeAsString(); - } - - std::stringstream ss; - for (auto &attr : ordered_attrs) { - ss << attr.first << ":" << attr.second << ";"; - } - return ss.str(); -} -} // namespace ge diff --git a/src/common/graph/ge_tensor.cc b/src/common/graph/ge_tensor.cc deleted file mode 100644 index 65881435..00000000 --- a/src/common/graph/ge_tensor.cc +++ /dev/null @@ -1,1021 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/ge_tensor.h" -#include -#include -#include -#include -#include "debug/ge_attr_define.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/ge_attr_value.h" -#include "graph/model_serialize.h" -#include "proto/ge_ir.pb.h" -#include "utils/attr_utils.h" -#include "utils/ge_ir_utils.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" - -namespace ge { -static const char *const kKeyDataTypeSelfDefined = "__tensor_desc_data_type__"; - -static const std::map kDataTypeMap = { - {DT_UNDEFINED, proto::DT_UNDEFINED}, - {DT_FLOAT, proto::DT_FLOAT}, - {DT_FLOAT16, proto::DT_FLOAT16}, - {DT_INT8, proto::DT_INT8}, - {DT_UINT8, proto::DT_UINT8}, - {DT_INT16, proto::DT_INT16}, - {DT_UINT16, proto::DT_UINT16}, - {DT_INT32, proto::DT_INT32}, - {DT_INT64, proto::DT_INT64}, - {DT_UINT32, proto::DT_UINT32}, - {DT_UINT64, proto::DT_UINT64}, - {DT_BOOL, proto::DT_BOOL}, - {DT_DOUBLE, proto::DT_DOUBLE}, - {DT_DUAL, proto::DT_DUAL}, - {DT_DUAL_SUB_INT8, proto::DT_DUAL_SUB_INT8}, - {DT_DUAL_SUB_UINT8, proto::DT_DUAL_SUB_UINT8}, - {DT_COMPLEX64, proto::DT_COMPLEX64}, - {DT_COMPLEX128, proto::DT_COMPLEX128}, - {DT_QINT8, proto::DT_QINT8}, - {DT_QINT16, proto::DT_QINT16}, - {DT_QINT32, proto::DT_QINT32}, - {DT_QUINT8, proto::DT_QUINT8}, - {DT_QUINT16, proto::DT_QUINT16}, - {DT_RESOURCE, proto::DT_RESOURCE}, - {DT_STRING_REF, proto::DT_STRING_REF}, - {DT_STRING, proto::DT_STRING}, -}; - -static const std::map kDataTypeSelfDefinedMap = { - {DT_DUAL, 13}, {DT_DUAL_SUB_INT8, 14}, {DT_DUAL_SUB_UINT8, 15}, {DT_COMPLEX64, 16}, {DT_COMPLEX128, 17}, - {DT_QINT8, 18}, {DT_QINT16, 19}, {DT_QINT32, 20}, {DT_QUINT8, 21}, {DT_QUINT16, 22}, -}; - -GeShape::GeShape() { shape_def_.InitDefault(); } - -// Default -GeShape::GeShape(std::vector s) : GeShape() { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto i : s) { - proto_msg->add_dim(i); - } - } -} - -size_t GeShape::GetDimNum() const { - auto proto_msg = 
shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - if (proto_msg->dim_size() >= 0) { - // check whether contain -2, if true, return -1 - for (auto i : proto_msg->dim()) { - if (i == UNKNOWN_DIM_NUM) { - return 0; - } - } - return proto_msg->dim_size(); - } else { - return 0; - } - } - return 0; -} - -int64_t GeShape::GetDim(size_t idx) const { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - if (proto_msg->dim_size() > static_cast(idx)) { - return proto_msg->dim(static_cast(idx)); - } - } - return 0; -} - -graphStatus GeShape::SetDim(size_t idx, int64_t value) { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - auto dims = proto_msg->mutable_dim(); - GE_CHECK_NOTNULL(dims); - if (dims->empty()) { - GELOGE(GRAPH_FAILED, "shape is empty"); - return GRAPH_FAILED; - } - if (static_cast(idx) >= dims->size()) { - GELOGE(GRAPH_FAILED, "idx is out of range"); - return GRAPH_FAILED; - } - proto_msg->set_dim(static_cast(idx), value); - } - return GRAPH_SUCCESS; -} - -std::vector GeShape::GetDims() const { - vector dims; - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto i : proto_msg->dim()) { - dims.push_back(i); - } - } - return dims; -} - -std::string GeShape::ToString() const { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - return ""; - } - - std::stringstream ss; - bool first = true; - for (auto i : proto_msg->dim()) { - if (first) { - first = false; - } else { - ss << ","; - } - ss << i; - } - return ss.str(); -} - -int64_t GeShape::GetShapeSize() const { - int64_t res = 1; - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - if (proto_msg->dim().empty()) { - return 0; - } - for (auto i : proto_msg->dim()) { - // if unknown shape, return -1 - if (i == UNKNOWN_DIM || i == UNKNOWN_DIM_NUM) { - return UNKNOWN_DIM; - } - res *= i; - } - } - return res; -} - -/// -/// @brief Check is unknown shape -/// @return bool -/// /// -bool GeShape::IsUnknownShape() const { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto i : proto_msg->dim()) { - if (i < 0) { - return true; - } - } - } - return false; -} - -/// -/// @brief Check is a scalar -/// @return bool -/// -bool GeShape::IsScalar() const { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->dim().empty(); - } - return false; -} - -const string TENSOR_UTILS_SIZE = "size"; -const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; -const string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; -const string TENSOR_UTILS_OUTPUT_TENSOR = "output_tensor"; -const string TENSOR_UTILS_DEVICE_TYPE = "device_type"; -const string TENSOR_UTILS_INPUT_TENSOR = "input_tensor"; -const string TENSOR_UTILS_REAL_DIM_CNT = "real_dim_cnt"; -const string TENSOR_UTILS_REUSE_INPUT_INDEX = "reuse_input_index"; -const string TENSOR_UTILS_DATA_OFFSET = "data_offset"; -const string TENSOR_UTILS_CMPS_SIZE = "cmps_size"; -const string TENSOR_UTILS_CMPS_TAB = "cmps_tab"; -const string TENSOR_UTILS_CMPS_TAB_OFFSET = "cmps_tab_offset"; -const string TENSOR_UTILS_CMPSINFO = "cmps_info"; -const string TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO = "alloffset_quantize_info"; -const string TENSOR_UTILS_RC = "rc"; -const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; -const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; -const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; -const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; -const string 
TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index"; - -GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} - -GeShape::GeShape(const GeShape &other) : GeShape() { shape_def_.CopyValueFrom(other.shape_def_); } - -GeShape::GeShape(GeShape &&other) : GeShape() { shape_def_.MoveValueFrom(std::move(other.shape_def_)); } - -GeShape &GeShape::operator=(const GeShape &other) { - if (&other != this) { - shape_def_.CopyValueFrom(other.shape_def_); - } - return *this; -} - -GeShape &GeShape::operator=(GeShape &&other) { - if (&other != this) { - shape_def_.CopyValueFrom(std::move(other.shape_def_)); - } - return *this; -} - -GeTensorDesc::GeTensorDesc() { - tensor_descriptor_.InitDefault(); - SetDataType(DT_FLOAT); - Init(); -} - -// Default -GeTensorDesc::GeTensorDesc(GeShape shape, Format format, DataType dt) : GeTensorDesc() { - SetFormat(format); - SetDataType(dt); - ShapeReference() = std::move(shape); -} - -// Default -GeTensorDesc::GeTensorDesc(const GeTensorDesc &desc) : GeTensorDesc() { - tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_); -} - -// Default -GeTensorDesc::GeTensorDesc(GeTensorDesc &&desc) : GeTensorDesc() { - tensor_descriptor_.MoveValueFrom(std::move(desc.tensor_descriptor_)); -} - -GeTensorDesc::GeTensorDesc(const ProtoMsgOwner &proto_owner, proto::TensorDescriptor *proto_msg) - : tensor_descriptor_(proto_owner, proto_msg) { - if (proto_msg != nullptr && !proto_msg->has_out_attr()) { - proto_msg->set_has_out_attr(true); - - int64_t size = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_SIZE, size); - proto_msg->set_size(size); - - int64_t weight_size = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_WEIGHT_SIZE, weight_size); - proto_msg->set_weight_size(weight_size); - - bool reuse_input = false; - (void)AttrUtils::GetBool(this, TENSOR_UTILS_REUSE_INPUT, reuse_input); - proto_msg->set_reuse_input(reuse_input); - - bool output_tensor = false; - (void)AttrUtils::GetBool(this, TENSOR_UTILS_OUTPUT_TENSOR, output_tensor); - proto_msg->set_output_tensor(output_tensor); - - string device_type = "NPU"; - (void)AttrUtils::GetStr(this, TENSOR_UTILS_DEVICE_TYPE, device_type); - proto_msg->set_device_type(device_type); - - bool input_tensor = false; - (void)AttrUtils::GetBool(this, TENSOR_UTILS_INPUT_TENSOR, input_tensor); - proto_msg->set_input_tensor(input_tensor); - - int64_t real_dim_cnt = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_REAL_DIM_CNT, real_dim_cnt); - proto_msg->set_real_dim_cnt(real_dim_cnt); - - int64_t reuse_input_index = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_REUSE_INPUT_INDEX, reuse_input_index); - proto_msg->set_reuse_input_index(reuse_input_index); - - int64_t data_offset = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_DATA_OFFSET, data_offset); - proto_msg->set_data_offset(data_offset); - - int64_t cmps_size = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_SIZE, cmps_size); - proto_msg->set_cmps_size(cmps_size); - - string cmps_tab; - (void)AttrUtils::GetStr(this, TENSOR_UTILS_CMPS_TAB, cmps_tab); - proto_msg->set_cmps_tab(cmps_tab); - - int64_t cmps_tab_offset = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_TAB_OFFSET, cmps_tab_offset); - proto_msg->set_cmps_tab_offset(cmps_tab_offset); - } -} - -bool GeTensorDesc::GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const { - const auto &tensor_descriptor = this->tensor_descriptor_.GetProtoMsg(); - const auto &r_tensor_descriptor = r_ge_tensor_desc.tensor_descriptor_.GetProtoMsg(); - if 
((tensor_descriptor != nullptr) && (r_tensor_descriptor != nullptr)) { - // Message TensorDescriptor in ge_ir.proto - return ( - IsEqual(tensor_descriptor->name(), r_tensor_descriptor->name(), "TensorDescriptor.name()") && - IsEqual(tensor_descriptor->dtype(), r_tensor_descriptor->dtype(), "TensorDescriptor.dtype()") && - // Message ShapeDef in ge_ir.proto - IsEqual(ToString(tensor_descriptor->shape().dim()), ToString(r_tensor_descriptor->shape().dim()), - "TensorDescriptor.shape().dim()") && - IsEqual(tensor_descriptor->layout(), r_tensor_descriptor->layout(), "TensorDescriptor.layout()") && - IsEqual(tensor_descriptor->has_out_attr(), r_tensor_descriptor->has_out_attr(), - "TensorDescriptor.has_out_attr()") && - IsEqual(tensor_descriptor->size(), r_tensor_descriptor->size(), "TensorDescriptor.size()") && - IsEqual(tensor_descriptor->weight_size(), r_tensor_descriptor->weight_size(), "TensorDescriptor.weight_size()") && - IsEqual(tensor_descriptor->reuse_input(), r_tensor_descriptor->reuse_input(), "TensorDescriptor.reuse_input()") && - IsEqual(tensor_descriptor->output_tensor(), r_tensor_descriptor->output_tensor(), - "TensorDescriptor.output_tensor()") && - IsEqual(tensor_descriptor->device_type(), r_tensor_descriptor->device_type(), "TensorDescriptor.device_type()") && - IsEqual(tensor_descriptor->input_tensor(), r_tensor_descriptor->input_tensor(), - "TensorDescriptor.input_tensor()") && - IsEqual(tensor_descriptor->real_dim_cnt(), r_tensor_descriptor->real_dim_cnt(), - "TensorDescriptor.real_dim_cnt()") && - IsEqual(tensor_descriptor->reuse_input_index(), r_tensor_descriptor->reuse_input_index(), - "TensorDescriptor.reuse_input_index()") && - IsEqual(tensor_descriptor->data_offset(), r_tensor_descriptor->data_offset(), "TensorDescriptor.data_offset()") && - IsEqual(tensor_descriptor->cmps_size(), r_tensor_descriptor->cmps_size(), "TensorDescriptor.cmps_size()") && - IsEqual(tensor_descriptor->cmps_tab(), r_tensor_descriptor->cmps_tab(), "TensorDescriptor.cmps_tab()") && - IsEqual(tensor_descriptor->cmps_tab_offset(), r_tensor_descriptor->cmps_tab_offset(), - "TensorDescriptor.cmps_tab_offset()")); - } else { - return ((tensor_descriptor == nullptr) && (r_tensor_descriptor == nullptr)); - } -} - -bool GeTensorDesc::operator==(const GeTensorDesc &r_ge_tensor_desc) const { - return GeTensorDescAttrsAreEqual(r_ge_tensor_desc); -} - -GeShape &GeTensorDesc::ShapeReference() const { - if (tensor_descriptor_.GetProtoMsg() != nullptr) { - GeShape refShape(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_shape()); - __shape_.RefTo(refShape); - } else { - GeShape refShape(tensor_descriptor_.GetProtoOwner(), nullptr); - __shape_.RefTo(refShape); - } - return __shape_; -} - -void GeTensorDesc::Init() { - SetFormat(FORMAT_ND); - SetOriginFormat(FORMAT_ND); - TensorUtils::SetDeviceType(*this, DeviceType::NPU); - if (tensor_descriptor_.GetProtoMsg() == nullptr) { - GELOGE(GRAPH_FAILED, "ProtoType nullptr."); - return; - } - tensor_descriptor_.GetProtoMsg()->set_has_out_attr(true); -} - -ProtoAttrMapHelper GeTensorDesc::MutableAttrMap() { - if (tensor_descriptor_.GetProtoMsg() != nullptr) { - return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_attr()); - } - return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr); -} - -ConstProtoAttrMapHelper GeTensorDesc::GetAttrMap() const { - if (tensor_descriptor_.GetProtoMsg() != nullptr) { - return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), - 
tensor_descriptor_.GetProtoMsg()->mutable_attr()); - } - return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr); -} - -void GeTensorDesc::Update(GeShape shape, Format format, DataType dt) { - ShapeReference() = std::move(shape); - SetFormat(format); - SetDataType(dt); -} -GeShape GeTensorDesc::GetShape() const { return ShapeReference(); } - -GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); } - -void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); } - -// set shape with -2, it stand for unknown shape -void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); } - -// for unknown shape -graphStatus GeTensorDesc::SetShapeRange(const std::vector> &range) { - std::vector> shape_range; - for (const auto &ele : range) { - shape_range.emplace_back(std::vector({ele.first, ele.second})); - } - auto ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); - return ret ? GRAPH_SUCCESS : GRAPH_FAILED; -} -graphStatus GeTensorDesc::GetShapeRange(std::vector> &range) const { - std::vector> shape_range; - (void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); - - for (const auto &ele : shape_range) { - // here must be only two elemenet because pair - if (ele.size() != 2) { - GELOGE(GRAPH_FAILED, "shape_range must contain only 2 value but really is %lu", ele.size()); - return GRAPH_FAILED; - } - std::pair pair({ele[0], ele[1]}); - range.emplace_back(pair); - } - - return GRAPH_SUCCESS; -} - -GeShape GeTensorDesc::GetOriginShape() const { - vector origin_shape; - if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) { - return GeShape(); - } - return GeShape(origin_shape); -} - -void GeTensorDesc::SetOriginShape(const GeShape &origin_shape) { - std::vector origin_shape_tmp = origin_shape.GetDims(); - (void)AttrUtils::SetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape_tmp); -} - -Format GeTensorDesc::GetFormat() const { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - return TypeUtils::SerialStringToFormat(tensor_descriptor_msg->layout()); - } - return FORMAT_RESERVED; -} - -void GeTensorDesc::SetFormat(Format format) { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_layout(TypeUtils::FormatToSerialString(format)); - } -} - -void GeTensorDesc::SetName(const std::string &name) { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_name(name); - return; - } - GELOGW("[SetName]tensor_descriptor_msg is null."); -} - -const std::string GeTensorDesc::GetName() const { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - return tensor_descriptor_msg->name(); - } - GELOGW("[GetName]tensor_descriptor_msg is null."); - return ""; -} - -Format GeTensorDesc::GetOriginFormat() const { - std::string origin_format_str; - if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str)) { - // Can not get the certificate and it's not set, return directly - return FORMAT_RESERVED; - } - if (origin_format_str == "RESERVED") { - return FORMAT_RESERVED; - } - return TypeUtils::SerialStringToFormat(origin_format_str); -} - -void GeTensorDesc::SetOriginFormat(Format origin_format) { - std::string origin_format_str = "RESERVED"; - if (origin_format != FORMAT_RESERVED) { - 
origin_format_str = TypeUtils::FormatToSerialString(origin_format); - } - (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str); -} - -DataType GeTensorDesc::GetDataType() const { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg == nullptr) { - return DT_UNDEFINED; - } - auto &attr_map = *(tensor_descriptor_msg->mutable_attr()); - // Data type - auto it_data_type = attr_map.find(kKeyDataTypeSelfDefined); - if (it_data_type != attr_map.end()) { - int64_t data_type_proto = it_data_type->second.i(); - for (auto it : kDataTypeSelfDefinedMap) { - if (it.second == data_type_proto) { - return it.first; - } - } - } else { - auto data_type_proto = tensor_descriptor_msg->dtype(); - for (auto it : kDataTypeMap) { - if (it.second == data_type_proto) { - return it.first; - } - } - } - return DT_UNDEFINED; -} - -void GeTensorDesc::SetDataType(DataType dataType) { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg == nullptr) { - return; - } - auto &attr_maps = *(tensor_descriptor_msg->mutable_attr()); - (void)attr_maps.erase(kKeyDataTypeSelfDefined); - - // Data type - auto it = kDataTypeMap.find(dataType); - if (it != kDataTypeMap.end()) { - tensor_descriptor_msg->set_dtype(it->second); - return; - } - auto it2 = kDataTypeSelfDefinedMap.find(dataType); - if (it2 != kDataTypeSelfDefinedMap.end()) { - attr_maps[kKeyDataTypeSelfDefined].set_i(it2->second); - } -} - -void GeTensorDesc::SetOriginDataType(DataType origin_data_type) { - std::string origin_data_type_str = "RESERVED"; - if (origin_data_type != DT_UNDEFINED) { - origin_data_type_str = TypeUtils::DataTypeToSerialString(origin_data_type); - } - (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str); -} - -DataType GeTensorDesc::GetOriginDataType() const { - std::string origin_data_type_str; - if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str)) { - return DT_UNDEFINED; - } - if (origin_data_type_str == "RESERVED") { - return DT_UNDEFINED; - } - return TypeUtils::SerialStringToDataType(origin_data_type_str); -} - -std::vector GeTensorDesc::GetRefPortIndex() const { - vector ref_port_index; - (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index); - return ref_port_index; -} - -void GeTensorDesc::SetRefPortByIndex(const std::vector &index) { - (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index); -} - -graphStatus GeTensorDesc::IsValid() const { - auto dtype = this->GetDataType(); - auto format = this->GetFormat(); - if (dtype == DT_UNDEFINED && format == FORMAT_RESERVED) { - return GRAPH_PARAM_INVALID; - } - return GRAPH_SUCCESS; -} - -GeTensorDesc GeTensorDesc::Clone() const { return *this; } - -GeTensorDesc &GeTensorDesc::operator=(const GeTensorDesc &desc) { - if (&desc != this) { - tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_); - } - return *this; -} - -GeTensorDesc &GeTensorDesc::operator=(GeTensorDesc &&desc) { - if (&desc != this) { - tensor_descriptor_.CopyValueFrom(std::move(desc.tensor_descriptor_)); - } - return *this; -} - -GeTensor::GeTensor::GeTensor() { - tensor_def_.InitDefault(); - // Default init desc - DescReference() = GeTensorDesc(); -} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc) : GeTensor() { DescReference() = tensor_desc; } - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const vector &data) : GeTensor() { - DescReference() = tensor_desc; - auto proto_msg = tensor_def_.GetProtoMsg(); - if 
(proto_msg != nullptr) { - proto_msg->set_data(data.data(), data.size()); - } -} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const uint8_t *data, size_t size) : GeTensor() { - DescReference() = tensor_desc; - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr && data != nullptr) { - proto_msg->set_data(data, size); - } -} - -GeTensor::GeTensor(GeTensorDesc &&tensor_desc, vector &&data) : GeTensor() { - DescReference() = std::move(tensor_desc); - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_data(data.data(), data.size()); - } -} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const Buffer &data) : GeTensor() { - DescReference() = tensor_desc; - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - if (data.size() == 0) { - GELOGI("GetSize res is 0."); - } - if (data.data() == nullptr) { - GELOGI("data addr is null."); - } - proto_msg->set_data(data.GetData(), data.GetSize()); - } -} - -GeTensor::GeTensor(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg) - : tensor_def_(proto_owner, proto_msg) {} - -GeTensorDesc GeTensor::GetTensorDesc() const { return DescReference(); } - -GeTensorDesc &GeTensor::MutableTensorDesc() { return DescReference(); } - -GeTensorDesc &GeTensor::DescReference() const { - if (tensor_def_.GetProtoMsg() != nullptr) { - GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), tensor_def_.GetProtoMsg()->mutable_desc()); - __desc_.RefTo(tensor_desc); - } else { - GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), nullptr); - __desc_.RefTo(tensor_desc); - } - return __desc_; -} - -void GeTensor::SetTensorDesc(const GeTensorDesc &tensor_desc) { DescReference() = tensor_desc; } - -const Buffer GeTensor::GetData() const { - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data()); - } - return Buffer(); -} - -Buffer GeTensor::MutableData() { - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data()); - } - return Buffer(); -} - -graphStatus GeTensor::SetData(vector &&data) { - auto proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - proto_msg->set_data(data.data(), data.size()); - return GRAPH_SUCCESS; -} - -graphStatus GeTensor::SetData(const vector &data) { - auto proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - proto_msg->set_data(data.data(), data.size()); - return GRAPH_SUCCESS; -} - -graphStatus GeTensor::SetData(const uint8_t *data, size_t size) { - GE_CHECK_NOTNULL(data); - auto proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - proto_msg->set_data(data, size); - return GRAPH_SUCCESS; -} - -graphStatus GeTensor::SetData(const Buffer &data) { - auto proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - if (data.size() == 0) { - GELOGI("GetSize res is 0."); - } - if (data.data() == nullptr) { - GELOGI("data addr is null."); - } - proto_msg->set_data(data.data(), data.size()); - return GRAPH_SUCCESS; -} - -GeTensor GeTensor::Clone() const { - GeTensor tensor; - tensor.tensor_def_.CopyValueFrom(tensor_def_); - return tensor; -} - -GeTensor::GeTensor(const GeTensor &other) { tensor_def_ = other.tensor_def_; } - -GeTensor &GeTensor::operator=(const GeTensor &other) { - if (&other != this) { - tensor_def_ = other.tensor_def_; - } - return *this; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY 
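For orientation, a minimal usage sketch of the GeTensor constructors and SetData overloads implemented above. It is not part of the deleted sources; the include path graph/ge_tensor.h, the three-argument GeTensorDesc constructor, and the FORMAT_NCHW / DT_FLOAT enum values are assumptions taken from the public GE headers.

#include <cstdint>
#include <vector>
#include "graph/ge_tensor.h"  // assumed public header for GeTensor / GeTensorDesc

namespace {
// Illustrative helper (hypothetical name), showing the copy-into-proto behaviour.
void BuildTensorSketch() {
  // Describe a small NCHW float tensor (enum values assumed from the GE type definitions).
  ge::GeTensorDesc desc(ge::GeShape({1, 3, 2, 2}), ge::FORMAT_NCHW, ge::DT_FLOAT);
  std::vector<uint8_t> bytes(1 * 3 * 2 * 2 * sizeof(float), 0);

  // The (desc, data) constructor copies the bytes into the underlying proto::TensorDef.
  ge::GeTensor tensor(desc, bytes);

  // SetData overwrites the proto data field; it only fails when the proto message is missing.
  (void)tensor.SetData(bytes.data(), bytes.size());
}
}  // namespace

Because both paths end in proto_msg->set_data(), a data-carrying constructor and a later SetData call are interchangeable ways of attaching the payload.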
graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, - int64_t &size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - size = static_cast(tensor_descriptor_msg->size()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_size(size); - } -} - -uint32_t TensorUtils::GetWeightSize(const GeTensorDesc &tensor_desc) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - return static_cast(tensor_descriptor_msg->weight_size()); - } - return 0; -} - -uint32_t TensorUtils::GetWeightSize(const GeTensor &tensor) { return GetWeightSize(tensor.GetTensorDesc()); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t TensorUtils::GetWeightSize(const ConstGeTensorPtr &tensor_ptr) { - if (tensor_ptr == nullptr) { - return 0; - } - return GetWeightSize(*tensor_ptr); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint8_t *TensorUtils::GetWeightAddr(const ConstGeTensorPtr &tensor_ptr, - uint8_t *base) { - if (tensor_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "tensor_ptr is null."); - return nullptr; - } - return GetWeightAddr(*tensor_ptr, base); -} - -uint8_t *TensorUtils::GetWeightAddr(const GeTensor &tensor, uint8_t *base) { - if (base == nullptr) { - GELOGE(GRAPH_FAILED, "base is null."); - return nullptr; - } - int64_t weight_data_offset = 0; - if (GetDataOffset(tensor.GetTensorDesc(), weight_data_offset) != GRAPH_SUCCESS) return nullptr; - - if (weight_data_offset == 0) { - // The weight of offset 0 is still in const op, still get from ATTR_NAME_WEIGHTS. 
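  // Reading of the branch below: a zero data offset means the weight bytes still
  // live inside this tensor's own data buffer (the const op case noted above), so
  // the address comes straight from tensor.GetData(); a non-zero offset means the
  // weights were packed into an external weight buffer and resolve to base + offset.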
- return const_cast(tensor.GetData().data()); - } - - return base + weight_data_offset; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetWeightSize(GeTensorDesc &tensor_desc, - uint32_t size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_weight_size(size); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetReuseInput(const GeTensorDesc &tensor_desc, - bool &flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - flag = tensor_descriptor_msg->reuse_input(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInput(GeTensorDesc &tensor_desc, bool flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_reuse_input(flag); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetOutputTensor(const GeTensorDesc &tensor_desc, - bool &flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - flag = tensor_descriptor_msg->output_tensor(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetOutputTensor(GeTensorDesc &tensor_desc, bool flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_output_tensor(flag); - } -} - -static map device_to_str_map{ - {0, "NPU"}, - {1, "CPU"}, -}; -static map str_to_device_map{ - {"NPU", 0}, - {"CPU", 1}, -}; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDeviceType(const GeTensorDesc &tensor_desc, - DeviceType &type) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - string type_str = tensor_descriptor_msg->device_type(); - type = DeviceType(str_to_device_map[type_str]); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDeviceType(GeTensorDesc &tensor_desc, - DeviceType type) { - auto type_str = device_to_str_map[type]; - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_device_type(type_str); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetInputTensor(const GeTensorDesc &tensor_desc, - bool &flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - flag = tensor_descriptor_msg->input_tensor(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetInputTensor(GeTensorDesc &tensor_desc, bool flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_input_tensor(flag); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRealDimCnt(const GeTensorDesc &tensor_desc, - uint32_t &cnt) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - cnt = static_cast(tensor_descriptor_msg->real_dim_cnt()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY 
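A short sketch of how these TensorUtils accessors are typically driven; the pattern is the same for the size, reuse-input, real-dim-cnt and related fields above. The include paths and the free-function name are assumptions for illustration, not part of the deleted sources.

#include <cstdint>
#include "graph/ge_tensor.h"           // assumed header for GeTensorDesc
#include "graph/utils/tensor_utils.h"  // assumed header for TensorUtils

// Illustrative helper (hypothetical name).
void TagTensorDescSketch(ge::GeTensorDesc &desc) {
  // Setters write straight into the underlying proto::TensorDescriptor and
  // silently do nothing when the proto message is absent.
  ge::TensorUtils::SetSize(desc, 1024);
  ge::TensorUtils::SetReuseInput(desc, true);
  ge::TensorUtils::SetRealDimCnt(desc, 4);

  // Getters mirror the setters; they fail only when the proto message is null.
  int64_t size = 0;
  if (ge::TensorUtils::GetSize(desc, size) != ge::GRAPH_SUCCESS) {
    return;  // no underlying proto message to read from
  }
}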
GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRealDimCnt(GeTensorDesc &tensor_desc, - uint32_t cnt) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_real_dim_cnt(cnt); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetReuseInputIndex(const GeTensorDesc &tensor_desc, uint32_t &idx) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - - idx = static_cast(tensor_descriptor_msg->reuse_input_index()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInputIndex(GeTensorDesc &tensor_desc, - uint32_t idx) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_reuse_input_index(idx); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDataOffset(const GeTensorDesc &tensor_desc, - int64_t &offset) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - offset = tensor_descriptor_msg->data_offset(); - return GRAPH_SUCCESS; - } else { - GELOGW("tensor_descriptor_msg is nullptr."); - return GRAPH_FAILED; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDataOffset(GeTensorDesc &tensor_desc, - int64_t offset) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_data_offset(offset); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsSize(const GeTensorDesc &tensor_desc, - uint32_t &cmp_size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - cmp_size = static_cast(tensor_descriptor_msg->cmps_size()); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsSize(GeTensorDesc &tensor_desc, - uint32_t cmp_size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_cmps_size(cmp_size); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsTab(const GeTensorDesc &tensor_desc, - vector &vec) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - string str = tensor_descriptor_msg->cmps_tab(); - vec.assign(str.begin(), str.end()); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTab(GeTensorDesc &tensor_desc, - const uint8_t *data, size_t size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - GE_CHK_BOOL_EXEC(data != nullptr, return, "data is null."); - string str((const char *)data, size); - tensor_descriptor_msg->set_cmps_tab(str); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetCmpsTabOffset(const GeTensorDesc &tensor_desc, int64_t &tab_offset) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tab_offset = tensor_descriptor_msg->cmps_tab_offset(); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void 
TensorUtils::SetCmpsTabOffset(GeTensorDesc &tensor_desc, - int64_t tab_offset) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_cmps_tab_offset(tab_offset); - } -} - -graphStatus TensorUtils::GetCmpsInfo(const GeTensorDesc &tensor_desc, CompressInfo &info) { - GeAttrValue attr_value; - if (tensor_desc.GetAttr(TENSOR_UTILS_CMPSINFO, attr_value) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - return attr_value.GetValue(info); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsInfo(GeTensorDesc &tensor_desc, - const CompressInfo &info) { - (void)tensor_desc.SetAttr(TENSOR_UTILS_CMPSINFO, GeAttrValue::CreateFrom(info)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool TensorUtils::HasAlloffsetQuantizeInfo( - const GeTensorDesc &tensor_desc) { - return tensor_desc.HasAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetAlloffsetQuantizeInfo(const GeTensorDesc &tensor_desc, AllOffsetQuantizeInfo &info) { - GeAttrValue attr_value; - if (tensor_desc.GetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, attr_value) != GRAPH_SUCCESS) { - GELOGW("get attr alloffset_quantize_info fail."); - } - return attr_value.GetValue(info); -} - -void TensorUtils::SetAlloffsetQuantizeInfo(GeTensorDesc &tensor_desc, const AllOffsetQuantizeInfo &info) { - (void)tensor_desc.SetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, GeAttrValue::CreateFrom(info)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRC(const GeTensorDesc &tensor_desc, - uint32_t &rc) { - return AttrUtils::GetInt(&tensor_desc, TENSOR_UTILS_RC, rc) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRC(GeTensorDesc &tensor_desc, uint32_t rc) { - (void)AttrUtils::SetInt(&tensor_desc, TENSOR_UTILS_RC, rc); -} -} // namespace ge diff --git a/src/common/graph/graph.cc b/src/common/graph/graph.cc deleted file mode 100644 index fc30e9d6..00000000 --- a/src/common/graph/graph.cc +++ /dev/null @@ -1,384 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "external/graph/graph.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_op_types.h" -#include "graph/model.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" - -using std::map; -using std::pair; -using std::string; -using std::vector; - -namespace ge { -class GraphImpl { - public: - friend class GraphUtils; - GraphImpl(const GraphImpl &) = delete; - GraphImpl &operator=(const GraphImpl &) = delete; - - explicit GraphImpl(const std::string &name) : name_(name) {} - - ~GraphImpl() { - if (IsValid()) { - if (compute_graph_ != nullptr) { - GraphUtils::BreakConnect(compute_graph_->GetAllNodesInfo()); - } - } - for (const auto &it : op_list_) { - Operator op = it.second; - op.BreakConnect(); - } - } - - graphStatus SetInputs(const std::vector &inputs) { - compute_graph_ = GraphUtils::CreateGraphFromOperator(name_, inputs); - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "Build Graph failed."); - GE_CHK_BOOL_RET_STATUS(inputs.size() != 0, GRAPH_FAILED, "set input NULL."); - compute_graph_->SetInputSize(static_cast(inputs.size())); - return GRAPH_SUCCESS; - } - - graphStatus SetOutputs(const std::vector &outputs) { - if (compute_graph_ == nullptr) { - GELOGE(GRAPH_FAILED, "set ComputeGraph failed."); - return GRAPH_FAILED; - } - if (outputs.empty()) { - GELOGW("set outputs size is 0."); - return GRAPH_SUCCESS; - } - - // Construct special output node - std::vector>> output_indexs; - for (size_t i = 0; i < outputs.size(); ++i) { - output_indexs.emplace_back(outputs[i], std::vector{}); - } - - graphStatus ret = SetOutputs(output_indexs); - return ret; - } - - graphStatus SetOutputs(const std::vector>> &output_indexs) { - if (compute_graph_ == nullptr) { - GELOGE(GRAPH_FAILED, "set ComputeGraph failed."); - return GRAPH_FAILED; - } - if (output_indexs.empty()) { - GELOGW("set outputs size is 0."); - return GRAPH_SUCCESS; - } - - // Construct special output node - std::vector> output_nodes; - for (const auto &item : output_indexs) { - const Operator &output = item.first; - const vector &indexs = item.second; - ge::NodePtr node = compute_graph_->FindNode(output.GetName()); - if (node == nullptr) { - GELOGW("user designated out_node [%s] not exist in graph, will ignored!", output.GetName().c_str()); - continue; - } - - ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue); - size_t out_size = tmp_op_ptr->GetOutputsSize(); - if (indexs.empty()) { - for (size_t i = 0; i < out_size; ++i) { - output_name_ += output.GetName() + ":" + std::to_string(i) + ";"; - output_nodes.emplace_back(node, i); - } - } else { - for (size_t i = 0; i < indexs.size(); ++i) { - if (indexs[i] >= out_size) { - GELOGW("index[%zu] is not belong to out_node[%s]", indexs[i], output.GetName().c_str()); - } else { - output_name_ += output.GetName() + ":" + std::to_string(i) + ";"; - output_nodes.emplace_back(node, indexs[i]); - } - } - } - } - - // Del last ";" - if (!output_name_.empty()) { - output_name_ = output_name_.substr(0, output_name_.length() - 1); - } - compute_graph_->SetUserDefOutput(output_name_); - compute_graph_->SetOutputSize(static_cast(output_indexs.size())); - compute_graph_->SetGraphOutNodesInfo(output_nodes); - return GRAPH_SUCCESS; - } - - graphStatus SetOutputs(const std::vector> &outputs) { - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild."); - 
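  // Step summary for this overload (as implemented below): each (operator, output
  // name) pair is resolved to a graph node; an empty name selects every output of
  // that node, otherwise GetOutputIndexByName picks the single index. The chosen
  // outputs are accumulated into output_name_ as "node:index;" entries, the
  // trailing ';' is trimmed, and the result is recorded on compute_graph_ via
  // SetGraphOutNodesInfo.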
GE_CHK_BOOL_EXEC_INFO(outputs.size() != 0, return GRAPH_SUCCESS, "set outputs size is 0."); - - // Construct specified output - std::vector> output_nodes; - for (auto item : outputs) { - ge::NodePtr node = compute_graph_->FindNode(item.first.GetName()); - if (node == nullptr) { - GELOGE(GRAPH_FAILED, " Warning, user designated out_node (%s) not exist in graph, this out_node ignored!", - item.first.GetName().c_str()); - return GRAPH_FAILED; - } - ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue); - size_t out_size = tmp_op_ptr->GetOutputsSize(); - - if (item.second.empty()) { - for (size_t i = 0; i < out_size; ++i) { - output_name_ += item.first.GetName() + ":" + std::to_string(i) + ";"; - output_nodes.push_back(std::make_pair(node, i)); - } - } else { - int32_t index = tmp_op_ptr->GetOutputIndexByName(item.second); - if (index < 0) { - GELOGE(GRAPH_FAILED, - " Warning, user designated out_node (%s):(%s) not exist in graph, this out_node ignored!", - item.first.GetName().c_str(), item.second.c_str()); - return GRAPH_FAILED; - } - output_name_ += item.first.GetName() + ":" + std::to_string(index) + ";"; - output_nodes.push_back(std::make_pair(node, index)); - } - } - // Del last ";" - if (!output_name_.empty()) { - output_name_ = output_name_.substr(0, output_name_.length() - 1); - } - compute_graph_->SetOutputSize(static_cast(outputs.size())); - compute_graph_->SetGraphOutNodesInfo(output_nodes); - GELOGI("********************SetOutputs Success***********************"); - GE_IF_BOOL_EXEC(!output_name_.empty(), GELOGI(" NetOutputs: (%s)", output_name_.c_str())); - - return GRAPH_SUCCESS; - } - - graphStatus SetTargets(const std::vector &targets) { - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild."); - GE_CHK_BOOL_EXEC_INFO(targets.size() != 0, return GRAPH_SUCCESS, "set targets size is 0."); - - std::vector target_nodes; - for (auto item : targets) { - ge::NodePtr node = compute_graph_->FindNode(item.GetName()); - if (node == nullptr) { - GELOGW(" Warning, user designated target_node (%s) not exist in graph, this target_node ignored!", - item.GetName().c_str()); - continue; - } - target_nodes.push_back(node); - } - compute_graph_->SetGraphTargetNodesInfo(target_nodes); - return GRAPH_SUCCESS; - } - bool IsValid() const { return (compute_graph_ != nullptr); } - - graphStatus AddOp(const ge::Operator &op) { - std::pair::iterator, bool> ret; - ret = op_list_.emplace(std::pair(op.GetName(), op)); - GE_CHK_BOOL_RET_STATUS(ret.second != false, GRAPH_FAILED, "the op have added before, op name:%s.", - op.GetName().c_str()); - return GRAPH_SUCCESS; - } - - graphStatus GetAllOpName(std::vector &op_name) const { - for (const auto &it : op_list_) { - op_name.push_back(it.second.GetName()); - } - return GRAPH_SUCCESS; - } - - graphStatus FindOpByName(const string &name, ge::Operator &op) const { - auto it = op_list_.find(name); - GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str()); - op = it->second; - return GRAPH_SUCCESS; - } - - graphStatus FindOpByType(const string &type, std::vector &ops) const { - for (auto &op : op_list_) { - auto op_type = op.second.GetOpType(); - if (op_type == type) { - ops.push_back(op.second); - continue; - } - if (op_type == ge::FRAMEWORKOP) { - op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type); - if (op_type == type) { - ops.push_back(op.second); - } - } - } - return GRAPH_SUCCESS; - } - - void SetNeedIteration(bool need_iteration) 
{ - if (compute_graph_ == nullptr) { - GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null."); - return; - } - compute_graph_->SetNeedIteration(need_iteration); - } - - const std::string &GetName() const { return name_; } - - private: - std::string name_; - std::string output_name_; - std::map op_list_; - ComputeGraphPtr compute_graph_{nullptr}; -}; - -Graph::Graph(const std::string &name) { - impl_ = ComGraphMakeShared(name); - if (impl_ == nullptr) { - GELOGW("GraphImpl make shared failed, impl_ is nullptr"); - } -} - -graphStatus Graph::AddOp(const ge::Operator &op) { - GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, "AddOp failed: graph can not be used, impl is nullptr."); - return impl_->AddOp(op); -} - -graphStatus Graph::GetAllOpName(std::vector &op_name) const { - GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, - "GetAllOpName failed: graph can not be used, impl is nullptr."); - return impl_->GetAllOpName(op_name); -} - -graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const { - Operator op_find_op_def("NULL"); - op = op_find_op_def; - GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, - "FindOpByName failed: graph can not be used, impl is nullptr."); - return impl_->FindOpByName(name, op); -} - -graphStatus Graph::FindOpByType(const string &type, std::vector &ops) const { - GE_CHECK_NOTNULL(impl_); - return impl_->FindOpByType(type, ops); -} - -Graph &Graph::SetInputs(const vector &inputs) { - GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.") - GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0."); - (void)impl_->SetInputs(inputs); - return *this; -} - -Graph &Graph::SetOutputs(const vector &outputs) { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr."); - return *this; - } - (void)impl_->SetOutputs(outputs); - return *this; -} - -Graph &Graph::SetOutputs(const std::vector>> &output_indexs) { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr."); - return *this; - } - (void)impl_->SetOutputs(output_indexs); - return *this; -} - -Graph &Graph::SetOutputs(const std::vector> &outputs) { - GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetOutputs failed: graph can not be used, impl is nullptr.") - (void)impl_->SetOutputs(outputs); - return *this; -} - -Graph &Graph::SetTargets(const vector &targets) { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "SetTargets failed: graph can not be used, impl is nullptr."); - return *this; - } - (void)impl_->SetTargets(targets); - return *this; -} - -bool Graph::IsValid() const { - if (impl_ == nullptr) { - return false; - } - return impl_->IsValid(); -} - -void Graph::SetNeedIteration(bool need_iteration) { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "Set need iteration failed, as impl is null."); - return; - } - impl_->SetNeedIteration(need_iteration); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) { - GE_CHK_BOOL_EXEC_NOLOG(graph.IsValid(), return nullptr); - return graph.impl_->compute_graph_; -} - -graphStatus Graph::SaveToFile(const string &file_name) const { - Model model = Model(); - model.SetGraph(*this); - return model.SaveToFile(file_name); -} - -graphStatus Graph::LoadFromFile(const string &file_name) { - Model model = Model(); - graphStatus ret = 
model.LoadFromFile(file_name); - if (ret != GRAPH_SUCCESS) { - return ret; - } - *this = model.GetGraph(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string &Graph::GetName() const { return impl_->GetName(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph -GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { - GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph("")); - - auto name = compute_graph->GetName(); - auto graph = Graph(name); - - GE_CHK_BOOL_EXEC_NOLOG(graph.impl_ != nullptr, return graph); - graph.impl_->compute_graph_ = compute_graph; - - return graph; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { - GE_CHECK_NOTNULL(graph.impl_); - GE_CHECK_NOTNULL(graph.impl_->compute_graph_); - - graph.impl_->op_list_.clear(); - for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) { - graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node); - } - return SUCCESS; -} -} // namespace ge diff --git a/src/common/graph/graph.mk b/src/common/graph/graph.mk deleted file mode 100644 index 4ea84919..00000000 --- a/src/common/graph/graph.mk +++ /dev/null @@ -1,294 +0,0 @@ -LOCAL_PATH := $(call my-dir) -include $(LOCAL_PATH)/stub/Makefile -COMMON_LOCAL_SRC_FILES := \ - ./proto/om.proto \ - ./proto/ge_ir.proto \ - ./proto/ge_onnx.proto \ - ./proto/insert_op.proto \ - ./proto/task.proto \ - ./proto/fwk_adapter.proto \ - ./proto/op_mapping_info.proto \ - ./proto/dump_task.proto \ - ./anchor.cc \ - ./ge_attr_value.cc \ - ./attr_value.cc \ - ./buffer.cc \ - ./compute_graph.cc \ - ./graph.cc \ - ./inference_context.cc \ - ./shape_refiner.cc \ - ./format_refiner.cc \ - ./ref_relation.cc \ - ./model.cc \ - ./model_serialize.cc \ - ./node.cc \ - ./op_desc.cc \ - ./operator.cc \ - ./operator_factory.cc \ - ./operator_factory_impl.cc \ - ./ge_attr_define.cc \ - ./ge_tensor.cc \ - ./detail/attributes_holder.cc \ - ./utils/anchor_utils.cc \ - ./utils/tuning_utils.cc \ - ./utils/graph_utils.cc \ - ./utils/ge_ir_utils.cc \ - ./utils/op_desc_utils.cc \ - ./utils/type_utils.cc \ - ./utils/tensor_utils.cc \ - ./tensor.cc \ - ./debug/graph_debug.cc \ - ./opsproto/opsproto_manager.cc \ - ../ops/op_imp.cpp \ - option/ge_context.cc \ - option/ge_local_context.cc \ - ./runtime_inference_context.cc \ - ./utils/node_utils.cc \ - -COMMON_LOCAL_C_INCLUDES := \ - proto/om.proto \ - proto/ge_ir.proto \ - proto_inner/ge_onnx.proto \ - proto/insert_op.proto \ - proto/task.proto \ - proto/fwk_adapter.proto \ - proto/op_mapping_info.proto \ - proto/dump_task.proto \ - inc \ - inc/external \ - inc/external/graph \ - inc/graph \ - inc/common \ - common \ - common/graph \ - third_party/protobuf/include \ - libc_sec/include \ - ops/built-in/op_proto/inc \ - - -#compiler for host -include $(CLEAR_VARS) -LOCAL_MODULE := libgraph - -LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -LOCAL_CPPFLAGS += -fexceptions - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libprotobuf \ - libslog \ - liberror_manager \ - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_HOST_SHARED_LIBRARY) - -#compiler for host -include $(CLEAR_VARS) -LOCAL_MODULE := stub/libgraph - -LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -LOCAL_CPPFLAGS += -fexceptions - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := \ - 
../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/tensor.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_HOST_SHARED_LIBRARY) - -#compiler for host -include $(CLEAR_VARS) -LOCAL_MODULE := fwk_stub/libgraph - -LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -LOCAL_CPPFLAGS += -fexceptions - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/attr_value.cc \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - ../../out/graph/lib64/stub/tensor.cc \ - ../../out/graph/lib64/stub/inference_context.cc \ - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_HOST_SHARED_LIBRARY) - -#compiler for device -include $(CLEAR_VARS) -LOCAL_MODULE := libgraph - -LOCAL_CFLAGS += -O2 - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libprotobuf \ - libslog \ - liberror_manager \ - -LOCAL_LDFLAGS := -lrt -ldl - -ifeq ($(device_os),android) -LOCAL_LDFLAGS := -ldl -endif - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_SHARED_LIBRARY) - -#compiler for device -include $(CLEAR_VARS) -LOCAL_MODULE := stub/libgraph - -LOCAL_CFLAGS += -O2 - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/tensor.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -ifeq ($(device_os),android) -LOCAL_LDFLAGS := -ldl -endif - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_SHARED_LIBRARY) - -#compiler for device -include $(CLEAR_VARS) -LOCAL_MODULE := fwk_stub/libgraph - -LOCAL_CFLAGS += -O2 - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/attr_value.cc \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - ../../out/graph/lib64/stub/tensor.cc \ - ../../out/graph/lib64/stub/inference_context.cc \ - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -ifeq ($(device_os),android) -LOCAL_LDFLAGS := -ldl -endif - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_SHARED_LIBRARY) - -# compile for ut/st -include $(CLEAR_VARS) -LOCAL_MODULE := libgraph - -LOCAL_CFLAGS += - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libprotobuf \ - libslog \ - liberror_manager \ - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_LLT_SHARED_LIBRARY) - - -#compiler for host static lib -include $(CLEAR_VARS) -LOCAL_MODULE := libgraph - -LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -LOCAL_CPPFLAGS += -fexceptions - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) - -LOCAL_STATIC_LIBRARIES := \ - libprotobuf \ - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libslog \ - liberror_manager \ - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := 
true - -include $(BUILD_HOST_STATIC_LIBRARY) - -#compiler for device static lib -include $(CLEAR_VARS) -LOCAL_MODULE := libgraph - -LOCAL_CFLAGS += -O2 - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) - -LOCAL_STATIC_LIBRARIES := \ - libprotobuf \ - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libslog \ - liberror_manager \ - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_STATIC_LIBRARY) diff --git a/src/common/graph/inference_context.cc b/src/common/graph/inference_context.cc deleted file mode 100644 index ed8193dc..00000000 --- a/src/common/graph/inference_context.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "external/graph/inference_context.h" -#include "debug/ge_util.h" - -namespace ge { -class ShapeAndTypeImpl { - public: - ShapeAndTypeImpl() = default; - ~ShapeAndTypeImpl() = default; - - ShapeAndTypeImpl(const Shape &shape, DataType data_type) : shape_(shape), data_type_(data_type) {} - - Shape shape_; - DataType data_type_ = DT_UNDEFINED; -}; - -class InferenceContextImpl { - public: - InferenceContextImpl() = default; - ~InferenceContextImpl() = default; - - // For deliver to op in pair, help to support dynamic shape - std::vector marks_; - std::vector> input_handle_shapes_and_types_; - std::vector> output_handle_shapes_and_types_; -}; - -ShapeAndType::ShapeAndType() { shape_and_type_impl_ = ComGraphMakeShared(); } - -ShapeAndType::ShapeAndType(const Shape &shape, DataType data_type) { - shape_and_type_impl_ = ComGraphMakeShared(shape, data_type); -} - -void ShapeAndType::SetShape(const Shape &shape) { - if (shape_and_type_impl_ != nullptr) { - shape_and_type_impl_->shape_ = shape; - } -} - -void ShapeAndType::SetType(DataType data_type) { - if (shape_and_type_impl_ != nullptr) { - shape_and_type_impl_->data_type_ = data_type; - } -} - -Shape ShapeAndType::GetShape() const { - if (shape_and_type_impl_ != nullptr) { - return shape_and_type_impl_->shape_; - } - return Shape(); -} - -DataType ShapeAndType::GetDataType() const { - if (shape_and_type_impl_ != nullptr) { - return shape_and_type_impl_->data_type_; - } - return DT_UNDEFINED; -} - -InferenceContext::InferenceContext(std::unique_ptr &impl) { - inference_context_impl_ = std::move(impl); -} - -std::unique_ptr InferenceContext::Create() { - std::unique_ptr impl = - std::unique_ptr(new (std::nothrow) InferenceContextImpl()); - if (impl == nullptr) { - return nullptr; - } - - return std::unique_ptr(new (std::nothrow) InferenceContext(impl)); -} - -void InferenceContext::SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types) { - inference_context_impl_->input_handle_shapes_and_types_.swap(shapes_and_types); -} - -const std::vector> &InferenceContext::GetInputHandleShapesAndTypes() const { - return inference_context_impl_->input_handle_shapes_and_types_; -} - -const std::vector> 
&InferenceContext::GetOutputHandleShapesAndTypes() const { - return inference_context_impl_->output_handle_shapes_and_types_; -} - -void InferenceContext::SetOutputHandleShapesAndTypes(const std::vector> &shapes_and_types) { - inference_context_impl_->output_handle_shapes_and_types_ = shapes_and_types; -} - -void InferenceContext::SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types) { - inference_context_impl_->output_handle_shapes_and_types_.swap(shapes_and_types); -} - -void InferenceContext::SetMarks(const std::vector &marks) { inference_context_impl_->marks_ = marks; } - -const std::vector &InferenceContext::GetMarks() const { return inference_context_impl_->marks_; } -} // namespace ge diff --git a/src/common/graph/model.cc b/src/common/graph/model.cc deleted file mode 100644 index a3628204..00000000 --- a/src/common/graph/model.cc +++ /dev/null @@ -1,190 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/model.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "debug/ge_attr_define.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/model_serialize.h" -#include "proto/ge_ir.pb.h" -#include "utils/attr_utils.h" -#include "utils/ge_ir_utils.h" - -using google::protobuf::io::FileInputStream; -using google::protobuf::io::FileOutputStream; -using google::protobuf::io::ZeroCopyInputStream; - -namespace { -const int DEFAULT_VERSION = 1; -const int ACCESS_PERMISSION_BITS = 0400; -} // namespace - -namespace ge { -void Model::Init() { - (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); - (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); - version_ = 0; -} - -Model::Model() { - attrs_.InitDefault(); - Init(); -} - -Model::Model(const string &name, const string &custom_version) - : name_(name), version_(DEFAULT_VERSION), platform_version_(custom_version) { - attrs_.InitDefault(); - Init(); -} - -string Model::GetName() const { return name_; } - -void Model::SetName(const string &name) { name_ = name; } - -uint32_t Model::GetVersion() const { return version_; } - -string Model::GetPlatformVersion() const { return platform_version_; } - -void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; } - -Graph Model::GetGraph() const { return graph_; } - -graphStatus Model::Save(Buffer &buffer, bool is_dump) const { - ModelSerialize serialize; - buffer = serialize.SerializeModel(*this, is_dump); - return buffer.GetSize() > 0 ? 
GRAPH_SUCCESS : GRAPH_FAILED; -} - -void Model::SetAttr(const ProtoAttrMapHelper &attrs) { attrs_ = attrs; } - -graphStatus Model::Load(const uint8_t *data, size_t len, Model &model) { - ModelSerialize serialize; - model = serialize.UnserializeModel(data, len); - return model.IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Model::SaveToFile(const string &file_name) const { - Buffer buffer; - if ((*this).Save(buffer) != GRAPH_SUCCESS) { - GE_LOGE("save to file fail."); - return GRAPH_FAILED; - } - // Write file - ge::proto::ModelDef ge_proto; - if (buffer.GetData() != nullptr) { - std::string str((const char *)buffer.GetData(), buffer.GetSize()); - if (!ge_proto.ParseFromString(str)) { - return GRAPH_FAILED; - } - char real_path[PATH_MAX] = {0x00}; - if (strlen(file_name.c_str()) >= PATH_MAX) { - return GRAPH_FAILED; - } - if (realpath(file_name.c_str(), real_path) == nullptr) { - GELOGI("file %s does not exit, it will be created.", file_name.c_str()); - } - int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); - if (fd < 0) { - GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno)); - return GRAPH_FAILED; - } - bool ret = ge_proto.SerializeToFileDescriptor(fd); - if (!ret) { - GELOGE(GRAPH_FAILED, "SerializeToFileDescriptor failed"); - if (close(fd) != 0) { - GELOGE(GRAPH_FAILED, "close file descriptor fail."); - return GRAPH_FAILED; - } - return GRAPH_FAILED; - } - if (close(fd) != 0) { - GELOGE(GRAPH_FAILED, "close file descriptor fail."); - return GRAPH_FAILED; - } - if (!ret) { - GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed"); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -graphStatus Model::Load(ge::proto::ModelDef &model_def) { - ModelSerialize serialize; - *this = serialize.UnserializeModel(model_def); - return this->IsValid() ? 
GRAPH_SUCCESS : GRAPH_FAILED; -} - -bool Model::IsValid() const { return graph_.IsValid(); } - -graphStatus Model::LoadFromFile(const string &file_name) { - char real_path[PATH_MAX] = {0x00}; - if (strlen(file_name.c_str()) >= PATH_MAX) { - return GRAPH_FAILED; - } - if (realpath(file_name.c_str(), real_path) == nullptr) { - GELOGE(GRAPH_FAILED, "file %s does not exit, can not load.", file_name.c_str()); - return GRAPH_FAILED; - } - int fd = open(real_path, O_RDONLY); - if (fd < 0) { - GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno)); - return GRAPH_FAILED; - } - - ge::proto::ModelDef model_def; - bool ret = model_def.ParseFromFileDescriptor(fd); - if (!ret) { - GELOGE(GRAPH_FAILED, "ParseFromFileDescriptor failed"); - if (close(fd) != 0) { - GELOGE(GRAPH_FAILED, "close file descriptor fail."); - return GRAPH_FAILED; - } - return GRAPH_FAILED; - } - if (close(fd) != 0) { - GELOGE(GRAPH_FAILED, "close file descriptor fail."); - return GRAPH_FAILED; - } - if (!ret) { - GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed"); - return GRAPH_FAILED; - } - return Load(model_def); -} - -ProtoAttrMapHelper Model::MutableAttrMap() { return attrs_; } - -ConstProtoAttrMapHelper Model::GetAttrMap() const { - return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); -} -} // namespace ge diff --git a/src/common/graph/model_serialize.cc b/src/common/graph/model_serialize.cc deleted file mode 100644 index 16855fc5..00000000 --- a/src/common/graph/model_serialize.cc +++ /dev/null @@ -1,763 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/model_serialize.h" -#include - -#include -#include - -#include "debug/ge_attr_define.h" -#include "debug/ge_log.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/detail/model_serialize_imp.h" -#include "proto/ge_ir.pb.h" -#include "utils/graph_utils.h" -#include "debug/ge_op_types.h" - -using std::map; -using std::string; - -namespace ge { -bool ModelSerializeImp::ParseNodeIndex(const string &node_index, string &node_name, int32_t &index) { - auto sep = node_index.rfind(":"); - if (sep == string::npos) { - GELOGW("separator is not found in node_index."); - return false; - } - node_name = node_index.substr(0, sep); - auto index_str = node_index.substr(sep + 1); - index = static_cast(std::strtol(index_str.c_str(), nullptr, 10)); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeTensor(const ConstGeTensorPtr &tensor, - proto::TensorDef *tensor_proto) { - GE_CHK_BOOL_EXEC(tensor != nullptr, return false, "tensor is null."); - GE_CHK_BOOL_EXEC(tensor_proto != nullptr, return false, "tensor_proto is null."); - - if (tensor->tensor_def_.GetProtoMsg() != nullptr) { - *tensor_proto = *tensor->tensor_def_.GetProtoMsg(); - return true; - } - return false; -} - -bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_proto) { - GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is null."); - GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); - - op_def_proto->clear_input(); - // Inputs - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - if (in_data_anchor != nullptr) { - auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) { - op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + - std::to_string(peer_out_anchor->GetIdx())); - } else { - op_def_proto->add_input(""); - } - } - } - // Control edge - auto control_anchor = node->GetInControlAnchor(); - if (control_anchor != nullptr) { - auto peer_out_anchors = control_anchor->GetPeerOutControlAnchors(); - for (const auto &peer_out_anchor : peer_out_anchors) { - if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) { - op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":-1"); - } - } - } - return true; -} - -bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { - GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); - GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); - if (op_desc->op_def_.GetProtoMsg() != nullptr) { - *op_def_proto = *op_desc->op_def_.GetProtoMsg(); - // Delete unnecessary attr - if (is_dump) { - auto attr = op_def_proto->mutable_attr(); - attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF); - attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF); - attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF); - GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP), - attr->erase(ATTR_NAME_WEIGHTS)); - } - op_def_proto->clear_input_desc(); - op_def_proto->clear_output_desc(); - // Input descs - if (op_desc->GetAllInputsSize() > 0) { - auto size = static_cast(op_desc->GetAllInputsSize()); - for (uint32_t i = 0; i < size; i++) { - auto tensor_desc = op_desc->GetInputDescPtrDfault(i); - if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { - *op_def_proto->add_input_desc() = 
*(tensor_desc->tensor_descriptor_.GetProtoMsg()); - } - } - } - // Output descs - if (op_desc->GetOutputsSize() > 0) { - auto size = static_cast(op_desc->GetOutputsSize()); - for (uint32_t i = 0; i < size; i++) { - auto tensor_desc = op_desc->GetOutputDescPtr(i); - if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { - *op_def_proto->add_output_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); - } - } - } - - op_def_proto->set_id(op_desc->GetId()); - for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { - op_def_proto->add_subgraph_name(name); - } - OpDescToAttrDef(op_desc, op_def_proto); - } - return true; -} - -void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { - proto::AttrDef key_in; - proto::AttrDef value_in; - auto op_desc_attr = op_def_proto->mutable_attr(); - if (!op_desc->input_name_idx_.empty()) { - for (auto &item : op_desc->input_name_idx_) { - key_in.mutable_list()->add_s(item.first); - value_in.mutable_list()->add_i(item.second); - } - op_desc_attr->insert({"_input_name_key", key_in}); - op_desc_attr->insert({"_input_name_value", value_in}); - } - proto::AttrDef key_out; - proto::AttrDef value_out; - if (!op_desc->output_name_idx_.empty()) { - for (auto &item : op_desc->output_name_idx_) { - key_out.mutable_list()->add_s(item.first); - value_out.mutable_list()->add_i(item.second); - } - op_desc_attr->insert({"_output_name_key", key_out}); - op_desc_attr->insert({"_output_name_value", value_out}); - } - proto::AttrDef opt_input; - if (!op_desc->optional_input_names_.empty()) { - for (auto &item : op_desc->optional_input_names_) { - opt_input.mutable_list()->add_s(item); - } - op_desc_attr->insert({"_opt_input", opt_input}); - } -} - -bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { - if (node == nullptr || op_def_proto == nullptr) { - GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); - return false; - } - if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) { - GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); - return false; - } - if (SerializeEdge(node, op_def_proto)) { - return true; - } else { - return false; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, - proto::GraphDef *graph_proto, - bool is_dump) { - if (graph == nullptr || graph_proto == nullptr) { - GELOGE(GRAPH_FAILED, "Input para Invalid"); - return false; - } - graph_proto->set_name(graph->GetName()); - // Inputs - for (const auto &input : graph->GetInputNodes()) { - if (input != nullptr) { - graph_proto->add_input(input->GetName() + ":0"); - } - } - // Outputs - for (const auto &output : graph->GetGraphOutNodesInfo()) { - if (output.first != nullptr) { - graph_proto->add_output(output.first->GetName() + ":" + std::to_string(output.second)); - GELOGI("Add output to graph proto, node name:%s, index:%ld", output.first->GetName().c_str(), output.second); - } - } - if (graph->attrs_.GetProtoMsg() != nullptr) { - *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); - } - for (const auto &node : graph->GetDirectNode()) { - if (!SerializeNode(node, graph_proto->add_op(), is_dump)) { - if (node->GetOpDesc() != nullptr) { - GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); - } - return false; - } - } - return true; -} - -bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) { - if 
(model_proto == nullptr) { - GELOGE(GRAPH_FAILED, "model_proto para Invalid"); - return false; - } - model_proto->set_name(model.GetName()); - model_proto->set_custom_version(model.GetPlatformVersion()); - model_proto->set_version(model.GetVersion()); - if (model.attrs_.GetProtoMsg()) { - *model_proto->mutable_attr() = *model.attrs_.GetProtoMsg(); - } - auto &graph = model.graph_; - auto compute_graph = GraphUtils::GetComputeGraph(graph); - if (compute_graph == nullptr) { - GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); - return false; - } - if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) { - GELOGE(GRAPH_FAILED, "SerializeGraph fail"); - return false; - } - - for (auto subgraph : compute_graph->GetAllSubgraphs()) { - if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) { - GELOGE(GRAPH_FAILED, "Serialize subgraph failed"); - return false; - } - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor( - GeTensorPtr &tensor, proto::TensorDef &tensor_proto) { - tensor = std::shared_ptr(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto)); - if (tensor == nullptr) { - GELOGE(GRAPH_FAILED, "tensor is nullptr"); - return false; - } else { - return true; - } -} - -void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_in, std::vector &key_out, - std::vector &value_in, std::vector &value_out, - std::vector &opt_input) { - if (!key_in.empty()) { - if (key_in.size() != value_in.size()) { - GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(), - value_in.size()); - } else { - for (uint32_t i = 0; i < key_in.size(); ++i) { - op_desc->input_name_idx_.insert(std::pair(key_in.at(i), value_in.at(i))); - } - } - } - if (!key_out.empty()) { - if (key_out.size() != value_out.size()) { - GELOGW("Key and value vector size is different. 
key_size: %zu, value_size: %zu.", key_out.size(), - value_out.size()); - } else { - for (uint32_t i = 0; i < key_out.size(); ++i) { - op_desc->output_name_idx_.insert(std::pair(key_out.at(i), value_out.at(i))); - } - } - } - if (!opt_input.empty()) { - for (const auto &i : opt_input) { - op_desc->optional_input_names_.insert(i); - } - } -} - -bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) { - std::vector opt_input; - std::vector key_in; - std::vector value_in; - if (op_def_proto.attr().count("_opt_input") > 0) { - auto &name_list = op_def_proto.attr().at("_opt_input").list(); - for (const auto &item_s : name_list.s()) { - opt_input.push_back(item_s); - } - auto op_desc_attr = op_def_proto.mutable_attr(); - op_desc_attr->erase("_opt_input"); - } - if (op_def_proto.attr().count("_input_name_key") > 0) { - auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list(); - for (const auto &item_s : output_name_key_list.s()) { - key_in.push_back(item_s); - } - auto op_desc_attr = op_def_proto.mutable_attr(); - op_desc_attr->erase("_input_name_key"); - } - if (op_def_proto.attr().count("_input_name_value") > 0) { - auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list(); - for (const auto &item_i : input_name_value_list.i()) { - value_in.push_back(static_cast(item_i)); - } - auto op_desc_attr = op_def_proto.mutable_attr(); - op_desc_attr->erase("_input_name_value"); - } - std::vector key_out; - std::vector value_out; - if (op_def_proto.attr().count("_output_name_key") > 0) { - auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list(); - for (const auto &item_s : output_name_key_list.s()) { - key_out.push_back(item_s); - } - auto op_desc_attr = op_def_proto.mutable_attr(); - op_desc_attr->erase("_output_name_key"); - } - if (op_def_proto.attr().count("_output_name_value") > 0) { - auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list(); - for (const auto &item_i : output_name_value_list.i()) { - value_out.push_back(static_cast(item_i)); - } - auto op_desc_attr = op_def_proto.mutable_attr(); - op_desc_attr->erase("_output_name_value"); - } - - op_desc = std::shared_ptr(new (std::nothrow) OpDesc(protobuf_owner_, &op_def_proto)); - GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr."); - - // Input tensor - for (auto &input_desc : *op_def_proto.mutable_input_desc()) { - std::shared_ptr temp_value = - std::shared_ptr(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc)); - GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); - op_desc->inputs_desc_.push_back(temp_value); - } - // Output tensor - for (auto &output_desc : *op_def_proto.mutable_output_desc()) { - std::shared_ptr temp_value = - std::shared_ptr(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc)); - GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); - op_desc->outputs_desc_.push_back(temp_value); - } - - op_desc->SetId(op_def_proto.id()); - uint32_t graph_index = 0; - for (const std::string &name : op_def_proto.subgraph_name()) { - op_desc->AddSubgraphName(name); - op_desc->SetSubgraphInstanceName(graph_index++, name); - } - - // insert name index by key and value - AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input); - - return true; -} - -bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) { - GE_RT_FALSE_CHECK_NOTNULL(graph); - OpDescPtr op_desc = 
nullptr; - if (!UnserializeOpDesc(op_desc, op_def_proto)) { - GELOGW("UnserializeOpDesc error."); - } - - NodePtr node = graph->AddNode(op_desc, op_desc->GetId()); - GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr."); - - // Inputs - int dst_index = 0; - for (const auto &input : op_def_proto.input()) { - string node_name; - int32_t index = 0; - if (ParseNodeIndex(input, node_name, index)) { - node_input_node_names_.push_back(NodeNameNodeReq{node_name, index, node, dst_index, op_def_proto.name()}); - } - if (index >= 0) { - dst_index++; - } - } - node_map_[op_def_proto.name()] = node; - return true; -} - -bool ModelSerializeImp::HandleNodeNameRef() { - // Edges - for (auto &item : node_input_node_names_) { - auto src_node_it = node_map_.find(item.src_node_name); - if (src_node_it == node_map_.end()) { - GELOGE(GRAPH_FAILED, "cannot find node %s", item.src_node_name.c_str()); - return false; - } - GE_IF_BOOL_EXEC(src_node_it->second == nullptr || item.dst_node == nullptr, continue); - if (item.src_out_index >= 0) { - auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index); - auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); - if (src_anchor == nullptr || dst_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "get anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, - item.dst_node_name.c_str(), item.dst_in_index); - return false; - } - GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 - } else { - // Control edge - auto src_anchor = src_node_it->second->GetOutControlAnchor(); - auto dst_anchor = item.dst_node->GetInControlAnchor(); - if (src_anchor != nullptr && dst_anchor != nullptr) { - GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 - } - } - } - // Graph input - for (auto &item : graph_input_node_names_) { - auto node_it = node_map_.find(item.node_name); - if (node_it == node_map_.end()) { - GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str()); - return false; - } - GE_IF_BOOL_EXEC(item.graph == nullptr, continue); - auto ret = item.graph->AddInputNode(node_it->second); - if (ret == nullptr) { - return false; - } - } - // Graph output - for (auto &item : graph_output_node_names_) { - auto node_it = node_map_.find(item.node_name); - if (node_it == node_map_.end()) { - GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str()); - return false; - } - - GE_IF_BOOL_EXEC(item.graph == nullptr, continue); - auto ret = item.graph->AddOutputNodeByIndex(node_it->second, item.index); - GELOGI("node name:%s, item.index:%ld", node_it->second->GetName().c_str(), item.index); - if (ret == nullptr) { - GELOGE(GRAPH_FAILED, "AddOutputNode failed."); - return false; - } - } - node_input_node_names_.clear(); - graph_input_node_names_.clear(); - graph_output_node_names_.clear(); - node_map_.clear(); - return true; -} - -bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map &subgraphs) { - std::queue all_graphs; - all_graphs.emplace(compute_graph); - while (!all_graphs.empty()) { - ComputeGraphPtr graph = all_graphs.front(); - all_graphs.pop(); - - for (const NodePtr &node : graph->GetDirectNode()) { - const OpDescPtr op_desc = node->GetOpDesc(); - for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { - auto it = subgraphs.find(name); - if (it == subgraphs.end()) { - GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), 
name.c_str(), - subgraphs.size()); - return false; - } - - ComputeGraphPtr &subgraph = it->second; - subgraph->SetParentGraph(graph); - subgraph->SetParentNode(node); - compute_graph->AddSubgraph(subgraph->GetName(), subgraph); - all_graphs.emplace(subgraph); - } - } - } - - return true; -} - -bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) { - model.name_ = model_proto.name(); - model.version_ = model_proto.version(); - model.platform_version_ = model_proto.custom_version(); - model.attrs_ = ProtoAttrMapHelper(protobuf_owner_, model_proto.mutable_attr()); - - auto &graphs_proto = *model_proto.mutable_graph(); - if (!graphs_proto.empty()) { - auto &graph_proto = graphs_proto[0]; - ComputeGraphPtr compute_graph_ptr; - if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) { - model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); - } - - // 0 is main graph, following is subgraph. - map subgraphs; - for (int idx = 1; idx < graphs_proto.size(); ++idx) { - ComputeGraphPtr subgraph; - ModelSerializeImp impl; - if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) { - GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed"); - return false; - } - - if (!impl.HandleNodeNameRef()) { - GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); - return false; - } - - subgraphs[subgraph->GetName()] = subgraph; - } - - if (!RebuildOwnership(compute_graph_ptr, subgraphs)) { - GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed"); - return false; - } - } - - if (!HandleNodeNameRef()) { - GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); - return false; - } - return true; -} - -bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) { - graph = ComGraphMakeShared(graph_proto.name()); - if (graph == nullptr) { - GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); - return false; - } - - // Inputs - for (auto input : graph_proto.input()) { - string node_name; - int32_t index; - if (ParseNodeIndex(input, node_name, index)) { - graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); - } - } - // Outputs - for (auto output : graph_proto.output()) { - string node_name; - int32_t index; - if (ParseNodeIndex(output, node_name, index)) { - graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); - } - } - graph->attrs_ = ProtoAttrMapHelper(protobuf_owner_, graph_proto.mutable_attr()); - for (auto &op_def_proto : *graph_proto.mutable_op()) { - if (!UnserializeNode(graph, op_def_proto)) { - GELOGE(GRAPH_FAILED, "UnserializeNode fail"); - return false; - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph, - proto::GraphDef &graph_proto) { - if (!UnserializeGraphWithoutEdge(graph, graph_proto)) { - GELOGW("UnserializeGraphWithoutEdge fail"); - } - if (!HandleNodeNameRef()) { - GELOGE(GRAPH_FAILED, "Link Anchor or set graph input or output fail"); - return false; - } - return true; -} - -bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::Message *proto) { - GE_CHK_BOOL_EXEC(data != nullptr, return false, "data is null."); - GE_CHK_BOOL_EXEC(proto != nullptr, return false, "proto is null."); - - google::protobuf::io::CodedInputStream coded_stream(data, len); - // 2048M -1 - coded_stream.SetTotalBytesLimit(INT32_MAX, -1); - if (!proto->ParseFromCodedStream(&coded_stream)) { - GELOGE(GRAPH_FAILED, "ReadProtoFromBinaryFile failed len 
%zu", len); - return false; - } - return true; -} - -Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) { - proto::ModelDef model_def; - ModelSerializeImp imp; - if (!imp.SerializeModel(model, &model_def, is_dump)) { - return Buffer(); - } -#if !defined(__ANDROID__) && !defined(ANDROID) - Buffer buffer(model_def.ByteSizeLong()); -#else - Buffer buffer(model_def.ByteSize()); -#endif - GE_CHK_BOOL_ONLY_LOG(buffer.GetSize() != 0, "get size failed"); - GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); - auto ret = model_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); - if (ret != true) { - GELOGW("serialize to array fail."); - } - return buffer; -} - -size_t ModelSerialize::GetSerializeModelSize(const Model &model) { - proto::ModelDef model_def; - ModelSerializeImp imp; - if (!imp.SerializeModel(model, &model_def)) { - return 0; - } -#if !defined(__ANDROID__) && !defined(ANDROID) - return model_def.ByteSizeLong(); -#else - return model_def.ByteSize(); -#endif -} - -Model ModelSerialize::UnserializeModel(const uint8_t *data, size_t len) { - if (data == nullptr) { - GELOGE(GRAPH_FAILED, "data is nullptr"); - return Model(); - } - - std::shared_ptr model_proto_ptr; - model_proto_ptr = ComGraphMakeShared(); - if (model_proto_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed"); - return Model(); - } - - auto &model_proto = *model_proto_ptr; - if (!ReadProtoFromBinaryFile(data, len, &model_proto)) { - GELOGE(GRAPH_FAILED, "ParseFromArray fail"); - return Model(); - } - - Model model; - ModelSerializeImp imp; - imp.SetProtobufOwner(model_proto_ptr); - if (!imp.UnserializeModel(model, model_proto)) { - GELOGE(GRAPH_FAILED, "Unserialize Model fail"); - return Model(); - } - return model; -} - -Model ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def) { - std::shared_ptr model_def_ptr = ComGraphMakeShared(model_def); - GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, return Model(), "mode_def make shared failed"); - - ModelSerializeImp imp; - imp.SetProtobufOwner(model_def_ptr); - Model model; - if (!imp.UnserializeModel(model, *model_def_ptr)) { - GELOGE(GRAPH_FAILED, "Unserialize Model fail"); - return Model(); - } - return model; -} - -Buffer ModelSerialize::SerializeGraph(const ComputeGraphPtr &graph) { - proto::GraphDef graph_def; - ModelSerializeImp imp; - if (!imp.SerializeGraph(graph, &graph_def)) { - return Buffer(); - } -#if !defined(__ANDROID__) && !defined(ANDROID) - Buffer buffer(graph_def.ByteSizeLong()); -#else - Buffer buffer(graph_def.ByteSize()); -#endif - GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed"); - GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); - auto ret = graph_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); - if (ret != true) { - GE_LOGE("serialize to array fail."); - } - - return buffer; -} - -ComputeGraphPtr ModelSerialize::UnserializeGraph(const uint8_t *data, size_t len) { - if (data == nullptr) { - GELOGE(GRAPH_FAILED, "data is nullptr"); - return nullptr; - } - - std::shared_ptr graph_proto_ptr; - graph_proto_ptr = ComGraphMakeShared(); - if (graph_proto_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); - return nullptr; - } - proto::GraphDef &graph_proto = *graph_proto_ptr; - if (!ReadProtoFromBinaryFile(data, len, &graph_proto)) { - GELOGE(GRAPH_FAILED, "ParseFromArray fail"); - return nullptr; - } - - ComputeGraphPtr graph; - ModelSerializeImp imp; - 
imp.SetProtobufOwner(graph_proto_ptr); - if (!imp.UnserializeGraph(graph, graph_proto)) { - return nullptr; - } - return graph; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer ModelSerialize::SerializeOpDesc(const ConstOpDescPtr &op_desc) { - proto::OpDef op_def; - ModelSerializeImp imp; - if (!imp.SerializeOpDesc(op_desc, &op_def)) { - return Buffer(); - } -#if !defined(__ANDROID__) && !defined(ANDROID) - Buffer buffer(op_def.ByteSizeLong()); -#else - Buffer buffer(op_def.ByteSize()); -#endif - GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed"); - GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); - auto ret = op_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); - if (ret != true) { - GE_LOGE("serialize to array fail."); - } - - return buffer; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr ModelSerialize::UnserializeOpDesc(const uint8_t *data, - size_t len) { - if (data == nullptr) { - GELOGE(GRAPH_FAILED, "data is nullptr"); - return nullptr; - } - - std::shared_ptr op_def_ptr; - op_def_ptr = ComGraphMakeShared(); - if (op_def_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); - return nullptr; - } - proto::OpDef &op_def = *op_def_ptr; - if (!ReadProtoFromBinaryFile(data, len, &op_def)) { - GELOGE(GRAPH_FAILED, "ParseFromArray fail"); - return nullptr; - } - - OpDescPtr op_desc; - ModelSerializeImp imp; - imp.SetProtobufOwner(op_def_ptr); - if (!imp.UnserializeOpDesc(op_desc, op_def)) { - GELOGW("UnserializeOpDesc error."); - } - return op_desc; -} -} // namespace ge diff --git a/src/common/graph/module.mk b/src/common/graph/module.mk deleted file mode 100644 index 1e00b7fc..00000000 --- a/src/common/graph/module.mk +++ /dev/null @@ -1,3 +0,0 @@ -LOCAL_PATH := $(call my-dir) - -include $(LOCAL_PATH)/graph.mk diff --git a/src/common/graph/node.cc b/src/common/graph/node.cc deleted file mode 100644 index d33c6008..00000000 --- a/src/common/graph/node.cc +++ /dev/null @@ -1,878 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/node.h" -#include -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "external/graph/operator_factory.h" -#include "framework/common/debug/ge_log.h" -#include "graph/ge_tensor.h" -#include "graph/operator_factory_impl.h" -#include "graph/shape_refiner.h" -#include "utils/ge_ir_utils.h" -#include "utils/node_utils.h" -#include "utils/op_desc_utils.h" -#include "common/util/error_manager/error_manager.h" - -using std::string; -using std::vector; - -namespace ge { -Node::Node(const OpDescPtr &op, const ComputeGraphPtr &owner_graph) - : op_(op), - owner_graph_(owner_graph), - in_data_anchors_(), - out_data_anchors_(), - in_control_anchor_(nullptr), - out_control_anchor_(nullptr), - attrs_(), - has_init_(false) { - anchor_status_updated_ = false; -} - -Node::~Node() { - for (const auto &in_data_anchor : in_data_anchors_) { - if (in_data_anchor != nullptr) { - in_data_anchor->UnlinkAll(); - } - } - for (const auto &out_data_anchor : out_data_anchors_) { - if (out_data_anchor != nullptr) { - out_data_anchor->UnlinkAll(); - } - } - if (in_control_anchor_ != nullptr) { - in_control_anchor_->UnlinkAll(); - } - if (out_control_anchor_ != nullptr) { - out_control_anchor_->UnlinkAll(); - } -} - -graphStatus Node::Init() { - if (has_init_) { - return GRAPH_SUCCESS; - } - GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - size_t size = op_->GetAllInputsSize(); - for (size_t i = 0; i < size; i++) { - std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), i); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - } - size = op_->GetOutputsSize(); - for (size_t i = 0; i < size; i++) { - std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), i); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Current out_data_anchor is null, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - out_data_anchors_.push_back(anchor); - } - in_control_anchor_ = ComGraphMakeShared(shared_from_this(), -1); - out_control_anchor_ = ComGraphMakeShared(shared_from_this(), -1); - if (in_control_anchor_ == nullptr || out_control_anchor_ == nullptr) { - GELOGE(GRAPH_FAILED, "Current in_control_anchor or out_control_anchor is null, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - has_init_ = true; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetName() const { - GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr"); - return op_->GetName(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetType() const { - GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr"); - return op_->GetType(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAttrsAreEqual(const Node &r_node) const { - const auto &attr_map = this->attrs_; - const auto &r_attr_map = r_node.attrs_; - // 1.Verify node's map size - if (attr_map.size() != r_attr_map.size()) { - GELOGE(GRAPH_FAILED, "Size of node's attr map verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - // 2.Verify node's map key, verify values is temporarily not implemented - for (const auto &it : attr_map) { - if (r_attr_map.count(it.first) == 0) { - GELOGE(GRAPH_FAILED, "Key of node's attr map verify failed, node name: %s key name: %s.", this->GetName().c_str(), - it.first.c_str()); - return false; - } - } - 
return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeMembersAreEqual(const Node &r_node) const { - return ((((this->op_ != nullptr) && (r_node.op_ != nullptr) && (IsEqual(*(this->op_), *(r_node.op_), "node.op_"))) || - ((this->op_ == nullptr) && (r_node.op_ == nullptr))) && - IsEqual(this->has_init_, r_node.has_init_, "node.has_init_") && - IsEqual(this->anchor_status_updated_, r_node.anchor_status_updated_, "node.anchor_status_updated_") && - IsEqual(this->send_event_id_list_, r_node.send_event_id_list_, "node.send_event_id_list_") && - IsEqual(this->recv_event_id_list_, r_node.recv_event_id_list_, "node.recv_event_id_list_")); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(const AnchorPtr &left_anchor, - const AnchorPtr &right_anchor, - size_t i) const { - GE_IF_BOOL_EXEC(left_anchor == nullptr, GELOGE(GRAPH_FAILED, "left_anchor is null."); return false); - GE_IF_BOOL_EXEC(right_anchor == nullptr, GELOGE(GRAPH_FAILED, "right_anchor is null."); return false); - - const auto anchor_peer_size = left_anchor->GetPeerAnchors().size(); - const auto right_anchor_peer_size = right_anchor->GetPeerAnchors().size(); - // Firstly, verify anchor's peer anchors size equal or not - if (anchor_peer_size != right_anchor_peer_size) { - GELOGE(GRAPH_FAILED, - "Size of anchor's peer anchors verify failed, node name: %s " - "anchor_peer_size [%zu] is different form [%zu] at index [%zu].", - this->GetName().c_str(), anchor_peer_size, right_anchor_peer_size, i); - return false; - } - // Secondly, verify anchor's peer anchor owner node equal or not - for (size_t j = 0; j < anchor_peer_size; j++) { - const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); - const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); - if (peer_node == nullptr || r_peer_node == nullptr) { - GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. 
", - this->GetName().c_str(), i, j); - return false; - } - // Determine the connection relationship by linking the node's name - if (peer_node->GetName() != r_peer_node->GetName()) { - GELOGE(GRAPH_FAILED, - "anchor's peer node name verify failed, node name: %s index[%zu]" - "peer node name %s is different from %s at index [%zu].", - this->GetName().c_str(), i, peer_node->GetName().c_str(), r_peer_node->GetName().c_str(), j); - return false; - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeInConnectsAreEqual(const Node &r_node) const { - // 1.Verify all in data and control anchors size - const auto in_data_anchor_size = this->GetAllInDataAnchors().size(); - const auto r_in_data_anchor_size = r_node.GetAllInDataAnchors().size(); - if (in_data_anchor_size != r_in_data_anchor_size) { - GELOGE(GRAPH_FAILED, "Size of node's in data anchors verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - const auto l_in_anchors = this->GetAllInAnchors(); - const auto r_in_anchors = r_node.GetAllInAnchors(); - // Data anchors size equal, all anchors size not equal, means control anchor size not equal - const auto in_control_anchor_size = l_in_anchors.size() - in_data_anchor_size; - const auto r_in_control_anchor_size = r_in_anchors.size() - r_in_data_anchor_size; - if (in_control_anchor_size != r_in_control_anchor_size) { - GELOGE(GRAPH_FAILED, "Size of node's in control anchors verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - // 2.Verify all in data and control anchors connect info - for (size_t i = 0; i < this->GetAllInAnchors().size(); i++) { - // Verify data anchors - if (i < in_data_anchor_size) { - const auto &in_anchor = l_in_anchors.at(i); - const auto &r_in_anchor = r_in_anchors.at(i); - if (!(NodeAnchorIsEqual(in_anchor, r_in_anchor, i))) { - GELOGE(GRAPH_FAILED, "Node's in data control anchor verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - } else { - // Verify control anchors - const auto &in_control_anchor = l_in_anchors.at(i); - const auto &r_in_control_anchor = r_in_anchors.at(i); - if (!(NodeAnchorIsEqual(in_control_anchor, r_in_control_anchor, i - in_data_anchor_size))) { - GELOGE(GRAPH_FAILED, "Node's in control anchor verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeOutConnectsAreEqual(const Node &r_node) const { - // 1.Verify all out data and control anchors size - const auto l_out_data_anchors = this->GetAllOutDataAnchors(); - const auto r_out_data_anchors = r_node.GetAllOutDataAnchors(); - const auto out_data_anchor_size = l_out_data_anchors.size(); - const auto r_out_data_anchor_size = r_out_data_anchors.size(); - if (out_data_anchor_size != r_out_data_anchor_size) { - GELOGE(GRAPH_FAILED, "Size of node's out data anchors verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - const auto l_out_anchors = this->GetAllOutAnchors(); - const auto r_out_anchors = r_node.GetAllOutAnchors(); - // Data anchors size equal, all anchors size not equal, means control anchor size not equal - const auto out_control_anchor_size = l_out_anchors.size() - out_data_anchor_size; - const auto r_out_control_anchor_size = r_out_anchors.size() - r_out_data_anchor_size; - if (out_control_anchor_size != r_out_control_anchor_size) { - GELOGE(GRAPH_FAILED, "Size of node's out control anchors verify failed, node name: %s.", 
this->GetName().c_str()); - return false; - } - - // 2.Verify all out data and control anchors connect info - for (size_t i = 0; i < this->GetAllOutAnchors().size(); i++) { - // Verify data anchors - if (i < out_data_anchor_size) { - const auto &out_anchor = l_out_data_anchors.at(i); - const auto &r_out_anchor = r_out_data_anchors.at(i); - if (!(NodeAnchorIsEqual(out_anchor, r_out_anchor, i))) { - GELOGE(GRAPH_FAILED, "Node's out data control anchor verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - } else { - // Verify control anchors - const auto &out_control_anchor = l_out_anchors.at(i); - const auto &r_out_control_anchor = r_out_anchors.at(i); - if (!(NodeAnchorIsEqual(out_control_anchor, r_out_control_anchor, i - out_data_anchor_size))) { - GELOGE(GRAPH_FAILED, "Node's out control anchor verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::operator==(const Node &r_node) const { - return (NodeMembersAreEqual(r_node) && NodeAttrsAreEqual(r_node) && NodeInConnectsAreEqual(r_node) && - NodeOutConnectsAreEqual(r_node)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const NodePtr &input_node) { - // This function is deprecated, please use other two overloaded functions - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1) { - GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - auto op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - if (op_->AddInputDesc(op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "add input desc failed."); - return GRAPH_FAILED; - } - std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const uint32_t &index, - NodePtr input_node) { - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1) { - GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - GE_CHECK_NOTNULL(op_); - auto op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - if (op_->AddInputDesc(index, op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "add input desc failed."); - return GRAPH_FAILED; - } - - if (index < GetAllInDataAnchors().size()) { - (void)out_anchors.at(0)->LinkTo(in_data_anchors_[index]); - } else { - std::shared_ptr anchor = - ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFromForParse(const 
NodePtr &input_node) { - // This function is used for ParseWeights. - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1) { - GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, make anchor failed", out_anchors.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const string &name, NodePtr input_node) { - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1) { - GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - GE_CHECK_NOTNULL(op_); - auto input_op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(input_op_desc); - auto index = op_->GetInputIndexByName(name); - if (index != -1) { - if (index >= static_cast(in_data_anchors_.size())) { - GELOGE(GRAPH_FAILED, "op %s get input name %s 's index %d is illegal.", op_->GetName().c_str(), name.c_str(), - index); - return GRAPH_FAILED; - } - (void)out_anchors.at(0)->LinkTo(in_data_anchors_[index]); - } else { - std::shared_ptr anchor = - ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "in_data_anchors_size is:%zu, malloc shared_ptr failed.", in_data_anchors_.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); - } - if (op_->AddInputDesc(name, input_op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "add input desc failed."); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr Node::GetOwnerComputeGraph() const { - return owner_graph_.lock(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::SetOwnerComputeGraph(const ComputeGraphPtr &graph) { - if (graph == nullptr) { - return GRAPH_PARAM_INVALID; - } - owner_graph_ = graph; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllInDataAnchors() const { - return Vistor(shared_from_this(), in_data_anchors_); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllOutDataAnchors() const { - return Vistor(shared_from_this(), out_data_anchors_); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllInDataAnchorsSize() const { - return in_data_anchors_.size(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllOutDataAnchorsSize() const { - return out_data_anchors_.size(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllInAnchors() const { - std::vector vec; - // Push back in_data_anchors_ - for (const auto &in_anchor_iter : Vistor(shared_from_this(), in_data_anchors_)) { - auto in_anchor = Anchor::DynamicAnchorCast(in_anchor_iter); - if (in_anchor != nullptr) { - vec.push_back(in_anchor); - } - } - // Push back in_control_anchor_ - if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || - 
(in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { - auto in_anchor = Anchor::DynamicAnchorCast(in_control_anchor_); - if (in_anchor != nullptr) { - vec.push_back(in_anchor); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllOutAnchors() const { - std::vector vec; - // Push back out_data_anchors_ - for (const auto &out_anchor_iter : Vistor(shared_from_this(), out_data_anchors_)) { - auto out_anchor = Anchor::DynamicAnchorCast(out_anchor_iter); - if (out_anchor != nullptr) { - vec.push_back(out_anchor); - } - } - // Push back out_control_anchor_ - if (out_control_anchor_->GetPeerInControlAnchors().size() > 0 || - out_control_anchor_->GetPeerInDataAnchors().size() > 0) { - auto out_anchor = Anchor::DynamicAnchorCast(out_control_anchor_); - if (out_anchor != nullptr) { - vec.push_back(out_anchor); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { - if (idx < 0 || idx >= static_cast(in_data_anchors_.size())) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19019", {"opname", "index", "anchorname", "optype"}, - {GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); - GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx, - GetType().c_str()); - return nullptr; - } else { - return in_data_anchors_[idx]; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { - // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ - if (idx < -1 || idx >= static_cast(in_data_anchors_.size())) { - GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); - return nullptr; - } else { - // Return control anchor - if (idx == -1) { - auto in_anchor = Anchor::DynamicAnchorCast(in_control_anchor_); - return in_anchor; - } - // Return data anchor - return in_data_anchors_[idx]; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { - // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ - if (idx < -1 || idx >= static_cast(out_data_anchors_.size())) { - ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"}, - { - GetName().c_str(), - std::to_string(idx), - "out_anchor", - GetType().c_str(), - }); - GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, - GetType().c_str()); - return nullptr; - } else { - // Return control anchor - if (idx == -1) { - auto out_anchor = Anchor::DynamicAnchorCast(out_control_anchor_); - return out_anchor; - } - // Return data anchor - return out_data_anchors_[idx]; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { - if (idx < 0 || idx >= static_cast(out_data_anchors_.size())) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19019", {"opname", "index", "anchorname", "optype"}, - {GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); - GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx, - GetType().c_str()); - return nullptr; - } else { - return out_data_anchors_[idx]; - } -} - -GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY InControlAnchorPtr Node::GetInControlAnchor() const { - return in_control_anchor_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchorPtr Node::GetOutControlAnchor() const { - return out_control_anchor_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInNodes() const { - std::vector vec; - for (const auto &in_anchor : in_data_anchors_) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - auto node = out_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - if (in_control_anchor_ != nullptr) { - if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { - return Node::Vistor(shared_from_this(), vec); - } - - auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); - for (const auto &out_anchor : peer_out_anchors) { - GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr"); - auto node = out_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - - auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); - for (const auto &out_control_anchor : peer_out_control_anchors) { - GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, - "in_control_anchor_ peer out control anchors is nullptr"); - auto node = out_control_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::IsAllInNodesSeen( - std::unordered_set &nodes_seen) const { - for (const auto &in_anchor : in_data_anchors_) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - auto node = out_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { - continue; - } - if (nodes_seen.count(node.get()) == 0) { - return false; - } - } - - if (in_control_anchor_ != nullptr) { - if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { - return true; - } - auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); - for (const auto &out_control_anchor : peer_out_control_anchors) { - GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, "out_control_anchor is nullptr"); - auto node = out_control_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { - continue; - } - if (nodes_seen.count(node.get()) == 0) { - return false; - } - } - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInDataNodes() const { - std::vector vec; - for (const auto &in_anchor : in_data_anchors_) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); - auto anchor_ptr = in_anchor->GetPeerOutAnchor(); - if (anchor_ptr == nullptr) { - continue; - } - auto node = anchor_ptr->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - return 
Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInControlNodes() const { - std::vector vec; - if (in_control_anchor_ != nullptr) { - for (const auto &in_anchor : in_control_anchor_->GetPeerOutControlAnchors()) { - GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerOutControlAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutNodes() const { - std::vector vec; - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); - for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC((peer_in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr"); - auto node = peer_in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - if (out_control_anchor_ != nullptr) { - auto peer_in_control_anchors = out_control_anchor_->GetPeerInControlAnchors(); - for (const auto &in_control_anchor : peer_in_control_anchors) { - GE_CHK_BOOL_EXEC(in_control_anchor != nullptr, continue, - "out_control_anchor_ peer in control anchors is nullptr"); - auto node = in_control_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInAllNodes() const { - std::vector vec; - for (const auto &in_node : GetInDataNodes()) { - vec.push_back(in_node); - } - for (const auto &in_control_node : GetInControlNodes()) { - vec.push_back(in_control_node); - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutDataNodes() const { - std::vector vec; - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); - for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetOutDataNodesSize() const { - uint32_t out_nums = 0; - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); - out_nums += out_anchor->GetPeerInDataNodesSize(); - } - return out_nums; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutControlNodes() const { - std::vector vec; - - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); - for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInControlAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - - if (out_control_anchor_ != nullptr) { - for 
(const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) { - GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutAllNodes() const { - std::vector vec; - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), { continue; }, "out_data_anchors_ is nullptr"); - for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), { continue; }, "GetPeerInDataAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - - if (out_control_anchor_ != nullptr) { - for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) { - GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -graphStatus Node::InferShapeAndType() const { - Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); - graphStatus ret = ShapeRefiner::InferShapeAndType(shared_from_this(), op); - return ret; -} - -graphStatus Node::InferOriginFormat() const { - Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); - // Get infer func and execute - GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - return op_->CallInferFormatFunc(op); -} -graphStatus Node::Verify() const { - const string data_type = "Data"; - const string aipp_data_type = "AippData"; - const string const_type = "Const"; - const string variable_type = "Variable"; - bool is_unknown_graph = GetOwnerComputeGraph()->GetGraphUnknownFlag(); - GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - - if (!is_unknown_graph) { - for (const auto &in_anchor_ptr : GetAllInDataAnchors()) { - GE_IF_BOOL_EXEC(in_anchor_ptr == nullptr, GELOGW("in anchor ptr is null"); continue); - bool valid_anchor = - op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type || - op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || - op_->MutableInputDesc(in_anchor_ptr->GetIdx()) == nullptr || in_anchor_ptr->GetPeerAnchors().size() > 0; - if (!valid_anchor) { - ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"}, - {GetName(), std::to_string(in_anchor_ptr->GetIdx())}); - GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); - return GRAPH_FAILED; - } - } - } - - string frameworkop_type = "FrameworkOp"; - bool need_update_name = op_->GetType() != frameworkop_type && !is_unknown_graph; - if (need_update_name) { - auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_->GetType()); - if (node_op.IsEmpty()) { - 
GELOGW("get op from OperatorFactory fail. opType: %s", op_->GetType().c_str()); - } else { - GELOGD("get op from OperatorFactory success. opType: %s", op_->GetType().c_str()); - auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); - if (temp_op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "temp op desc is null"); - return GRAPH_FAILED; - } - if (!op_->UpdateInputName(temp_op_desc->GetAllInputName())) { - GELOGW("Verify UpdateInputName failed"); - } - if (!op_->UpdateOutputName(temp_op_desc->GetAllOutputName())) { - GELOGW("Verify UpdateOutputName failed"); - } - } - node_op.BreakConnect(); - } - GE_IF_BOOL_EXEC(is_unknown_graph, return GRAPH_SUCCESS;); - if (op_->CommonVerify() == GRAPH_SUCCESS) { - Operator op_proxy = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); - auto verify_func = op_->GetVerifyFunc(); - if (verify_func == nullptr) { - verify_func = OperatorFactoryImpl::GetVerifyFunc(GetType()); - } - if (verify_func != nullptr) { - return (graphStatus)verify_func(op_proxy); - } - return GRAPH_SUCCESS; - } else { - GELOGE(GRAPH_FAILED, "%s Verify failed.", op_->GetType().c_str()); - return GRAPH_FAILED; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr Node::GetOpDesc() const { return op_; } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(const OpDescPtr &op_desc) { - GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - GE_CHK_BOOL_EXEC(op_desc != nullptr, return GRAPH_PARAM_INVALID, "Param OpDesc is nullptr"); - GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, - "Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), - op_desc->GetInputsSize()); - - GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, - "Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), - op_desc->GetOutputsSize()); - op_ = op_desc; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor> -Node::GetInDataNodesAndAnchors() const { - std::vector> vec; - for (const auto &p : in_data_anchors_) { - if (p == nullptr) { - GELOGW("indata anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); - continue; - } - auto anchor_ptr = p->GetPeerOutAnchor(); - if (anchor_ptr == nullptr) { - continue; - } - auto node = anchor_ptr->GetOwnerNode(); - if (node == nullptr) { - GELOGW("src node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); - continue; - } - vec.push_back(std::make_pair(node, anchor_ptr)); - } - return Node::Vistor>(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor> -Node::GetOutDataNodesAndAnchors() const { - std::vector> vec; - for (const auto &p : out_data_anchors_) { - if (p == nullptr) { - GELOGW("out data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); - continue; - } - for (const auto &in_anchor : p->GetPeerInDataAnchors()) { - if (in_anchor == nullptr) { - GELOGW("dst in data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); - continue; - } - auto node = in_anchor->GetOwnerNode(); - if (node == nullptr) { - GELOGW("dst node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); - continue; - } - vec.push_back(std::make_pair(node, in_anchor)); - } - } - return Node::Vistor>(shared_from_this(), vec); -} -} // namespace ge diff --git a/src/common/graph/op_desc.cc 
b/src/common/graph/op_desc.cc deleted file mode 100644 index dee0aece..00000000 --- a/src/common/graph/op_desc.cc +++ /dev/null @@ -1,1410 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/op_desc.h" -#include "debug/ge_attr_define.h" -#include "debug/ge_util.h" -#include "external/graph/operator.h" -#include "framework/common/debug/ge_log.h" -#include "common/util/error_manager/error_manager.h" -#include "graph/ge_attr_value.h" -#include "graph/ge_tensor.h" -#include "graph/operator_factory_impl.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/ge_ir_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "proto/ge_ir.pb.h" - -using std::make_pair; -using std::shared_ptr; -using std::string; -using std::vector; - -/*lint -save -e521 -e681 -e732 -e737*/ -namespace ge { -const std::string ATTR_NAME_ID = "id"; - -const std::string ATTR_NAME_STREAM_ID = "stream_id"; - -const std::string ATTR_NAME_INPUT_NAME = "input_name"; - -const std::string ATTR_NAME_SRC_NAME = "src_name"; - -const std::string ATTR_NAME_SRC_INDEX = "src_index"; - -const std::string ATTR_NAME_INPUT = "input"; - -const std::string ATTR_NAME_OUTPUT = "output"; - -const std::string ATTR_NAME_INPUT_DESC = "input_desc"; - -const std::string ATTR_NAME_OUTPUT_DESC = "output_desc"; - -const std::string ATTR_NAME_DST_NAME = "dst_name"; - -const std::string ATTR_NAME_DST_INDEX = "dst_index"; - -const std::string ATTR_NAME_WORKSPACE = "workspace"; - -const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; - -const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; - -const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { - op_def_.InitDefault(); - if (op_def_.GetProtoMsg() != nullptr) { - op_def_.GetProtoMsg()->set_has_out_attr(true); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::~OpDesc() {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const std::string &name, const std::string &type) { - op_def_.InitDefault(); - if (op_def_.GetProtoMsg() != nullptr) { - op_def_.GetProtoMsg()->set_has_out_attr(true); - } - SetName(name); - SetType(type); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const ProtoMsgOwner &proto_msg_owner, - ge::proto::OpDef *op_def) - : op_def_(proto_msg_owner, op_def) { - if (op_def != nullptr && !op_def->has_out_attr()) { - op_def->set_has_out_attr(true); - - int64_t id = 0; - (void)AttrUtils::GetInt(this, ATTR_NAME_ID, id); - op_def->set_id(id); - - int64_t stream_id = 0; - (void)AttrUtils::GetInt(this, ATTR_NAME_STREAM_ID, stream_id); - op_def->set_stream_id(stream_id); - - vector input_name; - (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME, input_name); - for (auto &item : input_name) { - op_def->add_input_name(item); - } - vector src_name; - (void)AttrUtils::GetListStr(this, ATTR_NAME_SRC_NAME, src_name); - for (auto &item : src_name) { - 
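OpDesc attributes are normally read and written through AttrUtils, the same helper this constructor uses to migrate the legacy id, stream_id and src/dst fields. A brief sketch; AttrUtils::SetInt is assumed as the counterpart of the AttrUtils::GetInt calls shown in this file.

#include <cstdint>
#include <memory>
#include "graph/op_desc.h"
#include "graph/utils/attr_utils.h"  // included by op_desc.cc above
#include "framework/common/debug/ge_log.h"

// Set an attribute on a freshly built OpDesc and read it back.
static void TagOpDesc() {
  auto op_desc = std::make_shared<ge::OpDesc>("conv1", "Conv2D");
  // "stream_id" matches ATTR_NAME_STREAM_ID defined above; SetInt is assumed
  // to return false when the attribute cannot be stored.
  (void)ge::AttrUtils::SetInt(op_desc, "stream_id", 0);
  int64_t stream_id = -1;
  if (ge::AttrUtils::GetInt(op_desc, "stream_id", stream_id)) {
    GELOGI("op %s stream_id = %ld", op_desc->GetName().c_str(), stream_id);
  }
}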
op_def->add_src_name(item); - } - vector src_index; - (void)AttrUtils::GetListInt(this, ATTR_NAME_SRC_INDEX, src_index); - for (auto &item : src_index) { - op_def->add_src_index(item); - } - vector input; - (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT, input); - for (auto &item : input) { - op_def->add_input_i(item); - } - vector output; - (void)AttrUtils::GetListInt(this, ATTR_NAME_OUTPUT, output); - for (auto &item : output) { - op_def->add_output_i(item); - } - vector dst_name; - (void)AttrUtils::GetListStr(this, ATTR_NAME_DST_NAME, dst_name); - for (auto &item : dst_name) { - op_def->add_dst_name(item); - } - vector dst_index; - (void)AttrUtils::GetListInt(this, ATTR_NAME_DST_INDEX, dst_index); - for (auto &item : dst_index) { - op_def->add_dst_index(item); - } - vector workspace; - (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE, workspace); - for (auto &item : workspace) { - op_def->add_workspace(item); - } - vector workspace_bytes; - (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE_BYTES, workspace_bytes); - for (auto &item : workspace_bytes) { - op_def->add_workspace_bytes(item); - } - vector is_input_const; - (void)AttrUtils::GetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const); - for (auto item : is_input_const) { - op_def->add_is_input_const(item); - } - auto input_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_INPUT_DESC].mutable_list(); - if (input_desc_mutable_list != nullptr) { - *op_def->mutable_input_desc() = *(input_desc_mutable_list->mutable_td()); - } - auto output_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_OUTPUT_DESC].mutable_list(); - if (output_desc_mutable_list != nullptr) { - *op_def->mutable_output_desc() = *(output_desc_mutable_list->mutable_td()); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetName() const { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->name(); - } - return ""; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetName(const std::string &name) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_name(name); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetType() const { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->type(); - } - return ""; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetType(const string &type) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_type(type); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddInputDesc(const ge::GeTensorDesc &input_desc) { - int index = static_cast(inputs_desc_.size()); - return AddInputDesc("__input" + std::to_string(index), input_desc); -} - -graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc) { - graphStatus ret = GRAPH_SUCCESS; - if (index < inputs_desc_.size()) { - // InputsDesc[index] is exist, then update it - ret = UpdateInputDesc(index, input_desc); - } else { - // InputDesc[index] is not exist, then add it - ret = AddInputDesc(input_desc); - } - return ret; -} - -graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { - if (input_name_idx_.find(name) != input_name_idx_.end()) { - GELOGI("input %s is exist, update it", name.c_str()); - graphStatus ret = UpdateInputDesc(name, input_desc); - return ret; - } else { - int index = static_cast(inputs_desc_.size()); - std::shared_ptr in_desc = 
ComGraphMakeShared(input_desc); - if (in_desc == nullptr) { - GELOGE(GRAPH_FAILED, "AddInputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - inputs_desc_.push_back(in_desc); - (void)input_name_idx_.insert(make_pair(name, index)); - if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) { - register_input_name_.push_back(name); - } - - return GRAPH_SUCCESS; - } -} - -graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) { - for (unsigned int i = 0; i < num; i++) { - string input_name = name + std::to_string(i); - GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, - "Add input tensor_desc is existed. name[%s]", input_name.c_str()); - - std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); - if (in_desc == nullptr) { - GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - if (index > inputs_desc_.size()) { - GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); - return GRAPH_FAILED; - } - - (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc); - - // Update index in input_name_idx - for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { - if (it->second >= (index + i)) { - it->second += 1; - } - } - - (void)input_name_idx_.insert(make_pair(input_name, i + index)); - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddOutputDescMiddle(const string &name, const unsigned int num, size_t index) { - for (unsigned int i = 0; i < num; i++) { - string output_name = name + std::to_string(i); - GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED, - "Add input tensor_desc is existed. name[%s]", output_name.c_str()); - - std::shared_ptr out_desc = ComGraphMakeShared(GeTensorDesc()); - if (out_desc == nullptr) { - GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - if (index > outputs_desc_.size()) { - GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); - return GRAPH_FAILED; - } - - (void)outputs_desc_.insert(outputs_desc_.begin() + index + i, out_desc); - - // Update index in input_name_idx - for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) { - if (it->second >= (index + i)) { - it->second += 1; - } - } - - (void)output_name_idx_.insert(make_pair(output_name, i + index)); - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { - for (unsigned int i = 0; i < num; i++) { - string input_name = name + std::to_string(i); - GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, - "Add input tensor_desc is existed. 
name[%s]", input_name.c_str()); - - std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); - if (in_desc == nullptr) { - GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); - - // Update index in input_name_idx - for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { - it->second += 1; - } - - (void)input_name_idx_.insert(make_pair(input_name, 0)); - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int num) { - for (unsigned int i = 0; i < num; i++) { - string output_name = name + std::to_string(i); - GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED, - "Add output tensor_desc is existed. name[%s]", output_name.c_str()); - - std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); - if (in_desc == nullptr) { - GELOGE(GRAPH_FAILED, "AddOutputDescForward failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - (void)outputs_desc_.insert(outputs_desc_.begin(), in_desc); - - // Update index in output_name_idx - for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) { - it->second += 1; - } - (void)output_name_idx_.insert(make_pair(output_name, 0)); - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { - if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; - (void)optional_input_names_.insert(name); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { - if (index >= inputs_desc_.size()) { - GELOGW("The index is invalid. 
index[%u]", index); - return GRAPH_FAILED; - } - - inputs_desc_[index] = ComGraphMakeShared(tensor_Desc); - if (inputs_desc_[index] == nullptr) { - GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { - return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && - IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && - IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") && - IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && - IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { - const auto &op_def = this->op_def_.GetProtoMsg(); - const auto &r_op_def = r_op_desc.op_def_.GetProtoMsg(); - if ((op_def != nullptr) && (r_op_def != nullptr)) { - // Message OpDef in ge_ir.proto - return ( - IsEqual(op_def->name(), r_op_def->name(), "OpDef_.name()") && - IsEqual(op_def->type(), r_op_def->type(), "OpDef_.type()") && - IsEqual(ToString(op_def->input()), ToString(r_op_def->input()), "OpDef_.input()") && - IsEqual(op_def->has_out_attr(), r_op_def->has_out_attr(), "OpDef_.has_out_attr()") && - IsEqual(op_def->stream_id(), r_op_def->stream_id(), "OpDef_.stream_id()") && - IsEqual(ToString(op_def->input_name()), ToString(r_op_def->input_name()), "OpDef_.input_name()") && - IsEqual(ToString(op_def->src_name()), ToString(r_op_def->src_name()), "OpDef_.src_name()") && - IsEqual(ToString(op_def->dst_name()), ToString(r_op_def->dst_name()), "OpDef_.dst_name()") && - IsEqual(ToString(op_def->src_index()), ToString(r_op_def->src_index()), "OpDef_.src_index()") && - IsEqual(ToString(op_def->dst_index()), ToString(r_op_def->dst_index()), "OpDef_.dst_index()") && - IsEqual(ToString(op_def->input_i()), ToString(r_op_def->input_i()), "OpDef_.input_i()") && - IsEqual(ToString(op_def->output_i()), ToString(r_op_def->output_i()), "OpDef_.output_i()") && - IsEqual(ToString(op_def->workspace()), ToString(r_op_def->workspace()), "OpDef_.workspace()") && - IsEqual(ToString(op_def->workspace_bytes()), ToString(r_op_def->workspace_bytes()), "OpDef_.workspace_bytes()") && - IsEqual(ToString(op_def->is_input_const()), ToString(r_op_def->is_input_const()), "OpDef_.is_input_const()")); - } else { - return ((op_def == nullptr) && (r_op_def == nullptr)); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescGenTensorDescsAreEqual( - const OpDesc &r_op_desc) const { - // 1.Verify inputs and outputs desc size - const auto inputs_desc_size = this->inputs_desc_.size(); - const auto r_inputs_desc_size = r_op_desc.inputs_desc_.size(); - if (inputs_desc_size != r_inputs_desc_size) { - GELOGE(GRAPH_FAILED, "Size of OpDesc's inputs desc verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - const auto outputs_desc_size = this->outputs_desc_.size(); - const auto r_outputs_desc_size = r_op_desc.outputs_desc_.size(); - if (outputs_desc_size != r_outputs_desc_size) { - GELOGE(GRAPH_FAILED, "Size of OpDesc's outputs desc verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - // 2.Verify all inputs desc equal - for (uint32_t i = 0; i < inputs_desc_size; i++) { - const auto 
&in_ge_tensor_desc = this->GetInputDesc(i); - const auto &r_in_ge_tensor_desc = r_op_desc.GetInputDesc(i); - // Determine the connection relationship by GeTensorDesc - if (!(in_ge_tensor_desc == r_in_ge_tensor_desc)) { - GELOGE(GRAPH_FAILED, "Link info of OpDesc's inputs desc verify failed, OpDesc name: %s.", - this->GetName().c_str()); - return false; - } - } - // 3.Verify all outputs desc equal - for (uint32_t i = 0; i < outputs_desc_size; i++) { - const auto &out_ge_tensor_desc = this->GetOutputDesc(i); - const auto &r_out_ge_tensor_desc = r_op_desc.GetOutputDesc(i); - if (!(out_ge_tensor_desc == r_out_ge_tensor_desc)) { - GELOGE(GRAPH_FAILED, "Link info of OpDesc's outputs desc verify failed, OpDesc name: %s.", - this->GetName().c_str()); - return false; - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpDesc &r_op_desc) const { - return (OpDescAttrsAreEqual(r_op_desc) && OpDescMembersAreEqual(r_op_desc) && - OpDescGenTensorDescsAreEqual(r_op_desc)); -} - -graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { - auto it = input_name_idx_.find(name); - if (it == input_name_idx_.end()) { - GELOGW("Cann't find the input desc. name[%s]", name.c_str()); - return GRAPH_FAILED; - } - if (it->second >= inputs_desc_.size()) { - GELOGE(GRAPH_FAILED, "[%d] more than size of inputs_desc_", it->second); - return GRAPH_FAILED; - } - GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); - return GRAPH_FAILED); - inputs_desc_[it->second] = ComGraphMakeShared(tensor_Desc); - if (inputs_desc_[it->second] == nullptr) { - GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -bool OpDesc::InputIsSet(const string &name) const { - auto it = input_name_idx_.find(name); - if (it != input_name_idx_.end()) { - GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); - auto tensor_desc = inputs_desc_[it->second]; - GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); - auto dims = tensor_desc->GetShape().GetDims(); - if (dims.size() > 0) { - return true; - } - } - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG(index < inputs_desc_.size(), GeTensorDesc()); - return *(inputs_desc_[index].get()); -} - -GeTensorDesc OpDesc::GetInputDesc(const string &name) const { - auto it = input_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc()); - GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); - return *(inputs_desc_[it->second].get()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); - if (inputs_desc_[index] == nullptr) { - return nullptr; - } - if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) { - GELOGW("input desc is invalid"); - return nullptr; - } - return inputs_desc_[index]; -} - -GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const { - auto input_name_idx = GetAllInputName(); - auto it = input_name_idx.find(name); - if (it == input_name_idx.end()) { - GELOGW("Failed to get [%s] input desc", name.c_str()); - return nullptr; - } - 
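  // Name resolved to an index; delegate to the index-based overload, which also
  // rejects null or invalid descriptors. For example, with input_name_idx_ = {{"x", 0}},
  // MutableInputDesc("x") forwards to MutableInputDesc(0).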
return MutableInputDesc(it->second); -} - -GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputNames() const { - vector names; - if (input_name_idx_.empty()) { - return OpDesc::Vistor(shared_from_this(), names); - } - for (std::pair input : input_name_idx_) { - names.push_back(input.first); - } - return OpDesc::Vistor(shared_from_this(), names); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpKernelLibName(const std::string &name) { - op_kernel_lib_name_ = name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpKernelLibName() const { - return op_kernel_lib_name_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(const std::string &name) { - engine_name_ = name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDesc() const { - vector temp{}; - for (const auto &it : inputs_desc_) { - if (it->IsValid() == GRAPH_SUCCESS) { - temp.push_back(*it); - } else { - GELOGW("this inputDesc is InValid, it won't be return"); - continue; - } - } - return OpDesc::Vistor(shared_from_this(), temp); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDescPtr() const { - vector temp{}; - for (const auto &it : inputs_desc_) { - if (it->IsValid() == GRAPH_SUCCESS) { - temp.push_back(it); - } else { - GELOGW("this inputDesc is InValid, it won't be return"); - continue; - } - } - return OpDesc::Vistor(shared_from_this(), temp); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() const { - // Just return valid inputs size.InValid desc is set in default OPTION_INPUT register. - size_t size = 0; - for (const auto &it : inputs_desc_) { - if (it->IsValid() == GRAPH_SUCCESS) { - size++; - } - } - return size; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { - int index = static_cast(outputs_desc_.size()); - return AddOutputDesc("__output" + std::to_string(index), output_desc); -} - -graphStatus OpDesc::AddOutputDesc(const string &name, const ge::GeTensorDesc &output_desc) { - GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(name) == output_name_idx_.end()), GRAPH_FAILED, - "Add output tensor_Desc is existed. name[%s]", name.c_str()); - int index = static_cast(outputs_desc_.size()); - - std::shared_ptr tensor = ComGraphMakeShared(output_desc); - if (tensor == nullptr) { - GELOGE(GRAPH_FAILED, "AddOutputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - outputs_desc_.push_back(tensor); - (void)output_name_idx_.insert(make_pair(name, index)); - if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) { - register_output_name_.push_back(name); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDesc::UpdateOutputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { - GE_CHK_BOOL_RET_STATUS((index < outputs_desc_.size()), GRAPH_FAILED, "The index is invalid. 
index[%u]", index); - - outputs_desc_[index] = ComGraphMakeShared(tensor_Desc); - if (outputs_desc_[index] == nullptr) { - GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::UpdateOutputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { - auto it = output_name_idx_.find(name); - if (it == output_name_idx_.end()) { - GELOGW("Cann't find the output desc. name[%s]", name.c_str()); - return GRAPH_FAILED; - } - GE_IF_BOOL_EXEC(it->second >= outputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); - return GRAPH_FAILED); - outputs_desc_[it->second] = ComGraphMakeShared(tensor_Desc); - if (outputs_desc_[it->second] == nullptr) { - GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetOutputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG(index < outputs_desc_.size(), GeTensorDesc()); - return *(outputs_desc_[index].get()); -} - -GeTensorDesc OpDesc::GetOutputDesc(const string &name) const { - auto it = output_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), GeTensorDesc()); - GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < outputs_desc_.size(), GeTensorDesc()); - return *(outputs_desc_[it->second].get()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS(index < outputs_desc_.size(), nullptr, "Cann't find the output desc %u", index); - return outputs_desc_[index]; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const { - auto it = output_name_idx_.find(name); - if (it == output_name_idx_.end()) { - GELOGW("Failed to get [%s] output desc", name.c_str()); - return nullptr; - } - return MutableOutputDesc(it->second); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { - return static_cast(outputs_desc_.size()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllOutputsDesc() const { - vector temp{}; - for (const auto &it : outputs_desc_) { - temp.push_back(*it); - } - return OpDesc::Vistor(shared_from_this(), temp); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllOutputsDescPtr() const { - return OpDesc::Vistor(shared_from_this(), outputs_desc_); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetOutputsSize() const { return outputs_desc_.size(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetOutputDescPtr(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(outputs_desc_.size()), nullptr); - return outputs_desc_[index]; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(inputs_desc_.size()), nullptr); - if (inputs_desc_[index] == nullptr) { - return nullptr; - } - if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) { - GELOGW("inputsDesc[%u] is InValid", index); - return nullptr; - } else { - return inputs_desc_[static_cast(index)]; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr -OpDesc::GetInputDescPtrDfault(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG((index) < 
(uint32_t)(inputs_desc_.size()), nullptr); - return inputs_desc_[(int32_t)index]; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const { - auto it = input_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr()); - return inputs_desc_[it->second]; -} - -graphStatus OpDesc::AddRegisterInputName(const std::string &name) { - if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) { - register_input_name_.push_back(name); - } - - return GRAPH_SUCCESS; -} - -vector OpDesc::GetRegisterInputName() const { return register_input_name_; } - -graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { - if (is_push_back) { - for (unsigned int i = 0; i < num; i++) { - if (AddInputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) return GRAPH_FAILED; - } - } else { - if (AddInputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED; - } - if (AddRegisterInputName(name) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) { - if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddRegisterOutputName(const string &name) { - if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) { - register_output_name_.push_back(name); - } - - return GRAPH_SUCCESS; -} - -vector OpDesc::GetRegisterOutputName() const { return register_output_name_; } - -graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { - if (is_push_back) { - for (unsigned int i = 0; i < num; i++) { - if (AddOutputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) return GRAPH_FAILED; - } - } else { - if (AddOutputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED; - } - - if (AddRegisterOutputName(name) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -bool OpDesc::IsOptionalInput(const string &name) const { - return optional_input_names_.find(name) != optional_input_names_.end(); -} - -bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } - -std::map OpDesc::GetAllInputName() const { return input_name_idx_; } - -std::map OpDesc::GetAllOutputName() { return output_name_idx_; } - -bool OpDesc::UpdateInputName(std::map input_name_idx) { - bool ret = true; - // Use inputDesc_.size() to contain the InValid OptionInput.GetInputsSize() will remove default OptionInput name. - auto input_map_size = inputs_desc_.size(); - auto factory_map_size = input_name_idx.size(); - // It indicates that some inputs have no optionalname. 
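  // (That is, the factory registered more input names than this op actually
  //  carries tensor descriptors for.)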
- // The redundant optionalname of factory needs to be deleted and then assigned - if (input_map_size < factory_map_size) { - GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, - factory_map_size); - for (auto it = input_name_idx.begin(); it != input_name_idx.end();) { - if (it->second >= input_map_size) { - it = input_name_idx.erase(it); - } else { - ++it; - } - } - if (input_name_idx.size() == input_map_size) { - GELOGI("UpdateInputName"); - input_name_idx_ = input_name_idx; - } else { - ret = false; - GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); - } - } else if (input_map_size == factory_map_size) { - input_name_idx_ = input_name_idx; - } else { - ret = false; - GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); - } - return ret; -} - -bool OpDesc::UpdateOutputName(std::map output_name_idx) { - size_t output_map_size = GetAllOutputsDescSize(); - size_t factory_map_size = output_name_idx.size(); - if (output_map_size < factory_map_size) { - GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size, - factory_map_size); - for (auto it = output_name_idx.begin(); it != output_name_idx.end();) { - if (it->second >= output_map_size) { - it = output_name_idx.erase(it); - } else { - ++it; - } - } - if (output_name_idx.size() == output_map_size) { - GELOGI("UpdateoutputName"); - output_name_idx_ = output_name_idx; - return true; - } - } else if (output_map_size == factory_map_size) { - output_name_idx_ = output_name_idx; - return true; - } else { - GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size); - return false; - } - GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size); - return false; -} - -std::function OpDesc::GetInferFunc() const { return infer_func_; } - -std::function OpDesc::GetVerifyFunc() const { return verifier_func_; } - -void OpDesc::AddInferFunc(const std::function &func) { infer_func_ = func; } - -std::function OpDesc::GetInferFormatFunc() const { return infer_format_func_; } - -void OpDesc::AddInferFormatFunc(const std::function &func) { infer_format_func_ = func; } - -void OpDesc::AddVerifierFunc(const std::function &func) { verifier_func_ = func; } - -graphStatus OpDesc::InferShapeAndType() { - if (infer_func_ == nullptr) { - infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType()); - if (infer_func_ == nullptr) { - GELOGW("%s does not have inferfunc_.", GetName().c_str()); - /// The infoshape function has not been added for each operator in the current operator information library. 
- /// No infoshape added operator skips the call - /// and directly uses the shape information passed down by the upper framework - return GRAPH_SUCCESS; - } - } - Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this()); - graphStatus ret = (graphStatus)infer_func_(op_proxy); - op_proxy.BreakConnect(); - return ret; -} - -graphStatus OpDesc::DefaultInferFormat() { - ge::Format first_none_nd_format = FORMAT_ND; - auto input_descs = GetAllInputsDescPtr(); - auto output_descs = GetAllOutputsDescPtr(); - // Overall input and output,get the first non-nd format - for (const auto &input_desc : input_descs) { - Format origin_format = input_desc->GetOriginFormat(); - if (origin_format != FORMAT_ND) { - first_none_nd_format = origin_format; - break; - } - } - for (const auto &output_desc : output_descs) { - Format origin_format = output_desc->GetOriginFormat(); - if (origin_format != FORMAT_ND) { - first_none_nd_format = origin_format; - break; - } - } - // Refresh all input output format - GELOGD("Default infer format.node[%s], first none nod format is:%d", GetName().c_str(), first_none_nd_format); - - for (const auto &input_desc : input_descs) { - Format origin_format = input_desc->GetOriginFormat(); - GELOGD("Default infer format[in].node[%s].origin format is:%d", GetName().c_str(), origin_format); - if (origin_format == FORMAT_ND) { - input_desc->SetOriginFormat(first_none_nd_format); - input_desc->SetFormat(first_none_nd_format); - } - } - for (const auto &output_desc : output_descs) { - Format origin_format = output_desc->GetOriginFormat(); - GELOGD("Default infer format[out].node[%s].origin format is:%d", GetName().c_str(), origin_format); - if (origin_format == FORMAT_ND) { - output_desc->SetOriginFormat(first_none_nd_format); - output_desc->SetFormat(first_none_nd_format); - } - } - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::OpVerify() { - if (verifier_func_ == nullptr) { - verifier_func_ = OperatorFactoryImpl::GetVerifyFunc(GetType()); - } - if (verifier_func_ != nullptr) { - Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this()); - graphStatus ret = (graphStatus)verifier_func_(op_proxy); - op_proxy.BreakConnect(); - return ret; - } - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::CommonVerify() const { - for (const string &iname : GetAllInputNames()) { - // Checking shape of all inputs - vector ishape = GetInputDescPtr(iname)->GetShape().GetDims(); - for (int64_t dim : ishape) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - dim < -2, ErrorManager::GetInstance().ATCReportErrMessage( - "E19014", {"opname", "value", "reason"}, - {GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); - return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), - iname.c_str()); - } - } - // Check all attributes defined - const auto &all_attributes = GetAllAttrs(); - for (const auto &name : GetAllAttrNames()) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - all_attributes.find(name) == all_attributes.end(), - ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, - {GetName(), "attribute " + name, "is empty"}); - return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { - auto it = input_name_idx_.begin(); - for (; it != input_name_idx_.end(); ++it) { - if (it->second == index) { - break; - } - } - 
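  // input_name_idx_ maps name -> index, so a reverse lookup needs a full scan.
  // An index with no matching entry falls through to the check below and yields "".
  // For example, with input_name_idx_ = {{"x", 0}, {"y", 1}}, GetInputNameByIndex(1)
  // returns "y" and GetInputNameByIndex(5) returns an empty string.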
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); - return it->first; -} - -int OpDesc::GetInputIndexByName(const string &name) const { - auto it_find = input_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); - return static_cast(it_find->second); -} - -int OpDesc::GetValidInputIndexByName(const string &name) const { - map valid_input_name_idx{}; - uint32_t j = 0; - for (size_t i = 0; i < GetAllInputsSize(); i++) { - if (MutableInputDesc(static_cast(i)) != nullptr) { - auto valid_name = GetInputNameByIndex(static_cast(i)); - GE_CHK_BOOL_RET_STATUS_NOLOG(!valid_name.empty(), -1); - valid_input_name_idx.insert({valid_name, j}); - j++; - } - } - auto it_find = valid_input_name_idx.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != valid_input_name_idx.end(), -1); - return static_cast(it_find->second); -} - -string OpDesc::GetValidInputNameByIndex(uint32_t index) const { - map valid_input_name_idx{}; - uint32_t j = 0; - for (size_t i = 0; i < GetAllInputsSize(); i++) { - if (MutableInputDesc(static_cast(i)) != nullptr) { - auto valid_name = GetInputNameByIndex(static_cast(i)); - GE_CHK_BOOL_RET_STATUS_NOLOG(!valid_name.empty(), ""); - valid_input_name_idx.insert({valid_name, j}); - j++; - } - } - auto it = valid_input_name_idx.begin(); - for (; it != valid_input_name_idx.end(); ++it) { - if (it->second == index) { - break; - } - } - GE_CHK_BOOL_RET_STATUS_NOLOG(it != valid_input_name_idx.end(), ""); - return it->first; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetOutputNameByIndex(uint32_t index) const { - auto it = output_name_idx_.begin(); - for (; it != output_name_idx_.end(); ++it) { - if (it->second == index) { - break; - } - } - GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), ""); - return it->first; -} - -int OpDesc::GetOutputIndexByName(const string &name) const { - auto it_find = output_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != output_name_idx_.end(), -1); - return static_cast(it_find->second); -} - -ProtoAttrMapHelper OpDesc::MutableAttrMap() { - if (op_def_.GetProtoMsg() == nullptr) { - GELOGE(GRAPH_FAILED, "op def get proto msg failed"); - return GeIrProtoHelper(); - } - return ProtoAttrMapHelper(op_def_.GetProtoOwner(), op_def_.GetProtoMsg()->mutable_attr()); -} - -ConstProtoAttrMapHelper OpDesc::GetAttrMap() const { - return ConstProtoAttrMapHelper(op_def_.GetProtoOwner(), &op_def_.GetProtoMsg()->attr()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetId(int64_t id) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_id(id); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetId() const { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->id(); - } - return 0; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetStreamId(int64_t stream_id) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_stream_id(stream_id); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetStreamId() const { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->stream_id(); - } - return 0; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputName(const vector &input_name) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_input_name(); - for (auto &item : input_name) { - 
proto_msg->add_input_name(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetInputName() const { - vector input_name; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->input_name()) { - input_name.push_back(item); - } - } - return input_name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcName(const vector &src_name) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_src_name(); - for (auto &item : src_name) { - proto_msg->add_src_name(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetSrcName() const { - vector src_name; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->src_name()) { - src_name.push_back(item); - } - } - return src_name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcIndex(const vector &src_index) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_src_index(); - for (auto &item : src_index) { - proto_msg->add_src_index(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetSrcIndex() const { - vector src_index; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->src_index()) { - src_index.push_back(item); - } - } - return src_index; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputOffset(const vector &input) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_input_i(); - for (auto &item : input) { - proto_msg->add_input_i(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetInputOffset() const { - vector input; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->input_i()) { - input.push_back(item); - } - } - return input; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOutputOffset(const vector &output) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_output_i(); - for (auto &item : output) { - proto_msg->add_output_i(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOutputOffset() const { - vector output; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->output_i()) { - output.push_back(item); - } - } - return output; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstName(const vector &dst_name) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_dst_name(); - for (auto &item : dst_name) { - proto_msg->add_dst_name(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetDstName() const { - vector dst_name; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->dst_name()) { - dst_name.push_back(item); - } - } - return dst_name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector &depend_names) { - auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); - if (ret != true) { - GELOGE(GRAPH_FAILED, "set op_infer_depends fail."); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOpInferDepends() const { - vector depend_names; - 
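  // The attribute may be absent, in which case the list stays empty; the return
  // status is deliberately ignored. SetOpInferDepends above stores the same list
  // under ATTR_NAME_OP_INFER_DEPENDS.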
(void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); - return depend_names; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector &dst_index) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_dst_index(); - for (auto &item : dst_index) { - proto_msg->add_dst_index(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetDstIndex() const { - vector dst_index; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->dst_index()) { - dst_index.push_back(item); - } - } - return dst_index; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspace(const vector &workspace) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_workspace(); - for (auto &item : workspace) { - proto_msg->add_workspace(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetWorkspace() const { - vector workspace; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->workspace()) { - workspace.push_back(item); - } - } - return workspace; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspaceBytes(const vector &workspace_bytes) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_workspace_bytes(); - for (auto &item : workspace_bytes) { - proto_msg->add_workspace_bytes(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetWorkspaceBytes() const { - vector workspace_bytes; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->workspace_bytes()) { - workspace_bytes.push_back(item); - } - } - return workspace_bytes; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetIsInputConst(const vector &is_input_const) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_is_input_const(); - for (auto item : is_input_const) { - proto_msg->add_is_input_const(item); - } - } - // If comes from ME,which is_input_const exist as attrs, outside no need to check GE_TRAIN flag - auto ret = AttrUtils::SetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const); - if (ret != true) { - GELOGE(GRAPH_FAILED, "set is_input_const fail."); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetIsInputConst() const { - vector is_input_const; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto item : proto_msg->is_input_const()) { - is_input_const.push_back(item); - } - } - return is_input_const; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, - const int &index) { - if (input_name_idx_.find(name) != input_name_idx_.end()) { - GELOGI("Restore input name index is existed. name[%s]", name.c_str()); - } - (void)input_name_idx_.insert(make_pair(name, index)); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreOutputNameIdx(const string &name, - const int &index) { - if (output_name_idx_.find(name) != output_name_idx_.end()) { - GELOGI("Restore output name index is existed. 
name[%s]", name.c_str()); - } - (void)output_name_idx_.insert(make_pair(name, index)); - return GRAPH_SUCCESS; -} -graphStatus OpDesc::CallInferFunc(Operator &op) { - if (infer_func_ == nullptr) { - infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType()); - if (infer_func_ == nullptr) { - GELOGW("%s does not have infer func.", GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - } - graphStatus graph_status = (graphStatus)infer_func_(op); - if (graph_status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "%s call infer func. ret: %u", GetName().c_str(), graph_status); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} -graphStatus OpDesc::CallInferFormatFunc(Operator &op) { - if (infer_format_func_ == nullptr) { - infer_format_func_ = OperatorFactoryImpl::GetInferFormatFunc(GetType()); - if (infer_format_func_ == nullptr) { - return DefaultInferFormat(); - } - } - return (graphStatus)infer_format_func_(op); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const { - if (static_cast(index) >= subgraph_instance_names_.size()) { - return ""; - } - return subgraph_instance_names_.at(index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector &OpDesc::GetSubgraphInstanceNames() - const { - return subgraph_instance_names_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { - for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { - if (*iter == name) { - *iter = ""; - return; - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { - GELOGI("Add subgraph name is %s", name.c_str()); - auto iter = subgraph_names_to_index_.find(name); - if (iter != subgraph_names_to_index_.end()) { - GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); - return GRAPH_FAILED; - } - auto size = subgraph_names_to_index_.size(); - subgraph_names_to_index_[name] = size; - subgraph_instance_names_.resize(size + 1); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map &OpDesc::GetSubgraphNameIndexes() - const { - return subgraph_names_to_index_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::SetSubgraphInstanceName(uint32_t index, - const std::string &name) { - GELOGI("Add sub graph instans name is %s, index is %u", name.c_str(), index); - if (index >= subgraph_instance_names_.size()) { - GE_LOGE("The index %u exceeds the max instance coutn %zu", index, subgraph_instance_names_.size()); - return GRAPH_PARAM_INVALID; - } - subgraph_instance_names_[index] = name; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RegisterSubgraphIrName(const string &name, - SubgraphType type) { - subgraph_ir_names_to_type_[name] = type; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map &OpDesc::GetSubgraphIrNames() - const { - return subgraph_ir_names_to_type_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY SubgraphType -OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { - auto iter = subgraph_ir_names_to_type_.find(name); - if (iter == subgraph_ir_names_to_type_.end()) { - return kSubgraphTypeEnd; - } - return iter->second; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDesc::GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const { - for (size_t 
idx = 0; idx < subgraph_instance_names_.size(); ++idx) { - if (subgraph_instance_names_[idx] != instance_name) { // find subgraph index. - continue; - } - - for (auto name_to_index : subgraph_names_to_index_) { - if (name_to_index.second != idx) { // find subgraph name. - continue; - } - - subgraph_name = name_to_index.first; - return GRAPH_SUCCESS; - } - } - - return GRAPH_PARAM_INVALID; -} - -} // namespace ge diff --git a/src/common/graph/op_imp.cc b/src/common/graph/op_imp.cc deleted file mode 100644 index 9abf242b..00000000 --- a/src/common/graph/op_imp.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "debug/ge_log.h" -#include "debug/ge_util.h" - -using namespace std; - -namespace ge { - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -BroadCastInfer(const function()>& get_in1_shape, const function()>& get_in2_shape, - const function& outShape)>& set_out_shape) { - auto x1_shape = get_in1_shape(); - auto x2_shape = get_in2_shape(); - vector y_shape; - - if (x1_shape.empty()) { - y_shape = x2_shape; - set_out_shape(y_shape); - return GRAPH_SUCCESS; - } - if (x2_shape.empty()) { - y_shape = x1_shape; - set_out_shape(y_shape); - return GRAPH_SUCCESS; - } - - int len_diff = static_cast(x1_shape.size() - x2_shape.size()); - if (len_diff >= 0) { - for (int i = 0; i < len_diff; i++) { - y_shape.push_back(x1_shape[i]); - } - int x2_shape_size = static_cast(x2_shape.size()); - for (int i = 0; i < x2_shape_size; i++) { - bool shapeFlag = - ((x1_shape[i + len_diff] != x2_shape[i]) && (std::min(x1_shape[i + len_diff], x2_shape[i]) != 1)); - if (shapeFlag) { - GE_LOGE("operands could not be broadcast together"); - return GRAPH_FAILED; - } - y_shape.push_back(std::max(x1_shape[i + len_diff], x2_shape[i])); - } - } else { - for (int i = 0; i < -len_diff; i++) { - y_shape.push_back(x2_shape[i]); - } - int x1_shape_size = static_cast(x1_shape.size()); - for (int i = 0; i < x1_shape_size; i++) { - bool shapeFlag = - ((x1_shape[i] != x2_shape[i - len_diff]) && (std::min(x1_shape[i], x2_shape[i - len_diff]) != 1)); - if (shapeFlag) { - GE_LOGE("operands could not be broadcast together"); - return GRAPH_FAILED; - } - y_shape.push_back(std::max(x1_shape[i], x2_shape[i - len_diff])); - } - } - set_out_shape(y_shape); - return GRAPH_SUCCESS; -} - -} // namespace ge diff --git a/src/common/graph/operator.cc b/src/common/graph/operator.cc deleted file mode 100644 index 21554fa1..00000000 --- a/src/common/graph/operator.cc +++ /dev/null @@ -1,1587 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "external/graph/operator.h" -#include "external/graph/operator_factory.h" -#include -#include -#include -#include -#include -#include "./array_ops.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "external/graph/attr_value.h" -#include "external/graph/types.h" -#include "framework/common/debug/ge_log.h" -#include "graph/compute_graph.h" -#include "graph/ge_attr_value.h" -#include "graph/ge_context.h" -#include "graph/ge_tensor.h" -#include "graph/node.h" -#include "graph/op_desc.h" -#include "graph/runtime_inference_context.h" -#include "graph/usr_types.h" -#include "graph/utils/node_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "utils/graph_utils.h" -#include "utils/op_desc_utils.h" -#include "utils/tensor_adapter.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" -#include -#include -#include -#include -#include - -using std::enable_shared_from_this; -using std::make_pair; -using std::shared_ptr; -using std::string; -using std::to_string; -using std::vector; - -/*lint -save -e529 -e728*/ -/*lint -e446 -e732*/ -/*lint -e665*/ -namespace ge { -class OpIO { - public: - OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} - - ~OpIO() = default; - - string GetName() const { return name_; } - - int GetIndex() const { return index_; } - - OperatorImplPtr GetOwner() const { return owner_; } - - bool operator==(const OpIO &r_value) const { - return (this->name_ == r_value.GetName()) && (this->index_ == r_value.GetIndex()) && - (this->GetOwner() == r_value.GetOwner()); - } - - private: - string name_; - int index_; - std::shared_ptr owner_; -}; - -class TensorTypeImpl { - public: - TensorTypeImpl() = default; - ~TensorTypeImpl() = default; - - std::vector dt_vec_; -}; - -TensorType::TensorType(DataType dt) { - tensor_type_impl_ = ComGraphMakeShared(); - if (tensor_type_impl_ != nullptr) { - tensor_type_impl_->dt_vec_.push_back(dt); - } -} - -TensorType::TensorType(const std::initializer_list &types) { - tensor_type_impl_ = ComGraphMakeShared(); - if (tensor_type_impl_ != nullptr) { - tensor_type_impl_->dt_vec_ = types; - } -} - -class OperatorImpl : public std::enable_shared_from_this { - friend class GraphBuilderImpl; - friend class OpDescUtils; - - public: - explicit OperatorImpl(const string &name, const string &type) : op_desc_(ComGraphMakeShared(name, type)) { - if (op_desc_ == nullptr) { - GELOGW("OpDesc make shared failed"); - } - } - explicit OperatorImpl(const OpDescPtr &op_desc) : op_desc_(op_desc) {} - explicit OperatorImpl(ge::ConstNodePtr node) : node_(std::move(node)) { - if (node_ != nullptr && node_->GetOpDesc() != nullptr) { - op_desc_ = node_->GetOpDesc(); - } - } - ~OperatorImpl() {} - void SetInputImpl(const string &dst_name, const ge::Operator &src_oprt) { - GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty"); - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr."); - GE_CHK_BOOL_EXEC(src_oprt.operator_impl_ != nullptr, return, "operator_impl_ is nullptr."); - 
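    // Precondition checks: both operators must be fully constructed. This overload
    // additionally requires the source operator to expose exactly one output (see the
    // size check below); otherwise the output has to be named or indexed explicitly
    // via the three-argument SetInput overloads.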
GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_ != nullptr, return, "op_desc_ is nullptr."); - - auto src_op_impl = src_oprt.GetOperatorImplPtr(); - GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return, "Src impl is null."); - GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return, "Src impl's opdesc is null."); - GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_->GetOutputsSize() == 1, return, - "The source operator[%s] must has one output", - src_oprt.operator_impl_->op_desc_->GetName().c_str()) - - uint32_t src_index = 0; - string src_name = src_op_impl->op_desc_->GetOutputNameByIndex(src_index); - GE_CHK_BOOL_EXEC(!src_name.empty(), return, "Src output's name is empty."); - - OpIO out_handler(src_name, src_index, src_op_impl); - input_link_.insert(std::make_pair(dst_name, out_handler)); - - int dst_index = op_desc_->GetInputIndexByName(dst_name); - GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), - op_desc_->GetName().c_str()); - - bool is_const = false; - if (src_oprt.GetOpType() == CONSTANT) { - is_const = true; - } - auto is_input_const = op_desc_->GetIsInputConst(); - for (int i = static_cast(is_input_const.size()); i <= dst_index; ++i) { - is_input_const.push_back(false); - } - - is_input_const[dst_index] = is_const; - op_desc_->SetIsInputConst(is_input_const); - - OpIO op_dst(dst_name, dst_index, shared_from_this()); - src_op_impl->UpdateLinkMapImpl(src_name, op_dst); - auto output_desc = src_op_impl->GetOutputDesc(src_name); - auto input_desc = op_desc_->GetInputDesc(dst_name); - if (input_desc.GetFormat() == FORMAT_RESERVED) { - output_desc.SetFormat(FORMAT_ND); - } else { - output_desc.SetFormat(input_desc.GetFormat()); - } - // Fix for linking opdesc - if (op_desc_->UpdateInputDesc(dst_name, output_desc) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Update inputdesc failed,dst name is %s, src name is %s", dst_name.c_str(), - src_name.c_str()); - return; - } - } - - void SetInputImpl(const string &dst_name, const ge::OutHandler &out_handler) { - GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty"); - GE_CHK_BOOL_EXEC(out_handler != nullptr, return, "SetInputImpl faild, out_handler is nullptr."); - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr."); - input_link_.insert(std::make_pair(dst_name, *out_handler)); - - string src_name = out_handler->GetName(); - int dst_index = op_desc_->GetInputIndexByName(dst_name); - GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), - op_desc_->GetName().c_str()); - auto out_op_impl = out_handler->GetOwner(); - GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return, - "out_handler invalid. 
name[%s]", dst_name.c_str()); - bool is_const = false; - if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { - is_const = true; - } - auto is_input_const = op_desc_->GetIsInputConst(); - for (int i = static_cast(is_input_const.size()); i <= dst_index; ++i) { - is_input_const.push_back(false); - } - is_input_const[dst_index] = is_const; - op_desc_->SetIsInputConst(is_input_const); - - OpIO in_handler(dst_name, dst_index, shared_from_this()); - GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed."); - - out_op_impl->UpdateLinkMapImpl(src_name, in_handler); - auto src_output_desc = out_op_impl->GetOutputDesc(src_name); - auto dst_input_desc = op_desc_->GetInputDesc(dst_name); - if (dst_input_desc.GetFormat() == FORMAT_RESERVED) { - src_output_desc.SetFormat(FORMAT_ND); - } else { - src_output_desc.SetFormat(dst_input_desc.GetFormat()); - } - GE_CHK_BOOL_EXEC(op_desc_->UpdateInputDesc(dst_name, src_output_desc) == GRAPH_SUCCESS, return, - "Update input desc failed,dst name is %s,src name is %s", dst_name.c_str(), - src_name.c_str()); // fix for linking opdesc - } - - void AddControlInputImp(const ge::Operator &src_oprt) { - if (src_oprt.operator_impl_ == nullptr) { - GELOGE(FAILED, "Src operator impl is nullptr"); - return; - } - for (auto &input : control_input_link_) { - if (input.lock() == src_oprt.operator_impl_) { - return; - } - } - control_input_link_.push_back(src_oprt.operator_impl_); - src_oprt.operator_impl_->control_output_link_.push_back(shared_from_this()); - } - - graphStatus GetInputImpl(const string &dst_name, ge::OpIO &out_handler) { - auto out = input_link_.find(dst_name); - if (out == input_link_.end()) { - return GRAPH_FAILED; - } - out_handler = out->second; - return GRAPH_SUCCESS; - } - - bool InputIsSet(const string &name) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return false, "op_desc_ is nullptr."); - return op_desc_->InputIsSet(name); - } - - string GetName() const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return string(), "op_desc_ is nullptr."); - return op_desc_->GetName(); - } - - GeTensorDesc GetInputDesc(const string &name) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); - return op_desc_->GetInputDesc(name); - } - - GeTensorDesc GetInputDesc(uint32_t index) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); - return op_desc_->GetInputDesc(index); - } - - graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GRAPH_FAILED, "op_desc_ is nullptr."); - - return op_desc_->UpdateInputDesc(name, tensor_desc); - } - - OutHandler GetOutput(const string &name) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); - - int src_index = op_desc_->GetOutputIndexByName(name); - GE_CHK_BOOL_EXEC(src_index >= 0, return nullptr, "Find src index by name failed. name[%s]", name.c_str()); - shared_ptr output_ptr = ComGraphMakeShared(name, src_index, shared_from_this()); - if (output_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "OpIO make shared failed"); - return nullptr; - } - return output_ptr; - } - - OutHandler GetOutput(uint32_t index) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); - - string name = op_desc_->GetOutputNameByIndex(index); - if (name.empty()) { - GELOGE(GRAPH_FAILED, "Find src name by index failed. 
index[%u]", index); - return nullptr; - } - shared_ptr output_ptr = ComGraphMakeShared(name, index, shared_from_this()); - if (output_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "OpIO make shared failed"); - return nullptr; - } - return output_ptr; - } - - GeTensorDesc GetOutputDesc(const string &name) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); - - return op_desc_->GetOutputDesc(name); - } - - GeTensorDesc GetOutputDesc(uint32_t index) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); - - return op_desc_->GetOutputDesc(index); - } - - graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc) { - GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); - - auto res = op_desc_->UpdateOutputDesc(name, tensor_desc); - if (res == GRAPH_SUCCESS) { - for (auto ol : output_links_[name]) { - if (ol.GetOwner() == nullptr) { - GELOGW("%s get owner is nullptr", ol.GetName().c_str()); - continue; - } - GE_CHK_BOOL_RET_STATUS(ol.GetOwner()->UpdateInputDesc(ol.GetName(), tensor_desc) == GRAPH_SUCCESS, GRAPH_FAILED, - "Could not update next operator's input %s.", ol.GetName().c_str()); - } - } - return res; - } - - size_t GetInputsSize() const { - GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0); - return op_desc_->GetInputsSize(); - } - - size_t GetOutputsSize() const { - GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0); - return op_desc_->GetOutputsSize(); - } - - graphStatus SetAttr(const string &name, GeAttrValue &&attr_value) { - GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); - return op_desc_->SetAttr(name, std::move(attr_value)); - } - - graphStatus GetAttr(const string &name, GeAttrValue &attr_value) const { - GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); - return op_desc_->GetAttr(name, attr_value); - } - - OpDescPtr GetOpDescImpl() const { return op_desc_; } - - void UpdateLinkMapImpl(const string &src_name, OpIO &op_dst) { - auto it_find = output_links_.find(src_name); - if (it_find == output_links_.end()) { - std::vector dsts{op_dst}; - output_links_.insert(std::make_pair(src_name, dsts)); - } else { - it_find->second.push_back(op_dst); - } - } - - Operator ToOperator() { return Operator(shared_from_this()); } - - static OpDescPtr GetOpDesc(const Operator &oprt) { - GE_IF_BOOL_EXEC(oprt.operator_impl_ == nullptr, return nullptr); - return oprt.operator_impl_->op_desc_; - } - - void ClearOutputLinks() noexcept { output_links_.clear(); } - - void ClearInputLinks() noexcept { input_link_.clear(); } - - ge::ConstNodePtr GetNode() { return node_; } - - void SetInferenceContext(const InferenceContextPtr &inference_context) { inference_context_ = inference_context; } - - InferenceContextPtr GetInferenceContext() const { return inference_context_; } - - void SubgraphRegister(const std::string &ir_name, bool dynamic) { - op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? 
kDynamic : kStatic); - } - - void SubgraphCountRegister(const std::string &ir_name, uint32_t count) { - if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) { - op_desc_->AddSubgraphName(ir_name); - subgraph_names_to_builders_[ir_name] = nullptr; - } else { - for (uint32_t i = 0; i < count; ++i) { - string key_name = ir_name + std::to_string(i); - op_desc_->AddSubgraphName(key_name); - subgraph_names_to_builders_[key_name] = nullptr; - } - } - } - - void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { - string key_name = ir_name; - if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { - key_name += std::to_string(index); - } - - auto it = subgraph_names_to_builders_.find(key_name); - if (it == subgraph_names_to_builders_.end()) { - GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index); - return; - } - it->second = builder; - } - - SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, uint32_t index) const { - string key_name = ir_name; - if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { - key_name += std::to_string(index); - } - - return GetSubgraphBuilder(key_name); - } - - SubgraphBuilder GetSubgraphBuilder(const std::string &name) const { - auto iter = subgraph_names_to_builders_.find(name); - if (iter == subgraph_names_to_builders_.end()) { - GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s", name.c_str()); - return nullptr; - } - - return iter->second; - } - - std::vector GetSubgraphNames() const { - std::vector names; - for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) { - names.emplace_back(subgraph_name_to_type.first); - } - return names; - } - - size_t GetSubgraphNamesCount() const { return op_desc_->GetSubgraphIrNames().size(); } - - OpDescPtr op_desc_ = nullptr; - - private: - ge::ConstNodePtr node_{nullptr}; - ge::InferenceContextPtr inference_context_; - std::map> output_links_{}; - std::map input_link_{}; - std::vector> control_input_link_{}; - std::vector> control_output_link_{}; - std::map subgraph_names_to_builders_; -}; - -// Used to manage OperatorImpl instances created by ge api. 
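// The keeper pins every impl created through the public Operator constructors so that
// ClearInputLinks()/ClearOutputLinks() can run when the keeper is destroyed, releasing
// the references that linked operators hold on each other. A rough usage sketch
// (illustrative only, with hypothetical names, mirroring the calls made elsewhere in
// this file):
//
//   auto impl = ComGraphMakeShared<OperatorImpl>("add_0", "Add");
//   OperatorKeeper::GetInstance().CheckInOperator(impl);   // track on creation
//   ...
//   OperatorKeeper::GetInstance().CheckOutOperator(impl);  // stop tracking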
-class OperatorKeeper { - private: - OperatorKeeper() = default; - ~OperatorKeeper() { - for (const auto &iter : operators_) { - if (iter) { - iter->ClearInputLinks(); - iter->ClearOutputLinks(); - } - } - } - std::set operators_; - std::mutex mutex_; - - public: - static OperatorKeeper &GetInstance() { - static OperatorKeeper instance; - return instance; - } - void CheckInOperator(const OperatorImplPtr &op_impl) { - if (op_impl) { - std::lock_guard lock(mutex_); - operators_.insert(op_impl); - } - } - void CheckOutOperator(const OperatorImplPtr &op_impl) { - if (op_impl) { - std::lock_guard lock(mutex_); - operators_.erase(op_impl); - } - } -}; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromNode(ge::ConstNodePtr node_ptr) { - ge::OperatorImplPtr operator_impl_ptr = ComGraphMakeShared(node_ptr); - if (operator_impl_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); - return Operator("default"); - } - return operator_impl_ptr->ToOperator(); -} - -Operator::Operator(const std::string &type) { - static uint32_t index = 0; - string name = type + "_" + std::to_string(index++); - operator_impl_ = ComGraphMakeShared(name, type); - if (operator_impl_ == nullptr) { - GELOGW("OperatorImpl make shared failed"); - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromOpDesc(OpDescPtr op_desc) { - shared_ptr operator_impl_ptr; - operator_impl_ptr = ComGraphMakeShared(op_desc); - if (operator_impl_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); - return Operator("default"); - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr); - return operator_impl_ptr->ToOperator(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) { - return OperatorImpl::GetOpDesc(oprt); -} - -GE_FUNC_HOST_VISIBILITY Operator::Operator(const string &name, const string &type) { - operator_impl_ = ComGraphMakeShared(name, type); - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); - return; - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); -} - -Operator::Operator(ge::OperatorImplPtr &&op_impl) { operator_impl_ = std::move(op_impl); } - -bool Operator::IsEmpty() const { - if (operator_impl_ == nullptr) { - return true; - } - return false; -} - -string Operator::GetName() const { - if (operator_impl_ != nullptr) { - return operator_impl_->GetName(); - } - return ""; -} - -GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const string &dst_name, const ge::Operator &src_oprt) { - // Describe the connection relationship between operators, no create action - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); - operator_impl_->SetInputImpl(dst_name, src_oprt); - return *this; -} - -Operator &Operator::SetInput(const string &dst_name, const ge::OutHandler &out_handler) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); - operator_impl_->SetInputImpl(dst_name, out_handler); - return *this; -} - -Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, const std::string &name) { - auto out_handler = src_oprt.GetOutput(name); - GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); - (void)SetInput(dst_name, out_handler); - return *this; -} - -Operator 
&Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) { - auto out_handler = src_oprt.GetOutput(index); - GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); - (void)SetInput(dst_name, out_handler); - return *this; -} - -Operator &Operator::AddControlInput(const Operator &src_oprt) { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr."); - return *this; - } - operator_impl_->AddControlInputImp(src_oprt); - return *this; -} - -graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { - GE_CHECK_NOTNULL(operator_impl_); - auto node_ptr = operator_impl_->GetNode(); - if (node_ptr != nullptr) { - // For inner compute graph - auto op_desc = node_ptr->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - auto index = op_desc->GetInputIndexByName(dst_name); - auto in_data_anchor = node_ptr->GetInDataAnchor(index); - GE_CHECK_NOTNULL(in_data_anchor); - auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(out_data_anchor); - auto peer_node = out_data_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(peer_node); - auto peer_op_desc = peer_node->GetOpDesc(); - GE_CHECK_NOTNULL(peer_op_desc); - auto peer_op_type = peer_op_desc->GetType(); - if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { - auto const_op_impl = ComGraphMakeShared(peer_node); - GE_CHECK_NOTNULL(const_op_impl); - Operator const_op(std::move(const_op_impl)); - return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); - } else if (peer_op_type == DATA) { - auto parent_node = NodeUtils::GetParentInput(peer_node); - while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { - parent_node = NodeUtils::GetParentInput(parent_node); - } - if ((parent_node != nullptr) && - ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { - auto const_op_impl = ComGraphMakeShared(parent_node); - GE_CHECK_NOTNULL(const_op_impl); - Operator const_op(std::move(const_op_impl)); - return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); - } - } - // Try get from runtime inference context - auto session_id = std::to_string(GetContext().SessionId()); - RuntimeInferenceContext *runtime_infer_ctx = nullptr; - if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { - GELOGD("To get constant from runtime inference context. 
session_id = %s", session_id.c_str()); - auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); - if (ret == GRAPH_SUCCESS) { - return GRAPH_SUCCESS; - } - } - } else { - // For outer graph - return GetInputConstDataOut(dst_name, data); - } - auto op_name = operator_impl_->GetName(); - GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); - return GRAPH_FAILED; -} -graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { - ge::OpIO out_handle("", 0, nullptr); - GE_CHECK_NOTNULL(operator_impl_); - if (operator_impl_->GetInputImpl(dst_name, out_handle) != GRAPH_SUCCESS) { - GELOGE(FAILED, "%s get input impl failed", dst_name.c_str()); - return GRAPH_FAILED; - } - if (out_handle.GetOwner() != nullptr && out_handle.GetOwner()->GetOpDescImpl() != nullptr) { - Operator const_op(out_handle.GetOwner()); - const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); - if (op_desc_impl_type == CONSTANTOP) { - return const_op.GetAttr(op::Constant::name_attr_value(), data); - } else if (op_desc_impl_type == CONSTANT) { - return const_op.GetAttr(op::Const::name_attr_value(), data); - } - } - return GRAPH_FAILED; -} - -std::shared_ptr Operator::GetNode() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); - return operator_impl_->GetNode(); -} - -TensorDesc Operator::GetInputDesc(const std::string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); -} - -void Operator::SetInferenceContext(const InferenceContextPtr &inference_context) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - operator_impl_->SetInferenceContext(inference_context); -} - -InferenceContextPtr Operator::GetInferenceContext() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); - return operator_impl_->GetInferenceContext(); -} -TensorDesc Operator::GetInputDesc(uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); -} - -graphStatus Operator::TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - auto check = operator_impl_->InputIsSet(name); - if (check) tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); - return check ? 
GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Operator::UpdateInputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - return operator_impl_->UpdateInputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -OutHandler Operator::GetOutput(const string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); - return operator_impl_->GetOutput(name); -} - -OutHandler Operator::GetOutput(uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); - return operator_impl_->GetOutput(index); -} - -TensorDesc Operator::GetOutputDesc(const std::string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name)); -} - -TensorDesc Operator::GetOutputDesc(uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(index)); -} - -graphStatus Operator::UpdateOutputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - return operator_impl_->UpdateOutputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -TensorDesc Operator::GetDynamicInputDesc(const string &name, uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name + std::to_string(index))); -} - -graphStatus Operator::UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - return operator_impl_->UpdateInputDesc(name + std::to_string(index), - TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -TensorDesc Operator::GetDynamicOutputDesc(const string &name, uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name + std::to_string(index))); -} - -graphStatus Operator::UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - return operator_impl_->UpdateOutputDesc(name + std::to_string(index), - TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -graphStatus Operator::InferShapeAndType() { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); - - return operator_impl_->GetOpDescImpl()->CallInferFunc(*this); -} - -graphStatus Operator::VerifyAllAttr(bool disable_common_verifier) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); - - if (!disable_common_verifier && (graphStatus)Operator::VerifyAll() == GRAPH_FAILED) { - return GRAPH_FAILED; - 
} else { - return (graphStatus)operator_impl_->GetOpDescImpl()->OpVerify(); - } -} - -GE_FUNC_HOST_VISIBILITY size_t Operator::GetInputsSize() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr"); - return operator_impl_->GetInputsSize(); -} - -GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr"); - return operator_impl_->GetOutputsSize(); -} - -// According to op get the attrs name and type -namespace { -const std::map kAttrTypesMap = { - {GeAttrValue::VT_NONE, "VT_STRING"}, - {GeAttrValue::VT_STRING, "VT_STRING"}, - {GeAttrValue::VT_FLOAT, "VT_FLOAT"}, - {GeAttrValue::VT_BOOL, "VT_BOOL"}, - {GeAttrValue::VT_INT, "VT_INT"}, - {GeAttrValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"}, - {GeAttrValue::VT_TENSOR, "VT_TENSOR"}, - {GeAttrValue::VT_BYTES, "VT_BYTES"}, - {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, - {GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"}, - {GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"}, - {GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"}, - {GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"}, - {GeAttrValue::VT_LIST_BOOL, "VT_LIST_BOOL"}, - {GeAttrValue::VT_LIST_INT, "VT_LIST_INT"}, - {GeAttrValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"}, - {GeAttrValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"}, - {GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"}, - {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, - {GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"}, -}; -} // namespace -const std::map Operator::GetAllAttrNamesAndTypes() const { - std::map attr_types; - - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return attr_types, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return attr_types, "GetOpDescImpl is nullptr."); - std::map attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs(); - - map::iterator iter; - for (iter = attr_map.begin(); iter != attr_map.end(); ++iter) { - string name = iter->first; - GeAttrValue attr_value = iter->second; - - GeAttrValue::ValueType type = attr_value.GetValueType(); - - auto iter2 = kAttrTypesMap.find(type); - if (iter2 != kAttrTypesMap.end()) { - attr_types[name] = iter2->second; - } - } - - return attr_types; -} - -void Operator::InputRegister(const string &name) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); -} - -void Operator::OptionalInputRegister(const string &name) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, - GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); -} - -void Operator::InferFuncRegister(const std::function &func) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); -} - -void Operator::InferFormatFuncRegister(const std::function &func) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, 
return, "GetOpDescImpl is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); -} - -void Operator::VerifierFuncRegister(const std::function &func) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); -} - -void Operator::OutputRegister(const string &name) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); -} - -void Operator::DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return, - "set int failed"); - (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); -} - -void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) { - GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr."); - operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index); -} - -int Operator::GetDynamicInputNum(const string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); - int num = 0; - GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return num, - "Get %s int failed", name.c_str()); - return num; -} - -void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return, - "Set %s int failed", name.c_str()); - (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); -} - -int Operator::GetDynamicOutputNum(const string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); - int num = 0; - GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num, - "Get %s int failed", name.c_str()); - return num; -} - -void Operator::RequiredAttrRegister(const string &name) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - operator_impl_->GetOpDescImpl()->AddRequiredAttr(name); -} - -graphStatus Operator::VerifyAll() { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 
GRAPH_FAILED, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); - - // Check all inputs defined - for (const string &iname : operator_impl_->GetOpDescImpl()->GetAllInputNames()) { - GE_CHK_BOOL_RET_STATUS(operator_impl_->GetOpDescImpl()->IsOptionalInput(iname) || operator_impl_->InputIsSet(iname), - GRAPH_FAILED, "operator input %s is not linked.", iname.c_str()); - vector ishape = operator_impl_->GetOpDescImpl()->GetInputDesc(iname).GetShape().GetDims(); - for (int64_t dim : ishape) { - GE_CHK_BOOL_RET_STATUS(dim > 0, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", - iname.c_str()); - } - } - // Check all attributes defined - const auto all_attributes = operator_impl_->GetOpDescImpl()->GetAllAttrs(); - for (const auto &name : operator_impl_->GetOpDescImpl()->GetAllAttrNames()) { - GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, - "operator attribute %s is empty.", name.c_str()); - } - - return GRAPH_SUCCESS; -} - -string Operator::GetOpType() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return "Data", "operator impl is nullptr."); - return OperatorImpl::GetOpDesc(*this)->GetType(); -} - -Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt) { - string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index); - return SetInput(dynamic_dst_name, src_oprt); -} - -Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt, - const std::string &name) { - string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index); - return SetInput(dynamic_dst_name, src_oprt, name); -} - -OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } - -#define OP_ATTR_SET_IMP(ArgType, AttrUtilsFun) \ - Operator &Operator::SetAttr(const string &name, ArgType attr_value) { \ - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ - return *this; \ - } \ - if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ - GELOGW("set attr name %s failed.", name.c_str()); \ - } \ - return *this; \ - } // lint !e665 - -#define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \ - graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \ - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ - return GRAPH_FAILED; \ - } \ - if (!AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ - GELOGW("get attr name %s failed.", name.c_str()); \ - return GRAPH_FAILED; \ - } \ - return GRAPH_SUCCESS; \ - } // lint !e665 - -void Operator::BreakConnect() const { - if (operator_impl_ == nullptr) { - GELOGW("operator impl is nullptr."); - return; - } - operator_impl_->ClearInputLinks(); - operator_impl_->ClearOutputLinks(); - OperatorKeeper::GetInstance().CheckOutOperator(operator_impl_); -} - -#define OP_ATTR_REG_IMP(ArgType, AttrUtilsFun) \ - void Operator::AttrRegister(const string &name, ArgType attr_value) { \ - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ - return; \ - } \ - if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, 
attr_value)) { \ - GELOGW("reg attr name %s failed.", name.c_str()); \ - } \ - } // lint !e665 - -OP_ATTR_SET_IMP(int64_t, Int) -OP_ATTR_SET_IMP(int32_t, Int) -OP_ATTR_SET_IMP(uint32_t, Int) -OP_ATTR_GET_IMP(int64_t &, Int) -OP_ATTR_GET_IMP(int32_t &, Int) -OP_ATTR_GET_IMP(uint32_t &, Int) -OP_ATTR_SET_IMP(const vector &, ListInt) -OP_ATTR_SET_IMP(const vector &, ListInt) -OP_ATTR_SET_IMP(const vector &, ListInt) -OP_ATTR_SET_IMP(std::initializer_list &&, ListInt) -OP_ATTR_GET_IMP(vector &, ListInt) -OP_ATTR_GET_IMP(vector &, ListInt) -OP_ATTR_GET_IMP(vector &, ListInt) -OP_ATTR_GET_IMP(vector> &, ListListInt) -OP_ATTR_SET_IMP(const vector> &, ListListInt) - -OP_ATTR_SET_IMP(float, Float) -OP_ATTR_GET_IMP(float &, Float) -OP_ATTR_SET_IMP(const vector &, ListFloat) -OP_ATTR_GET_IMP(vector &, ListFloat) // lint !e665 - -OP_ATTR_SET_IMP(bool, Bool) -OP_ATTR_GET_IMP(bool &, Bool) -OP_ATTR_SET_IMP(const vector &, ListBool) -OP_ATTR_GET_IMP(vector &, ListBool) // lint !e665 - -OP_ATTR_SET_IMP(const string &, Str) -OP_ATTR_GET_IMP(string &, Str) -OP_ATTR_SET_IMP(const vector &, ListStr) -OP_ATTR_GET_IMP(vector &, ListStr) // lint !e665 - -OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) -OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) -OP_ATTR_SET_IMP(const vector &, ListNamedAttrs) -OP_ATTR_GET_IMP(vector &, ListNamedAttrs) // lint !e665 - -OP_ATTR_REG_IMP(int64_t, Int) -OP_ATTR_REG_IMP(const vector &, ListInt) -OP_ATTR_REG_IMP(float, Float) -OP_ATTR_REG_IMP(const vector &, ListFloat) -OP_ATTR_REG_IMP(const string &, Str) -OP_ATTR_REG_IMP(const vector &, ListStr) -OP_ATTR_REG_IMP(bool, Bool) -OP_ATTR_REG_IMP(const vector &, ListBool) -OP_ATTR_REG_IMP(const vector> &, ListListInt) -OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) -OP_ATTR_REG_IMP(const vector &, ListNamedAttrs) - -#undef OP_ATTR_SET_IMP -#undef OP_ATTR_GET_IMP -#undef OP_ATTR_REG_IMP - -Operator &Operator::SetAttr(const string &name, const Tensor &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - GeTensor tensor = TensorAdapter::AsGeTensor(attr_value); - if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { - GELOGW("set attr name %s failed.", name.c_str()); - } - return *this; -} - -Operator &Operator::SetAttr(const string &name, const vector &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - vector val_list; - for (const auto &item : attr_value) { - auto tensor = TensorAdapter::AsGeTensor(item); - val_list.push_back(tensor); - } - if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { - GELOGW("set attr name %s failed.", name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const string &name, Tensor &attr_value) const { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - ConstGeTensorPtr tensor; - if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { - GELOGW("get attr name %s failed.", name.c_str()); - return GRAPH_FAILED; - } - attr_value = TensorAdapter::GeTensor2Tensor(tensor); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetAttr(const string &name, vector &attr_value) const { - 
attr_value.clear(); - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - vector val_list; - if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { - GELOGW("get attr name %s failed.", name.c_str()); - return GRAPH_FAILED; - } - for (auto &tensor : val_list) { - attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor)); - } - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const string &name, const OpBytes &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, - Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { - GELOGW("set attr name %s failed.", name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const string &name, OpBytes &attr_value) const { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - Buffer buffer; - if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, buffer)) { - GELOGW("get attr name %s failed.", name.c_str()); - return GRAPH_FAILED; - } - attr_value.clear(); - if (buffer.data() == nullptr) { - GELOGE(GRAPH_FAILED, "buffer data is null."); - return GRAPH_FAILED; - } - attr_value.assign(buffer.data(), buffer.data() + buffer.size()); - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const string &name, ge::AttrValue &&attrValue) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); - (void)operator_impl_->SetAttr(name, std::move(attrValue.impl->geAttrValue_)); - return *this; -} - -graphStatus Operator::GetAttr(const string &name, ge::AttrValue &attrValue) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - return operator_impl_->GetAttr(name, attrValue.impl->geAttrValue_); -} - -Operator &Operator::SetAttr(const string &name, const std::vector &attr_value) { - if (operator_impl_ == nullptr || !operator_impl_->GetOpDescImpl()) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("set attr name %s failed.", name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const string &name, std::vector &attr_value) const { - attr_value.clear(); - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("get attr name %s failed.", name.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const string &name, const ge::DataType &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("set attr name %s failed.", name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const string &name, 
ge::DataType &attr_value) const { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("get attr name %s failed.", name.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -void Operator::AttrRegister(const string &name, const std::vector &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("set attr name %s failed.", name.c_str()); - } -} - -void Operator::AttrRegister(const string &name, const ge::DataType &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("set attr name %s failed.", name.c_str()); - } -} - -void Operator::AttrRegister(const string &name, const Tensor &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - auto tensor = TensorAdapter::AsGeTensor(attr_value); - if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { - GELOGW("reg attr name %s failed.", name.c_str()); - } -} - -void Operator::AttrRegister(const string &name, const vector &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - vector val_list; - for (const auto &item : attr_value) { - val_list.push_back(TensorAdapter::AsGeTensor(item)); - } - if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { - GELOGW("reg attr name %s failed.", name.c_str()); - } -} - -void Operator::AttrRegister(const string &name, const OpBytes &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, - Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { - GELOGW("reg attr name %s failed.", name.c_str()); - } -} - -void Operator::SubgraphRegister(const std::string &name, bool dynamic) { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - operator_impl_->SubgraphRegister(name, dynamic ? 
kDynamic : kStatic); -} - -void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - operator_impl_->SubgraphCountRegister(name, count); -} - -void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", ir_name.c_str()); - return; - } - operator_impl_->SetSubgraphBuilder(ir_name, index, builder); -} - -std::vector Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } - -SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &ir_name, uint32_t index) const { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr."); - return nullptr; - } - return operator_impl_->GetSubgraphBuilder(ir_name, index); -} - -SubgraphBuilder Operator::GetSubgraphBuilder(const string &ir_name) const { - return GetDynamicSubgraphBuilder(ir_name, 0); -} - -Graph Operator::GetSubgraph(const string &name) const { - if (operator_impl_ == nullptr) { - GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str()); - return Graph(""); - } - auto op_desc = OpDescUtils::GetOpDescFromOperator(*this); - if (op_desc == nullptr) { - GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str()); - return Graph(""); - } - const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); - auto iter = subgraph_names_to_index.find(name); - if (iter == subgraph_names_to_index.end()) { - GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str()); - return Graph(""); - } - auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); - if (subgraph_instance_name.empty()) { - GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", name.c_str(), iter->second); - return Graph(""); - } - - auto node = operator_impl_->GetNode(); - if (node == nullptr) { - GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str()); - return Graph(""); - } - auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - if (root_graph == nullptr) { - GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str()); - return Graph(""); - } - auto subgraph = root_graph->GetSubgraph(subgraph_instance_name); - if (subgraph == nullptr) { - GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", name.c_str(), - iter->second, subgraph_instance_name.c_str()); - return Graph(""); - } - return GraphUtils::CreateGraphFromComputeGraph(subgraph); -} - -Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const { - return GetSubgraph(name + std::to_string(index)); -} - -size_t Operator::GetSubgraphNamesCount() const { - if (operator_impl_ == nullptr) { - GE_LOGE("Failed to get subgraph names count, the operator impl is null"); - return 0; - } - return operator_impl_->GetSubgraphNamesCount(); -} - -class GraphBuilderImpl { - public: - explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared(name)) { - if (graph_ == nullptr) { - GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); - return; - } - } - - ~GraphBuilderImpl() {} - - ComputeGraphPtr BuildGraph(const std::vector &inputs) { - std::vector vec_inputs; - for (auto &it : inputs) { - auto src_op_impl = it.operator_impl_; - GE_CHK_BOOL_EXEC(src_op_impl != 
nullptr, return nullptr, "Operator Impl is null."); - GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return nullptr, "Operator impl's opdesc is null."); - - string type = src_op_impl->op_desc_->GetType(); - auto node_op = ge::OperatorFactory::CreateOperator("node_op", type); - auto tensor_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); - node_op.BreakConnect(); - - GE_CHK_BOOL_EXEC(tensor_desc != nullptr, continue, "tensor_desc is null."); - if ((tensor_desc->GetInputsSize() == 0 && tensor_desc->GetOutputsSize() > 0) || type == DATA || - type == VARIABLE || type == INITDATA || type == GETNEXT) { - vec_inputs.push_back(it.operator_impl_); - } else { - GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); - } - } - GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, - "User Input do not include operator such as " - "Data, Variable operator or operator that has output but no input."); - auto ret = WalkAllOperators(vec_inputs); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); - - ret = AddEdge(); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "AddEdge failed."); - - return graph_; - } - - const std::map &GetAllNodesInfo() const { return all_nodes_info_; } - - private: - graphStatus WalkAllOperators(const std::vector &vec_ops) { - GE_CHK_BOOL_EXEC(graph_ != nullptr, return GRAPH_FAILED, "graph_ is null.") - std::queue> que; - que.push(vec_ops); - while (!que.empty()) { - auto vec_tem = que.front(); - que.pop(); - for (const auto &op_impl : vec_tem) { - GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") - GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue, - "This node %s has created.", op_impl->GetName().c_str()) - auto node_ptr = graph_->AddNode(op_impl->op_desc_); - GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); - all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); - - auto &out_links = op_impl->output_links_; - std::vector vec_op_forward{}; - for (const auto &out_link : out_links) { - for (const auto &op_forward : out_link.second) { - vec_op_forward.push_back(op_forward.GetOwner()); - } - } - - auto &out_control_links = op_impl->control_output_link_; - for (const auto &out_link : out_control_links) { - vec_op_forward.push_back(out_link.lock()); - } - que.push(vec_op_forward); - - auto &in_links = op_impl->input_link_; - std::vector vec_op_back_forward{}; - for (const auto &in_link : in_links) { - vec_op_back_forward.push_back(in_link.second.GetOwner()); - } - - auto &in_control_links = op_impl->control_input_link_; - for (const auto &in_link : in_control_links) { - vec_op_back_forward.push_back(in_link.lock()); - } - que.push(vec_op_back_forward); - - if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } - return MoveSubgraphToRoot(graph_); - } - - graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) { - const string name = node->GetName(); - for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) { - const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first); - GE_CHK_BOOL_EXEC(builder != nullptr, return GRAPH_FAILED, "Node: %s, Get builder failed.", name.c_str()); - - Graph graph = builder(); // Build subgraph from user define builder. 
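// The BuildGraph/WalkAllOperators logic above is essentially a breadth-first walk over the
// operator links, with all_nodes_info_ doubling as the visited set so each operator is
// materialized as a graph node exactly once. Below is a minimal sketch of that traversal
// shape only; FakeOp/FakeOpPtr are stand-ins, not the GE OperatorImpl/NodePtr types.

#include <map>
#include <memory>
#include <queue>
#include <vector>

struct FakeOp;
using FakeOpPtr = std::shared_ptr<FakeOp>;
struct FakeOp {
  std::vector<FakeOpPtr> neighbors;  // union of data and control links, both directions
};

// Walk every reachable operator once, assigning it an id (the "AddNode" step).
std::map<FakeOpPtr, int> WalkAll(const std::vector<FakeOpPtr> &inputs) {
  std::map<FakeOpPtr, int> visited;        // plays the role of all_nodes_info_
  std::queue<std::vector<FakeOpPtr>> que;  // batches of operators still to visit
  que.push(inputs);
  int next_id = 0;
  while (!que.empty()) {
    auto batch = que.front();
    que.pop();
    for (const auto &op : batch) {
      if (op == nullptr || visited.count(op) != 0) {
        continue;  // already created, skip duplicates
      }
      visited.emplace(op, next_id++);
      que.push(op->neighbors);  // discovered neighbors are handled in a later round
    }
  }
  return visited;
}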
- const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph); - GE_CHK_BOOL_EXEC(subgraph != nullptr, return GRAPH_FAILED, "Node: %s, Build graph failed.", name.c_str()); - - subgraph->SetParentNode(node); - subgraph->SetParentGraph(graph_); - if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; - } - - graphStatus MoveSubgraphToRoot(const ComputeGraphPtr &graph) { - const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph); - if (root_graph == nullptr) { - GELOGE(GRAPH_FAILED, "Graph: %s, Find root graph failed.", graph->GetName().c_str()); - return GRAPH_FAILED; - } - - if (root_graph == graph) { - auto subgraphs = graph->GetAllSubgraphs(); - for (auto &subgraph : subgraphs) { - if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } else { - auto subgraphs = graph->GetAllSubgraphs(); - for (auto &subgraph : subgraphs) { - if (root_graph->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - graph->RemoveSubgraph(subgraph->GetName()); - if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; - } - - graphStatus AddEdge() { - for (const auto &node_info : all_nodes_info_) { - auto src_op_impl_ptr = node_info.first; - auto src_node_ptr = node_info.second; - - GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); - auto out_links = src_op_impl_ptr->output_links_; - GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED, - "Src operator impl's op_desc is null."); - auto &op_desc = src_op_impl_ptr->op_desc_; - GE_IF_BOOL_EXEC(op_desc == nullptr, continue); - for (const auto &out : out_links) { - auto src_idx = op_desc->GetOutputIndexByName(out.first); - GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); - - auto src_anchor = src_node_ptr->GetOutDataAnchor(src_idx); - GE_CHK_BOOL_EXEC(src_anchor != nullptr, return GRAPH_FAILED, "GetOutDataAnchor failed."); - - for (const auto &dst_opio : out.second) { - auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); - GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); - - GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); - - auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); - GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); - - auto ret = GraphUtils::AddEdge(src_anchor, dst_anchor); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, - "from node[%s][%d] to node[%s][%d]AddEdge failed.", src_node_ptr->GetName().c_str(), - src_anchor->GetIdx(), dst_node_info->second->GetName().c_str(), dst_anchor->GetIdx()); - } - } - auto out_control_anchor = src_node_ptr->GetOutControlAnchor(); - for (const auto &control_out : src_op_impl_ptr->control_output_link_) { - auto dst_node_info = all_nodes_info_.find(control_out.lock()); - if (dst_node_info == all_nodes_info_.end()) { - GELOGE(GRAPH_FAILED, "Find Dst node failed."); - return GRAPH_FAILED; - } - GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); - auto in_control_anchor = 
dst_node_info->second->GetInControlAnchor(); - auto ret = GraphUtils::AddEdge(out_control_anchor, in_control_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(ret, "AddEdge failed. srcNode %s:%s, dstNode %s:%s", op_desc->GetName().c_str(), - op_desc->GetType().c_str(), dst_node_info->second->GetName().c_str(), - dst_node_info->second->GetType().c_str()); - return ret; - } - } - } - return GRAPH_SUCCESS; - } - - ComputeGraphPtr graph_ = nullptr; - std::map all_nodes_info_{}; -}; - -inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { - for (const auto &graph : compute_graph->GetAllSubgraphs()) { - std::set node_names; - for (auto const &node : graph->GetDirectNode()) { - auto result = node_names.insert(node->GetName()); - if (!result.second) { - GELOGE(GRAPH_FAILED, "graph %s has same name node%s", graph->GetName().c_str(), node->GetName().c_str()); - return true; - } - } - } - - std::set node_names; - for (auto const &node : compute_graph->GetDirectNode()) { - auto result = node_names.insert(node->GetName()); - if (!result.second) { - GELOGE(GRAPH_FAILED, "graph %s has same name node%s", compute_graph->GetName().c_str(), node->GetName().c_str()); - return true; - } - } - return false; -} - -ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector &inputs) { - auto graph_builder_impl = GraphBuilderImpl(name); - ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); - GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Computer graph is nullptr"); - compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); - if (HasSameNameNode(compute_graph)) { - GELOGW("Compute do not allow has same name nodes."); - compute_graph = nullptr; - } - - return compute_graph; -} - -void GraphUtils::BreakConnect(const std::map &all_nodes_infos) { - for (const auto &it : all_nodes_infos) { - OperatorImplPtr op_impl = it.first; - if (op_impl == nullptr) { - GELOGW("operator impl is nullptr."); - continue; - } - op_impl->ClearOutputLinks(); - op_impl->ClearInputLinks(); - OperatorKeeper::GetInstance().CheckOutOperator(op_impl); - } -} -} // namespace ge -/*lint +e446 +e732*/ -/*lint +e665*/ diff --git a/src/common/graph/operator_factory.cc b/src/common/graph/operator_factory.cc deleted file mode 100644 index 43d61a7c..00000000 --- a/src/common/graph/operator_factory.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
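// HasSameNameNode above leans on std::set::insert: the .second of the returned pair is false
// when the value was already present, which is how a duplicate node name is detected. A
// stripped-down illustration of the idiom, using plain strings instead of GE nodes:

#include <set>
#include <string>
#include <vector>

bool HasDuplicate(const std::vector<std::string> &names) {
  std::set<std::string> seen;
  for (const auto &name : names) {
    if (!seen.insert(name).second) {
      return true;  // second insertion of the same name
    }
  }
  return false;
}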
- */ - -#include "graph/operator_factory_impl.h" -#include "debug/ge_log.h" - -namespace ge { -Operator OperatorFactory::CreateOperator(const std::string &operator_name, const std::string &operator_type) { - return OperatorFactoryImpl::CreateOperator(operator_name, operator_type); -} - -graphStatus OperatorFactory::GetOpsTypeList(std::vector &all_ops) { - return OperatorFactoryImpl::GetOpsTypeList(all_ops); -} - -bool OperatorFactory::IsExistOp(const string &operator_type) { return OperatorFactoryImpl::IsExistOp(operator_type); } - -OperatorCreatorRegister::OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator) { - (void)OperatorFactoryImpl::RegisterOperatorCreator(operator_type, op_creator); -} - -InferShapeFuncRegister::InferShapeFuncRegister(const std::string &operator_type, - const InferShapeFunc &infer_shape_func) { - (void)OperatorFactoryImpl::RegisterInferShapeFunc(operator_type, infer_shape_func); -} - -InferFormatFuncRegister::InferFormatFuncRegister(const std::string &operator_type, - const InferFormatFunc &infer_format_func) { - (void)OperatorFactoryImpl::RegisterInferFormatFunc(operator_type, infer_format_func); -} - -VerifyFuncRegister::VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func) { - (void)OperatorFactoryImpl::RegisterVerifyFunc(operator_type, verify_func); -} -} // namespace ge diff --git a/src/common/graph/operator_factory_impl.cc b/src/common/graph/operator_factory_impl.cc deleted file mode 100644 index 026a85bc..00000000 --- a/src/common/graph/operator_factory_impl.cc +++ /dev/null @@ -1,149 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
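// operator_factory.cc above is a thin facade: each *Register class performs the registration
// in its constructor, so a file-scope registrar object wires up a creator during static
// initialization, before main() runs. A minimal sketch of that idiom; the registry and the
// "Add" creator below are illustrative, not the actual OperatorFactoryImpl API.

#include <functional>
#include <iostream>
#include <map>
#include <string>

using Creator = std::function<void(const std::string &)>;

std::map<std::string, Creator> &Registry() {
  static std::map<std::string, Creator> registry;  // constructed on first use
  return registry;
}

struct CreatorRegister {
  CreatorRegister(const std::string &type, Creator creator) {
    (void)Registry().emplace(type, std::move(creator));  // construction == registration
  }
};

// Defining this object is all it takes to make "Add" creatable by type name.
static CreatorRegister g_add_register("Add", [](const std::string &name) {
  std::cout << "creating Add named " << name << std::endl;
});

int main() {
  auto it = Registry().find("Add");
  if (it != Registry().end()) {
    it->second("add_0");
  }
  return 0;
}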
- */ - -#include "graph/operator_factory_impl.h" -#include "debug/ge_log.h" -#include "framework/common/debug/ge_log.h" - -namespace ge { -shared_ptr> OperatorFactoryImpl::operator_creators_; -shared_ptr> OperatorFactoryImpl::operator_infershape_funcs_; -shared_ptr> OperatorFactoryImpl::operator_inferformat_funcs_; -shared_ptr> OperatorFactoryImpl::operator_verify_funcs_; - -Operator OperatorFactoryImpl::CreateOperator(const std::string &operator_name, const std::string &operator_type) { - if (operator_creators_ == nullptr) { - return Operator(); - } - auto it = operator_creators_->find(operator_type); - if (it == operator_creators_->end()) { - GELOGW("no OpProto of [%s] registered", operator_type.c_str()); - return Operator(); - } - return it->second(operator_name); -} - -graphStatus OperatorFactoryImpl::GetOpsTypeList(std::vector &all_ops) { - all_ops.clear(); - if (operator_creators_ != nullptr) { - for (auto it = operator_creators_->begin(); it != operator_creators_->end(); ++it) { - all_ops.emplace_back(it->first); - } - } else { - GELOGE(GRAPH_FAILED, "no operator creators found"); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -bool OperatorFactoryImpl::IsExistOp(const string &operator_type) { - if (operator_creators_ == nullptr) { - return false; - } - auto it = operator_creators_->find(operator_type); - if (it == operator_creators_->end()) { - return false; - } - return true; -} - -InferShapeFunc OperatorFactoryImpl::GetInferShapeFunc(const std::string &operator_type) { - if (operator_infershape_funcs_ == nullptr) { - return nullptr; - } - auto it = operator_infershape_funcs_->find(operator_type); - if (it == operator_infershape_funcs_->end()) { - return nullptr; - } - return it->second; -} - -InferFormatFunc OperatorFactoryImpl::GetInferFormatFunc(const std::string &operator_type) { - if (operator_inferformat_funcs_ == nullptr) { - GELOGI("operator_inferformat_funcs_ is null"); - return nullptr; - } - auto it = operator_inferformat_funcs_->find(operator_type); - if (it == operator_inferformat_funcs_->end()) { - return nullptr; - } - return it->second; -} - -VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) { - if (operator_verify_funcs_ == nullptr) { - return nullptr; - } - auto it = operator_verify_funcs_->find(operator_type); - if (it == operator_verify_funcs_->end()) { - return nullptr; - } - return it->second; -} - -graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { - if (operator_creators_ == nullptr) { - operator_creators_.reset(new (std::nothrow) std::map()); - } - auto it = operator_creators_->find(operator_type); - if (it != operator_creators_->end()) { - return GRAPH_FAILED; - } - (void)operator_creators_->emplace(operator_type, op_creator); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterInferShapeFunc(const std::string &operator_type, - InferShapeFunc const infer_shape_func) { - if (operator_infershape_funcs_ == nullptr) { - GELOGI("operator_infershape_funcs_ init"); - operator_infershape_funcs_.reset(new (std::nothrow) std::map()); - } - auto it = operator_infershape_funcs_->find(operator_type); - if (it != operator_infershape_funcs_->end()) { - return GRAPH_FAILED; - } - (void)operator_infershape_funcs_->emplace(operator_type, infer_shape_func); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterInferFormatFunc(const std::string &operator_type, - InferFormatFunc const infer_format_func) { - if 
(operator_inferformat_funcs_ == nullptr) { - GELOGI("operator_inferformat_funcs_ init"); - operator_inferformat_funcs_.reset(new (std::nothrow) std::map()); - } - auto it = operator_inferformat_funcs_->find(operator_type); - if (it != operator_inferformat_funcs_->end()) { - return GRAPH_FAILED; - } - (void)operator_inferformat_funcs_->emplace(operator_type, infer_format_func); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func) { - if (operator_verify_funcs_ == nullptr) { - GELOGI("operator_verify_funcs_ init"); - operator_verify_funcs_.reset(new (std::nothrow) std::map()); - } - auto it = operator_verify_funcs_->find(operator_type); - if (it != operator_verify_funcs_->end()) { - return GRAPH_FAILED; - } - (void)operator_verify_funcs_->emplace(operator_type, verify_func); - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/src/common/graph/opsproto/opsproto_manager.cc b/src/common/graph/opsproto/opsproto_manager.cc deleted file mode 100644 index d482715b..00000000 --- a/src/common/graph/opsproto/opsproto_manager.cc +++ /dev/null @@ -1,187 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
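// OperatorFactoryImpl above keeps each function table in a shared_ptr to a map that is only
// allocated on the first registration, using new (std::nothrow), and a second registration
// for the same operator type is rejected. A condensed sketch of that shape with simplified
// types (plain function pointers instead of the GE functor types):

#include <map>
#include <memory>
#include <string>

enum Status { kSuccess = 0, kFailed = 1 };
using VerifyFn = int (*)();

static std::shared_ptr<std::map<std::string, VerifyFn>> g_verify_funcs;

Status RegisterVerifyFn(const std::string &op_type, VerifyFn fn) {
  if (g_verify_funcs == nullptr) {
    g_verify_funcs.reset(new (std::nothrow) std::map<std::string, VerifyFn>());
    if (g_verify_funcs == nullptr) {
      return kFailed;  // allocation failed, nothing registered
    }
  }
  if (g_verify_funcs->count(op_type) != 0) {
    return kFailed;  // duplicate registration is refused
  }
  (void)g_verify_funcs->emplace(op_type, fn);
  return kSuccess;
}

VerifyFn GetVerifyFn(const std::string &op_type) {
  if (g_verify_funcs == nullptr) {
    return nullptr;
  }
  auto it = g_verify_funcs->find(op_type);
  return it == g_verify_funcs->end() ? nullptr : it->second;
}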
- */ - -#include "graph/opsproto_manager.h" -#include -#include -#include -#include -#include -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/debug/ge_log.h" - -namespace ge { -OpsProtoManager *OpsProtoManager::Instance() { - static OpsProtoManager instance; - return &instance; -} - -bool OpsProtoManager::Initialize(const std::map &options) { - std::lock_guard lock(mutex_); - - if (is_init_) { - GELOGI("OpsProtoManager is already initialized."); - return true; - } - - /*lint -e1561*/ - auto proto_iter = options.find("ge.opsProtoLibPath"); - /*lint +e1561*/ - if (proto_iter == options.end()) { - GELOGW("ge.opsProtoLibPath option not set, return."); - return false; - } - - pluginPath_ = proto_iter->second; - LoadOpsProtoPluginSo(pluginPath_); - - is_init_ = true; - - return true; -} - -void OpsProtoManager::Finalize() { - std::lock_guard lock(mutex_); - - if (!is_init_) { - GELOGI("OpsProtoManager is not initialized."); - return; - } - - for (auto handle : handles_) { - if (handle != nullptr) { - if (dlclose(handle) != 0) { - GELOGW("failed to close handle, message: %s", dlerror()); - continue; - } - GELOGI("close opsprotomanager handler success"); - } else { - GELOGW("close opsprotomanager handler failure, handler is nullptr"); - } - } - - is_init_ = false; -} - -static std::vector Split(const std::string &str, char delim) { - std::vector elems; - if (str.empty()) { - elems.emplace_back(""); - return elems; - } - - std::stringstream ss(str); - std::string item; - - while (getline(ss, item, delim)) { - elems.push_back(item); - } - - auto str_size = str.size(); - if (str_size > 0 && str[str_size - 1] == delim) { - elems.emplace_back(""); - } - - return elems; -} - -static void FindParserSo(const std::string &path, std::vector &file_list) { - // Lib plugin path not exist - if (path.empty()) { - GELOGI("realPath is empty"); - return; - } - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path.size() >= PATH_MAX, return, "path is invalid"); - - char resolved_path[PATH_MAX] = {0}; - - // Nullptr is returned when the path does not exist or there is no permission - // Return absolute path when path is accessible - if (realpath(path.c_str(), resolved_path) == nullptr) { - GELOGW("the path [%s] not exsit.", path.c_str()); - return; - } - - struct dirent *dent = nullptr; - DIR *dir = opendir(resolved_path); - // Lib plugin path not exist - if (dir == nullptr) { - GELOGW("Open directory %s failed,maybe it is not exit or not a dir", resolved_path); - return; - } - - while ((dent = readdir(dir)) != nullptr) { - if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) { - continue; - } - std::string name = dent->d_name; - std::string full_name = path + "/" + name; - const std::string so_suff = ".so"; - - if (dent->d_type != DT_DIR && name.size() >= so_suff.size() && - name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { - file_list.push_back(full_name); - GELOGI("OpsProtoManager Parse full name = %s \n", full_name.c_str()); - } - } - if (closedir(dir) != 0) { - GELOGW("close dir fail."); - } -} - -static void GetPluginSoFileList(const std::string &path, std::vector &file_list) { - // Support multi lib directory with ":" as delimiter - std::vector v_path = Split(path, ':'); - - for (size_t i = 0; i < v_path.size(); ++i) { - FindParserSo(v_path[i], file_list); - GELOGI("OpsProtoManager full name = %s", v_path[i].c_str()); - } -} - -void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) { - if (path.empty()) { - GELOGE(GRAPH_FAILED, "filePath 
is invalid. please check your text file %s.", path.c_str()); - return; - } - std::vector file_list; - - // If there is .so file in the lib path - GetPluginSoFileList(path, file_list); - - // Not found any .so file in the lib path - if (file_list.empty()) { - GELOGE(GRAPH_FAILED, "OpsProtoManager can not find any plugin file in pluginPath: %s \n", path.c_str()); - return; - } - // Warning message - GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted."); - - // Load .so file - for (auto elem : file_list) { - void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); - if (handle == nullptr) { - GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); - continue; - } else { - // Close dl when the program exist, not close here - GELOGI("OpsProtoManager plugin load %s success.", elem.c_str()); - handles_.push_back(handle); - } - } -} -} // namespace ge diff --git a/src/common/graph/option/ge_context.cc b/src/common/graph/option/ge_context.cc deleted file mode 100644 index 421e0aff..00000000 --- a/src/common/graph/option/ge_context.cc +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
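// The ops-proto manager above scans a ":"-separated path list for .so files and loads each one
// with dlopen(RTLD_NOW | RTLD_GLOBAL), keeping the handles so Finalize() can dlclose them
// later. A minimal, Linux-only sketch of that load/unload cycle (link with -ldl; the path
// discovery and trust checks are left out):

#include <dlfcn.h>

#include <iostream>
#include <string>
#include <vector>

std::vector<void *> LoadPlugins(const std::vector<std::string> &so_files) {
  std::vector<void *> handles;
  for (const auto &path : so_files) {
    void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL);  // eager bind, global symbols
    if (handle == nullptr) {
      std::cerr << "dlopen failed for " << path << ": " << dlerror() << std::endl;
      continue;  // skip this plugin, keep loading the rest
    }
    handles.push_back(handle);
  }
  return handles;
}

void UnloadPlugins(std::vector<void *> &handles) {
  for (void *handle : handles) {
    if (handle != nullptr && dlclose(handle) != 0) {
      std::cerr << "dlclose failed: " << dlerror() << std::endl;
    }
  }
  handles.clear();
}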
- */ - -#include "./ge_context.h" -#include "./ge_global_options.h" -#include "./ge_local_context.h" -#include "framework/common/ge_types.h" -#include "framework/common/debug/ge_log.h" - -namespace ge { -namespace { -const int64_t kMinTrainingTraceJobId = 256; -const int kDecimal = 10; -const char *kHostExecPlacement = "HOST"; -} // namespace -GEContext &GetContext() { - static GEContext ge_context{}; - return ge_context; -} - -graphStatus GEContext::GetOption(const std::string &key, std::string &option) { - return GetThreadLocalContext().GetOption(key, option); -} - -bool GEContext::GetHostExecFlag() { - std::string exec_placement; - if (GetThreadLocalContext().GetOption(GE_OPTION_EXEC_PLACEMENT, exec_placement) != GRAPH_SUCCESS) { - GELOGW("get option OPTION_EXEC_PLACEMENT failed."); - return false; - } - GELOGD("Option ge.exec.placement is %s.", exec_placement.c_str()); - return exec_placement == kHostExecPlacement; -} - -std::map &GetMutableGlobalOptions() { - static std::map global_options{}; - return global_options; -} - -void GEContext::Init() { - string session_id; - (void)GetOption("ge.exec.sessionId", session_id); - try { - session_id_ = static_cast(std::stoi(session_id.c_str())); - } catch (std::invalid_argument &) { - GELOGW("%s transform to int failed.", session_id.c_str()); - } catch (std::out_of_range &) { - GELOGW("%s transform to int failed.", session_id.c_str()); - } - - string device_id; - (void)GetOption("ge.exec.deviceId", device_id); - try { - device_id_ = static_cast(std::stoi(device_id.c_str())); - } catch (std::invalid_argument &) { - GELOGW("%s transform to int failed.", device_id.c_str()); - } catch (std::out_of_range &) { - GELOGW("%s transform to int failed.", device_id.c_str()); - } - - string job_id; - (void)GetOption("ge.exec.jobId", job_id); - std::string s_job_id = ""; - for (auto c : job_id) { - if (c >= '0' && c <= '9') { - s_job_id += c; - } - } - if (s_job_id == "") { - trace_id_ = kMinTrainingTraceJobId; - return; - } - int64_t d_job_id = std::strtoll(s_job_id.c_str(), nullptr, kDecimal); - if (d_job_id < kMinTrainingTraceJobId) { - trace_id_ = d_job_id + kMinTrainingTraceJobId; - } else { - trace_id_ = d_job_id; - } -} - -uint64_t GEContext::SessionId() { return session_id_; } - -uint32_t GEContext::DeviceId() { return device_id_; } - -uint64_t GEContext::TraceId() { return trace_id_; } - -void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; } - -void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } - -} // namespace ge diff --git a/src/common/graph/option/ge_local_context.cc b/src/common/graph/option/ge_local_context.cc deleted file mode 100644 index 82b1cb01..00000000 --- a/src/common/graph/option/ge_local_context.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
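// GEContext::Init above parses numeric options defensively: std::stoi runs inside a try/catch
// so a malformed ge.exec.sessionId or ge.exec.deviceId only logs a warning, and ge.exec.jobId
// is first reduced to its digit characters. A small sketch of that parsing style, reading from
// a hypothetical option map rather than the thread-local context:

#include <cctype>
#include <cstdint>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Returns fallback when the option is missing or does not parse as an integer.
uint64_t ParseUintOption(const std::map<std::string, std::string> &options,
                         const std::string &key, uint64_t fallback) {
  auto it = options.find(key);
  if (it == options.end()) {
    return fallback;
  }
  try {
    return static_cast<uint64_t>(std::stoull(it->second));
  } catch (const std::invalid_argument &) {
    std::cerr << it->second << " is not a number for " << key << std::endl;
  } catch (const std::out_of_range &) {
    std::cerr << it->second << " is out of range for " << key << std::endl;
  }
  return fallback;
}

// Keep only the digit characters, as Init() does for ge.exec.jobId.
std::string DigitsOnly(const std::string &raw) {
  std::string digits;
  for (char c : raw) {
    if (std::isdigit(static_cast<unsigned char>(c)) != 0) {
      digits += c;
    }
  }
  return digits;
}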
- */ - -#include "./ge_local_context.h" -#include - -namespace ge { -namespace { -thread_local GEThreadLocalContext thread_context; -} - -GEThreadLocalContext &GetThreadLocalContext() { return thread_context; } - -graphStatus GEThreadLocalContext::GetOption(const string &key, string &option) { - auto graph_iter = graph_options_.find(key); - if (graph_iter != graph_options_.end()) { - option = graph_iter->second; - return GRAPH_SUCCESS; - } - auto session_iter = session_options_.find(key); - if (session_iter != session_options_.end()) { - option = session_iter->second; - return GRAPH_SUCCESS; - } - auto global_iter = global_options_.find(key); - if (global_iter != global_options_.end()) { - option = global_iter->second; - return GRAPH_SUCCESS; - } - return GRAPH_PARAM_INVALID; -} - -void GEThreadLocalContext::SetGlobalOption(map options_map) { - global_options_.clear(); - global_options_ = std::move(options_map); -} - -void GEThreadLocalContext::SetSessionOption(map options_map) { - session_options_.clear(); - session_options_ = std::move(options_map); -} - -void GEThreadLocalContext::SetGraphOption(map options_map) { - graph_options_.clear(); - graph_options_ = std::move(options_map); -} -} // namespace ge diff --git a/src/common/graph/ref_relation.cc b/src/common/graph/ref_relation.cc deleted file mode 100644 index 48e136fb..00000000 --- a/src/common/graph/ref_relation.cc +++ /dev/null @@ -1,455 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/ref_relation.h" - -#include -#include - -#include "utils/mem_utils.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "debug/ge_attr_define.h" -#include "graph/ge_error_codes.h" -#include "graph/utils/graph_utils.h" -#include "framework/common/debug/ge_log.h" - -using namespace std; -using namespace ge; -namespace ge { -namespace { -const char *kRefIndex = "_parent_node_index"; -const string kWhile = "While"; -const string kIf = "If"; -const string kCase = "Case"; - -const uint16_t kMaxElementNum = 100; - -std::unordered_set function_op = {kWhile, kIf, kCase}; -} // namespace - -/* Impl */ -class RefRelations::Impl { - public: - graphStatus LookUpRefRelations(const RefCell &key, unordered_set &result) { - unsigned long number = static_cast(reinterpret_cast(key.node.get())); - std::string lookup_key = - key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + std::to_string(number); - auto iter = look_up_table_.find(lookup_key); - if (iter != look_up_table_.end()) { - for (auto &c : iter->second) { - result.insert(c); - } - return GRAPH_SUCCESS; - } - GELOGW("can not find any relations! 
key value of dest relation is %s", lookup_key.c_str()); - return GRAPH_SUCCESS; - }; - graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); - graphStatus Clear() { - GELOGD("Start clear boundary reflections between main graph and sub graph!"); - look_up_table_.clear(); - values_.clear(); - return GRAPH_SUCCESS; - }; - - private: - graphStatus BuildLookUpTables(); - graphStatus BuildRefRelationsForBranch(const NodePtr &root_node, const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, - vector> &node_refs); - graphStatus BuildRefRelationsForWhile(const NodePtr &root_node, const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, - vector> &node_refs); - graphStatus BuildRelationsWithFuncNodeType(const NodePtr &root_node, - const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, - vector> &node_refs); - void GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector &data_nodes, - vector &netoutput_nodes, const std::vector &sub_graph_names, - const std::string &node_type); - - graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph); - graphStatus ProcessSubgraphDataNodes(vector &data_nodes, vector> &classed_data_nodes); - graphStatus ProcessSubgraphNetoutput(const vector &netoutput_nodes, - vector>> &classed_netoutput_nodes); - - std::unordered_map> look_up_table_; - std::vector>> values_; -}; - -// Node Level -graphStatus RefRelations::Impl::BuildRefRelationsForBranch( - const NodePtr &root_node, const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, vector> &node_refs) { - GELOGD("Enter BuildRefRelationsForBranch!"); - - size_t ref_i = 0; - for (const auto &ref_i_data_nodes : classed_data_nodes) { - vector in_ref_i_all_refs; - RefCell cell_root; - cell_root.node_name = root_node->GetName(); - cell_root.node = root_node; - cell_root.in_out = NODE_IN; - cell_root.in_out_idx = ref_i; - in_ref_i_all_refs.emplace_back(cell_root); - for (const auto &data : ref_i_data_nodes) { - RefCell cell_in; - RefCell cell_out; - cell_in.node_name = data->GetName(); - cell_in.node = data; - cell_in.in_out = NODE_IN; - cell_in.in_out_idx = 0; - cell_out.node_name = data->GetName(); - cell_out.node = data; - cell_out.in_out = NODE_OUT; - cell_out.in_out_idx = 0; - in_ref_i_all_refs.emplace_back(cell_in); - in_ref_i_all_refs.emplace_back(cell_out); - } - node_refs.emplace_back(in_ref_i_all_refs); - ref_i++; - } - - size_t ref_o = 0; - for (const auto &ref_o_net_nodes : classed_netoutput_nodes) { - vector out_ref_i_all_refs; - RefCell cell_root; - cell_root.node_name = root_node->GetName(); - cell_root.node = root_node; - cell_root.in_out = NODE_OUT; - cell_root.in_out_idx = ref_o; - out_ref_i_all_refs.emplace_back(cell_root); - for (const auto &ele : ref_o_net_nodes) { - RefCell cell_netoutput_in; - cell_netoutput_in.node_name = (ele.first)->GetName(); - cell_netoutput_in.node = ele.first; - cell_netoutput_in.in_out = NODE_IN; - cell_netoutput_in.in_out_idx = ele.second; - out_ref_i_all_refs.emplace_back(cell_netoutput_in); - } - node_refs.emplace_back(out_ref_i_all_refs); - ref_o++; - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::BuildLookUpTables() { - GELOGD("start to build look up table!"); - for (size_t i = 0; i < values_.size(); i++) { - vector> &val = values_[i]; - for (const auto &ele : val) { - for (const auto &ref_cell : ele) { - string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + std::to_string(ref_cell.in_out_idx) + - 
std::to_string(static_cast(reinterpret_cast(ref_cell.node.get()))); - look_up_table_[key] = ele; - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::BuildRefRelationsForWhile( - const NodePtr &root_node, const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, vector> &node_refs) { - GELOGD("Enter BuildRefRelations for while op!"); - // data_nodes has been sorted - // for while, input num must be same as output num - auto input_num = root_node->GetAllInDataAnchorsSize(); - NodePtr netoutput = nullptr; - - size_t ref_i = 0; - while (ref_i < input_num) { - auto &ref_i_data_nodes = classed_data_nodes[ref_i]; - auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i]; - - vector ref_i_all_refs; - RefCell cell_root_i; - RefCell cell_root_o; - cell_root_i.node_name = root_node->GetName(); - cell_root_i.node = root_node; - cell_root_i.in_out = NODE_IN; - cell_root_i.in_out_idx = ref_i; - ref_i_all_refs.emplace_back(cell_root_i); - cell_root_o.node_name = root_node->GetName(); - cell_root_o.node = root_node; - cell_root_o.in_out = NODE_OUT; - cell_root_o.in_out_idx = ref_i; - ref_i_all_refs.emplace_back(cell_root_o); - for (const auto &data : ref_i_data_nodes) { - RefCell cell_in; - RefCell cell_out; - cell_in.node_name = data->GetName(); - cell_in.node = data; - cell_in.in_out = NODE_IN; - cell_in.in_out_idx = 0; - cell_out.node_name = data->GetName(); - cell_out.node = data; - cell_out.in_out = NODE_OUT; - cell_out.in_out_idx = 0; - ref_i_all_refs.emplace_back(cell_in); - ref_i_all_refs.emplace_back(cell_out); - } - - for (const auto &ele : ref_i_net_nodes) { - RefCell cell_netoutput_in; - RefCell cell_netoutput_out; - cell_netoutput_in.node_name = (ele.first)->GetName(); - cell_netoutput_in.node = ele.first; - cell_netoutput_in.in_out = NODE_IN; - cell_netoutput_in.in_out_idx = ele.second; - ref_i_all_refs.emplace_back(cell_netoutput_in); - netoutput = ele.first; - } - node_refs.emplace_back(ref_i_all_refs); - ref_i++; - } - /* There exist scene like the follows, it means data0 data1 netoutput 0'th - * and 1'th tensor should be the same addr. 
- * Data0 Data1 - * \/ - * /\ - * netoutput - */ - if (netoutput == nullptr) { - return GRAPH_SUCCESS; - } - for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) { - auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_data_anchor == nullptr) { - continue; - } - auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); - if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { - GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str()); - continue; - } - if (peer_out_data_node->GetType() != DATA) { - continue; - } - auto in_data_anchor_idx = in_anchor->GetIdx(); - auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); - int ref_d = 0; - int ref_n = 0; - (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d); - (void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n); - - node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end()); - node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end()); - } - - return GRAPH_SUCCESS; -} -// build ref relations according to diff func op type -graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( - const NodePtr &root_node, const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, vector> &node_refs) { - // data_nodes has been sorted - auto node_type = root_node->GetType(); - - auto status = GRAPH_SUCCESS; - if (node_type != kWhile) { - status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); - } else { - status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); - } - return status; -} - -void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector &data_nodes, - vector &netoutput_nodes, - const std::vector &sub_graph_names, - const std::string &node_type) { - int sub_graph_idx = 0; - for (const auto &name : sub_graph_names) { - auto sub_graph = root_graph.GetSubgraph(name); - if (sub_graph == nullptr) { - GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); - continue; - } - for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { - auto sub_graph_node_type = sub_graph_node->GetType(); - - if (sub_graph_node_type == DATA) { - data_nodes.emplace_back(sub_graph_node); - } else if (sub_graph_node_type == NETOUTPUT) { - // if while, the first subgraph must be cond subgraph. 
- // There is no meaning for refs ,so continue - if (node_type == kWhile && sub_graph_idx == 0) { - continue; - } - netoutput_nodes.emplace_back(sub_graph_node); - } - continue; - } - sub_graph_idx++; - } -} - -graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) { - auto parent_graph_ptr = graph.GetParentGraph(); - if (parent_graph_ptr == nullptr) { - root_graph = graph; - return GRAPH_SUCCESS; - } - auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr); - if (root_graph_ptr == nullptr) { - GE_LOGE("Get null root graph"); - return GRAPH_PARAM_INVALID; - } - root_graph = *root_graph_ptr; - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector &data_nodes, - vector> &classed_data_nodes) { - GELOGD("start to process subgraph data nodes!"); - int max_ref_idx = 0; - for (const auto &e : data_nodes) { - int i; - bool is_exist = true; - is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i); - if (!is_exist) { - GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIndex); - return GRAPH_FAILED; - } - max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx; - } - - while (!data_nodes.empty()) { - auto data = data_nodes.back(); - data_nodes.pop_back(); - int ref_idx = 0; - (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); - if (ref_idx >= static_cast(classed_data_nodes.size())) { - return GRAPH_FAILED; - } - classed_data_nodes[ref_idx].emplace_back(data); - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( - const vector &netoutput_nodes, vector>> &classed_netoutput_nodes) { - GELOGD("[RefRelations]Start to process subgraph netoutput!"); - for (const auto &sub_netoutput_node : netoutput_nodes) { - auto op_desc = sub_netoutput_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { - auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx()); - if (in_desc == nullptr) { - GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it", - sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); - return GRAPH_FAILED; - } - int ref_o; - if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { - if (ref_o >= static_cast(classed_netoutput_nodes.size())) { - return GRAPH_FAILED; - } - classed_netoutput_nodes[ref_o].emplace_back( - std::pair({sub_netoutput_node, static_cast(in_data_anchor->GetIdx())})); - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { - GELOGD("Start to build ref relations!"); - /* First Step: Get root graph */ - ge::ComputeGraph &root_graph = graph; - auto status = GetRootGraph(graph, root_graph); - if (status != GRAPH_SUCCESS) { - return status; - } - - for (const auto &node : graph.GetAllNodes()) { - auto node_type = node->GetType(); - std::vector ref_nodes; - auto op_desc = node->GetOpDesc(); - auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - continue; - } - vector data_nodes; - vector netoutput_nodes; - // Get data and netoutput of sub_graph - GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); - size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? 
data_nodes.size() : kMaxElementNum; - vector> classed_data_nodes(max_elem_num); // according to ref_idx - vector>> classed_netoutput_nodes(max_elem_num); // according to ref_idx - status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); - return status; - } - - // for netoutput - // check netoutput - // here main graph output number must be the same as every sub_graph netoutput node - // key: netoutput node_ptr , - status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "process netoutput failed!"); - return status; - } - - vector> node_refs; - status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs); - if (status != GRAPH_SUCCESS) { - GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str()); - return status; - } - if (!node_refs.empty()) { - values_.push_back(node_refs); - } - } - /* Seconde Step: generate map */ - status = BuildLookUpTables(); - if (status != GRAPH_SUCCESS) { - GELOGE(status, "Build look up tables failed!"); - return status; - } - return GRAPH_SUCCESS; -} - -/* Ref Relations Interface */ -RefRelations::RefRelations() { - impl_ = MakeShared(); - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "MakeShared failed!"); - return; - } -} - -graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set &result) { - GE_CHECK_NOTNULL(impl_); - return impl_->LookUpRefRelations(key, result); -} - -graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) { - GE_CHECK_NOTNULL(impl_); - return impl_->BuildRefRelations(root_graph); -} - -graphStatus RefRelations::Clear() { - GE_CHECK_NOTNULL(impl_); - return impl_->Clear(); -} -} // namespace ge \ No newline at end of file diff --git a/src/common/graph/runtime_inference_context.cc b/src/common/graph/runtime_inference_context.cc deleted file mode 100644 index 361d893c..00000000 --- a/src/common/graph/runtime_inference_context.cc +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/runtime_inference_context.h" -#include "graph/utils/tensor_adapter.h" -#include -#include "framework/common/debug/ge_log.h" - -namespace ge { -std::map> RuntimeInferenceContext::contexts_; -std::mutex RuntimeInferenceContext::ctx_mu_; - -graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id) { - GELOGI("To create context. session id = %s", context_id.c_str()); - auto ctx = std::unique_ptr(new (std::nothrow) RuntimeInferenceContext()); - if (ctx == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to create instance of RuntimeInferenceContext. 
context_id = %s", context_id.c_str()); - return GRAPH_FAILED; - } - - std::lock_guard lk(ctx_mu_); - auto emplace_ret = contexts_.emplace(context_id, std::move(ctx)); - if (!emplace_ret.second) { - GELOGE(GRAPH_FAILED, "Old context not destroyed"); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -void RuntimeInferenceContext::DestroyContext(const std::string &context_id) { - GELOGI("To destroy context. session id = %s", context_id.c_str()); - std::lock_guard lk(ctx_mu_); - contexts_.erase(context_id); -} - -graphStatus RuntimeInferenceContext::GetContext(const std::string &context_id, RuntimeInferenceContext **ctx) { - std::lock_guard lk(ctx_mu_); - auto it = contexts_.find(context_id); - if (it != contexts_.end()) { - *ctx = it->second.get(); - return GRAPH_SUCCESS; - } - - GELOGD("Runtime inference context not created. session id = %s", context_id.c_str()); - return GRAPH_FAILED; -} - -graphStatus RuntimeInferenceContext::SetTensor(int64_t node_id, int output_id, Tensor &&tensor) { - std::lock_guard lk(mu_); - auto &output_tensors = tensors_[node_id]; - if (static_cast(output_id) >= output_tensors.size()) { - output_tensors.resize(output_id + 1); - } - - GELOGD("Set tensor for node_id = %ld, output_id = %d", node_id, output_id); - output_tensors[output_id] = std::move(tensor); - - auto &output_ge_tensors = ge_tensors_[node_id]; - if (static_cast(output_id) >= output_ge_tensors.size()) { - output_ge_tensors.resize(output_id + 1); - } - - GELOGD("Set ge tensor for node_id = %ld, output_id = %d", node_id, output_id); - output_ge_tensors[output_id] = TensorAdapter::AsGeTensorPtr(tensor); - return GRAPH_SUCCESS; -} - -graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, Tensor &tensor) { - if (output_id < 0) { - GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id); - return GRAPH_PARAM_INVALID; - } - - std::lock_guard lk(mu_); - auto iter = tensors_.find(node_id); - if (iter == tensors_.end()) { - GELOGE(INTERNAL_ERROR, "Node not register. Id = %ld", node_id); - return INTERNAL_ERROR; - } - - auto &output_tensors = iter->second; - if (static_cast(output_id) >= output_tensors.size()) { - GELOGE(GRAPH_FAILED, "Node output is not registered. node_id = %ld, output index = %d", node_id, output_id); - return GRAPH_FAILED; - } - - GELOGD("Get tensor for node_id = %ld, output_id = %d", node_id, output_id); - tensor = output_tensors[output_id]; - return GRAPH_SUCCESS; -} - -graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, GeTensorPtr &tensor) { - if (output_id < 0) { - GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id); - return GRAPH_PARAM_INVALID; - } - - std::lock_guard lk(mu_); - auto iter = ge_tensors_.find(node_id); - if (iter == ge_tensors_.end()) { - GELOGE(INTERNAL_ERROR, "Node not register. Id = %ld", node_id); - return INTERNAL_ERROR; - } - - auto &output_tensors = iter->second; - if (static_cast(output_id) >= output_tensors.size()) { - GELOGE(GRAPH_FAILED, "Node output is not registered. 
node_id = %ld, output index = %d", node_id, output_id); - return GRAPH_FAILED; - } - - GELOGD("Get ge tensor for node_id = %ld, output_id = %d", node_id, output_id); - tensor = output_tensors[output_id]; - return GRAPH_SUCCESS; -} -} // namespace ge \ No newline at end of file diff --git a/src/common/graph/shape_refiner.cc b/src/common/graph/shape_refiner.cc deleted file mode 100644 index 17423da4..00000000 --- a/src/common/graph/shape_refiner.cc +++ /dev/null @@ -1,688 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/shape_refiner.h" - -#include -#include -#include -#include -#include -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" - -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "external/graph/operator.h" -#include "external/graph/operator_factory.h" -#include "framework/common/debug/ge_log.h" -#include "graph/compute_graph.h" -#include "utils/node_utils.h" -#include "utils/op_desc_utils.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" - -namespace ge { -namespace { -const uint32_t kWhileBodySubGraphIdx = 1; - -graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { - GELOGD("Enter reverse brush while body subgraph process!"); - - auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); - if (sub_graph_body == nullptr) { - GELOGE(GRAPH_FAILED, "Get while body graph failed!"); - return GRAPH_FAILED; - } - - for (const auto &node_sub : sub_graph_body->GetAllNodes()) { - for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) { - auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i); - GE_IF_BOOL_EXEC(input_desc == nullptr, - GELOGW("Get null input by index %zu from node %s ", i, node_sub->GetName().c_str()); - continue); - (void)input_desc->SetUnknownDimNumShape(); - } - for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) { - auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i); - (void)output_desc->SetUnknownDimNumShape(); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus UpdataOutputForMultiBatcch(const ConstNodePtr &node, - std::vector> &ref_out_tensors) { - // check sub_graph shape. Get max for update. 
- for (size_t i = 0; i < ref_out_tensors.size(); ++i) { - if (ref_out_tensors[i].empty()) { - continue; - } - - int64_t max_size = 0; - size_t max_shape_index = 0; - auto &ref_out_tensor = ref_out_tensors[i].at(0); - const auto &ref_out_tensor_shape = ref_out_tensor.MutableShape(); - for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) { - auto &tensor = ref_out_tensors[i].at(j); - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); - return GRAPH_FAILED; - } - - auto shape = tensor.MutableShape(); - if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { - GELOGE(GRAPH_FAILED, "node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", - node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); - return GRAPH_FAILED; - } - - int64_t size = 1; - for (auto dim : shape.GetDims()) { - if (INT64_MAX / dim < size) { - GELOGE(PARAM_INVALID, "The shape size overflow"); - return PARAM_INVALID; - } - size *= dim; - } - - if (size > max_size) { - max_size = size; - max_shape_index = j; - } - } - - (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index)); - } - - return GRAPH_SUCCESS; -} - -graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, - std::vector> &ref_out_tensors) { - GELOGD("Enter update parent node shape for class branch op process"); - if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { - return UpdataOutputForMultiBatcch(node, ref_out_tensors); - } - - // check sub_graph shape.If not same ,do unknown shape process - for (size_t i = 0; i < ref_out_tensors.size(); i++) { - if (ref_out_tensors[i].empty()) { - continue; - } - auto ref_out_tensor = ref_out_tensors[i].at(0); - ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); - for (auto &tensor : ref_out_tensors[i]) { - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); - return GRAPH_FAILED; - } - auto shape = tensor.MutableShape(); - if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { - GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, - shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); - ref_out_tensor_shape = GeShape(UNKNOWN_RANK); - break; - } - for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { - if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { - continue; - } - GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, - j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); - (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); - } - } - (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); - } - return GRAPH_SUCCESS; -} - -graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector> &ref_data_tensors, - std::vector> &ref_out_tensors) { - GELOGD("Enter update parent node shape for class while op process"); - if (ref_data_tensors.size() != ref_out_tensors.size()) { - GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(), - ref_data_tensors.size(), ref_out_tensors.size()); - return GRAPH_FAILED; - } - for (size_t i = 0; i < ref_data_tensors.size(); i++) { - if (ref_out_tensors[i].size() != 1) { - GELOGE(GRAPH_FAILED, "while op, every output should only find one 
output tensor in all graph!"); - return GRAPH_FAILED; - } - } - bool is_need_reverse_brush = false; - // check input and output - for (size_t i = 0; i < ref_out_tensors.size(); i++) { - if (ref_out_tensors[i].empty()) { - continue; - } - auto ref_out_tensor = ref_out_tensors[i].at(0); - auto tmp_shape = ref_out_tensor.MutableShape(); - // ref_i's data and output tensor shape should be same - for (auto &tensor : ref_data_tensors[i]) { - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str()); - return GRAPH_FAILED; - } - auto shape = tensor.MutableShape(); - if (shape.GetDims() != tmp_shape.GetDims()) { - ref_out_tensor.SetUnknownDimNumShape(); - is_need_reverse_brush = true; - break; - } - } - (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); - } - // reverse refresh while body shape - if (is_need_reverse_brush) { - return ReverseBrushWhileBodySubGraph(node); - } - return GRAPH_SUCCESS; -} - -graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { - auto op_desc = node->GetOpDesc(); - auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - return GRAPH_SUCCESS; - } - - auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - for (const auto &name : sub_graph_names) { - if (name.empty()) { - GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); - continue; - } - auto sub_graph = root_graph->GetSubgraph(name); - if (sub_graph == nullptr) { - GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - for (const auto &node_sub : sub_graph->GetDirectNode()) { - if (node_sub->GetType() != DATA) { - continue; - } - int ref_i; - auto data_opdesc = node_sub->GetOpDesc(); - if (data_opdesc == nullptr) { - GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), - node->GetName().c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { - GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), - node->GetName().c_str()); - return GRAPH_FAILED; - } - if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { - continue; - } - auto input_desc = op_desc->MutableInputDesc(ref_i); - if (input_desc == nullptr) { - GE_LOGE( - "The ref index(%d) on the data %s on the sub graph %s " - "parent node %s are incompatible, inputs num %u", - ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); - return GRAPH_FAILED; - } - GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), - node->GetName().c_str()); - auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); - - if (ret != GRAPH_SUCCESS) { - GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", - node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); - return ret; - } - ret = data_opdesc->UpdateOutputDesc(0, *input_desc); - if (ret != GRAPH_SUCCESS) { - GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s", - node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); - return ret; - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr &sub_graph, NodePtr &netoutput, - const ConstNodePtr &node, - std::vector> 
&ref_data_tensors) { - auto sub_nodes = sub_graph->GetDirectNode(); - for (size_t i = sub_nodes.size(); i > 0; --i) { - auto sub_node = sub_nodes.at(i - 1); - if (sub_node->GetType() == NETOUTPUT) { - netoutput = sub_node; - } - if (sub_node->GetType() == DATA) { - if (sub_node->GetOpDesc() == nullptr) { - return GRAPH_FAILED; - } - - int ref_i; - if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { - GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); - return GRAPH_FAILED; - } - if (ref_i < 0 || static_cast(ref_i) >= node->GetAllInDataAnchorsSize()) { - GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(), - ref_i, node->GetAllInDataAnchorsSize()); - return GRAPH_FAILED; - } - ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); - } - } - return GRAPH_SUCCESS; -} - -graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { - auto op_desc = node->GetOpDesc(); - auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - return GRAPH_SUCCESS; - } - - std::vector> ref_data_tensors(node->GetAllInDataAnchorsSize()); - std::vector> ref_out_tensors(node->GetAllOutDataAnchorsSize()); - auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - - for (const auto &name : sub_graph_names) { - if (name.empty()) { - GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); - continue; - } - auto sub_graph = root_graph->GetSubgraph(name); - if (sub_graph == nullptr) { - GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - NodePtr netoutput = nullptr; - auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); - if (ret != GRAPH_SUCCESS) { - return ret; - } - if (netoutput == nullptr) { - GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - auto netoutput_opdesc = netoutput->GetOpDesc(); - if (netoutput_opdesc == nullptr) { - GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(), - node->GetName().c_str()); - return GRAPH_FAILED; - } - for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { - auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); - if (edge_desc == nullptr) { - GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(), - node->GetName().c_str(), edge_anchor->GetIdx()); - return GRAPH_FAILED; - } - GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", edge_anchor->GetIdx(), - edge_desc->GetShape().GetDimNum()); - int ref_i; - if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { - // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. 
- continue; - } - GELOGI("Parent node index of edge desc is %d", ref_i); - if (ref_i < 0 || static_cast(ref_i) >= node->GetAllOutDataAnchorsSize()) { - return GRAPH_FAILED; - } - ref_out_tensors[ref_i].emplace_back(*edge_desc); - } - } - - if (node->GetType() == WHILE) { - return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); - } - return UpdateParentNodeForBranch(node, ref_out_tensors); -} - -string Serial(const vector &dims) { - string serial_string; - serial_string += "["; - for (int64_t dim : dims) { - serial_string += std::to_string(dim) + " "; - } - serial_string += "]"; - return serial_string; -} - -graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { - GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); - GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); - for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) { - auto in_idx = in_anchor->GetIdx(); - auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_data_anchor == nullptr) { - continue; - } - auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); - if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { - continue; - } - int peer_out_idx = peer_out_data_anchor->GetIdx(); - auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast(peer_out_idx)); - - // check shape and dtype continuity. do not stop process - auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast(in_idx)); - if (in_desc == nullptr) { - continue; - } - auto in_shape = in_desc->GetShape().GetDims(); - auto in_dtype = in_desc->GetDataType(); - auto peer_out_shape = peer_out_desc->GetShape().GetDims(); - auto peer_out_dtype = peer_out_desc->GetDataType(); - if (peer_out_dtype != in_dtype) { - GELOGW( - "current node [%s] [%d]\'th out_dtype is [%s].peer output node [%s] [%d]\'th " - "output_dtype is [%s].The two dtype should be same! Please check graph and fix it", - node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(), - peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str()); - } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) { - string in_shape_str = Serial(in_shape); - string peer_out_shape_str = Serial(peer_out_shape); - GELOGW( - "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " - "input_shape is [%s].The two shape should be same! 
Please check graph and fix it", - node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx, - peer_out_shape_str.c_str()); - } - // refresh current node input desc - in_desc->SetOriginShape(peer_out_desc->GetOriginShape()); - in_desc->SetShape(peer_out_desc->GetShape()); - in_desc->SetDataType(peer_out_desc->GetDataType()); - in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType()); - std::vector> shape_range; - (void)peer_out_desc->GetShapeRange(shape_range); - in_desc->SetShapeRange(shape_range); - ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast(peer_out_desc->GetShape().GetDims().size())); - } - return GRAPH_SUCCESS; -} -} // namespace -void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { - if (!IsLogEnable(GE, DLOG_DEBUG)) { - return; - } - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "node is null"); - return; - } - ge::OpDescPtr op_desc = node->GetOpDesc(); - GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); - std::string str; - if (op_desc->GetInputsSize() != 0) { - std::string input_desc_str = "input shape: "; - for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { - input_desc_str += "["; - for (int64_t dim : input_desc->GetShape().GetDims()) { - input_desc_str += std::to_string(dim) + " "; - } - input_desc_str += "]"; - input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) + ":" + - TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " "; - } - str += input_desc_str; - - input_desc_str = "input origin shape: "; - for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { - input_desc_str += "["; - for (int64_t dim : input_desc->GetOriginShape().GetDims()) { - input_desc_str += std::to_string(dim) + " "; - } - input_desc_str += "]"; - input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" + - TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " "; - } - str += input_desc_str; - } - - if (op_desc->GetAllOutputsDescSize() != 0) { - std::string output_desc_str = "output shape: "; - for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { - if (output_desc == nullptr) { - continue; - } - output_desc_str += "["; - for (int64_t dim : output_desc->GetShape().GetDims()) { - output_desc_str += std::to_string(dim) + " "; - } - output_desc_str += "]"; - output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) + ":" + - TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " "; - } - str += output_desc_str; - - output_desc_str = "output origin shape: "; - for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { - if (output_desc == nullptr) { - continue; - } - output_desc_str += "["; - for (int64_t dim : output_desc->GetOriginShape().GetDims()) { - output_desc_str += std::to_string(dim) + " "; - } - output_desc_str += "]"; - output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" + - TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " "; - } - str += output_desc_str; - } - GELOGD("Shape dump [%s], Node name: [%s]. 
%s", phase.c_str(), node->GetName().c_str(), str.c_str()); -} - -graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { - return InferShapeAndType(node, op, true); -} -graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) { - auto op_desc = node->GetOpDesc(); - const auto &op_type = op_desc->GetType(); - - graphStatus ret; - if (before_subgraph) { - ret = UpdateSubGraphDataNodes(node); - if (ret != GRAPH_SUCCESS) { - return ret; - } - } - // Get infer func and execute - ret = op_desc->CallInferFunc(op); - if (ret == GRAPH_PARAM_INVALID) { - // Op ir no infer func, try to get infer func from operator factory - auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); - if (node_op.IsEmpty()) { - GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); - return ret; - } - - GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); - auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); - node_op.BreakConnect(); - if (temp_op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "temp op desc is null"); - return GRAPH_FAILED; - } - if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { - GELOGW("InferShapeAndType UpdateInputName failed"); - for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { - if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { - break; - } - return GRAPH_SUCCESS; - } - } - if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { - GELOGW("InferShapeAndType UpdateOutputName failed"); - } - op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); - ret = op_desc->CallInferFunc(op); - GELOGI("op CallInferFunc second. ret: %u", ret); - } - if (ret != GRAPH_SUCCESS) { - return ret; - } - - if (!before_subgraph) { - return UpdateParentNodeOutTensor(node); - } - return GRAPH_SUCCESS; -} - -InferenceContextPtr CreateInferenceContext(const std::unordered_map &context_map, - const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "node is null"); - return nullptr; - } - InferenceContextPtr inference_context = std::shared_ptr(InferenceContext::Create()); - if (inference_context == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext"); - return nullptr; - } - - auto all_in_data_anchors = node->GetAllInDataAnchors(); - std::vector> input_shapes_and_types(all_in_data_anchors.size()); - std::vector marks; - - bool has_input_shapes_and_types = false; - for (const auto &in_anchor : all_in_data_anchors) { - const auto &out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - - auto input_node = out_anchor->GetOwnerNode(); - if (input_node == nullptr) { - continue; - } - - auto iter = context_map.find(input_node); - if (iter != context_map.end()) { - const auto &src_context = iter->second; - GE_IF_BOOL_EXEC(src_context == nullptr, GELOGE(GRAPH_FAILED, "src_context is null."); return nullptr); - GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(), - input_node->GetName().c_str()); - for (auto mark : src_context->GetMarks()) { - marks.push_back(mark); - } - auto output_idx = out_anchor->GetIdx(); - auto input_idx = in_anchor->GetIdx(); - auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes(); - if (output_idx < static_cast(output_shape_and_type.size())) { - GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx, - node->GetName().c_str(), 
input_idx); - input_shapes_and_types[input_idx] = output_shape_and_type[output_idx]; - has_input_shapes_and_types = true; - } else { - GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx, - output_shape_and_type.size()); - } - } - } - - if (has_input_shapes_and_types) { - inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types)); - } - inference_context->SetMarks(marks); - - return inference_context; -} - -namespace { -thread_local std::unordered_map context_map; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { - return InferShapeAndType(node, true); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, - bool before_subgraph) { - GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); - bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); - auto opdesc = node->GetOpDesc(); - GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); - // some op can not infershape twice such as aipp - bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified"); - if (need_update_input) { - auto status = UpdateOpInputDesc(node); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "update op input_desc failed!"); - return status; - } - } - - if (node->Verify() != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - PrintInOutTensorShape(node, "before_infershape"); - Operator op = OpDescUtils::CreateOperatorFromNode(node); - - if (!is_unknown_graph) { - auto inference_context = CreateInferenceContext(context_map, node); - if (inference_context == nullptr) { - GELOGE(GRAPH_FAILED, "inference context is null"); - return GRAPH_FAILED; - } - GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); - op.SetInferenceContext(inference_context); - } - - graphStatus status = InferShapeAndType(node, op, before_subgraph); - if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { - if (is_unknown_graph) { - PrintInOutTensorShape(node, "after_infershape when running"); - return GRAPH_SUCCESS; - } - auto op_desc = node->GetOpDesc(); - for (const auto &out_anchor : node->GetAllOutDataAnchors()) { - auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); - ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); - output_tensor->SetOriginShape(output_tensor->GetShape()); - output_tensor->SetOriginDataType(output_tensor->GetDataType()); - - GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", - node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), - TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), - TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); - } - } else { - GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - if (!is_unknown_graph) { - auto ctx_after_infer = op.GetInferenceContext(); - if (ctx_after_infer != nullptr) { - GELOGD("[%s] after infershape. 
mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); - if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { - GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), - ctx_after_infer->GetMarks().size()); - (void)context_map.emplace(node, ctx_after_infer); - } - } - } - PrintInOutTensorShape(node, "after_infershape"); - - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/src/common/graph/stub/Makefile b/src/common/graph/stub/Makefile deleted file mode 100644 index f339fa33..00000000 --- a/src/common/graph/stub/Makefile +++ /dev/null @@ -1,6 +0,0 @@ -inc_path := $(shell pwd)/metadef/inc/external/ -out_path := $(shell pwd)/out/graph/lib64/stub/ -stub_path := $(shell pwd)/metadef/graph/stub/ - -mkdir_stub := $(shell mkdir -p $(out_path)) -graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) diff --git a/src/common/graph/stub/gen_stubapi.py b/src/common/graph/stub/gen_stubapi.py deleted file mode 100644 index 7263ff17..00000000 --- a/src/common/graph/stub/gen_stubapi.py +++ /dev/null @@ -1,578 +0,0 @@ -import os -import re -import sys -import logging - -logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', - level=logging.INFO) - -""" - this attr is used for symbol table visible -""" -GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' - -""" - generate stub func body by return type -""" -RETURN_STATEMENTS = { - 'graphStatus': ' std::cout << "[ERROR]: stub library libgraph or libge_compiler cannot be used for execution, please check your "\n ' - ' << "environment variables and compilation options to make sure you use the correct library."\n' - ' << std::endl;\n' - ' return ACL_ERROR_COMPILING_STUB_MODE;', - 'Status': ' return SUCCESS;', - 'Graph': ' return Graph();', - 'Graph&': ' return *this;', - 'Format': ' return Format();', - 'Format&': ' return *this;', - 'Shape': ' return Shape();', - 'Shape&': ' return *this;', - 'TensorDesc': ' return TensorDesc();', - 'TensorDesc&': ' return *this;', - 'Tensor': ' return Tensor();', - 'Tensor&': ' return *this;', - 'Operator': ' return Operator();', - 'Operator&': ' return *this;', - 'Ptr': ' return nullptr;', - 'std::string': ' return "";', - 'std::string&': ' return "";', - 'string': ' return "";', - 'int': ' return 0;', - 'DataType': ' return DT_FLOAT;', - 'InferenceContextPtr': ' return nullptr;', - 'SubgraphBuilder': ' return nullptr;', - 'OperatorImplPtr': ' return nullptr;', - 'OutHandler': ' return nullptr;', - 'std::vector': ' return {};', - 'std::vector': ' return {};', - 'std::map': ' return {};', - 'uint32_t': ' return 0;', - 'int64_t': ' return 0;', - 'uint64_t': ' return 0;', - 'size_t': ' return 0;', - 'float': ' return 0.0f;', - 'bool': ' return false;', -} - -""" - max code len per line in hua_wei software programming specifications -""" -max_code_len_per_line = 100 - -""" - white_list_for_debug, include_dir_key_words is to - determines which header files to generate cc files from - when DEBUG on -""" -white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h", "inference_context.h", - "ge_ir_build.h", "ge_api.h", "ascend_string.h", "gnode.h"] -include_dir_key_words = ["ge", "graph"] -DEBUG = True - - -def need_generate_func(func_line): - """ - :param func_line: - :return: - """ - if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ - or 
func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): - return False - return True - - -def file_endswith_white_list_suffix(file): - """ - :param file: - :return: - """ - if DEBUG: - for suffix in white_list_for_debug: - if file.endswith(suffix): - return True - return False - else: - return True - - -""" - belows are patterns used for analyse .h file -""" -# pattern function -pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after -([a-zA-Z~_] # void int likely -.* -[)] #we find ) -(?!.*{) # we do not want the case int abc() const -.*) -(;.*) #we want to find ; and after for we will replace these later -\n$ -""", re.VERBOSE | re.MULTILINE | re.DOTALL) - -# pattern comment -pattern_comment = re.compile(r'^\s*//') -pattern_comment_2_start = re.compile(r'^\s*/[*]') -pattern_comment_2_end = re.compile(r'[*]/\s*$') -# pattern define -pattern_define = re.compile(r'^\s*#define') -pattern_define_return = re.compile(r'\\\s*$') -# blank line -pattern_blank_line = re.compile(r'^\s*$') -# virtual,explicit,friend,static -pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') -# lead space -pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') -# functions will have patterns such as func ( or func( -# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist -# format like :"operator = ()" -pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') -# template -pattern_template = re.compile(r'^\s*template') -pattern_template_end = re.compile(r'>\s*$') -# namespace -pattern_namespace = re.compile(r'namespace.*{') -# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with -pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+ 0 and not friend_match: - line, func_name = self.handle_class_member_func(line, template_string) - # Normal functions - else: - line, func_name = self.handle_normal_func(line, template_string) - - need_generate = need_generate_func(line) - # func body - line += self.implement_function(line) - # comment - line = self.gen_comment(start_i) + line - # write to out file - self.write_func_content(line, func_name, need_generate) - # next loop - self.line_index += 1 - - logging.info('Added %s functions', len(self.func_list_exist)) - logging.info('Successfully converted,please see ' + self.output_file) - - def handle_func1(self, line): - """ - :param line: - :return: - """ - find1 = re.search('[(]', line) - if not find1: - self.line_index += 1 - return "continue", line, None - find2 = re.search('[)]', line) - start_i = self.line_index - space_match = pattern_leading_space.search(line) - # deal with - # int abc(int a, - # int b) - if find1 and (not find2): - self.line_index += 1 - line2 = self.input_content[self.line_index] - if space_match: - line2 = re.sub('^' + space_match.group(1), '', line2) - line += line2 - while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): - self.line_index += 1 - line2 = self.input_content[self.line_index] - line2 = re.sub('^' + space_match.group(1), '', line2) - line += line2 - - match_start = pattern_start.search(self.input_content[self.line_index]) - match_end = pattern_end.search(self.input_content[self.line_index]) - if match_start: # like ) { or ) {} int the last line - if not match_end: - self.stack.append('normal_now') - ii = start_i - while ii <= self.line_index: - ii += 1 - self.line_index += 1 - 
return "continue", line, start_i - logging.info("line[%s]", line) - # ' int abc();'->'int abc()' - (line, match) = pattern_func.subn(r'\2\n', line) - logging.info("line[%s]", line) - # deal with case: - # 'int \n abc(int a, int b)' - if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): - line = self.input_content[start_i - 1] + line - line = line.lstrip() - if not match: - self.line_index += 1 - return "continue", line, start_i - return "pass", line, start_i - - def handle_stack(self, match_start): - """ - :param match_start: - :return: - """ - line = self.input_content[self.line_index] - match_end = pattern_end.search(line) - if match_start: - self.stack.append('normal_now') - if match_end: - top_status = self.stack.pop() - if top_status == 'namespace_now': - self.output_fd.write(line + '\n') - elif top_status == 'class_now': - self.stack_class.pop() - self.stack_template.pop() - if match_start or match_end: - self.line_index += 1 - return "continue" - - if len(self.stack) > 0 and self.stack[-1] == 'normal_now': - self.line_index += 1 - return "continue" - return "pass" - - def handle_class(self, template_string, line, match_start, match_class): - """ - :param template_string: - :param line: - :param match_start: - :param match_class: - :return: - """ - if match_class: # we face a class - self.stack_template.append(template_string) - self.stack.append('class_now') - class_name = match_class.group(3) - - # class template specializations: class A > - if '<' in class_name: - k = line.index('<') - fit = 1 - for ii in range(k + 1, len(line)): - if line[ii] == '<': - fit += 1 - if line[ii] == '>': - fit -= 1 - if fit == 0: - break - class_name += line[k + 1:ii + 1] - logging.info('class_name[%s]', class_name) - self.stack_class.append(class_name) - while not match_start: - self.line_index += 1 - line = self.input_content[self.line_index] - match_start = pattern_start.search(line) - self.line_index += 1 - return "continue" - return "pass" - - def handle_template(self): - line = self.input_content[self.line_index] - match_template = pattern_template.search(line) - template_string = '' - if match_template: - match_template_end = pattern_template_end.search(line) - template_string = line - while not match_template_end: - self.line_index += 1 - line = self.input_content[self.line_index] - template_string += line - match_template_end = pattern_template_end.search(line) - self.line_index += 1 - return template_string - - def handle_namespace(self): - line = self.input_content[self.line_index] - match_namespace = pattern_namespace.search(line) - if match_namespace: # we face namespace - self.output_fd.write(line + '\n') - self.stack.append('namespace_now') - self.line_index += 1 - - def handle_normal_func(self, line, template_string): - template_line = '' - self.stack_template.append(template_string) - if self.stack_template[-1] != '': - template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) - # change '< class T = a, class U = A(3)>' to '' - template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) - template_line = re.sub(r'\s*=.*,', ',', template_line) - template_line = re.sub(r'\s*=.*', '', template_line) - line = re.sub(r'\s*=.*,', ',', line) - line = re.sub(r'\s*=.*\)', ')', line) - line = template_line + line - self.stack_template.pop() - func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() - logging.info("line[%s]", line) - logging.info("func_name[%s]", func_name) - return line, func_name - - def 
handle_class_member_func(self, line, template_string): - template_line = '' - x = '' - if template_string != '': - template_string = re.sub(r'\s*template', 'template', template_string) - template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) - template_string = re.sub(r'\s*=.*,', ',', template_string) - template_string = re.sub(r'\s*=.*', '', template_string) - if self.stack_template[-1] != '': - if not (re.search(r'<\s*>', stack_template[-1])): - template_line = re.sub(r'^\s*template', 'template', stack_template[-1]) - if not (re.search(r'<.*>', self.stack_class[-1])): - # for x we get like template -> - x = re.sub(r'template\s*<', '<', template_line) # remove template -> - x = re.sub(r'\n', '', x) - x = re.sub(r'\s*=.*,', ',', x) - x = re.sub(r'\s*=.*\>', '>', x) - x = x.rstrip() # remove \n - x = re.sub(r'(class|typename)\s+|(|\s*class)', '', - x) # remove class,typename -> - x = re.sub(r'<\s+', '<', x) - x = re.sub(r'\s+>', '>', x) - x = re.sub(r'\s+,', ',', x) - x = re.sub(r',\s+', ', ', x) - line = re.sub(r'\s*=\s+0', '', line) - line = re.sub(r'\s*=\s+.*,', ',', line) - line = re.sub(r'\s*=\s+.*\)', ')', line) - logging.info("x[%s]\nline[%s]", x, line) - # if the function is long, void ABC::foo() - # breaks into two lines void ABC::\n foo() - temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1) - if len(temp_line) > max_code_len_per_line: - line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1) - else: - line = temp_line - logging.info("line[%s]", line) - # add template as the above if there is one - template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) - template_line = re.sub(r'\s*=.*,', ',', template_line) - template_line = re.sub(r'\s*=.*', '', template_line) - line = template_line + template_string + line - func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() - logging.info("line[%s]", line) - logging.info("func_name[%s]", func_name) - return line, func_name - - def write_func_content(self, content, func_name, need_generate): - if not (func_name in self.func_list_exist) and need_generate: - self.output_fd.write(content) - self.func_list_exist.append(func_name) - logging.info('add func:[%s]', func_name) - - def gen_comment(self, start_i): - comment_line = '' - # Function comments are on top of function declarations, copy them over - k = start_i - 1 # one line before this func start - if pattern_template.search(self.input_content[k]): - k -= 1 - if pattern_comment_2_end.search(self.input_content[k]): - comment_line = self.input_content[k].lstrip() - while not pattern_comment_2_start.search(self.input_content[k]): - k -= 1 - comment_line = self.input_content[k].lstrip() + comment_line - else: - for j in range(k, 0, -1): - c_line = self.input_content[j] - if pattern_comment.search(c_line): - c_line = re.sub(r'\s*//', '//', c_line) - comment_line = c_line + comment_line - else: - break - return comment_line - - @staticmethod - def implement_function(func): - function_def = '' - function_def += '{\n' - - all_items = func.split() - start = 0 - return_type = all_items[start] - if return_type == "const": - start += 1 - return_type = all_items[start] - if return_type.startswith(('std::map', 'std::set', 'std::vector')): - return_type = "std::map" - if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): - return_type = "Ptr" - if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): - return_type += "&" - if 
RETURN_STATEMENTS.__contains__(return_type): - function_def += RETURN_STATEMENTS[return_type] - else: - logging.warning("Unhandled return type[%s]", return_type) - - function_def += '\n' - function_def += '}\n' - function_def += '\n' - return function_def - - -def collect_header_files(path): - """ - :param path: - :return: - """ - header_files = [] - shared_includes_content = [] - for root, dirs, files in os.walk(path): - files.sort() - for file in files: - if file.find("git") >= 0: - continue - if not file.endswith('.h'): - continue - file_path = os.path.join(root, file) - file_path = file_path.replace('\\', '/') - header_files.append(file_path) - include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) - shared_includes_content.append(include_str) - # for acl error code - shared_includes_content.append('#include \n') - shared_includes_content.append('const int ACL_ERROR_COMPILING_STUB_MODE = 100039;\n') - return header_files, shared_includes_content - - -def generate_stub_file(inc_dir, out_cc_dir): - """ - :param inc_dir: - :param out_cc_dir: - :return: - """ - target_header_files, shared_includes_content = collect_header_files(inc_dir) - for header_file in target_header_files: - if not file_endswith_white_list_suffix(header_file): - continue - cc_file = re.sub('.h*$', '.cc', header_file) - h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) - h_2_cc.h2cc() - - -def gen_code(inc_dir, out_cc_dir): - """ - :param inc_dir: - :param out_cc_dir: - :return: - """ - if not inc_dir.endswith('/'): - inc_dir += '/' - if not out_cc_dir.endswith('/'): - out_cc_dir += '/' - for include_dir_key_word in include_dir_key_words: - generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) - - -if __name__ == '__main__': - inc_dir = sys.argv[1] - out_cc_dir = sys.argv[2] - gen_code(inc_dir, out_cc_dir) diff --git a/src/common/graph/tensor.cc b/src/common/graph/tensor.cc deleted file mode 100644 index 1f30c876..00000000 --- a/src/common/graph/tensor.cc +++ /dev/null @@ -1,704 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "external/graph/tensor.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/ge_tensor.h" -#include "securec.h" -#include "utils/attr_utils.h" -#include "utils/tensor_adapter.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" - -namespace { -/// Extra 8 bytes store pointer of string -/// Extra 1 byte store '\0' -const int EXTRA_STORE_POINTER_FOR_STRING = 8; -const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9; -const int64_t UNKNOWN_DIM_SIZE = -1; -} // namespace - -namespace ge { -// If not overflow return true -static bool Int64MulNotOverflow(int64_t a, int64_t b) { - if (a > 0) { - if (b > 0) { - if (a > (INT64_MAX / b)) { - return false; - } - } else { - if (b < (INT64_MIN / a)) { - return false; - } - } - } else { - if (b > 0) { - if (a < (INT64_MIN / b)) { - return false; - } - } else { - if ((a != 0) && (b < (INT64_MAX / a))) { - return false; - } - } - } - return true; -} - -class TensorDescImpl { - public: - TensorDescImpl() = default; - ~TensorDescImpl() = default; - TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {} - - Shape shape_; - std::vector> range_; - Format format_ = FORMAT_ND; - Format origin_format_ = FORMAT_ND; - DataType data_type_ = DT_FLOAT; - Shape origin_shape_; - int64_t size_ = 0; - int64_t real_dim_cnt_ = 0; - std::string name_; -}; - -class TensorImpl { - public: - TensorImpl() = default; - ~TensorImpl() = default; - - explicit TensorImpl(const TensorDesc &tensor_desc) : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)) {} - TensorImpl(const TensorDesc &tensor_desc, const std::vector &data) - : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data) {} - TensorImpl(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) - : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data, size) {} - TensorImpl(TensorDesc &&tensor_desc, std::vector &&data) - : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), std::move(data)) {} - - GeTensor ge_tensor; -}; - -class ShapeImpl { - public: - ShapeImpl() = default; - ~ShapeImpl() = default; - explicit ShapeImpl(const std::vector &dims) { - bool is_unknown_dim_num = false; - for (const auto &dim : dims) { - if (dim == UNKNOWN_DIM_NUM) { - is_unknown_dim_num = true; - break; - } - } - dims_ = is_unknown_dim_num ? 
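
The Int64MulNotOverflow helper deleted above protects GetShapeSize against signed overflow by checking each sign combination against INT64_MAX / INT64_MIN before multiplying. A self-contained sketch of the same quadrant check, illustrative only and not taken verbatim from the patch:

#include <cstdint>
#include <iostream>
#include <limits>

// Returns true when a * b cannot overflow int64_t, using the same
// sign-quadrant comparisons as the deleted Int64MulNotOverflow.
static bool Int64MulNotOverflow(int64_t a, int64_t b) {
  if (a > 0) {
    if (b > 0) {
      return a <= std::numeric_limits<int64_t>::max() / b;
    }
    return b >= std::numeric_limits<int64_t>::min() / a;
  }
  if (b > 0) {
    return a >= std::numeric_limits<int64_t>::min() / b;
  }
  return (a == 0) || (b >= std::numeric_limits<int64_t>::max() / a);
}

int main() {
  std::cout << std::boolalpha
            << Int64MulNotOverflow(INT64_MAX, 2) << " "              // false
            << Int64MulNotOverflow(1 << 20, 1 << 20) << std::endl;   // true
  return 0;
}
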
std::vector({UNKNOWN_DIM_NUM}) : dims; - } - - std::vector dims_; -}; - -Shape::Shape() { impl_ = ComGraphMakeShared(); } - -Shape::Shape(const std::vector &dims) { impl_ = ComGraphMakeShared(dims); } - -size_t Shape::GetDimNum() const { - if (impl_ != nullptr) { - for (auto i : impl_->dims_) { - if (i == UNKNOWN_DIM_NUM) { - return 0; - } - } - return impl_->dims_.size(); - } - return 0; -} - -int64_t Shape::GetDim(size_t idx) const { - if (impl_ != nullptr) { - if (idx >= impl_->dims_.size()) { - return 0; - } - return impl_->dims_[idx]; - } - return 0; -} - -graphStatus Shape::SetDim(size_t idx, int64_t value) { - if (impl_ != nullptr) { - if (idx >= impl_->dims_.size()) { - return GRAPH_FAILED; - } - impl_->dims_[idx] = value; - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -std::vector Shape::GetDims() const { - vector dims; - if (impl_ != nullptr) { - return impl_->dims_; - } - return dims; -} - -int64_t Shape::GetShapeSize() const { - if (impl_ != nullptr) { - if (impl_->dims_.empty()) { - return 0; - } - int64_t size = 1; - for (auto i : impl_->dims_) { - if (i == UNKNOWN_DIM_NUM || i == UNKNOWN_DIM) { - return UNKNOWN_DIM_SIZE; - } - - if (!Int64MulNotOverflow(size, i)) { - GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); - size = 0; - return size; - } - size *= i; - } - return size; - } - return 0; -} - -TensorDesc::TensorDesc() { - impl = ComGraphMakeShared(); // lint !e665 -} - -TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) { - impl = ComGraphMakeShared(shape, format, dt); // lint !e665 - SetRealDimCnt(shape.GetDimNum()); -} - -TensorDesc::TensorDesc(const TensorDesc &desc) { - // Copy - impl = ComGraphMakeShared(); // lint !e665 - if (desc.impl != nullptr && impl != nullptr) { - *impl = *desc.impl; - } -} - -TensorDesc::TensorDesc(TensorDesc &&desc) { - // Move - impl = std::move(desc.impl); -} - -TensorDesc &TensorDesc::operator=(const TensorDesc &desc) { - // Copy - if (&desc != this) { - impl = ComGraphMakeShared(); - if (desc.impl != nullptr && impl != nullptr) { - *impl = *desc.impl; - } - } - return *this; -} - -TensorDesc &TensorDesc::operator=(TensorDesc &&desc) { - if (&desc != this) { - impl = std::move(desc.impl); - } - return *this; -} - -void TensorDesc::Update(const Shape &shape, Format format, DataType dt) { - if (impl != nullptr) { - impl->shape_ = shape; - impl->format_ = format; - impl->data_type_ = dt; - } -} - -Shape TensorDesc::GetShape() const { - if (impl != nullptr) { - return impl->shape_; - } - return Shape(); -} - -void TensorDesc::SetShape(const Shape &shape) { - if (impl != nullptr) { - impl->shape_ = shape; - } -} - -// set shape with -2, it stand for unknown shape -graphStatus TensorDesc::SetUnknownDimNumShape() { - if (impl != nullptr) { - impl->shape_ = Shape({UNKNOWN_DIM_NUM}); - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Set unknown shape failed,because no impl class!"); - return GRAPH_FAILED; -} - -// for unknown shape -graphStatus TensorDesc::SetShapeRange(const std::vector> &range) { - if (impl != nullptr) { - impl->range_ = range; - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "SetShapeRange failed!impl is nullptr!"); - return GRAPH_FAILED; -} -graphStatus TensorDesc::GetShapeRange(std::vector> &range) const { - if (impl != nullptr) { - range = impl->range_; - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "impl is nullptr!"); - return GRAPH_FAILED; -} - -Shape TensorDesc::GetOriginShape() const { - if (impl != nullptr) { - return impl->origin_shape_; - } - return Shape(); -} - -void 
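
Shape, TensorDesc and Tensor above all follow one pattern: the public object only holds a shared pointer to an Impl class created through ComGraphMakeShared, and every accessor degrades to a default value when that pointer is null (for example after an allocation failure). A minimal sketch of that null-guarded pimpl idiom, using hypothetical DemoShape/DemoShapeImpl names and plain std::make_shared in place of ComGraphMakeShared:

#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

class DemoShapeImpl {
 public:
  explicit DemoShapeImpl(std::vector<int64_t> dims) : dims_(std::move(dims)) {}
  std::vector<int64_t> dims_;
};

class DemoShape {
 public:
  explicit DemoShape(const std::vector<int64_t> &dims)
      : impl_(std::make_shared<DemoShapeImpl>(dims)) {}

  // Accessors never dereference a null impl_; they fall back to defaults,
  // matching the style of the deleted Shape/TensorDesc accessors.
  size_t GetDimNum() const { return impl_ != nullptr ? impl_->dims_.size() : 0; }
  int64_t GetDim(size_t idx) const {
    if (impl_ == nullptr || idx >= impl_->dims_.size()) {
      return 0;
    }
    return impl_->dims_[idx];
  }

 private:
  std::shared_ptr<DemoShapeImpl> impl_;
};

int main() {
  DemoShape shape({8, 3, 224, 224});
  std::cout << shape.GetDimNum() << " " << shape.GetDim(1) << std::endl;  // 4 3
  return 0;
}
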
TensorDesc::SetOriginShape(const Shape &origin_shape) { - if (impl != nullptr) { - impl->origin_shape_ = origin_shape; - } -} - -Format TensorDesc::GetFormat() const { - if (impl != nullptr) { - return impl->format_; - } - return FORMAT_RESERVED; -} - -void TensorDesc::SetFormat(Format format) { - if (impl != nullptr) { - impl->format_ = format; - } -} - -Format TensorDesc::GetOriginFormat() const { - if (impl != nullptr) { - return impl->origin_format_; - } - return FORMAT_RESERVED; -} - -void TensorDesc::SetOriginFormat(Format origin_format) { - if (impl != nullptr) { - impl->origin_format_ = origin_format; - } -} - -DataType TensorDesc::GetDataType() const { - if (impl != nullptr) { - return impl->data_type_; - } - return DT_UNDEFINED; -} - -void TensorDesc::SetDataType(DataType dt) { - if (impl != nullptr) { - impl->data_type_ = dt; - } -} - -void TensorDesc::SetSize(int64_t size) { - if (impl != nullptr) { - impl->size_ = size; - } -} - -int64_t TensorDesc::GetSize() const { - if (impl != nullptr) { - return impl->size_; - } - return 0; -} - -void TensorDesc::SetRealDimCnt(const int64_t real_dim_cnt) { - if (impl != nullptr) { - impl->real_dim_cnt_ = real_dim_cnt; - } -} - -int64_t TensorDesc::GetRealDimCnt() const { - if (impl != nullptr) { - return impl->real_dim_cnt_; - } - return 0; -} - -std::string TensorDesc::GetName() const { - if (impl != nullptr) { - return impl->name_; - } - return ""; -} - -void TensorDesc::SetName(const std::string &name) { - if (impl != nullptr) { - impl->name_ = name; - } -} - -Tensor::Tensor() { impl = ComGraphMakeShared(); } - -Tensor::Tensor(const TensorDesc &tensor_desc) { - impl = ComGraphMakeShared(tensor_desc); // lint !e665 -} - -Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector &data) { - uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); - DataType data_type = tensor_desc.GetDataType(); - uint32_t type_length; - bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); - if (!ret) { - GELOGW("datatype %d is not found.", data_type); - } - - auto data_size = data.size(); - if (ret && (shape_size || (data_size != type_length))) { - if (type_length != 0 && UINT64_MAX / type_length < shape_size) { - GELOGW("mul overflow: %lu, %u", shape_size, type_length); - } else { - if (shape_size * type_length != data_size) { - GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, - data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); - } - } - } - impl = ComGraphMakeShared(tensor_desc, data); // lint !e665 -} - -Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) { - uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); - DataType data_type = tensor_desc.GetDataType(); - uint32_t type_length; - bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); - if (!ret) { - GELOGW("datatype %d is not found.", data_type); - } - if (ret && (shape_size || (size != type_length))) { - if (type_length != 0 && UINT64_MAX / type_length < shape_size) { - GELOGW("mul overflow: %lu, %u", shape_size, type_length); - } else { - if (shape_size * type_length != size) { - GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, - size, TypeUtils::DataTypeToSerialString(data_type).c_str()); - } - } - } - - impl = ComGraphMakeShared(tensor_desc, data, size); // lint !e665 -} - -Tensor::Tensor(TensorDesc &&tensor_desc, std::vector &&data) { - uint64_t shape_size = 
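
Each Tensor constructor above cross-checks the caller's byte buffer: it multiplies the element count from the shape by the byte width of the data type, guards that multiplication against uint64 overflow, and only warns (it does not fail) when the product differs from the supplied size. A compact sketch of that check, assuming a caller-supplied element width rather than the GE TypeUtils lookup:

#include <cstdint>
#include <iostream>

// Returns true when `data_size` bytes are consistent with `elem_count`
// elements of `type_length` bytes each; mirrors the warning-only check
// in the deleted Tensor constructors.
static bool SizeMatches(uint64_t elem_count, uint32_t type_length, size_t data_size) {
  if (type_length != 0 && UINT64_MAX / type_length < elem_count) {
    std::cout << "mul overflow: " << elem_count << ", " << type_length << std::endl;
    return false;
  }
  return elem_count * type_length == data_size;
}

int main() {
  std::cout << std::boolalpha
            << SizeMatches(6, 4, 24) << " "          // true: 2x3 float tensor, 24 bytes
            << SizeMatches(6, 4, 20) << std::endl;   // false: size mismatch
  return 0;
}
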
tensor_desc.GetShape().GetShapeSize(); - DataType data_type = tensor_desc.GetDataType(); - uint32_t type_length; - bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); - if (!ret) { - GELOGW("datatype %d is not found.", data_type); - } - - auto data_size = data.size(); - if (ret && (shape_size || (data_size != type_length))) { - if (type_length != 0 && UINT64_MAX / type_length < shape_size) { - GELOGW("mul overflow: %lu, %u", shape_size, type_length); - } else { - if (shape_size * type_length != data_size) { - GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, - data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); - } - } - } - impl = ComGraphMakeShared(std::move(tensor_desc), std::move(data)); // lint !e665 -} - -TensorDesc Tensor::GetTensorDesc() const { - if (impl != nullptr) { - return TensorAdapter::GeTensorDesc2TensorDesc(impl->ge_tensor.MutableTensorDesc()); - } - return TensorDesc(); -} - -graphStatus Tensor::SetTensorDesc(const TensorDesc &tensor_desc) { - if (impl != nullptr) { - impl->ge_tensor.SetTensorDesc(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -const uint8_t *Tensor::GetData() const { - if (impl != nullptr) { - return impl->ge_tensor.GetData().data(); - } - return nullptr; -} - -uint8_t *Tensor::GetData() { - if (impl != nullptr) { - return impl->ge_tensor.MutableData().data(); - } - return nullptr; -} - -size_t Tensor::GetSize() const { - if (impl != nullptr) { - return impl->ge_tensor.GetData().size(); - } - return 0; -} - -graphStatus Tensor::SetData(std::vector &&data) { - if (impl != nullptr) { - (void)impl->ge_tensor.SetData(data); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const std::vector &data) { - if (impl != nullptr) { - (void)impl->ge_tensor.SetData(data); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const uint8_t *data, size_t size) { - if (impl != nullptr) { - (void)impl->ge_tensor.SetData(data, size); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const std::string &data) { - if (impl != nullptr && (!data.empty())) { - /// Extra 8 bytes store pointer of string - /// Extra 1 byte store '\0' - size_t total_size = data.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL; - std::unique_ptr buff(new (std::nothrow) char[total_size]()); - if (buff == nullptr) { - GELOGE(GRAPH_FAILED, "allocate string raw data buff failed"); - return GRAPH_FAILED; - } - uint64_t *p = reinterpret_cast(buff.get()); - // Front 8 bytes store pointer of string - char *raw_data = buff.get() + EXTRA_STORE_POINTER_FOR_STRING; - p[0] = reinterpret_cast(raw_data); - int32_t memcpy_ret = memcpy_s(raw_data, total_size - EXTRA_STORE_POINTER_FOR_STRING, data.c_str(), data.size() + 1); - GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); - (void)impl->ge_tensor.SetData(reinterpret_cast(buff.get()), total_size); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} -graphStatus Tensor::SetData(const std::vector &data) { - if (impl != nullptr) { - if (data.empty()) { - GELOGE(GRAPH_FAILED, "there is no data, please check the input variable"); - return GRAPH_FAILED; - } - size_t total_size = 0; - for (auto str : data) { - /// Extra 8 bytes store pointer of each string - /// Extra 1 byte store '\0' - total_size += (str.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL); - } - std::unique_ptr buff(new 
(std::nothrow) char[total_size]); - if (buff == nullptr) { - GELOGE(GRAPH_FAILED, "allocate string raw data buff failed"); - return GRAPH_FAILED; - } - uint64_t *p = reinterpret_cast(buff.get()); - // Front some bytes store pointer of each string - char *raw_data = buff.get() + data.size() * sizeof(uint64_t); - uint64_t ptr_size = data.size() * sizeof(uint64_t); - for (size_t i = 0; i < data.size(); ++i) { - p[i] = reinterpret_cast(raw_data); - if (total_size < ptr_size) { - GELOGE(GRAPH_FAILED, "Subtraction invalid, total_size: %zu, ptr_size: %lu", total_size, ptr_size); - return GRAPH_FAILED; - } - int32_t memcpy_ret = memcpy_s(raw_data, total_size - ptr_size, data[i].c_str(), data[i].size() + 1); - GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); - raw_data += (data[i].size() + 1); - ptr_size += (data[i].size() + 1); - } - - (void)impl->ge_tensor.SetData(reinterpret_cast(buff.get()), total_size); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::IsValid() { - uint64_t shape_size = GetTensorDesc().GetShape().GetShapeSize(); - DataType data_type = GetTensorDesc().GetDataType(); - uint32_t type_length; - bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); - if (!ret) { - GELOGW("datatype %d is not found.", data_type); - return GRAPH_SUCCESS; - } - - size_t data_size = GetSize(); - if (data_type != DT_STRING) { - if (shape_size || (data_size != type_length)) { - if (type_length != 0 && UINT64_MAX / type_length < shape_size) { - GELOGW("mul overflow: %lu, %u", shape_size, type_length); - } else { - if (shape_size * type_length != data_size) { - GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, - data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); - return GRAPH_FAILED; - } - } - } - } - - return GRAPH_SUCCESS; -} - -Tensor Tensor::Clone() const { - Tensor tensor; - if (impl != nullptr && tensor.impl != nullptr) { - tensor.impl->ge_tensor = impl->ge_tensor.Clone(); - } - return tensor; -} - -GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_desc) { - GeTensorDesc ge_tensor_desc(GeShape(tensor_desc.GetShape().GetDims()), tensor_desc.GetFormat(), - tensor_desc.GetDataType()); - ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); - ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); - ge_tensor_desc.SetName(tensor_desc.GetName()); - std::vector> shape_range; - auto status = tensor_desc.GetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Get shape range failed!"); - return ge_tensor_desc; - } - status = ge_tensor_desc.SetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set shape range failed!"); - return ge_tensor_desc; - } - auto size = tensor_desc.GetSize(); - TensorUtils::SetSize(ge_tensor_desc, size); - - auto real_dim_cnt = static_cast(tensor_desc.GetRealDimCnt()); - TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); - return ge_tensor_desc; -} - -TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_desc) { - TensorDesc tensor_desc(Shape(ge_tensor_desc.GetShape().GetDims()), ge_tensor_desc.GetFormat(), - ge_tensor_desc.GetDataType()); - tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); - tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); - tensor_desc.SetName(ge_tensor_desc.GetName()); - std::vector> shape_range; - auto status = 
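
Tensor::SetData(const std::vector<std::string>&) above packs all strings into one buffer: the first N * 8 bytes hold one uint64 slot per string (later filled with the address of that string's characters), followed by each string's characters plus a terminating '\0'. A standalone sketch of that layout; the real code goes through memcpy_s and GeTensor, which are replaced here by std::memcpy and a bare buffer:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Build the pointer-table + packed-characters layout used by the deleted
// Tensor::SetData(const std::vector<std::string>&).
static std::unique_ptr<char[]> PackStrings(const std::vector<std::string> &data, size_t &total_size) {
  total_size = 0;
  for (const auto &s : data) {
    total_size += s.size() + sizeof(uint64_t) + 1;  // 8-byte slot + chars + '\0'
  }
  std::unique_ptr<char[]> buff(new char[total_size]());
  uint64_t *slots = reinterpret_cast<uint64_t *>(buff.get());
  char *raw = buff.get() + data.size() * sizeof(uint64_t);
  for (size_t i = 0; i < data.size(); ++i) {
    slots[i] = reinterpret_cast<uint64_t>(raw);              // record where this string starts
    std::memcpy(raw, data[i].c_str(), data[i].size() + 1);   // copy characters and '\0'
    raw += data[i].size() + 1;
  }
  return buff;
}

int main() {
  size_t total = 0;
  auto buff = PackStrings({"abc", "de"}, total);
  const uint64_t *slots = reinterpret_cast<const uint64_t *>(buff.get());
  std::cout << total << " " << reinterpret_cast<const char *>(slots[1]) << std::endl;  // 23 de
  return 0;
}
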
ge_tensor_desc.GetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Get shape range failed!"); - return tensor_desc; - } - status = tensor_desc.SetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set shape range failed!"); - return tensor_desc; - } - int64_t size = 0; - (void)TensorUtils::GetSize(ge_tensor_desc, size); - tensor_desc.SetSize(size); - - uint32_t real_dim_cnt = 0; - (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); - tensor_desc.SetRealDimCnt(real_dim_cnt); - return tensor_desc; -} - -GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) { - GeTensorPtr ge_tensor; - if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor.Clone()); // lint !e665 - } - return ge_tensor; -} - -Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) { - Tensor tensor; - if (ge_tensor != nullptr && tensor.impl != nullptr) { - tensor.impl->ge_tensor = ge_tensor->Clone(); - } - return tensor; -} - -ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { - GeTensorPtr ge_tensor; - if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); // lint !e665 - } - return ge_tensor; -} - -GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) { - GeTensorPtr ge_tensor; - if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); // lint !e665 - } - return ge_tensor; -} - -const GeTensor TensorAdapter::AsGeTensor(const Tensor &tensor) { - if (tensor.impl != nullptr) { - return tensor.impl->ge_tensor; - } - return GeTensor(); -} - -GeTensor TensorAdapter::AsGeTensor(Tensor &tensor) { - if (tensor.impl != nullptr) { - return tensor.impl->ge_tensor; - } - return GeTensor(); -} - -const Tensor TensorAdapter::AsTensor(const GeTensor &ge_tensor) { - Tensor tensor; - if (tensor.impl != nullptr) { - tensor.impl->ge_tensor = ge_tensor; - } - return tensor; -} - -Tensor TensorAdapter::AsTensor(GeTensor &ge_tensor) { - Tensor tensor; - if (tensor.impl != nullptr) { - tensor.impl->ge_tensor = ge_tensor; - } - return tensor; -} -} // namespace ge diff --git a/src/common/graph/utils/anchor_utils.cc b/src/common/graph/utils/anchor_utils.cc deleted file mode 100644 index 5a042283..00000000 --- a/src/common/graph/utils/anchor_utils.cc +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "utils/anchor_utils.h" -#include -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" - -namespace ge { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Format AnchorUtils::GetFormat(const DataAnchorPtr &data_anchor) { - if (data_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "The input data anchor is invalid."); - return FORMAT_RESERVED; - } - return data_anchor->format_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetFormat(const DataAnchorPtr &data_anchor, - Format data_format) { - if ((data_anchor == nullptr) || (data_format == FORMAT_RESERVED)) { - GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid ."); - return GRAPH_FAILED; - } - data_anchor->format_ = data_format; - return GRAPH_SUCCESS; -} - -// Get anchor status -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorStatus AnchorUtils::GetStatus(const DataAnchorPtr &data_anchor) { - if (data_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "The input data anchor is invalid."); - return ANCHOR_RESERVED; - } - return data_anchor->status_; -} - -// Set anchor status -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetStatus(const DataAnchorPtr &data_anchor, - AnchorStatus anchor_status) { - if ((data_anchor == nullptr) || (anchor_status == ANCHOR_RESERVED)) { - GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid ."); - return GRAPH_FAILED; - } - data_anchor->status_ = anchor_status; - return GRAPH_SUCCESS; -} - -bool AnchorUtils::HasControlEdge(const AnchorPtr &anchor) { - auto control_anchor = Anchor::DynamicAnchorCast(anchor); - if (control_anchor != nullptr) { - return (control_anchor->GetPeerAnchors().size() != 0); - } - - auto data_anchor = Anchor::DynamicAnchorCast(anchor); - if (data_anchor) { - for (const auto &peer : data_anchor->GetPeerAnchors()) { - auto peer_cast = Anchor::DynamicAnchorCast(peer); - if (peer_cast) { - return true; - } - } - return false; - } - GELOGE(GRAPH_FAILED, "the anchor is neither control anchor nor data anchor"); - return false; -} - -bool AnchorUtils::IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst) { - GE_CHK_BOOL_EXEC(src != nullptr, return false, "src is null."); - GE_CHK_BOOL_RET_STATUS_NOLOG(src->IsLinkedWith(dst), false); - auto src_control_anchor = Anchor::DynamicAnchorCast(src); - auto dst_control_anchor = Anchor::DynamicAnchorCast(dst); - return (src_control_anchor || dst_control_anchor); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int AnchorUtils::GetIdx(const AnchorPtr &anchor) { - // Check if it can add edge between DataAnchor - auto data_anchor = Anchor::DynamicAnchorCast(anchor); - if (data_anchor != nullptr) { - return data_anchor->GetIdx(); - } - // Check if it can add edge between ControlAnchor - auto control_anchor = Anchor::DynamicAnchorCast(anchor); - if (control_anchor != nullptr) { - return control_anchor->GetIdx(); - } - return -1; -} -} // namespace ge diff --git a/src/common/graph/utils/ge_ir_utils.cc b/src/common/graph/utils/ge_ir_utils.cc deleted file mode 100644 index f238c6e8..00000000 --- a/src/common/graph/utils/ge_ir_utils.cc +++ /dev/null @@ -1,1178 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
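
AnchorUtils::HasControlEdge and GetIdx above decide whether a generic AnchorPtr is a data or a control anchor by attempting Anchor::DynamicAnchorCast to each concrete type and testing for null. The same idea expressed with plain std::dynamic_pointer_cast and hypothetical DemoAnchor types, as an illustration rather than the GE class hierarchy:

#include <iostream>
#include <memory>

struct DemoAnchor {
  virtual ~DemoAnchor() = default;
};
struct DemoDataAnchor : DemoAnchor {
  int idx = 0;
};
struct DemoControlAnchor : DemoAnchor {};

// In the spirit of the deleted AnchorUtils::GetIdx: data anchors report
// their stored index; anything else reports -1 in this sketch.
static int GetIdx(const std::shared_ptr<DemoAnchor> &anchor) {
  if (auto data = std::dynamic_pointer_cast<DemoDataAnchor>(anchor)) {
    return data->idx;
  }
  return -1;
}

int main() {
  auto d = std::make_shared<DemoDataAnchor>();
  d->idx = 2;
  std::shared_ptr<DemoAnchor> a = d;
  std::cout << GetIdx(a) << " " << GetIdx(std::make_shared<DemoControlAnchor>()) << std::endl;  // 2 -1
  return 0;
}
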
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/utils/ge_ir_utils.h" -#include -#include "framework/common/debug/ge_log.h" - -namespace { -const char *const kControlAnchorIndex = ":-1"; -const char *const kNodeTypeForSubgraph = "subgraph"; -const char *const kPrefixForInputDesc = "input_desc_attr_"; -const char *const kPrefixForOutputDesc = "output_desc_attr_"; -const char *const kDumpGEGraph = "DUMP_GE_GRAPH"; -const int8_t kMaxRecursionDepth = 10; -const char *const kDumpGeGraph = std::getenv(kDumpGEGraph); -const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP; -const int64_t kInputPrefixLength = 5; -const int64_t kOutputPrefixLength = 6; -using AttrDefPair = ::google::protobuf::MapPair; -} // namespace - -namespace ge { -// Part 1: from IR convert to ONNX Protobuf -static const std::map kGeDataTypeToOnnxMap = { - {DT_INT64, onnx::TensorProto_DataType_INT64}, {DT_UINT64, onnx::TensorProto_DataType_UINT64}, - {DT_FLOAT, onnx::TensorProto_DataType_FLOAT}, {DT_INT32, onnx::TensorProto_DataType_INT32}, - {DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8}, - {DT_UINT8, onnx::TensorProto_DataType_UINT8}, {DT_INT16, onnx::TensorProto_DataType_INT16}, - {DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16}, - {DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL}, -}; - -onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) { - auto it = kGeDataTypeToOnnxMap.find(data_type); - if (it != kGeDataTypeToOnnxMap.end()) { - return it->second; - } else { - GELOGW("EncodeDataType: datatype not support %u", data_type); - return onnx::TensorProto_DataType_UNDEFINED; - } -} - -void OnnxUtils::AddAttrProtoFromAttribute(const std::pair &string_attr_value, - onnx::NodeProto *node_proto) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node proto is nullptr."); - return; - } - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - auto attr_name = string_attr_value.first; - attr->set_name(attr_name); - auto attr_value = string_attr_value.second; - auto value_type = attr_value.GetValueType(); - switch (value_type) { - case GeAttrValue::VT_FLOAT: { - GeAttrValue::FLOAT data_f = 0; - (void)attr_value.GetValue(data_f); - attr->set_f(data_f); - attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); - break; - } - case GeAttrValue::VT_LIST_FLOAT: { - GeAttrValue::LIST_FLOAT data_fs = {}; - (void)attr_value.GetValue(data_fs); - attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); - for (auto &v : data_fs) { - attr->add_floats(v); - } - break; - } - case GeAttrValue::VT_INT: { - GeAttrValue::INT data_i = 0; - (void)attr_value.GetValue(data_i); - attr->set_type(onnx::AttributeProto_AttributeType_INT); - attr->set_i(data_i); - break; - } - case GeAttrValue::VT_LIST_INT: { - GeAttrValue::LIST_INT data_is = {}; - (void)attr_value.GetValue(data_is); - attr->set_type(onnx::AttributeProto_AttributeType_INTS); - for (auto &v : data_is) { - 
attr->add_ints(v); - } - break; - } - case GeAttrValue::VT_STRING: { - GeAttrValue::STR data_s; - (void)attr_value.GetValue(data_s); - attr->set_type(onnx::AttributeProto_AttributeType_STRING); - attr->set_s(data_s); - break; - } - case GeAttrValue::VT_LIST_STRING: { - GeAttrValue::LIST_STR data_ss = {}; - (void)attr_value.GetValue(data_ss); - attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); - for (auto &v : data_ss) { - attr->add_strings(v); - } - break; - } - default: - GELOGW("GeAttrValue ValueType: %u is not supported for now", value_type); - break; - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, - void *data) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); - return; - } - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - attr->set_name(name); - switch (type) { - case onnx::AttributeProto_AttributeType_FLOAT: - attr->set_f((*(static_cast(data)))); - attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); - break; - - case onnx::AttributeProto_AttributeType_FLOATS: - attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); - for (auto &v : (*(static_cast *>(data)))) { - attr->add_floats(v); - } - break; - - case onnx::AttributeProto_AttributeType_INT: - attr->set_type(onnx::AttributeProto_AttributeType_INT); - attr->set_i((*(static_cast(data)))); - break; - - case onnx::AttributeProto_AttributeType_INTS: - attr->set_type(onnx::AttributeProto_AttributeType_INTS); - for (auto &v : *(static_cast *>(data))) { - attr->add_ints(v); - } - break; - - case onnx::AttributeProto_AttributeType_STRING: - attr->set_type(onnx::AttributeProto_AttributeType_STRING); - attr->set_s((*(static_cast(data)))); - break; - - case onnx::AttributeProto_AttributeType_STRINGS: - attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); - for (auto &v : *(static_cast *>(data))) { - attr->add_strings(v); - } - break; - - default: - GELOGW("AttributeProto AttributeType: %u is not supported for now", type); - break; - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, - ::google::protobuf::RepeatedField<::google::protobuf::int64> data) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); - return; - } - if (!data.empty()) { - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_ints(v); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, - ::google::protobuf::RepeatedField data) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); - return; - } - if (!data.empty()) { - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_ints(static_cast(v)); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, - ::google::protobuf::RepeatedField data) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); - return; - } - if (!data.empty()) { - auto attr = 
node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_floats(v); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, - ::google::protobuf::RepeatedPtrField<::std::string> data) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); - return; - } - if (!data.empty()) { - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_strings(v); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc) { - if (node_proto == nullptr || op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "node_proto or op_desc is nullptr"); - return; - } - // Input describes - auto size_in = op_desc->GetAllInputsSize(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_nums", &size_in); - if (size_in > 0) { - for (uint32_t i = 0; i < size_in; i++) { - auto input_desc = op_desc->GetInputDescPtrDfault(i); - if (input_desc != nullptr) { - auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_dtype:" + std::to_string(i), - &data_type); - auto data_type_origin = TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "input_desc_origin_dtype:" + std::to_string(i), &data_type_origin); - auto dims = input_desc->GetShape().GetDims(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_desc_shape:" + std::to_string(i), - &dims); - auto dims_origin = input_desc->GetOriginShape().GetDims(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, - "input_desc_origin_shape:" + std::to_string(i), &dims_origin); - auto layout = TypeUtils::FormatToSerialString(input_desc->GetFormat()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_layout:" + std::to_string(i), - &layout); - auto layout_origin = TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "input_desc_origin_layout:" + std::to_string(i), &layout_origin); - auto tensor_descriptor = input_desc->tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor != nullptr) { - auto size = tensor_descriptor->size(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_size:" + std::to_string(i), - &size); - auto weight_size = tensor_descriptor->weight_size(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_weight_size:" + std::to_string(i), &weight_size); - auto reuse_input = tensor_descriptor->reuse_input(); - auto reuse_input_int = static_cast(reuse_input); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_reuse_input:" + std::to_string(i), &reuse_input_int); - auto output_tensor = tensor_descriptor->output_tensor(); - auto output_tensor_int = static_cast(output_tensor); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_output_tensor:" + std::to_string(i), &output_tensor_int); - auto device_type = tensor_descriptor->device_type(); - 
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "input_desc_device_type:" + std::to_string(i), &device_type); - auto input_tensor = tensor_descriptor->input_tensor(); - auto input_tensor_int = static_cast(input_tensor); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_input_tensor:" + std::to_string(i), &input_tensor_int); - auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); - auto data_offset = tensor_descriptor->data_offset(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_data_offset:" + std::to_string(i), &data_offset); - auto cmps_size = tensor_descriptor->cmps_size(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_cmps_size:" + std::to_string(i), - &cmps_size); - auto cmps_tab = tensor_descriptor->cmps_tab(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "input_desc_cmps_tab:" + std::to_string(i), &cmps_tab); - auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset); - const auto &tensor_desc_map = tensor_descriptor->attr(); - std::string suffix = ":" + std::to_string(i); - AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix); - } else { - GELOGW("Tensor descriptor is nullptr"); - continue; - } - } else { - GELOGW("Input desc is nullptr"); - continue; - } - } - } - // Output describes - auto size_out = op_desc->GetOutputsSize(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_nums", &size_out); - if (size_out > 0) { - for (uint32_t i = 0; i < size_out; i++) { - auto output_desc = op_desc->GetOutputDescPtr(i); - if (output_desc != nullptr) { - auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_dtype:" + std::to_string(i), - &data_type); - auto origin_data_type = TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "output_desc_origin_dtype:" + std::to_string(i), &origin_data_type); - auto dims = output_desc->GetShape().GetDims(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_desc_shape:" + std::to_string(i), - &dims); - auto dims_origin = output_desc->GetOriginShape().GetDims(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, - "output_desc_origin_shape:" + std::to_string(i), &dims_origin); - auto layout = TypeUtils::FormatToSerialString(output_desc->GetFormat()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_layout:" + std::to_string(i), - &layout); - auto layout_origin = TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "output_desc_origin_layout:" + std::to_string(i), &layout_origin); - auto tensor_descriptor = output_desc->tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor != nullptr) { - auto size = tensor_descriptor->size(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_size:" + std::to_string(i), - &size); - auto weight_size = tensor_descriptor->weight_size(); - 
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "output_desc_weight_size:" + std::to_string(i), &weight_size); - auto device_type = tensor_descriptor->device_type(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "output_desc_device_type:" + std::to_string(i), &device_type); - auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); - const auto &tensor_desc_map = tensor_descriptor->attr(); - std::string suffix = ":" + std::to_string(i); - AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix); - } else { - GELOGW("Tensor descriptor is nullptr"); - continue; - } - } else { - GELOGW("Output desc is nullptr"); - continue; - } - } - } -} - -void OnnxUtils::AddAttrProtoForAttrsFromAttrMap( - const ::google::protobuf::Map &attr_map, onnx::NodeProto *node_proto, - const std::string &prefix, const std::string &suffix) { - for (const auto &item : attr_map) { - auto attr_name = item.first; - auto attr_def = item.second; - auto attr_type = attr_def.value_case(); - if (attr_type == ge::proto::AttrDef::kT) { - const auto &tensor_def = attr_def.t(); - const auto &tensor_desc = tensor_def.desc(); - auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix, - &data_type); - auto dims = tensor_desc.shape().dim(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix, - dims); - auto layout = tensor_desc.layout(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix, - &layout); - auto device_type = tensor_desc.device_type(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - prefix + attr_name + "_desc_device_type" + suffix, &device_type); - if (kDumpLevel == DUMP_ALL) { - auto data = tensor_def.data(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix, - &data); - } - } - if (attr_type == ge::proto::AttrDef::kS) { - if (kDumpLevel == DUMP_ALL) { - auto str_value = attr_def.s(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value); - } - } - if (attr_type == ge::proto::AttrDef::kI) { - auto int_value = attr_def.i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); - } - if (attr_type == ge::proto::AttrDef::kF) { - auto float_value = attr_def.f(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value); - } - if (attr_type == ge::proto::AttrDef::kB) { - auto int_value = static_cast(attr_def.b()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); - } - if (attr_type == ge::proto::AttrDef::kList) { - const auto &list_value = attr_def.list(); - auto list_value_type = list_value.val_type(); - if (list_value_type == - ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { - if (kDumpLevel == DUMP_ALL) { - const auto &strings = list_value.s(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings); - } - } - if (list_value_type == - 
ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { - const auto &floats = list_value.f(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats); - } - if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { - const auto &ints = list_value.i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints); - } - if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { - const auto &bools = list_value.b(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools); - } - } - } -} - -void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "node is nullptr"); - return; - } - // 1.Attributes added from node's methods - auto send_list = node->send_event_id_list_; - if (!send_list.empty()) { - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "send_event_id_list", &send_list); - } - auto recv_list = node->recv_event_id_list_; - if (!recv_list.empty()) { - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "recv_event_id_list", &recv_list); - } - auto op_desc = node->op_; - if (op_desc != nullptr) { - // for input_name_idx_ in opdesc - auto input_name_2_indexs = op_desc->GetAllInputName(); - ::google::protobuf::RepeatedPtrField<::std::string> input_names; - ::google::protobuf::RepeatedField<::google::protobuf::int64> input_indexes; - for (const auto &input_name_2_index : input_name_2_indexs) { - std::string input_name = input_name_2_index.first; - input_names.Add(std::move(input_name)); - input_indexes.Add(input_name_2_index.second); - } - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "_input_name_key", input_names); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "_input_name_value", input_indexes); - // 2.Attributes added from node's op_(message OpDef) - // Input and out describes - AddAttrProtoForOpInAndOutDesc(node_proto, op_desc); - // Others - auto op_def = op_desc->op_def_.GetProtoMsg(); - if (op_def != nullptr) { - auto id = op_def->id(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "id", &id); - auto stream_id = op_def->stream_id(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "stream_id", &stream_id); - const auto &input_name = op_def->input_name(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "input_name", input_name); - const auto &src_name = op_def->src_name(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "src_name", src_name); - const auto &src_index = op_def->src_index(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "src_index", src_index); - const auto &dst_name = op_def->dst_name(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "dst_name", dst_name); - const auto &dst_index = op_def->dst_index(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "dst_index", dst_index); - const auto &input_i = op_def->input_i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_i", input_i); - const auto &output_i = op_def->output_i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_i", output_i); - const auto &workspace = 
op_def->workspace(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace", workspace); - const auto &workspace_bytes = op_def->workspace_bytes(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes); - const auto &is_input_const = op_def->is_input_const(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const); - const auto &op_def_attr_map = op_def->attr(); - AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto); - } else { - GELOGE(FAILED, "Opdef is nullptr"); - return; - } - } else { - GELOGE(FAILED, "Opdesc is nullptr"); - return; - } -} - -bool OnnxUtils::EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeOpDesc: Input Para Node Invalid"); - return false; - } - - // 2.Encode map attrs_ to AttributeProto - for (auto &node_attr : node->attrs_) { - AddAttrProtoFromAttribute(node_attr, node_proto); - } - // 3.Encode ge::Node members to AttributeProto - AddAttrProtoFromNodeMembers(node, node_proto); - return true; -} - -void OnnxUtils::EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeNodeLinkForNetronVisual: Input Para Node Invalid"); - return; - } - const auto &node_name = node->GetName(); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - if ((out_data_anchor != nullptr) && (!out_data_anchor->GetPeerInDataAnchors().empty())) { - node_proto->add_output(node_name + ":" + std::to_string(out_data_anchor->GetIdx())); - } - } - auto out_control_anchor = node->GetOutControlAnchor(); - if ((out_control_anchor != nullptr) && (!out_control_anchor->GetPeerInControlAnchors().empty())) { - node_proto->add_output(node_name + kControlAnchorIndex); - } -} - -bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeNodeLink: Input Para Node Invalid"); - return false; - } - node_proto->clear_input(); - // 1. Add input by in data edge - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { - node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + - std::to_string(peer_out_anchor->GetIdx())); - } else { - // Add "" input - node_proto->add_input(""); - } - } - - // 2. Add input by in control edge - auto in_control_anchor = node->GetInControlAnchor(); - if (in_control_anchor != nullptr) { - auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors(); - for (const auto &peer_out_anchor : peer_out_anchors) { - if (peer_out_anchor->GetOwnerNode()) { - node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); - } - } - } else { - GELOGE(FAILED, "Incontrol anchor is nullptr"); - return false; - } - - // 3. Add output for Netron visual support - EncodeNodeLinkForNetronVisual(node, node_proto); - return true; -} - -bool OnnxUtils::EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeNode: Input Para Node Invalid"); - return false; - } - // 1. 
Encode name and type - node_proto->set_name(node->GetName()); - /// Netron believes that some operators, such as the activation operator of softplus, only have one input, - /// while the link relation of control anchor may exist in ge, resulting in two inputs. Therefore, "ge:" prefix - /// is added to correctly display the link relation at the expense of some color features - node_proto->set_op_type("ge:" + node->GetType()); - - if (kDumpLevel != DUMP_WITH_OUT_DESC) { - // 2.for attr - if (!EncodeNodeDesc(node, node_proto)) { - GELOGE(GRAPH_FAILED, "Encode NodeDesc: %s failed", node->GetName().c_str()); - return false; - } - } - // 3.for link info - return EncodeNodeLink(node, node_proto); -} - -void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type) { - if ((node == nullptr) || (tensor_type == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeTypeProtoTensorType: Input Para Node or tensor_type Invalid"); - return; - } - const auto &op_desc = node->GetOpDesc(); - if (op_desc != nullptr) { - uint32_t size_out = static_cast(op_desc->GetOutputsSize()); - if (size_out > 0) { - for (uint32_t i = 0; i < size_out; i++) { - const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); - if (ge_tensor != nullptr) { - auto ge_data_type = ge_tensor->GetDataType(); - auto onnx_data_type = EncodeDataType(ge_data_type); - tensor_type->set_elem_type(onnx_data_type); - onnx::TensorShapeProto *shape = tensor_type->mutable_shape(); - if (shape != nullptr) { - for (auto d : ge_tensor->GetShape().GetDims()) { - auto dim = shape->add_dim(); - dim->set_dim_value(d); - } - } else { - GELOGW("Shape is nullptr"); - continue; - } - } else { - GELOGW("Ge tensor is nullptr"); - continue; - } - } - } - } else { - GELOGW("OpDesc Is Empty, nodeName %s nodeType %s", node->GetName().c_str(), node->GetType().c_str()); - return; - } -} - -void OnnxUtils::EncodeValueInfo(const NodePtr &node, onnx::ValueInfoProto *value_info_proto) { - if ((node == nullptr) || (value_info_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeValueInfo: Input Para Node or value_info_proto Invalid"); - return; - } - value_info_proto->set_name(node->GetName()); - onnx::TypeProto *t = value_info_proto->mutable_type(); - onnx::TypeProto_Tensor *tensor_type = t->mutable_tensor_type(); - EncodeTypeProtoTensorType(node, tensor_type); -} - -bool OnnxUtils::EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto) { - if ((graph == nullptr) || (graph_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeGraph: Input para Invalid"); - return false; - } - graph_proto->set_name(graph->GetName()); - // 1. Add graph inputs - for (const auto &input : graph->GetInputNodes()) { - auto value_info_proto = graph_proto->add_input(); - EncodeValueInfo(input, value_info_proto); - } - // 2. Add graph outputs - for (const auto &output : graph->GetOutputNodes()) { - auto value_info_proto = graph_proto->add_output(); - EncodeValueInfo(output, value_info_proto); - } - // 3. 
Add nodes - for (const auto &node : graph->GetDirectNode()) { - if (!EncodeNode(node, graph_proto->add_node())) { - GELOGW("EncodeNode failed"); - continue; - } - } - return true; -} - -bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto) { - model_proto.set_model_version(model.GetVersion()); - model_proto.set_ir_version(onnx::IR_VERSION); - model_proto.set_producer_name(model.GetName()); - auto &graph = model.graph_; - auto compute_graph = GraphUtils::GetComputeGraph(graph); - if (compute_graph == nullptr) { - GELOGE(GRAPH_FAILED, "GetComputeGraph: return nullptr"); - return false; - } - auto graph_proto = model_proto.mutable_graph(); - if (graph_proto == nullptr) { - GELOGE(GRAPH_FAILED, "mutable_graph: %s return nullptr", compute_graph->GetName().c_str()); - return false; - } - if (!EncodeGraph(compute_graph, graph_proto)) { - GELOGE(GRAPH_FAILED, "EncodeGraph: %s fail", compute_graph->GetName().c_str()); - return false; - } - - // For subgraphs: a subgraph is represented by a node - for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) { - if (sub_compute_graph != nullptr) { - auto node_proto = graph_proto->add_node(); - if (node_proto == nullptr) { - GELOGW("Node proto is nullptr"); - continue; - } - node_proto->set_name(sub_compute_graph->GetName()); - node_proto->set_op_type(kNodeTypeForSubgraph); - auto attr = node_proto->add_attribute(); - attr->set_name("graph"); - attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); - auto sub_graph_proto = attr->mutable_g(); - if (sub_graph_proto == nullptr) { - GELOGW("Sub graph proto is nullptr"); - continue; - } - if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { - GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); - continue; - } - } else { - GELOGW("Graph: %s subgraph is nullptr, skip EncodeGraph", compute_graph->GetName().c_str()); - continue; - } - } - return true; -} - -// Part 2: from ONNX Protobuf convert to IR -static std::map onnxDataTypeToGeMap = { - {onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64}, - {onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32}, - {onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8}, - {onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16}, - {onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16}, - {onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL}, -}; - -ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) { - auto it = onnxDataTypeToGeMap.find(data_type); - if (it != onnxDataTypeToGeMap.end()) { - return it->second; - } else { - GELOGW("DecodeDataType: datatype not support %u", data_type); - return ge::DT_UNDEFINED; - } -} - -bool OnnxUtils::ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index) { - auto sep = node_name_index.rfind(':'); - if (sep == std::string::npos) { - return false; - } - node_name = node_name_index.substr(0, sep); - auto index_str = node_name_index.substr(sep + 1); - index = static_cast(std::strtol(index_str.c_str(), nullptr, 10)); - return true; -} - -bool OnnxUtils::DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr) { - if (node_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp: node_ptr is nullptr"); - return false; - } - // Data edge - if (item.src_out_index >= 0) { - 
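
EncodeDataType and DecodeDataType above translate between ge::DataType and onnx::TensorProto_DataType through two static maps, logging a warning and returning an UNDEFINED value when a type has no counterpart. A minimal sketch of that map-with-fallback pattern; the two enums here are stand-ins, since the real ones come from the GE and ONNX protobuf definitions:

#include <iostream>
#include <map>

// Stand-ins for ge::DataType and onnx::TensorProto_DataType values.
enum class GeType { kFloat, kInt32, kUndefined };
enum class OnnxType { kFloat, kInt32, kUndefined };

static OnnxType EncodeDataType(GeType dt) {
  static const std::map<GeType, OnnxType> kMap = {
      {GeType::kFloat, OnnxType::kFloat},
      {GeType::kInt32, OnnxType::kInt32},
  };
  auto it = kMap.find(dt);
  if (it != kMap.end()) {
    return it->second;
  }
  // Unknown type: warn and fall back to UNDEFINED, as the deleted code does.
  std::cerr << "EncodeDataType: datatype not supported" << std::endl;
  return OnnxType::kUndefined;
}

int main() {
  std::cout << (EncodeDataType(GeType::kInt32) == OnnxType::kInt32) << " "
            << (EncodeDataType(GeType::kUndefined) == OnnxType::kUndefined) << std::endl;  // 1 1
  return 0;
}
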
auto src_anchor = node_ptr->GetOutDataAnchor(item.src_out_index); - auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); - if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { - GELOGE(GRAPH_FAILED, "Get data anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, - item.dst_node_name.c_str(), item.dst_in_index); - return false; - } - if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Data Anchor: src_anchor->LinkTo(dst_anchor) failed"); - return false; - } - // Control edge - } else { - auto src_anchor = node_ptr->GetOutControlAnchor(); - auto dst_anchor = item.dst_node->GetInControlAnchor(); - if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { - GELOGE(GRAPH_FAILED, "Get control anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, - item.dst_node_name.c_str(), item.dst_in_index); - return false; - } - if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Control Anchor: src_anchor->LinkTo(dst_anchor) failed"); - return false; - } - } - return true; -} - -bool OnnxUtils::DecodeNodeLink(const std::vector &node_proto_vector, - const std::map &node_map) { - for (const auto &node_proto : node_proto_vector) { - const auto &node_name = node_proto.name(); - auto dst_node = node_map.find(node_name); - if ((dst_node == node_map.end()) || (dst_node->second == nullptr)) { - GELOGE(GRAPH_FAILED, "destination node: %s find failed or is nullptr", node_name.c_str()); - return false; - } - int32_t dst_index = 0; - for (const auto &input : node_proto.input()) { - std::string input_node_name; - int32_t index = 0; - if (ParseNameIndex(input, input_node_name, index)) { - auto item = NodeLinkInfo{input_node_name, index, dst_node->second, dst_index, node_proto.name()}; - auto src_node = node_map.find(input_node_name); - if (src_node == node_map.end()) { - GELOGE(GRAPH_FAILED, "find src node: %s failed", input_node_name.c_str()); - return false; - } - auto node_ptr = src_node->second; - if (node_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "src node: %s is nullptr", input_node_name.c_str()); - return false; - } - if (!DecodeNodeLinkImp(item, node_ptr)) { - GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp node: %s failed", input_node_name.c_str()); - return false; - } - } - if (index >= 0) { - dst_index++; - } - } - } - return true; -} - -void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &strings) { - if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRINGS) { - GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - for (int i = 0; i < attr_proto.strings_size(); i++) { - strings.push_back(attr_proto.strings(i)); - } -} - -void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value) { - if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRING) { - GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - value = attr_proto.s(); -} - -void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &ints) { - if (attr_proto.type() != onnx::AttributeProto_AttributeType_INTS) { - GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - for (int i = 0; i < attr_proto.ints_size(); i++) { - ints.push_back(attr_proto.ints(i)); - } -} - -void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value) 
{ - if (attr_proto.type() != onnx::AttributeProto_AttributeType_INT) { - GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - value = attr_proto.i(); -} - -void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_desc, int32_t index, - OpDescPtr &op_desc) { - if (op_desc->MutableInputDesc(static_cast(index)) == nullptr) { - GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast(index)) is nullptr", - op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); - return; - } - if (attr_name_for_input_desc == "input_desc_dtype") { - auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); - op_desc->MutableInputDesc(static_cast(index))->SetDataType(data_type); - } else if (attr_name_for_input_desc == "input_desc_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - GeShape ge_shape(ints); - op_desc->MutableInputDesc(static_cast(index))->SetShape(ge_shape); - } else if (attr_name_for_input_desc == "input_desc_layout") { - auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - op_desc->MutableInputDesc(static_cast(index))->SetFormat(data_format); - } else if (attr_name_for_input_desc == "input_desc_origin_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - GeShape ge_shape(ints); - op_desc->MutableInputDesc(static_cast(index))->SetOriginShape(ge_shape); - } else if (attr_name_for_input_desc == "input_desc_origin_layout") { - auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - op_desc->MutableInputDesc(static_cast(index))->SetOriginFormat(data_format); - } else if (attr_name_for_input_desc == "input_desc_size") { - int64_t input_size = 0; - auto tensor_descriptor = op_desc->MutableInputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); - DecodeAttribute(attr_proto, input_size); - tensor_descriptor->set_size(input_size); - } else if (attr_name_for_input_desc == "input_desc_data_offset") { - auto tensor_descriptor = op_desc->MutableInputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); - int64_t offset = 0; - DecodeAttribute(attr_proto, offset); - tensor_descriptor->set_data_offset(offset); - } else { - return; - } -} - -void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_output_desc, int32_t index, - OpDescPtr &op_desc) { - if (op_desc->MutableOutputDesc(static_cast(index)) == nullptr) { - GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast(index)) is nullptr", - op_desc->GetName().c_str(), attr_name_for_output_desc.c_str()); - return; - } - if (attr_name_for_output_desc == "output_desc_dtype") { - auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); - op_desc->MutableOutputDesc(static_cast(index))->SetDataType(data_type); - } else if (attr_name_for_output_desc == "output_desc_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - GeShape ge_shape(ints); - op_desc->MutableOutputDesc(static_cast(index))->SetShape(ge_shape); - } else if (attr_name_for_output_desc == "output_desc_layout") { - auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - op_desc->MutableOutputDesc(static_cast(index))->SetFormat(data_format); - } else if (attr_name_for_output_desc == "output_desc_origin_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - GeShape ge_shape(ints); - 
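// Editor's sketch (assumption, not part of the original file): the per-index attributes handled
// in these two helpers are named "<input|output>_desc_<field>:<index>". For example, an
// AttributeProto named "output_desc_shape:0" of type INTS with ints = {1, 3, 224, 224} reaches
// the matching branch and effectively performs
//   op_desc->MutableOutputDesc(0U)->SetShape(GeShape({1, 3, 224, 224}));
// while "input_desc_dtype:1" carrying the string "DT_FLOAT" becomes
//   op_desc->MutableInputDesc(1U)->SetDataType(DT_FLOAT);
// The concrete names, indices and shapes here are illustrative only.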
op_desc->MutableOutputDesc(static_cast(index))->SetOriginShape(ge_shape); - } else if (attr_name_for_output_desc == "output_desc_origin_layout") { - auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - op_desc->MutableOutputDesc(static_cast(index))->SetOriginFormat(data_format); - } else if (attr_name_for_output_desc == "output_desc_size") { - int64_t output_size = 0; - auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); - DecodeAttribute(attr_proto, output_size); - tensor_descriptor->set_size(output_size); - } else if (attr_name_for_output_desc == "output_desc_data_offset") { - auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); - int64_t offset = 0; - DecodeAttribute(attr_proto, offset); - tensor_descriptor->set_data_offset(offset); - } else { - return; - } -} - -void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_output_desc, int32_t index, - OpDescPtr &op_desc) { - if (op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "op_desc is nullptr"); - return; - } - if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") { - DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); - } else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") { - DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); - } else { - return; - } -} - -void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) { - auto attr_map = op_def.mutable_attr(); - const auto &attr_name = attr_proto.name(); - ge::proto::AttrDef op_attr; - int64_t value = 0; - DecodeAttribute(attr_proto, value); - op_attr.set_i(value); - attr_map->insert(AttrDefPair(attr_name, op_attr)); -} - -void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { - if (op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); - return; - } - const auto &attr_name = attr_proto.name(); - std::string attr_name_for_input_output_desc; - int32_t index = 0; - if (!ParseNameIndex(attr_name, attr_name_for_input_output_desc, index)) { - if (attr_name == "id") { - op_desc->SetId(attr_proto.i()); - } else if (attr_name == "stream_id") { - op_desc->SetStreamId(attr_proto.i()); - } else if (attr_name == "src_name") { - std::vector strings; - DecodeAttribute(attr_proto, strings); - op_desc->SetSrcName(strings); - } else if (attr_name == "dst_name") { - std::vector strings; - DecodeAttribute(attr_proto, strings); - op_desc->SetDstName(strings); - } else if (attr_name == "src_index") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetSrcIndex(ints); - } else if (attr_name == "dst_index") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetDstIndex(ints); - } else if (attr_name == "fusion_scope") { - DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg()); - } else if (attr_name == "input_i") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetInputOffset(ints); - } else if (attr_name == "output_i") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetOutputOffset(ints); - } else { - return; - } - // Update input and output desc - } else { - DecodeNodeAttributeForOpInAndOutDesc(attr_proto, 
attr_name_for_input_output_desc, index, op_desc);
-  }
-}
-
-bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_desc) {
-  if (op_desc == nullptr || node_proto == nullptr) {
-    GELOGE(GRAPH_FAILED, " Op_desc is nullptr or node_proto is nullptr");
-    return false;
-  }
-  // 1. Decode node_proto name and type
-  op_desc->SetName(node_proto->name());
-  const auto &node_type_with_ge_prefix = node_proto->op_type();
-  auto sep = node_type_with_ge_prefix.find(':');
-  if (sep == std::string::npos) {
-    return false;
-  }
-  auto node_type = node_type_with_ge_prefix.substr(sep + 1);
-  op_desc->SetType(node_type);
-  // 2. Add empty input and output desc
-  for (const auto &attr : node_proto->attribute()) {
-    if (attr.name() == "input_desc_nums") {
-      auto size_in = attr.i();
-      for (int64_t i = 0; i < size_in; i++) {
-        GeTensorDesc ge_tensor_desc;
-        GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed.");
-      }
-    }
-    if (attr.name() == "output_desc_nums") {
-      auto size_out = attr.i();
-      for (int64_t i = 0; i < size_out; i++) {
-        GeTensorDesc ge_tensor_desc;
-        GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed.");
-      }
-    }
-  }
-  // 3.Decode node_proto attributes
-  for (int i = 0; i < node_proto->attribute_size(); i++) {
-    DecodeNodeAttributeForOpDesc(node_proto->attribute(i), op_desc);
-  }
-  return true;
-}
-
-bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph) {
-  if (recursion_depth > kMaxRecursionDepth) {
-    GELOGE(GRAPH_FAILED, "DecodeGraph: recursion depth is too large, abort");
-    return false;
-  }
-
-  graph = ComGraphMakeShared<ComputeGraph>(graph_proto.name());
-  GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed");
-  /// 1. Decode all nodes first, node should include input
-  /// and output nodes and nodes which represent sub graphs
-  std::map<std::string, NodePtr> node_map;
-  std::vector<onnx::NodeProto> node_proto_vector;
-  for (const auto &node_proto : graph_proto.node()) {
-    // a. nodes represent sub graphs
-    if (node_proto.op_type() == kNodeTypeForSubgraph) {
-      ComputeGraphPtr compute_graph;
-      // in this case, node only have one attr, whose type is AttributeProto_AttributeType_GRAPH
-      const auto &node_attr = node_proto.attribute(0);
-      if ((node_attr.type() == onnx::AttributeProto_AttributeType_GRAPH) &&
-          DecodeGraph(recursion_depth + 1, node_attr.g(), compute_graph)) {
-        (void)graph->AddSubGraph(compute_graph);
-      } else {
-        GELOGE(GRAPH_FAILED, "Decode sub graph %s failed with node type:%d", node_proto.name().c_str(),
-               node_attr.type());
-        return false;
-      }
-      // b. direct nodes in graph
-    } else {
-      node_proto_vector.push_back(node_proto);
-      OpDescPtr op_desc = ComGraphMakeShared<OpDesc>();
-      // b.1 For node desc
-      if (!DecodeNodeDesc(&node_proto, op_desc)) {
-        GELOGE(GRAPH_FAILED, "Decode node desc %s failed ", node_proto.name().c_str());
-        return false;
-      }
-      auto node = graph->AddNode(op_desc);
-      node_map.insert(std::make_pair(node_proto.name(), node));
-    }
-  }
-  /// We get all nodes in graph here
-  /// b.2 For node link
-  if (!DecodeNodeLink(node_proto_vector, node_map)) {
-    GELOGE(GRAPH_FAILED, "Decode node link failed");
-    return false;
-  }
-
-  // 2.
Add inputs nodes for graph - for (const auto &input : graph_proto.input()) { - const auto &input_node_name = input.name(); - auto input_node_item = node_map.find(input_node_name); - if (input_node_item == node_map.end()) { - GELOGE(GRAPH_FAILED, "cannot find graph's input node %s in node_", input_node_name.c_str()); - return false; - } - auto ret = graph->AddInputNode(input_node_item->second); - GE_CHK_BOOL_EXEC(ret != nullptr, continue, "Add inputnode failed"); - } - // 3. Add outputs nodes for graph - for (const auto &output : graph_proto.output()) { - const auto &output_node_name = output.name(); - auto output_node_item = node_map.find(output_node_name); - if (output_node_item == node_map.end()) { - GELOGE(GRAPH_FAILED, "cannot find graph's output node %s in node_", output_node_name.c_str()); - return false; - } - auto ret = graph->AddOutputNode(output_node_item->second); - if (ret == nullptr) { - GELOGW("Add outputnode failed,out put node is %s", output_node_name.c_str()); - continue; - } - } - return true; -} - -bool OnnxUtils::ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model) { - model.name_ = model_proto.producer_name(); - model.version_ = static_cast(model_proto.model_version()); - - auto &graph_proto = model_proto.graph(); - ComputeGraphPtr compute_graph; - // 0 means recursion depth, father call - if (!DecodeGraph(0, graph_proto, compute_graph)) { - GELOGE(GRAPH_FAILED, "Decode compute graph from graph_proto failed"); - return false; - } - model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph); - return true; -} -} // namespace ge diff --git a/src/common/graph/utils/ge_ir_utils.h b/src/common/graph/utils/ge_ir_utils.h deleted file mode 100644 index 9b16be18..00000000 --- a/src/common/graph/utils/ge_ir_utils.h +++ /dev/null @@ -1,206 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ -#define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "proto/ge_ir.pb.h" -#include "proto/onnx.pb.h" - -namespace ge { -const int kOffsetToString = 2; - -/// -/// @ingroup ge_ir_utils -/// @brief RepeatedField->String -/// @param [in] const rpd_field RepeatedField -/// @return String -/// -template -const std::string ToString(const google::protobuf::RepeatedField &rpd_field) { - std::stringstream ss; - ss << "["; - for (const T &x : rpd_field) { - ss << x; - ss << ", "; - } - std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); - str_ret += "]"; - return str_ret; -} - -/// -/// @ingroup ge_ir_utils -/// @brief RepeatedPtrField->String -/// @param [in] const rpd_field RepeatedPtrField -/// @return String -/// -template -const std::string ToString(const google::protobuf::RepeatedPtrField &rpd_ptr_field) { - std::stringstream ss; - ss << "["; - for (const T &x : rpd_ptr_field) { - ss << x; - ss << ", "; - } - std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); - str_ret += "]"; - return str_ret; -} - -/// -/// @ingroup ge_ir_utils -/// @brief check, if not equal, log with tag -/// @param [in] const left_value, right_value reference, log_info_tag -/// @return bool -/// -template -bool IsEqual(const T &l_value, const T &r_value, const std::string &log_info_tag) { - if (l_value == r_value) { - return true; - } else { - GELOGE(GRAPH_FAILED, "Check failed with %s", log_info_tag.c_str()); - return false; - } -} - -class OnnxUtils { - public: - enum DumpLevel { NO_DUMP = 0, DUMP_ALL = 1, DUMP_WITH_OUT_DATA = 2, DUMP_WITH_OUT_DESC = 3, DUMP_LEVEL_END }; - - static bool ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto); - - static bool ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model); - - private: - // Part 1: from IR convert to ONNX Protobuf - static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, - const std::string &name, void *data); - - static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, - const std::string &name, ::google::protobuf::RepeatedField<::google::protobuf::int64> data); - - static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, - const std::string &name, ::google::protobuf::RepeatedField data); - - static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, - const std::string &name, ::google::protobuf::RepeatedField data); - - static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, - const std::string &name, ::google::protobuf::RepeatedPtrField<::std::string> data); - - static void AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto); - - static void AddAttrProtoFromAttribute(const std::pair &string_attr_value, - onnx::NodeProto *node_proto); - - static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); - - static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map &attr_map, - onnx::NodeProto *node_proto, const std::string &prefix = "", - const std::string &suffix = ""); - - static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, 
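// Editor's sketch (illustrative, not from the original header): the ToString helpers declared
// earlier in this header render a protobuf repeated field as "[v0, v1, ...]", and IsEqual simply
// compares with operator== and logs the tag on mismatch. For example:
//   google::protobuf::RepeatedField<google::protobuf::int64> f;
//   f.Add(1); f.Add(2); f.Add(3);
//   std::string s = ToString(f);                                         // "[1, 2, 3]"
//   bool same = IsEqual(s, std::string("[1, 2, 3]"), "tostring_check");  // true, nothing logged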
onnx::NodeProto *node_proto); - - static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); - - static void EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto); - - static bool EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto); - - static bool EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto); - - static bool EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto); - - static void EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type); - - static void EncodeValueInfo(const NodePtr &n, onnx::ValueInfoProto *v); - - static bool EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto); - - /// Part 2: from ONNX Protobuf convert to IR - /// Describes node's link relationships - struct NodeLinkInfo { - std::string src_node_name; - int32_t src_out_index; - NodePtr dst_node; - int32_t dst_in_index; - std::string dst_node_name; - }; - - // Parse node name and index - static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index); - - static ge::DataType DecodeDataType(onnx::TensorProto_DataType data_type); - - static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &strings); - - static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &ints); - - static void DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value); - - static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); - - static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_output_desc, int32_t index, - OpDescPtr &op_desc); - - static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_desc, int32_t index, - OpDescPtr &op_desc); - - static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_output_desc, int32_t index, - OpDescPtr &op_desc); - - static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def); - - static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); - - static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); - - static bool DecodeNodeLink(const std::vector &node_proto_vector, - const std::map &node_map); - - static bool DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &node); - - static bool DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph); -}; -} // namespace ge - -#endif // COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ diff --git a/src/common/graph/utils/graph_utils.cc b/src/common/graph/utils/graph_utils.cc deleted file mode 100644 index c741a316..00000000 --- a/src/common/graph/utils/graph_utils.cc +++ /dev/null @@ -1,2767 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/graph_utils.h" - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "./ge_context.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "proto/ge_ir.pb.h" -#include "utils/attr_utils.h" -#include "utils/ge_ir_utils.h" -#include "utils/node_utils.h" -#include "debug/ge_op_types.h" -#include "external/ge/ge_api_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" - -using google::protobuf::io::FileOutputStream; - -namespace ge { -enum DumpGraphLevel { - kDumpLevel1 = 1, - kDumpLevel2 = 2, - kDumpLevel3 = 3, - kDumpLevelOther, -}; - -namespace { -const int32_t kBaseOfIntegerValue = 10; -#ifdef FMK_SUPPORT_DUMP -const char *const kDumpGeGraph = "DUMP_GE_GRAPH"; -const int kDumpGraphIndexWidth = 5; -#endif -const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; -const char *const kDumpStrBuild = "Build"; -const char *const kDumpStrPartition = "partition"; -const char *const kDumpStrOptimizeSubgraph = "OptimizeSubGraph"; -const char *const kDumpStrSubgraphFunc = "sub_graph"; -const char *const kDumpStrAicpu = "Aicpu"; -}; // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, - const InDataAnchorPtr &dst) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Add edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const AnchorPtr &src, - const AnchorPtr &dst) { - OutDataAnchorPtr src_data = Anchor::DynamicAnchorCast(src); - InDataAnchorPtr dst_data = Anchor::DynamicAnchorCast(dst); - OutControlAnchorPtr src_control = Anchor::DynamicAnchorCast(src); - InControlAnchorPtr dst_control = Anchor::DynamicAnchorCast(dst); - if ((src_data != nullptr) && (dst_data != nullptr) && (src_data->LinkTo(dst_data) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - if ((src_data != nullptr) && (dst_control != nullptr) && (src_data->LinkTo(dst_control) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - if ((src_control != nullptr) && (dst_control != nullptr) && (src_control->LinkTo(dst_control) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - if ((src_control != nullptr) && (dst_data != nullptr) && (src_control->LinkTo(dst_data) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Add edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, - const Format &src_format, - const InDataAnchorPtr &dst, - const Format &dst_format) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - auto ret = AnchorUtils::SetFormat(src, src_format); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed, format is %d", static_cast(src_format)); - return ret; - } - ret = AnchorUtils::SetFormat(dst, dst_format); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(dst_format)); - return ret; - } - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Add edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutControlAnchorPtr &src, - const 
InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Add edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Add edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutDataAnchorPtr &src, - const InDataAnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Remove edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const AnchorPtr &src, - const AnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Remove edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutControlAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Remove edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutDataAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Remove edge Failed."); - return GRAPH_FAILED; -} - -graphStatus GraphUtils::ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, - const InDataAnchorPtr &new_dst) { - if (RemoveEdge(src, dst) == GRAPH_SUCCESS && AddEdge(src, new_dst) == GRAPH_SUCCESS) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Replace edge dst Failed."); - return GRAPH_FAILED; -} - -graphStatus GraphUtils::ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, - const InControlAnchorPtr &new_dst) { - if (RemoveEdge(src, dst) == GRAPH_SUCCESS && AddEdge(src, new_dst) == GRAPH_SUCCESS) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Replace edge dst Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNodeBetweenDataAnchors( - const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, const NodePtr &new_node) { - GE_CHECK_NOTNULL(src); - GE_CHECK_NOTNULL(dst); - GE_CHECK_NOTNULL(new_node); - - InDataAnchorPtr node_in_anchor = new_node->GetInDataAnchor(0); - GE_CHK_BOOL_RET_STATUS(node_in_anchor != nullptr, GRAPH_FAILED, "this node has not inDataAnchor"); - OutDataAnchorPtr node_out_anchor = new_node->GetOutDataAnchor(0); - GE_CHK_BOOL_RET_STATUS(node_out_anchor != nullptr, GRAPH_FAILED, "this node has not outDataAnchor"); - GE_CHK_STATUS_RET(src->ReplacePeer(dst, node_in_anchor, node_out_anchor), "ReplacePeer Failed"); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node) { - GE_CHECK_NOTNULL(compute_graph); - if (remove_node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return GRAPH_FAILED; - } - - // Check if this node is belong to this compute graph, maybe a little slow - 
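// Editor's sketch (hypothetical names, not part of the original file): RemoveSubgraphRecursively
// walks the subgraph instance names attached to the node's OpDesc (and to the OpDescs of nodes
// inside those subgraphs), collects every reachable subgraph from the root graph, and removes
// them all before the node itself is erased, e.g.
//   // case_node's OpDesc lists subgraph instances {"branch_0", "branch_1"}
//   GraphUtils::RemoveNodeWithoutRelink(graph, case_node);  // also removes branch_0 / branch_1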
const auto &all_nodes_in_graph = compute_graph->GetDirectNode(); - if (std::find(all_nodes_in_graph.begin(), all_nodes_in_graph.end(), remove_node) == all_nodes_in_graph.end()) { - GELOGE(GRAPH_FAILED, "Can not find node %s in graph %s.", remove_node->GetName().c_str(), - compute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - // Find all subgraph of this node - const auto &root_graph = GraphUtils::FindRootGraph(compute_graph); - std::vector subgraphs; - std::vector all_nodes; - std::deque candidates; - NodePtr remove_node_new = remove_node; - candidates.emplace_back(remove_node_new); - while (!candidates.empty()) { - const NodePtr node = candidates.front(); - all_nodes.emplace_back(node); - candidates.pop_front(); - - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - - const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); - for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { - auto subgraph = root_graph->GetSubgraph(*name_iter); - if (subgraph != nullptr) { - subgraphs.emplace_back(subgraph); - candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); - } - } - } - // Remove all subgraph - for (const auto &remove_graph : subgraphs) { - if (root_graph->RemoveSubGraph(remove_graph) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Remove subgraph failed, sub graph name is %s, compute graph is %s.", - remove_node->GetName().c_str(), compute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node) { - GE_CHECK_NOTNULL(compute_graph); - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return GRAPH_FAILED; - } - - // If the node save as input node, delete it - (void)compute_graph->RemoveInputNode(node); - - // If the node save as output node, delete it - (void)compute_graph->RemoveOutputNode(node); - - // If the node has sub-graphs, delete them - auto ret = RemoveSubgraphRecursively(compute_graph, node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Remove subgraph recursively failed."); - return GRAPH_FAILED; - } - - auto iter = find(compute_graph->nodes_.begin(), compute_graph->nodes_.end(), node); - if (iter != compute_graph->nodes_.end()) { - compute_graph->nodes_.erase(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -/// Add two edges to the new node, respectively connecting the SRC and DST -/// associated with the original edge -/// A ---> B transfered to A ---> N ---> B -graphStatus InsertTransNode(ComputeGraph &compute_graph, const InDataAnchorPtr &in_data_anchor, - const std::vector &vec_op_desc) { - GE_CHECK_NOTNULL(in_data_anchor); - for (const auto &op_desc : vec_op_desc) { - GE_CHECK_NOTNULL(op_desc); - - auto ret = op_desc->AddInputDesc(GeTensorDesc()); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, "Add input desc failed"); - ret = op_desc->AddOutputDesc(GeTensorDesc()); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, "Add input desc failed"); - auto node_to_insert = compute_graph.AddNode(op_desc); - - GE_CHECK_NOTNULL(node_to_insert); - GE_CHECK_NOTNULL(in_data_anchor->GetPeerOutAnchor()); - - auto src = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); - if (!src) { - GELOGE(GRAPH_FAILED, "src nullptr error."); - return GRAPH_FAILED; - } - - auto src_out_index = 
in_data_anchor->GetPeerOutAnchor()->GetIdx(); - - auto dst = in_data_anchor->GetOwnerNode(); - if (!dst) { - GELOGE(GRAPH_FAILED, "dst nullptr error."); - return GRAPH_FAILED; - } - - auto dst_in_index = in_data_anchor->GetIdx(); - - auto in_data_anchor_src_format = AnchorUtils::GetFormat(in_data_anchor->GetPeerOutAnchor()); - auto in_data_anchor_dst_format = AnchorUtils::GetFormat(in_data_anchor); - - GE_CHECK_NOTNULL(src->GetOutDataAnchor(src_out_index)); - GE_CHECK_NOTNULL(dst->GetInDataAnchor(dst_in_index)); - - ret = GraphUtils::RemoveEdge(src->GetOutDataAnchor(src_out_index), dst->GetInDataAnchor(dst_in_index)); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Remove edge failed"); - return GRAPH_FAILED; - } - - GE_CHECK_NOTNULL(node_to_insert->GetInDataAnchor(0)); - GE_CHECK_NOTNULL(node_to_insert->GetOutDataAnchor(0)); - - ret = GraphUtils::AddEdge(src->GetOutDataAnchor(src_out_index), node_to_insert->GetInDataAnchor(0)); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add edge failed"); - return ret; - } - ret = GraphUtils::AddEdge(node_to_insert->GetOutDataAnchor(0), dst->GetInDataAnchor(dst_in_index)); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add edge failed"); - return ret; - } - - if (op_desc->HasAttr("input_format")) { - int64_t input_format = 0; - int64_t output_format = 0; - if (!AttrUtils::GetInt(op_desc, "input_format", input_format)) { - GELOGW("get attr input_format failed"); - continue; - } - if (!AttrUtils::GetInt(op_desc, "output_format", output_format)) { - GELOGW("get attr output_format failed"); - continue; - } - - GE_CHECK_NOTNULL(node_to_insert->GetInDataAnchor(0)->GetPeerOutAnchor()); - GE_CHK_BOOL_RET_STATUS(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().empty(), GRAPH_FAILED, - "Vistor is empty"); - GE_CHECK_NOTNULL(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)); - - auto status = - AnchorUtils::SetFormat(node_to_insert->GetInDataAnchor(0)->GetPeerOutAnchor(), in_data_anchor_src_format); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(in_data_anchor_src_format)); - return status; - } - status = AnchorUtils::SetFormat(node_to_insert->GetInDataAnchor(0), static_cast(input_format)); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed,format is %ld", input_format); - return status; - } - status = AnchorUtils::SetFormat(node_to_insert->GetOutDataAnchor(0), static_cast(output_format)); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed,format is %ld", output_format); - return status; - } - status = AnchorUtils::SetFormat(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0), - in_data_anchor_dst_format); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(in_data_anchor_dst_format)); - return status; - } - } - std::vector original_nodes; - GraphUtils::RecordOriginalNames(original_nodes, node_to_insert); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTransNode( - ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, const std::vector &vec_op_desc) { - GE_CHECK_NOTNULL(compute_graph); - GE_CHECK_NOTNULL(in_data_anchor); - graphStatus ret = - ge::InsertTransNode(*compute_graph, in_data_anchor, vec_op_desc) == GRAPH_SUCCESS ? 
GRAPH_SUCCESS : GRAPH_FAILED; - return ret; -} - -/// -/// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst -/// @param [in] src -/// @param [in] dsts -/// @param [in] insert_node -/// @param [in] input_index -/// @param [in] output_index -/// @return graphStatus -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, - const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { - GE_CHECK_NOTNULL(src); - GE_CHECK_NOTNULL(insert_node); - - NodePtr src_node = src->GetOwnerNode(); - if (src_node->GetOwnerComputeGraph() != insert_node->GetOwnerComputeGraph()) { - GELOGE(GRAPH_FAILED, "src:%s and insert_node:%s not exist in the same graph.", src_node->GetName().c_str(), - insert_node->GetName().c_str()); - return GRAPH_FAILED; - } - - if (AddEdge(src, insert_node->GetInDataAnchor(input_index)) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "AddEdge %s->%s failed.", src_node->GetName().c_str(), insert_node->GetName().c_str()); - return GRAPH_FAILED; - } - - OutControlAnchorPtr src_out_ctrl_anchor = src_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(src_out_ctrl_anchor); - - bool ctrl_edge_flag = true; - std::string type = NodeUtils::GetNodeType(src->GetOwnerNode()); - if ((type == SWITCH) || (type == REFSWITCH) || (type == SWITCHN)) { - ctrl_edge_flag = false; - } - - for (auto &dst : dsts) { - GE_CHECK_NOTNULL(dst); - NodePtr dst_node = dst->GetOwnerNode(); - GELOGI("Insert node %s between %s->%s.", insert_node->GetName().c_str(), src_node->GetName().c_str(), - dst_node->GetName().c_str()); - if (src_node->GetOwnerComputeGraph() != dst_node->GetOwnerComputeGraph()) { - GELOGE(GRAPH_FAILED, "src:%s and dst:%s not exist in the same graph.", src_node->GetName().c_str(), - dst_node->GetName().c_str()); - return GRAPH_FAILED; - } - - (void)RemoveEdge(src, dst); - if (AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), - dst_node->GetName().c_str(), insert_node->GetName().c_str(), dst_node->GetName().c_str()); - return GRAPH_FAILED; - } - - if (!ctrl_edge_flag) { - continue; - } - for (const InControlAnchorPtr &peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { - if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || - (AddEdge(insert_node->GetOutControlAnchor(), peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { - GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), - peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(), insert_node->GetName().c_str(), - peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveJustNode(ComputeGraph &compute_graph, - const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should be not null."); - return GRAPH_FAILED; - } - auto iter = find(compute_graph.nodes_.begin(), compute_graph.nodes_.end(), node); - if (iter != compute_graph.nodes_.end()) { - compute_graph.nodes_.erase(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveJustNode(ComputeGraphPtr compute_graph, - const NodePtr &node) { - GE_CHECK_NOTNULL(compute_graph); - GE_CHECK_NOTNULL(node); - graphStatus ret = 
(RemoveJustNode(*compute_graph, node) == GRAPH_SUCCESS ? GRAPH_SUCCESS : GRAPH_FAILED); - return ret; -} - -void GraphUtils::RecordOriginalNames(std::vector original_nodes, const ge::NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return, "node is null."); - std::vector original_names; - for (const auto &node_tmp : original_nodes) { - std::vector names_tmp; - ge::OpDescPtr opdesc_tmp = node_tmp->GetOpDesc(); - if (opdesc_tmp == nullptr) { - GELOGE(GRAPH_FAILED, "Node %s get opdesc is nullptr", node_tmp->GetName().c_str()); - continue; - } - auto ret = ge::AttrUtils::GetListStr(opdesc_tmp, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, names_tmp); - if (!ret) { - GELOGW("Get list str failed"); - continue; - } - if (names_tmp.size() != 0) { - original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); - } else { - original_names.push_back(opdesc_tmp->GetName()); - } - } - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names), - return, "Set original_op_names fail."); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::RecordOriginalNames(std::vector names_tmp, - const ge::NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return, "node is null."); - std::vector original_names; - if (names_tmp.size() != 0) { - original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); - } else { - std::string tmp; - original_names.push_back(tmp); - } - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names), - return, "Set original_op_names fail."); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::MatchDumpStr(const std::string &suffix) { - char *dump_level = std::getenv(kDumpGraphLevel); - int64_t dump_graph_level = - (dump_level != nullptr) ? 
std::strtol(dump_level, nullptr, kBaseOfIntegerValue) : kDumpLevel2; - - if (dump_graph_level == kDumpLevel1) { - return false; - } - - if (dump_graph_level == kDumpLevel2 && - ((suffix.find(kDumpStrPartition) != std::string::npos) || - (suffix.find(kDumpStrOptimizeSubgraph) != std::string::npos) || - (suffix.find(kDumpStrAicpu) != std::string::npos) || (suffix.find(kDumpStrSubgraphFunc) != std::string::npos))) { - return true; - } - - if (dump_graph_level == kDumpLevel3 && suffix.compare(kDumpStrBuild) != 0) { - return true; - } - - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(const ge::ComputeGraphPtr &graph, - const std::string &suffix, - bool is_always_dump, - const std::string &user_graph_name) { -#ifdef FMK_SUPPORT_DUMP - char *dump_ge_graph = std::getenv(kDumpGeGraph); - GE_IF_BOOL_EXEC(dump_ge_graph == nullptr && !is_always_dump, return;); - - // dump the graph according to different graph level - if (GraphUtils::MatchDumpStr(suffix)) { - return; - } - - // file name - static std::atomic_long atomic_file_index(0); - auto file_index = atomic_file_index.fetch_add(1); - GELOGD("Start to dump om txt: %ld", file_index); - - thread_local long max_dump_file_num = 0; - if (max_dump_file_num == 0) { - string opt = "0"; - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); - max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); - } - if (max_dump_file_num != 0 && file_index > max_dump_file_num) { - GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileCnt=%ld.", max_dump_file_num); - return; - } - - std::stringstream stream_file_name; - stream_file_name << "ge_proto_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; - stream_file_name << "_" << suffix << ".txt"; - std::string proto_file = user_graph_name.empty() ? stream_file_name.str() : user_graph_name; - - // Create buffer - ge::Model model("", ""); - model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast(graph))); - Buffer buffer; - const int64_t kDumpLevel = - (dump_ge_graph != nullptr) ? 
std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : ge::OnnxUtils::NO_DUMP; - model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL); - - // Write file - ge::proto::ModelDef ge_proto; - if (buffer.GetData() != nullptr) { - std::string str(reinterpret_cast(buffer.GetData()), buffer.GetSize()); - if (!ge_proto.ParseFromString(str)) { - GELOGE(GRAPH_FAILED, "parse from string failed."); - return; - } - char real_path[PATH_MAX] = {0x00}; - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(proto_file.c_str()) >= PATH_MAX, return, "file path is too longer!"); - GE_IF_BOOL_EXEC(realpath(proto_file.c_str(), real_path) == nullptr, - GELOGI("file %s does not exist, it will be created.", proto_file.c_str())); - - GraphUtils::WriteProtoToTextFile(ge_proto, real_path); - } -#else - GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); -#endif -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, - ge::ComputeGraph &compute_graph) { - ge::proto::ModelDef model_def; - // Get ModelDef object from file generated by DumpGEGraph() - if (!ReadProtoFromTextFile(file, &model_def)) { - GELOGE(GRAPH_FAILED, "Get ModelDef failed from file"); - return false; - } - ge::Model model; - // Get Model object from ModelDef by deserialize ModelDef - if (model.Load(model_def) == GRAPH_SUCCESS) { - GE_CHK_BOOL_EXEC(GraphUtils::GetComputeGraph(model.GetGraph()) != nullptr, return false, - "Get computer graph is nullptr"); - compute_graph = *(GraphUtils::GetComputeGraph(model.GetGraph())); - return true; - } else { - GELOGE(GRAPH_FAILED, "Get Model failed from ModelDef"); - return false; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, - ge::ComputeGraphPtr &compute_graph) { - ge::proto::ModelDef model_def; - // Get ModelDef object from file generated by DumpGEGraph() - if (!ReadProtoFromTextFile(file, &model_def)) { - GELOGE(GRAPH_FAILED, "Get ModelDef failed from file"); - return false; - } - ge::Model model; - // Get Model object from ModelDef by deserialize ModelDef - if (model.Load(model_def) == GRAPH_SUCCESS) { - GE_CHK_BOOL_EXEC(GraphUtils::GetComputeGraph(model.GetGraph()) != nullptr, return false, - "Get computer graph is nullptr"); - compute_graph = GraphUtils::GetComputeGraph(model.GetGraph()); - for (const auto &node : compute_graph->GetDirectNode()) { - GELOGI("Node %s set owner graph", node->GetName().c_str()); - GE_CHECK_NOTNULL(node); - if (node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Node %s set owner graph failed", node->GetName().c_str()); - return false; - } - } - return true; - } else { - GELOGE(GRAPH_FAILED, "Get Model failed from ModelDef"); - return false; - } -} - -// Printing protocol messages in text format is useful for debugging and human editing of messages. 
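// Editor's note (sketch derived from the functions above; only effective in builds compiled with
// FMK_SUPPORT_DUMP): dumping is driven by environment variables and GE options, e.g.
//   setenv("DUMP_GE_GRAPH", "2", 1);     // kDumpGeGraph: enables DumpGEGraph and selects the model dump level
//   setenv("DUMP_GRAPH_LEVEL", "1", 1);  // kDumpGraphLevel: 1 dumps every suffix, 2 and 3 filter via MatchDumpStr
// while OPTION_GE_MAX_DUMP_FILE_NUM / OPTION_GE_MAX_DUMP_FILE_SIZE cap the number and size of the
// generated "ge_proto_*.txt" files.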
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToTextFile( - const google::protobuf::Message &proto, const char *real_path) { -#ifdef FMK_SUPPORT_DUMP - const int FILE_AUTHORITY = 0600; - int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, FILE_AUTHORITY); - if (fd < 0) { - GELOGE(GRAPH_FAILED, "fail to open the file: %s, %s", real_path, strerror(errno)); - return; - } - google::protobuf::io::FileOutputStream *output = new (std::nothrow) FileOutputStream(fd); - if (output == nullptr) { - GELOGE(GRAPH_FAILED, "Output is nullptr"); - if (close(fd) != 0) { - GELOGE(GRAPH_FAILED, "Close fileoutputstream failed"); - } - return; - } - bool ret = google::protobuf::TextFormat::Print(proto, output); - if (!ret) { - GELOGE(GRAPH_FAILED, "Fail to write the file: %s", real_path); - delete output; - output = nullptr; - GE_CHK_BOOL_EXEC(close(fd) == 0, return, "Close fileoutputstream failed"); - return; - } - delete output; - output = nullptr; - GE_CHK_BOOL_EXEC(close(fd) == 0, return, "Close fileoutputstream failed"); - - FILE *file = fopen(real_path, "rb"); - if (file == nullptr) { - return; - } - if (fseek(file, 0L, SEEK_END) == 0) { - long fileSize = ftell(file); - thread_local long max_dump_file_size = 0; - if (max_dump_file_size == 0) { - string opt = "0"; - // Can not check return value - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); - max_dump_file_size = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); - } - if (max_dump_file_size != 0 && fileSize != -1 && fileSize > max_dump_file_size) { - GELOGW("dump graph file size > maxDumpFileSize, maxDumpFileSize=%ld.", max_dump_file_size); - GE_IF_BOOL_EXEC(std::remove(real_path) != 0, GELOGW("remove %s failed", real_path)); - GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose %s failed", real_path); - return; - } - } - GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose fileoutputstream failed"); -#else - GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); -#endif -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::ReadProtoFromTextFile( - const char *file, google::protobuf::Message *proto) { - if (file == nullptr || proto == nullptr) { - GELOGE(GRAPH_FAILED, "incorrect parameter. file path or message is invalid"); - return false; - } - std::ifstream fs(file, std::ifstream::in); - if (!fs.is_open()) { - GELOGE(GRAPH_FAILED, "proto file '%s' open fail.", file); - return false; - } - google::protobuf::io::IstreamInputStream input(&fs); - bool ret = google::protobuf::TextFormat::Parse(&input, proto); - if (!ret) { - GELOGE(GRAPH_FAILED, "parse proto from text ret fail, please check your text file '%s'.", file); - } - fs.close(); - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, - const std::string &suffix) { -#ifdef FMK_SUPPORT_DUMP - char *dump_ge_graph = std::getenv(kDumpGeGraph); - int64_t dump_ge_graph_level = - (dump_ge_graph != nullptr) ? 
std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : OnnxUtils::NO_DUMP; - if ((dump_ge_graph_level == OnnxUtils::NO_DUMP) || (dump_ge_graph_level >= OnnxUtils::DUMP_LEVEL_END)) { - GELOGD("Skip DumpGEGraphToOnnx with dump_ge_graph_level %ld.", dump_ge_graph_level); - return; - } - - // dump the graph according to different graph level - if (GraphUtils::MatchDumpStr(suffix)) { - return; - } - - // 1.Get ge::onnx::ModelProto from ge::Model - ge::Model model("GE", ""); - std::shared_ptr compute_graph_ptr = ComGraphMakeShared(compute_graph); - model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast(compute_graph_ptr))); - onnx::ModelProto model_proto; - if (!OnnxUtils::ConvertGeModelToModelProto(model, model_proto)) { - GELOGE(GRAPH_FAILED, "DumpGEGraphToOnnx failed."); - return; - } - - // 2.Set file name - static std::atomic_long atomic_file_index(0); - auto file_index = atomic_file_index.fetch_add(1); - GELOGD("Start to dump ge onnx file: %ld", file_index); - - thread_local long max_dump_file_num = 0; - if (max_dump_file_num == 0) { - string opt = "0"; - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); - max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); - } - if (max_dump_file_num != 0 && file_index > max_dump_file_num) { - GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileNum=%ld.", max_dump_file_num); - return; - } - - std::stringstream stream_file_name; - stream_file_name << "ge_onnx_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; - stream_file_name << "_graph_" << compute_graph.GetGraphID(); - stream_file_name << "_" << suffix << ".pbtxt"; - std::string proto_file = stream_file_name.str(); - if ((proto_file.length()) >= NAME_MAX) { - GELOGE(GRAPH_FAILED, "File name is too longer!"); - return; - } - std::unique_ptr real_path(new (std::nothrow) char[PATH_MAX]{0}); - if (real_path == nullptr) { - GELOGE(GRAPH_FAILED, "New real_path failed."); - return; - } - /// Returning nullptr means 3 case as follows: - /// a.path is PATH_MAX chars or more - /// b.the file does not exist - /// c.the path has no permissions - /// Distinguish between last the two cases in the function WriteProtoToTextFile call open() - if (realpath(proto_file.c_str(), real_path.get()) == nullptr) { - // For case a - if (errno == ENAMETOOLONG) { - GELOGE(GRAPH_FAILED, "Call realpath failed: path is PATH_MAX chars or more."); - return; - } - } - - // 3. Serialize to file in current path - GraphUtils::WriteProtoToTextFile(model_proto, real_path.get()); -#else - GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); -#endif -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraphFromOnnx(const char *file, - ge::ComputeGraph &compute_graph) { - if (file == nullptr) { - GELOGE(GRAPH_FAILED, "incorrect parameter. file path is invalid"); - return false; - } - onnx::ModelProto model_proto; - // 1. 
Get ModelDef object from file generated by DumpGEGraphToOnnx() - if (!ReadProtoFromTextFile(file, &model_proto)) { - GELOGE(GRAPH_FAILED, "Get ModelDef from file failed"); - return false; - } - // 2.Convert onnx::ModelProto To ge::Model - ge::Model model; - if (!OnnxUtils::ConvertModelProtoToGeModel(model_proto, model)) { - GELOGE(GRAPH_FAILED, "Convert ModelDef to Model failed"); - return false; - } - auto compute_graph_ptr = GraphUtils::GetComputeGraph(model.GetGraph()); - if (compute_graph_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "Get compute graph from Model failed"); - return false; - } - compute_graph = *(compute_graph_ptr); - return true; -} - -namespace { -using InNodesToOut = std::unordered_map>; - -inline std::string GetNodeNameByAnchor(const Anchor *anchor) { - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Anchor is nullptr"); - return "Null"; - } - auto node = anchor->GetOwnerNode(); - return node == nullptr ? "Null" : node->GetName(); -} - -graphStatus ReplaceOutDataAnchor(const OutDataAnchorPtr &new_anchor, const OutDataAnchorPtr &old_anchor, - InNodesToOut *in_nodes_to_out = nullptr) { - if (new_anchor == nullptr || old_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "new_anchor or old_anchor is nullptr"); - return GRAPH_PARAM_INVALID; - } - auto new_node = new_anchor->GetOwnerNode(); - for (const auto &peer_in_anchor : old_anchor->GetPeerInDataAnchors()) { - auto ret = peer_in_anchor->Unlink(old_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to unlink old anchor link from %s(%d) to %s(%d)", - GetNodeNameByAnchor(old_anchor.get()).c_str(), old_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - ret = peer_in_anchor->LinkFrom(new_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to relink new anchors from %s(%d) to %s(%d)", - GetNodeNameByAnchor(new_anchor.get()).c_str(), new_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - - if (in_nodes_to_out != nullptr) { - (*in_nodes_to_out)[new_node].insert(peer_in_anchor->GetOwnerNode()); - } - } - return GRAPH_SUCCESS; -} - -graphStatus RelinkDataIO(const NodePtr &node, const std::vector &io_map, InNodesToOut &in_nodes_to_out) { - GE_CHECK_NOTNULL(node); - auto in_data_anchors = node->GetAllInDataAnchors(); - auto out_data_anchors = node->GetAllOutDataAnchors(); - if (out_data_anchors.size() < io_map.size()) { - GELOGE(GRAPH_FAILED, "The io_map specified for node %s type %s is larger %zu than the actual size %zu", - node->GetName().c_str(), node->GetType().c_str(), io_map.size(), out_data_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - for (size_t i = 0; i < out_data_anchors.size(); ++i) { - auto out_data_anchor = out_data_anchors.at(i); - if (out_data_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to relink for node %s type %s, the out data anchor at index %zu is null", - node->GetName().c_str(), node->GetType().c_str(), i); - return GRAPH_FAILED; - } - - int in_index = -1; - if (i < io_map.size()) { - in_index = io_map.at(i); - } - if (in_index < 0) { - out_data_anchor->UnlinkAll(); - continue; - } - - if (in_index >= static_cast(in_data_anchors.size())) { - GELOGE(GRAPH_PARAM_INVALID, "Failed to relink for node %s type %s, invalid index %d specified for input(%zu)", - node->GetName().c_str(), node->GetType().c_str(), in_index, in_data_anchors.size()); - return GRAPH_PARAM_INVALID; - } - auto in_anchor = 
in_data_anchors.at(in_index); - if (in_anchor == nullptr) { - GELOGW("Invalid in data anchors(null) found at node %s type %s index %d, ignore it.", node->GetName().c_str(), - node->GetType().c_str(), in_index); - continue; - } - auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - continue; - } - if (peer_out_anchor->Unlink(in_anchor) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, - "Failed relink node %s type %s, failed to unlink the data link" - " from %s(%d) to it at input-index %d", - node->GetName().c_str(), node->GetType().c_str(), GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), - peer_out_anchor->GetIdx(), in_index); - return GRAPH_FAILED; - } - auto ret = ReplaceOutDataAnchor(peer_out_anchor, out_data_anchor, &in_nodes_to_out); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to relink node %s type %s for relinking data anchors", node->GetName().c_str(), - node->GetType().c_str()); - return GRAPH_FAILED; - } - } - - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - in_anchor->UnlinkAll(); - } - return GRAPH_SUCCESS; -} - -InNodesToOut GetFullConnectIONodes(const NodePtr &node) { - InNodesToOut in_nodes_to_out; - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Node is nullptr"); - return in_nodes_to_out; - } - auto in_nodes_list = node->GetInNodes(); - auto out_nodes_list = node->GetOutNodes(); - auto out_nodes = std::unordered_set(out_nodes_list.begin(), out_nodes_list.end()); - - for (const auto &in_node : in_nodes_list) { - in_nodes_to_out.insert(std::make_pair(in_node, out_nodes)); - } - return in_nodes_to_out; -} - -graphStatus RelinkControlNodeIfNeed(const NodePtr &node, InNodesToOut &in_nodes_to_out, - InNodesToOut &connected_data_in_to_out) { - GE_CHECK_NOTNULL(node); - for (const auto &in_node_to_out : in_nodes_to_out) { - auto &in_node = in_node_to_out.first; - GE_CHECK_NOTNULL(in_node); - auto &connected_data_out = connected_data_in_to_out[in_node]; - for (const auto &out_node : in_node_to_out.second) { - GE_CHECK_NOTNULL(out_node); - if (connected_data_out.count(out_node) == 0) { - GE_CHECK_NOTNULL(in_node->GetOutControlAnchor()); - if (in_node->GetOutControlAnchor()->IsLinkedWith(out_node->GetInControlAnchor())) { - continue; - } - auto ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), out_node->GetInControlAnchor()); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when isolating node %s type %s", - in_node->GetName().c_str(), out_node->GetName().c_str(), node->GetName().c_str(), - node->GetType().c_str()); - return GRAPH_FAILED; - } - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus ReplaceOutDataAnchors(const Node::Vistor &new_outs, - const Node::Vistor &old_outs, const std::vector &outputs_map) { - auto new_out_size = new_outs.size(); - if (new_out_size < outputs_map.size()) { - GELOGE(GRAPH_PARAM_INVALID, - "Failed to replace out data anchors, the actual size %zu is less than the mapping size %zu", new_out_size, - outputs_map.size()); - return GRAPH_PARAM_INVALID; - } - for (size_t i = 0; i < new_out_size; ++i) { - auto &new_out_anchor = new_outs.at(i); - if (new_out_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to replace out data anchors, the out data anchor on new node is null, index %zu", i); - return GRAPH_FAILED; - } - if (i >= outputs_map.size()) { - continue; - } - auto old_index = outputs_map.at(i); - if (old_index < 0) { - continue; - } - - const OutDataAnchorPtr &old_out_anchor = old_outs.at(old_index); - if (old_out_anchor == 
nullptr) { - GELOGE(GRAPH_FAILED, "Failed to replace out data anchors, the out data anchor on old node is null, index %d", - old_index); - return GRAPH_FAILED; - } - auto ret = ReplaceOutDataAnchor(new_out_anchor, old_out_anchor); - if (ret != GRAPH_SUCCESS) { - return ret; - } - } - - return GRAPH_SUCCESS; -} - -graphStatus ReplaceInDataAnchors(const Node::Vistor &new_ins, - const Node::Vistor &old_ins, const std::vector &inputs_map) { - auto new_in_size = new_ins.size(); - if (new_in_size < inputs_map.size()) { - GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the actual size %zu is less than the mapping size %zu", - new_in_size, inputs_map.size()); - return GRAPH_PARAM_INVALID; - } - - for (size_t i = 0; i < new_in_size; ++i) { - auto &new_in_anchor = new_ins.at(i); - if (new_in_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the out data anchor on new node is null, index %zu", i); - return GRAPH_FAILED; - } - if (i >= inputs_map.size()) { - continue; - } - auto old_index = inputs_map.at(i); - if (old_index < 0) { - continue; - } - const InDataAnchorPtr &old_in_anchor = old_ins.at(old_index); - if (old_in_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the out data anchor on old node is null, index %d", - old_index); - return GRAPH_FAILED; - } - - auto peer_out_anchor = old_in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - GELOGW("Peer out anchor is nullptr"); - continue; - } - auto ret = peer_out_anchor->Unlink(old_in_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to unlink old anchors, unlink from %s(%d) to %s(%d)", - GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), - GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - ret = peer_out_anchor->LinkTo(new_in_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to link new anchors, link from %s(%d) to %s(%d)", - GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), - GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -graphStatus ReplaceControlAnchors(const NodePtr &new_node, const NodePtr &old_node) { - GE_CHECK_NOTNULL(new_node); - GE_CHECK_NOTNULL(new_node->GetInControlAnchor()); - GE_CHECK_NOTNULL(old_node); - GE_CHECK_NOTNULL(old_node->GetInControlAnchor()); - auto peer_out_anchors = old_node->GetInControlAnchor()->GetPeerAnchors(); - auto new_in_control_anchor = new_node->GetInControlAnchor(); - auto exists_out_anchors = new_in_control_anchor->GetPeerAnchors(); - auto exists_out_anchors_set = std::set(exists_out_anchors.begin(), exists_out_anchors.end()); - for (const auto &peer_out_anchor : peer_out_anchors) { - if (peer_out_anchor != nullptr) { - if (exists_out_anchors_set.count(peer_out_anchor) > 0) { - continue; - } - auto ret = GraphUtils::AddEdge(peer_out_anchor, new_in_control_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add edge failed"); - return GRAPH_FAILED; - } - } else { - GELOGW("peer outanchor is nullptr"); - continue; - } - } - auto old_out_control_anchor = old_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(old_out_control_anchor); - auto peer_in_anchors = old_out_control_anchor->GetPeerAnchors(); - auto new_out_control_anchor = new_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(new_out_control_anchor); - auto exists_in_anchors = new_out_control_anchor->GetPeerAnchors(); - auto 
exists_in_anchors_set = std::set(exists_in_anchors.begin(), exists_in_anchors.end()); - for (const auto &peer_in_anchor : peer_in_anchors) { - if (peer_in_anchor != nullptr) { - if (exists_in_anchors_set.count(peer_in_anchor) > 0) { - continue; - } - auto ret = GraphUtils::AddEdge(new_out_control_anchor, peer_in_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add edge failed"); - return GRAPH_FAILED; - } - } else { - GELOGW("Peer inanchor is nullptr"); - continue; - } - } - - return GRAPH_SUCCESS; -} -} // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNode(const NodePtr &node, - const std::vector &io_map) { - if (node == nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "Failed to isolate node(null)"); - return GRAPH_PARAM_INVALID; - } - - /// We must get full connections info before re-link data io, because the data - /// edges may be unlinked when relink data io - auto in_nodes_to_out = GetFullConnectIONodes(node); - - InNodesToOut data_in_to_out; - auto ret = RelinkDataIO(node, io_map, data_in_to_out); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to isolate node %s type %s when relink data IO", node->GetName().c_str(), - node->GetType().c_str()); - return ret; - } - - ret = RelinkControlNodeIfNeed(node, in_nodes_to_out, data_in_to_out); - if (ret != GRAPH_SUCCESS) { - return ret; - } - NodeUtils::UnlinkAll(*node); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::IsolateNode(const NodePtr &node, const std::initializer_list &io_map) { - return IsolateNode(node, std::vector(io_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNodeOneIO(const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "incorrect parameter. 
node is invalid"); - return GRAPH_PARAM_INVALID; - } - if (node->GetAllInDataAnchorsSize() != 1) { - return GRAPH_PARAM_INVALID; - } - if (node->GetAllOutDataAnchorsSize() != 1) { - return GRAPH_PARAM_INVALID; - } - return IsolateNode(node, {0}); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, const std::vector &inputs_map, - const std::vector &outputs_map) { - if ((new_node == nullptr) || (old_node == nullptr)) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - auto ret = ReplaceNodeDataAnchors(new_node, old_node, inputs_map, outputs_map); - if (ret != GRAPH_SUCCESS) { - // The error log was printed in `ReplaceNodeDataAnchors` - return GRAPH_FAILED; - } - ret = ReplaceControlAnchors(new_node, old_node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, - "Failed to replace control anchors when replace node from old node %s type %s to new node %s type %s", - old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), - new_node->GetType().c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ReplaceNodeAnchors( - const NodePtr &new_node, const NodePtr &old_node, const std::initializer_list inputs_map, - const std::initializer_list outputs_map) { - return ReplaceNodeAnchors(new_node, old_node, std::vector(inputs_map), std::vector(outputs_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, - std::initializer_list inputs_map, std::initializer_list outputs_map) { - return ReplaceNodeDataAnchors(new_node, old_node, std::vector(inputs_map), std::vector(outputs_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, const std::vector &inputs_map, - const std::vector &outputs_map) { - if (new_node == nullptr || old_node == nullptr) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - - auto ret = ReplaceOutDataAnchors(new_node->GetAllOutDataAnchors(), old_node->GetAllOutDataAnchors(), outputs_map); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, - "Failed to replace out data anchors when replace node from old node %s type %s to new node %s type %s", - old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), - new_node->GetType().c_str()); - return GRAPH_FAILED; - } - ret = ReplaceInDataAnchors(new_node->GetAllInDataAnchors(), old_node->GetAllInDataAnchors(), inputs_map); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, - "Failed to replace in data anchors when replace node from old node %s type %s to new node %s type %s", - old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), - new_node->GetType().c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInCtrlEdges(const NodePtr &src_node, - NodePtr &dst_node) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - auto src_ctrl_in_nodes = src_node->GetInControlNodes(); - if (src_ctrl_in_nodes.empty()) { - return GRAPH_SUCCESS; - } - - std::unordered_set exist_in_ctrl_nodes_set; - auto exist_in_ctrl_nodes = 
dst_node->GetInControlNodes(); - if (!exist_in_ctrl_nodes.empty()) { - exist_in_ctrl_nodes_set.insert(exist_in_ctrl_nodes.begin(), exist_in_ctrl_nodes.end()); - } - - auto dst_ctrl = dst_node->GetInControlAnchor(); - for (const auto &in_node : src_ctrl_in_nodes) { - if (exist_in_ctrl_nodes_set.count(in_node) > 0) { - continue; - } - auto ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), dst_ctrl); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when copy control dependencies from %s to %s", - in_node->GetName().c_str(), dst_node->GetName().c_str(), src_node->GetName().c_str(), - dst_node->GetName().c_str()); - return ret; - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveInCtrlEdges(const NodePtr &src_node, - NodePtr &dst_node) { - if (src_node == nullptr || dst_node == nullptr) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_FAILED; - } - auto ret = CopyInCtrlEdges(src_node, dst_node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Copy in ctrl edges failed"); - return ret; - } - GE_CHECK_NOTNULL(src_node->GetInControlAnchor()); - src_node->GetInControlAnchor()->UnlinkAll(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyOutCtrlEdges(const NodePtr &src_node, - NodePtr &dst_node) { - if (src_node == nullptr || dst_node == nullptr) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_FAILED; - } - auto out_ctrl_nodes = src_node->GetOutControlNodes(); - if (out_ctrl_nodes.empty()) { - return GRAPH_SUCCESS; - } - - std::unordered_set exists_out_ctrl_nodes_set; - for (const auto &node : dst_node->GetOutControlNodes()) { - exists_out_ctrl_nodes_set.insert(node.get()); - } - - auto dst_out_ctrl = dst_node->GetOutControlAnchor(); - for (const auto &node : out_ctrl_nodes) { - if (exists_out_ctrl_nodes_set.count(node.get()) > 0) { - continue; - } - auto ret = GraphUtils::AddEdge(dst_out_ctrl, node->GetInControlAnchor()); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when copy control dependencies from %s to %s", - dst_node->GetName().c_str(), node->GetName().c_str(), src_node->GetName().c_str(), - dst_node->GetName().c_str()); - return ret; - } - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCtrlEdges(NodePtr &src_node, - NodePtr &dst_node) { - if (src_node == nullptr || dst_node == nullptr) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_FAILED; - } - auto ret = CopyOutCtrlEdges(src_node, dst_node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Copyout ctrl edges failed"); - return ret; - } - GE_CHECK_NOTNULL(src_node->GetOutControlAnchor()); - src_node->GetOutControlAnchor()->UnlinkAll(); - return GRAPH_SUCCESS; -} - -/// -/// Copy all in-data edges from `src_node` to `dst_node`. 
-/// @param src_node: node whose in-data edges are copied
-/// @param dst_node: node that receives the copied edges
-/// @return success: GRAPH_SUCCESS
-///
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node,
-                                                                                       NodePtr &dst_node) {
-  if ((src_node == nullptr) || (dst_node == nullptr)) {
-    GELOGE(GRAPH_FAILED, "Parameter is nullptr");
-    return GRAPH_PARAM_INVALID;
-  }
-  auto src_data_in_nodes = src_node->GetInDataNodes();
-  if (src_data_in_nodes.empty()) {
-    return GRAPH_SUCCESS;
-  }
-  for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) {
-    auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx());
-    auto ret =
-      GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx()));
-    if (ret != GRAPH_SUCCESS) {
-      GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copying in-data edges from %s to %s",
-             in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(),
-             src_node->GetName().c_str(), dst_node->GetName().c_str());
-      return ret;
-    }
-  }
-  return GRAPH_SUCCESS;
-}
-
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph,
-                                                                                       const NodePtr &node) {
-  if (graph->AddInputNode(node) == nullptr) {
-    GELOGE(GRAPH_FAILED, "Add input node %s to graph failed", node->GetName().c_str());
-    return GRAPH_FAILED;
-  }
-  graph->SetInputSize(graph->GetInputSize() + 1);
-  graph->inputs_order_.emplace_back(node->GetName());
-  return GRAPH_SUCCESS;
-}
-
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindRootGraph(ComputeGraphPtr graph) {
-  ComputeGraphPtr result = nullptr;
-  while (graph != nullptr) {
-    result = std::move(graph);
-    graph = result->GetParentGraph();
-  }
-  return result;
-}
-
-///
-/// Make a copy of ComputeGraph.
-/// @param graph: original graph.
-/// @param prefix: node name prefix of new graph.
-/// @param input_nodes: input nodes of new graph.
-/// @param output_nodes: output nodes of new graph.
-/// @return ComputeGraphPtr
-///
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr
-GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix, std::vector<NodePtr> &input_nodes,
-                       std::vector<NodePtr> &output_nodes) {
-  GE_CHK_BOOL_EXEC(graph != nullptr, return nullptr, "Original graph is null");
-  ComputeGraphPtr new_graph = ComGraphMakeShared<ComputeGraph>(graph->GetName());
-  GE_CHK_BOOL_EXEC(new_graph != nullptr, return nullptr, "Create new graph failed");
-
-  std::unordered_map<std::string, NodePtr> all_new_nodes;
-  for (const auto &n : graph->GetDirectNode()) {
-    OpDescPtr op_desc = AttrUtils::CopyOpDesc(n->GetOpDesc());
-    GE_CHK_BOOL_EXEC(op_desc != nullptr, return nullptr, "Create new node failed");
-
-    if (CopyTensorAttrs(op_desc, n) != GRAPH_SUCCESS) {
-      return nullptr;
-    }
-
-    op_desc->SetName(prefix + n->GetName());
-    NodePtr node = new_graph->AddNode(op_desc);
-    GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Add node[%s] to graph failed", op_desc->GetName().c_str());
-    all_new_nodes[node->GetName()] = node;
-
-    if (node->GetType() == DATA) {
-      input_nodes.emplace_back(node);
-    } else if (node->GetType() == NETOUTPUT) {
-      output_nodes.emplace_back(node);
-    }
-  }
-
-  for (const auto &n : graph->GetDirectNode()) {
-    if (RelinkGraphEdges(n, prefix, all_new_nodes) != GRAPH_SUCCESS) {
-      return nullptr;
-    }
-  }
-
-  std::string session_graph_id;
-  if (AttrUtils::GetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) {
-    bool ret = AttrUtils::SetStr(*new_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id);
-    if (!ret) {
-      GELOGE(GRAPH_FAILED, "Set attr ATTR_NAME_SESSION_GRAPH_ID failed.");
-      return nullptr;
-    }
-  }
-  return new_graph;
-}
-
-///
-/// Copy tensor attributes to the new node.
-/// @param [in] dst_desc: op_desc of the cloned node.
-/// @param [in] src_node: original node.
-/// @return success: GRAPH_SUCCESS
-///
-graphStatus GraphUtils::CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node) {
-  if (dst_desc == nullptr) {
-    GELOGE(GRAPH_FAILED, "Input param dst_desc is invalid");
-    return GRAPH_FAILED;
-  }
-  if (src_node == nullptr || src_node->GetOpDesc() == nullptr) {
-    GELOGE(GRAPH_FAILED, "Input param src_node is invalid");
-    return GRAPH_FAILED;
-  }
-
-  const auto &src_desc = src_node->GetOpDesc();
-  dst_desc->CopyAttrsFrom(*src_desc);
-
-  for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) {
-    auto input_desc = dst_desc->MutableInputDesc(i);
-    if (input_desc == nullptr) {
-      continue;
-    }
-    input_desc->CopyAttrsFrom(src_desc->GetInputDesc(i));
-  }
-
-  for (uint32_t i = 0; i < src_node->GetAllOutDataAnchorsSize(); ++i) {
-    auto output_desc = dst_desc->MutableOutputDesc(i);
-    if (output_desc == nullptr) {
-      GELOGE(GRAPH_FAILED, "Output desc of dst_desc is null");
-      return GRAPH_FAILED;
-    }
-    output_desc->CopyAttrsFrom(src_desc->GetOutputDesc(i));
-  }
-
-  return GRAPH_SUCCESS;
-}
-
-///
-/// Relink all edges for cloned ComputeGraph.
-/// @param [in] node: original node.
-/// @param [in] prefix: node name prefix of new node.
-/// @param [in] all_nodes: all nodes in new graph.
-/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &prefix, - const std::unordered_map &all_nodes) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "Input node not valid"); - return GRAPH_FAILED; - } - - auto it = all_nodes.find(prefix + node->GetName()); - if (it == all_nodes.end()) { - GELOGE(GRAPH_FAILED, "node[%s] not found", node->GetName().c_str()); - return GRAPH_FAILED; - } - const auto &new_node = it->second; - - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - GE_CHK_BOOL_EXEC(in_anchor != nullptr, return GRAPH_FAILED, "In data anchor is null"); - const auto &out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - GELOGW("Peer out anchor is null: %s", node->GetName().c_str()); - continue; - } - GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); - - it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); - if (it == all_nodes.end()) { - GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED; - } - const auto &new_out_node = it->second; - - auto rslt = - GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInAnchor(in_anchor->GetIdx())); - GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", - new_out_node->GetName().c_str(), new_node->GetName().c_str()); - } - - if (node->GetInControlAnchor() != nullptr) { - for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) { - GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "Peer out anchor is null: %s", node->GetName().c_str()); - GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); - - it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); - if (it == all_nodes.end()) { - GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED; - } - const auto &new_out_node = it->second; - - auto rslt = GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInControlAnchor()); - GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", - new_out_node->GetName().c_str(), new_node->GetName().c_str()); - } - } - - return GRAPH_SUCCESS; -} - -/// -/// Get reference-mapping of all data_anchors in graph -/// @param [in] graph -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(graph); - for (const auto &node : graph->GetAllNodes()) { - // in_data_anchor - if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - - // out_data_anchor - if (HandleOutAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("Find ref_mapping for out_data_anchors of node %s failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr GraphUtils::FindNodeFromAllNodes(ComputeGraphPtr &graph, - const std::string &name) { - auto root_graph = FindRootGraph(graph); - if (root_graph == nullptr) { - 
GE_LOGE("Failed find node %s, null root graph", name.c_str()); - return nullptr; - } - - for (const auto &node : root_graph->GetAllNodes()) { - if (node == nullptr) { - continue; - } - if (node->GetName() == name) { - return node; - } - } - - return nullptr; -} - -/// -/// Get reference-mapping for in_data_anchors of node -/// @param [in] node -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - - if (NodeUtils::IsSubgraphOutput(node)) { - return HandleSubgraphOutput(node, symbol_to_anchors, anchor_to_symbol); - } - - if (NodeUtils::IsSubgraphInput(node)) { - return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); - } - - const std::string &type = node->GetType(); - if ((type == MERGE) || (type == STREAMMERGE)) { - return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); - } - - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - const std::string &symbol = cur_node_info.ToString(); - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - symbol_to_anchors[symbol] = {cur_node_info}; - anchor_to_symbol[symbol] = symbol; - } else { - NodeIndexIO exist_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); - if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("Update symbol mapping failed."); - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; -} - -/// -/// Get reference-mapping for out_data_anchors of node -/// @param [in] node -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); - if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { - continue; - } - - int32_t reuse_in_index = -1; - if (IsRefFromInput(out_data_anchor, reuse_in_index)) { - NodeIndexIO exist_node_info(node, reuse_in_index, kIn); - if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("Update symbol mapping failed."); - return GRAPH_FAILED; - } - } else { - const std::string &symbol = cur_node_info.ToString(); - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - symbol_to_anchors.emplace(std::make_pair(symbol, std::list{cur_node_info})); - anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); - } - } - - return GRAPH_SUCCESS; -} - -/// -/// Handle input of subgraph -/// @param [in] node -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - - // Data in subgraph - uint32_t index = 0; - if (!ge::AttrUtils::GetInt(node->GetOpDesc(), 
ATTR_NAME_PARENT_NODE_INDEX, index)) { - GE_LOGE("Get attr ATTR_NAME_PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); - return GRAPH_FAILED; - } - NodePtr parent_node = node->GetOwnerComputeGraph()->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - InDataAnchorPtr parent_in_anchor = parent_node->GetInDataAnchor(index); - GE_CHECK_NOTNULL(parent_in_anchor); - OutDataAnchorPtr peer_out_anchor = parent_in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor != nullptr) { - // Data has and only has one input - NodeIndexIO cur_node_info(node, 0, kIn); - NodeIndexIO exist_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); - if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("Update symbol mapping failed."); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; -} - -/// -/// Handle input of Merge op -/// @param [in] node -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - std::vector exist_node_infos; - std::vector cur_node_infos; - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - std::string next_name; - if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) { - ComputeGraphPtr graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - ge::NodePtr next_node = graph->FindNode(next_name); - GE_CHECK_NOTNULL(next_node); - // NextIteration has and only has one output - peer_out_anchor = next_node->GetOutDataAnchor(0); - GE_CHECK_NOTNULL(peer_out_anchor); - cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); - cur_node_infos.emplace_back(NodeIndexIO(next_node, peer_out_anchor->GetIdx(), kOut)); - } - } else { - cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); - exist_node_infos.emplace_back(NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut)); - } - } - - size_t anchor_nums = 0; - NodeIndexIO max_node_index_io(nullptr, 0, kOut); - for (const auto &temp_node_info : exist_node_infos) { - auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); - if (iter1 != anchor_to_symbol.end()) { - const std::string &temp_symbol = iter1->second; - auto iter2 = symbol_to_anchors.find(temp_symbol); - if (iter2 != symbol_to_anchors.end()) { - if (iter2->second.size() > anchor_nums) { - max_node_index_io = temp_node_info; - anchor_nums = iter2->second.size(); - } - } - } - } - - std::string symbol; - for (const auto &temp_node_info : exist_node_infos) { - if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != - GRAPH_SUCCESS) || - symbol.empty()) { - GE_LOGE("Union symbol map anchor1:%s & anchor2:%s.", max_node_index_io.ToString().c_str(), - temp_node_info.ToString().c_str()); - return GRAPH_FAILED; - } - } - - auto iter = symbol_to_anchors.find(symbol); - if (iter != symbol_to_anchors.end()) { - for (const auto &temp_node_info : cur_node_infos) { - GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); - iter->second.emplace_back(temp_node_info); - anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); - } - } - - return GRAPH_SUCCESS; -} - 
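// Illustrative caller-side sketch (hypothetical helper, not from the GraphEngine sources): it shows how the
// reference mapping assembled by GetRefMapping and the Handle*AnchorMapping routines above is typically
// consumed. The map element types (symbol -> std::list<NodeIndexIO>, anchor string -> symbol string) are
// reconstructed from how the handlers use the maps, since the template arguments were lost in this patch
// text; verify them against graph/utils/graph_utils.h before relying on this.
#include <cstddef>
#include <list>
#include <map>
#include <string>
#include "graph/utils/graph_utils.h"

namespace ge {
// Counts symbols that are shared by more than one data anchor, i.e. anchor groups that
// reference the same memory (Merge inputs, subgraph Data/NetOutput pairs, ref outputs).
size_t CountSharedSymbols(const ComputeGraphPtr &graph) {
  std::map<std::string, std::list<NodeIndexIO>> symbol_to_anchors;
  std::map<std::string, std::string> anchor_to_symbol;
  if (GraphUtils::GetRefMapping(graph, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) {
    return 0;
  }
  size_t shared = 0;
  for (const auto &group : symbol_to_anchors) {
    if (group.second.size() > 1) {
      ++shared;
    }
  }
  return shared;
}
}  // namespace ge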
-/// -/// Handle output of subgraph -/// @param [in] node -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_graph); - NodePtr parent_node = owner_graph->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - - GeTensorDesc in_tensor = op_desc->GetInputDesc(in_data_anchor->GetIdx()); - uint32_t index = 0; - if (!ge::AttrUtils::GetInt(in_tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { - continue; - } - GE_CHECK_NOTNULL(parent_node->GetOutDataAnchor(index)); - // Union symbol of peer_out_anchor & parent_out_anchor - NodeIndexIO peer_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); - NodeIndexIO parent_node_info(parent_node, index, kOut); - std::string symbol; - if ((UnionSymbolMapping(peer_node_info, parent_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != - GRAPH_SUCCESS) || - symbol.empty()) { - GE_LOGE("Union symbol map anchor1:%s, anchor2:%s.", peer_node_info.ToString().c_str(), - parent_node_info.ToString().c_str()); - return GRAPH_FAILED; - } - - NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - symbol_to_anchors[symbol].emplace_back(cur_node_info); - anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); - } - - return GRAPH_SUCCESS; -} - -/// -/// Union ref-mapping -/// @param [in] exist_node_info1 -/// @param [in] exist_node_info2 -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @param [out] symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol, std::string &symbol) { - const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; - const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; - if (symbol1 == symbol2) { - symbol = symbol1; - GELOGI("no need to union."); - return GRAPH_SUCCESS; - } - - auto iter1 = symbol_to_anchors.find(symbol1); - auto iter2 = symbol_to_anchors.find(symbol2); - if ((iter1 == symbol_to_anchors.end()) || (iter2 == symbol_to_anchors.end())) { - GE_LOGE("symbol %s or %s not exist.", symbol1.c_str(), symbol2.c_str()); - return GRAPH_FAILED; - } - - auto &max_iter = (iter1->second.size() > iter2->second.size() ? iter1 : iter2); - auto &min_iter = (iter1->second.size() > iter2->second.size() ? iter2 : iter1); - symbol = (iter1->second.size() > iter2->second.size() ? symbol1 : symbol2); - std::string min_symbol = (iter1->second.size() > iter2->second.size() ? 
symbol2 : symbol1); - for (auto &node_index_io : min_iter->second) { - GELOGD("Update anchor %s, symbol %s.", node_index_io.ToString().c_str(), symbol.c_str()); - max_iter->second.emplace_back(node_index_io); - auto iter = anchor_to_symbol.find(node_index_io.ToString()); - if (iter == anchor_to_symbol.end()) { - GE_LOGE("anchor %s not exist.", node_index_io.ToString().c_str()); - return GRAPH_FAILED; - } - if (iter->second != min_symbol) { - GELOGW("not expected symbol of anchor %s, expect %s but %s exactly.", iter->first.c_str(), min_symbol.c_str(), - iter->second.c_str()); - } - iter->second = symbol; - } - - GELOGI("Union symbol %s and %s succ.", symbol.c_str(), min_symbol.c_str()); - symbol_to_anchors.erase(min_iter); - return GRAPH_SUCCESS; -} - -/// -/// Update symbol mapping with a new reference pair -/// @param [in] cur_node_info -/// @param [in] exist_node_info -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - auto iter1 = anchor_to_symbol.find(exist_node_info.ToString()); - if (iter1 == anchor_to_symbol.end()) { - GE_LOGE("data_anchor %s is not visible before data_anchor %s, maybe TopoSorting is missing.", - exist_node_info.ToString().c_str(), cur_node_info.ToString().c_str()); - return GRAPH_FAILED; - } - - const std::string &symbol = iter1->second; - auto iter2 = symbol_to_anchors.find(symbol); - if (iter2 == symbol_to_anchors.end()) { - GE_LOGE("symbol %s not found.", symbol.c_str()); - return GRAPH_FAILED; - } - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - iter2->second.emplace_back(cur_node_info); - anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); - - return GRAPH_SUCCESS; -} - -/// -/// Check if out_data_anchor is reference of input -/// @param [in] out_data_anchor -/// @param [out] reuse_in_index -/// @return bool -/// -bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { - if (out_data_anchor == nullptr) { - GELOGW("out_data_anchor is NULL."); - return false; - } - int32_t output_index = out_data_anchor->GetIdx(); - - // pass-through op - NodePtr node = out_data_anchor->GetOwnerNode(); - const std::string &type = node->GetType(); - const std::set pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; - if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { - reuse_in_index = output_index; - GELOGI("Pass-Through node name[%s] index[%u].", node->GetName().c_str(), reuse_in_index); - return true; - } - - // Merge op 0th output - if ((type == MERGE) && (output_index == 0)) { - reuse_in_index = 0; - GELOGI("Merge name[%s] output_index[0].", node->GetName().c_str()); - return true; - } - - // ref op - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - GELOGW("op_desc is NULL."); - return false; - } - bool is_ref = false; - (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); - if (is_ref) { - const string &output_name = op_desc->GetOutputNameByIndex(output_index); - for (const auto &input_name : op_desc->GetAllInputNames()) { - if (!input_name.empty() && (output_name == input_name)) { - reuse_in_index = op_desc->GetInputIndexByName(input_name); - GELOGI("Reference name[%s] output[%s][%d] ref to input[%s][%d].", op_desc->GetName().c_str(), - output_name.c_str(), 
output_index, input_name.c_str(), reuse_in_index); - return true; - } - } - } - - // reuse input - auto output_op_desc = op_desc->GetOutputDescPtr(output_index); - bool reuse_input = false; - if (output_op_desc != nullptr) { - if ((TensorUtils::GetReuseInput(*output_op_desc, reuse_input) == GRAPH_SUCCESS) && reuse_input) { - uint32_t reuse_input_index = 0; - if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { - reuse_in_index = static_cast(reuse_input_index); - GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index, - reuse_in_index); - return true; - } - } - } - - return false; -} - -/// -/// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs -/// of the graph have UNKNOWN_SHAPE operators or not. -/// Note: This function will only look 'down' from the graph, not 'up'. For example, the following -/// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE -/// ROOT graph: A -----> B -----> C -/// K subgraph U -/// | -/// V -/// SUB graph: D --> E --> F -/// K K K -/// @param [in] graph -/// @return bool -/// -bool GraphUtils::IsUnknownShapeGraph(const ComputeGraphPtr &graph) { - if (graph == nullptr) { - GELOGW("Input graph is nullptr."); - return false; - } - for (const auto &node : graph->GetDirectNode()) { - bool is_unknown = false; - auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); - if (ret != GRAPH_SUCCESS) { - GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), - node->GetType().c_str()); - continue; - } - if (is_unknown) { - GELOGD("Node %s, type %s is unknown shape in graph %s.", node->GetName().c_str(), node->GetType().c_str(), - graph->GetName().c_str()); - return true; - } - } - GELOGD("Graph %s does not have unknown shape node.", graph->GetName().c_str()); - return false; -} - -/// -/// @brief Add node to graph -/// @param [in] op_desc -/// @return ComputeGraphBuilder -/// -ComputeGraphBuilder &ComputeGraphBuilder::AddNode(const OpDescPtr &op_desc) { - nodes_.emplace_back(op_desc); - return *this; -} - -/// -/// @brief Add data-link among nodes in graph -/// @param [in] src_name -/// @param [in] out_anchor_ind -/// @param [in] dst_name -/// @param [in] in_anchor_ind -/// @return ComputeGraphBuilder -/// -ComputeGraphBuilder &ComputeGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, - const std::string &dst_name, uint32_t in_anchor_ind) { - data_links_.emplace_back( - std::make_pair(std::make_pair(src_name, out_anchor_ind), std::make_pair(dst_name, in_anchor_ind))); - return *this; -} - -/// -/// @brief Add ctrl-link among nodes in graph -/// @param [in] src_name -/// @param [in] dst_name -/// @return ComputeGraphBuilder -/// -ComputeGraphBuilder &ComputeGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { - ctrl_links_.emplace_back(std::make_pair(src_name, dst_name)); - return *this; -} - -/// -/// @brief Build nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void ComputeGraphBuilder::BuildNodes(graphStatus &error_code, std::string &error_msg) { - if (owner_graph_ == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "graph is NULL."; - return; - } - - std::string node_name; - for (auto &op_desc : nodes_) { - if (op_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "op_desc is NULL."; - return; - } - - node_name = 
op_desc->GetName(); - NodePtr node = owner_graph_->AddNode(op_desc); - if (node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "Add node " + node_name + " failed."; - return; - } - - GELOGD("Add node name:%s, type:%s.", node_name.c_str(), op_desc->GetType().c_str()); - node_names_[node_name] = node; - } - - GELOGD("BuildNodes succ."); -} - -/// -/// @brief Build data-links -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void ComputeGraphBuilder::BuildDataLinks(graphStatus &error_code, std::string &error_msg) { - for (auto &pair : data_links_) { - std::string src_name = pair.first.first; - uint32_t out_ind = pair.first.second; - std::string dst_name = pair.second.first; - uint32_t in_ind = pair.second.second; - std::string log_msg = "Add data-edge "; - log_msg.append(src_name) - .append(":") - .append(std::to_string(out_ind)) - .append("->") - .append(dst_name) - .append(":") - .append(std::to_string(in_ind)); - - auto src_iter = node_names_.find(src_name); - auto dst_iter = node_names_.find(dst_name); - if ((src_iter == node_names_.end()) || (dst_iter == node_names_.end())) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node not exist in graph."; - return; - } - - NodePtr src_node = node_names_[src_name]; - NodePtr dst_node = node_names_[dst_name]; - if ((src_node == nullptr) || (dst_node == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node is NULL."; - return; - } - - if (GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_ind), dst_node->GetInDataAnchor(in_ind)) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed."; - return; - } - - GELOGD("%s succ.", log_msg.c_str()); - } - - GELOGD("BuildDataLinks succ."); -} - -/// -/// @brief Build ctrl-links -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void ComputeGraphBuilder::BuildCtrlLinks(graphStatus &error_code, std::string &error_msg) { - for (auto &pair : ctrl_links_) { - std::string src_name = pair.first; - std::string dst_name = pair.second; - std::string log_msg = "Add ctrl-edge "; - log_msg.append(src_name).append("->").append(dst_name); - - auto src_iter = node_names_.find(src_name); - auto dst_iter = node_names_.find(dst_name); - if ((src_iter == node_names_.end()) || (dst_iter == node_names_.end())) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node not exist in graph."; - return; - } - - NodePtr src_node = node_names_[src_name]; - NodePtr dst_node = node_names_[dst_name]; - if ((src_node == nullptr) || (dst_node == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node is NULL."; - return; - } - - if (GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed."; - return; - } - - GELOGD("%s succ.", log_msg.c_str()); - } - - GELOGD("BuildCtrlLinks succ."); -} - -/// @brief Get node with name -/// @param [in] name -/// @return NodePtr -/// -NodePtr ComputeGraphBuilder::GetNode(const std::string &name) { - auto iter = node_names_.find(name); - if (iter == node_names_.end()) { - GE_LOGE("node %s not exist.", name.c_str()); - return nullptr; - } - return iter->second; -} - -/// @brief Get all nodes -/// @return std::vector -/// -std::vector ComputeGraphBuilder::GetAllNodes() { - std::vector nodes; - for (const auto &iter : node_names_) { - nodes.emplace_back(iter.second); - } - return nodes; -} - -/// -/// @brief Add node to graph -/// 
@param [in] op_desc -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::AddNode(const OpDescPtr &op_desc) { - ComputeGraphBuilder::AddNode(op_desc); - return *this; -} - -/// -/// @brief Add data-link among nodes in graph -/// @param [in] src_name -/// @param [in] out_anchor_ind -/// @param [in] dst_name -/// @param [in] in_anchor_ind -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, - const std::string &dst_name, uint32_t in_anchor_ind) { - ComputeGraphBuilder::AddDataLink(src_name, out_anchor_ind, dst_name, in_anchor_ind); - return *this; -} - -/// -/// @brief Add ctrl-link among nodes in graph -/// @param [in] src_name -/// @param [in] dst_name -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { - ComputeGraphBuilder::AddControlLink(src_name, dst_name); - return *this; -} - -/// -/// @brief Set index_th input anchor for graph -/// @param [in] index -/// @param [in] node_names -/// @param [in] anchor_inds -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::SetInput(uint32_t index, const std::vector &node_names, - const std::vector &anchor_inds) { - graph_inputs_[index] = std::make_pair(node_names, anchor_inds); - return *this; -} - -/// -/// @brief Set index_th input of graph as useless -/// @param [in] index -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::SetUselessInput(uint32_t index) { - graph_inputs_[index] = std::make_pair(std::vector(), std::vector()); - return *this; -} - -/// -/// @brief Add output anchor for graph -/// @param [in] owner_node_name -/// @param [in] anchor_ind -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::AddOutput(const std::string &owner_node_name, uint32_t anchor_ind) { - graph_outputs_.emplace_back(std::make_pair(owner_node_name, anchor_ind)); - return *this; -} - -/// -/// @brief Add target for graph -/// @param [in] target_name -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::AddTarget(const std::string &target_name) { - graph_targets_.emplace_back(target_name); - return *this; -} - -/// -/// @brief Set parent-node of graph -/// @param [in] parent_node -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::SetParentNode(const NodePtr &parent_node) { - parent_node_ = parent_node; - return *this; -} - -/// -/// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node -/// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::SetInputMapping(const std::map &input_mapping) { - for (auto &item : input_mapping) { - input_mapping_[item.first] = item.second; - } - return *this; -} - -/// -/// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind -/// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::SetOutputMapping(const std::map &output_mapping) { - for (auto &item : output_mapping) { - output_mapping_[item.first] = item.second; - } - return *this; -} - -/// -/// @brief Build graph -/// @param [out] error_code -/// @param [out] error_msg -/// @return ComputeGraphPtr -/// -ComputeGraphPtr 
CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { - owner_graph_ = shared_ptr(new (std::nothrow) ComputeGraph(name_)); - if ((owner_graph_ == nullptr) || (parent_node_ == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = "graph / parent_node is NULL."; - return nullptr; - } - - owner_graph_->SetParentNode(parent_node_); - owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph()); - - BuildNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildDataLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildCtrlLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - AddDataNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - AddRetValNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildGraphTargets(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - // ATTR_NAME_SESSION_GRAPH_ID - std::string graph_id; - if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { - error_code = GRAPH_FAILED; - error_msg = "Get attr session_graph_id failed."; - return nullptr; - } - if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { - error_code = GRAPH_FAILED; - error_msg = "Set attr session_graph_id failed."; - return nullptr; - } - - // refresh node name - for (const NodePtr &node : owner_graph_->GetDirectNode()) { - if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { - continue; - } - node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); - } - - return owner_graph_; -} - -/// -/// @brief Add data nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void CompleteGraphBuilder::AddDataNodes(graphStatus &error_code, std::string &error_msg) { - for (auto &input : graph_inputs_) { - NodePtr data_node = AddDataNode(input.first, error_code, error_msg); - if (data_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: add node Data:" + std::to_string(input.first) + +" failed."; - return; - } - - if (owner_graph_->AddInputNode(data_node) == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: add input node Data:" + std::to_string(input.first) + +" failed."; - return; - } - - // useless input - std::vector input_names = input.second.first; - std::vector anchor_indes = input.second.second; - if (input_names.size() != anchor_indes.size()) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: num of input_names and indexs not equal."; - return; - } - if (input_names.empty()) { - continue; - } - - size_t input_num = input_names.size(); - for (size_t i = 0; i < input_num; i++) { - std::string input_name = input_names[i]; - uint32_t ind = anchor_indes[i]; - auto iter = node_names_.find(input_name); - if (iter == node_names_.end()) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: node " + input_name + " not exist in graph."; - return; - } - - NodePtr in_node = node_names_[input_name]; - if (in_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: node " + input_name + " is NULL."; - return; - } - - if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), in_node->GetInDataAnchor(ind)) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes 
failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + - ":" + std::to_string(ind) + " failed."; - return; - } - } - - GELOGD("AddDataNodes : Add %u input succ.", input.first); - } - - GELOGD("AddDataNodes succ."); -} - -/// -/// @brief Add data node -/// @param [in] index -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -NodePtr CompleteGraphBuilder::AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { - std::string data_name = "Data_" + std::to_string(index); - OpDescBuilder op_desc_builder(data_name, "Data"); - OpDescPtr op_desc = op_desc_builder.AddInput("x").AddOutput("y").Build(); - if (op_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNode failed: create op_desc " + data_name + " failed."; - return nullptr; - } - - auto index_iter = input_mapping_.find(index); - if (index_iter != input_mapping_.end()) { - if (!ge::AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, index_iter->second)) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNode failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; - return nullptr; - } - } - - NodePtr data_node = owner_graph_->AddNode(op_desc); - if (data_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNode failed: add node " + data_name + " failed."; - return nullptr; - } - node_names_[data_name] = data_node; - - return data_node; -} - -/// -/// @brief Add RetVal nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &error_msg) { - size_t output_num = graph_outputs_.size(); - for (size_t i = 0; i < output_num; i++) { - int32_t index = graph_outputs_[i].second; - auto out_iter = node_names_.find(graph_outputs_[i].first); - if (out_iter == node_names_.end()) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode failed: node " + graph_outputs_[i].first + " not exist in graph."; - return; - } - NodePtr node = out_iter->second; - if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode failed: node is NULL."; - return; - } - - std::string name = node->GetName() + "_RetVal_" + std::to_string(index); - OpDescPtr ret_val_desc = shared_ptr(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); - if (ret_val_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: op_desc is NULL."; - return; - } - ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); - if ((ret_val_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) || - (ret_val_desc->AddOutputDesc(tensor) != GRAPH_SUCCESS)) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: add input_desc / output_desc failed."; - return; - } - - if (!(ge::AttrUtils::SetStr(ret_val_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_RetVal") && - ge::AttrUtils::SetInt(ret_val_desc, RETVAL_ATTR_NAME_INDEX, i))) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: set FRAMEWORK_ORIGINAL_TYPE / RETVAL_ATTR_NAME_INDEX failed."; - return; - } - auto iter = output_mapping_.find(i); - if (iter != output_mapping_.end()) { - if (!ge::AttrUtils::SetInt(ret_val_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: set attr PARENT_NODE_INDEX failed."; - return; - } - } - - NodePtr ret_val_node = owner_graph_->AddNode(ret_val_desc); - if 
(ret_val_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: add node failed."; - return; - } - - if (GraphUtils::AddEdge(node->GetOutDataAnchor(index), ret_val_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: add data-edge " + node->GetName() + ":" + std::to_string(index) + - "->" + ret_val_node->GetName() + ":0 failed."; - return; - } - } - - GELOGD("AddRetValNodes succ."); -} - -/// -/// @brief Build target-nodes for graph -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void CompleteGraphBuilder::BuildGraphTargets(graphStatus &error_code, std::string &error_msg) { - std::vector target_nodes; - for (const std::string &target_name : graph_targets_) { - auto target_iter = node_names_.find(target_name); - if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = "BuildGraphTargets failed: target_node " + target_name + " not exist in graph."; - return; - } - target_nodes.emplace_back(target_iter->second); - } - owner_graph_->SetGraphTargetNodesInfo(target_nodes); - return; -} - -/// -/// @brief Add node to graph -/// @param [in] op_desc -/// @return PartialGraphBuilder -/// -PartialGraphBuilder &PartialGraphBuilder::AddNode(const OpDescPtr &op_desc) { - ComputeGraphBuilder::AddNode(op_desc); - return *this; -} - -/// -/// @brief Add data-link among nodes in graph -/// @param [in] src_name -/// @param [in] out_anchor_ind -/// @param [in] dst_name -/// @param [in] in_anchor_ind -/// @return PartialGraphBuilder -/// -PartialGraphBuilder &PartialGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, - const std::string &dst_name, uint32_t in_anchor_ind) { - ComputeGraphBuilder::AddDataLink(src_name, out_anchor_ind, dst_name, in_anchor_ind); - return *this; -} - -/// -/// @brief Add ctrl-link among nodes in graph -/// @param [in] src_name -/// @param [in] dst_name -/// @return PartialGraphBuilder -/// -PartialGraphBuilder &PartialGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { - ComputeGraphBuilder::AddControlLink(src_name, dst_name); - return *this; -} - -/// -/// @brief Set owner graph -/// @param [in] graph -/// @return PartialGraphBuilder -/// -PartialGraphBuilder &PartialGraphBuilder::SetOwnerGraph(const ComputeGraphPtr &graph) { - owner_graph_ = graph; - return *this; -} - -/// -/// @brief Add exist node -/// @param [in] node -/// @return PartialGraphBuilder -/// -PartialGraphBuilder &PartialGraphBuilder::AddExistNode(const NodePtr &node) { - exist_nodes_.emplace_back(node); - return *this; -} - -/// -/// @brief Build partial graph -/// @param [out] error_code -/// @param [out] error_msg -/// @return ComputeGraphPtr -/// -ComputeGraphPtr PartialGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { - if (owner_graph_ == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "graph is NULL."; - return nullptr; - } - - BuildNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildExistNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildDataLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildCtrlLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - return owner_graph_; -} - -/// -/// @brief Build exist nodes -/// @param [out] 
error_code
-/// @param [out] error_msg
-/// @return void
-///
-void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string &error_msg) {
-  std::string node_name;
-  for (auto &node : exist_nodes_) {
-    if (node == nullptr) {
-      error_code = GRAPH_FAILED;
-      error_msg = "Build exist nodes failed: node is NULL.";
-      return;
-    }
-
-    node_name = node->GetName();
-    if (node->GetOwnerComputeGraph() != owner_graph_) {
-      error_code = GRAPH_FAILED;
-      error_msg = "Build exist nodes failed: node " + node_name + " not belongs to this graph.";
-      return;
-    }
-
-    GELOGD("Add exist_node name:%s.", node_name.c_str());
-    node_names_[node_name] = node;
-  }
-
-  GELOGD("Build exist nodes succ.");
-}
-
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
-GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec) {
-  std::vector<NodePtr> stack_input;
-  std::map<NodePtr, uint32_t> map_in_edge_num;
-  graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num);
-  if (ret != GRAPH_SUCCESS) {
-    GELOGE(GRAPH_FAILED, "Sort nodes failed.");
-    return GRAPH_FAILED;
-  }
-  const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1;
-  std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index,
-            [](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); });
-
-  std::queue<NodePtr> stack;
-  NodePtr cur_node = nullptr;
-  std::map<std::string, NodePtr> name_node_map;
-  vector<std::string> nodes_name;
-  while (!stack_input.empty() || !stack.empty()) {
-    if (!stack.empty()) {
-      cur_node = stack.front();
-      stack.pop();
-    } else {
-      cur_node = stack_input.back();
-      stack_input.pop_back();
-    }
-    node_vec.emplace_back(cur_node);
-    compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map);
-    for (const auto &iter : name_node_map) {
-      nodes_name.emplace_back(iter.first);
-    }
-    std::sort(nodes_name.begin(), nodes_name.end());
-    for (const auto &iter : nodes_name) {
-      stack.push(name_node_map[iter]);
-    }
-    name_node_map.clear();
-    nodes_name.clear();
-  }
-  // If they are not equal, there is a closed loop
-  if (node_vec.size() != compute_graph->nodes_.size()) {
-    std::set<Node *> itered_nodes_set;
-    for (auto &node : node_vec) {
-      itered_nodes_set.insert(node.get());
-    }
-    GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.",
-            compute_graph->nodes_.size(), node_vec.size());
-    for (auto &node : compute_graph->nodes_) {
-      if (itered_nodes_set.count(node.get()) == 0) {
-        GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str());
-      }
-    }
-    return GRAPH_FAILED;
-  }
-  return GRAPH_SUCCESS;
-}
-
-}  // namespace ge
diff --git a/src/common/graph/utils/mem_utils.h b/src/common/graph/utils/mem_utils.h
deleted file mode 100644
index 24bbc86c..00000000
--- a/src/common/graph/utils/mem_utils.h
+++ /dev/null
@@ -1,32 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef COMMON_GRAPH_UTILS_MEM_UTILS_H_
-#define COMMON_GRAPH_UTILS_MEM_UTILS_H_
-
-#include <memory>
-#include <utility>
-
-namespace ge {
-template <typename _Tp, typename... _Args>
-static inline std::shared_ptr<_Tp> MakeShared(_Args &&... __args) {
-  typedef typename std::remove_const<_Tp>::type _Tp_nc;
-  std::shared_ptr<_Tp> ret(new (std::nothrow) _Tp_nc(std::forward<_Args>(__args)...));
-  return ret;
-}
-}  // namespace ge
-
-#endif  // COMMON_GRAPH_UTILS_MEM_UTILS_H_
diff --git a/src/common/graph/utils/node_utils.cc b/src/common/graph/utils/node_utils.cc
deleted file mode 100644
index 684e37ac..00000000
--- a/src/common/graph/utils/node_utils.cc
+++ /dev/null
@@ -1,1005 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "graph/utils/node_utils.h"
-#include "graph/utils/op_desc_utils.h"
-#include "graph/utils/graph_utils.h"
-#include "debug/ge_op_types.h"
-#include "debug/ge_util.h"
-#include "framework/common/debug/ge_log.h"
-#include "graph/anchor.h"
-#include "graph/debug/ge_attr_define.h"
-#include "graph/types.h"
-#include "external/graph/operator.h"
-#include "graph/ge_context.h"
-#include "graph/runtime_inference_context.h"
-#include "graph/utils/op_desc_utils.h"
-#include "graph/utils/tensor_utils.h"
-#include "graph/utils/tensor_adapter.h"
-#include "graph/utils/type_utils.h"
-
-namespace ge {
-std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
-std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};
-
-const std::set<std::string> kConstOpTypes = {"Const", "Constant"};
-
-const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"};
-const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"};
-const std::set<std::string> kCaseOpTypes = {"Case"};
-const std::set<std::string> kForOpTypes = {"For"};
-
-bool OpShapeIsUnknown(const OpDescPtr &desc) {
-  for (const auto &ptr : desc->GetAllInputsDescPtr()) {
-    auto ge_shape = ptr->GetShape();
-    for (const auto &dim : ge_shape.GetDims()) {
-      if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
-        return true;
-      }
-    }
-  }
-  for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
-    auto ge_shape = ptr->GetShape();
-    for (const auto &dim : ge_shape.GetDims()) {
-      if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
-                                                                                     const uint32_t &event_id) {
-  GE_CHECK_NOTNULL(node);
-  map_send_info_[node].push_back(event_id);
-  return GRAPH_SUCCESS;
-}
-
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node,
-                                                                                     const uint32_t &event_id) {
-  GE_CHECK_NOTNULL(node);
-  map_recv_info_[node].push_back(event_id);
-  return GRAPH_SUCCESS;
-}
-
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
-NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) {
-  GE_CHECK_NOTNULL(node);
-  auto find = map_send_info_.find(node);
-  if (find == map_send_info_.end()) {
-    return GRAPH_FAILED;
-  } else {
-    vec_send = find->second;
-    return GRAPH_SUCCESS;
-  }
-}
-
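The send/recv event-id helpers above keep a static registry keyed by NodePtr inside NodeUtils. A minimal usage sketch follows; it is illustrative only and not part of the deleted file, and it assumes a hypothetical caller `RecordSendEvents` holding a valid ge::NodePtr, while the NodeUtils calls themselves are the ones declared in this hunk.

    #include <cstdint>
    #include <vector>
    #include "graph/utils/node_utils.h"

    // Record two send-event ids on a node, then read them back.
    void RecordSendEvents(const ge::NodePtr &node) {
      (void)ge::NodeUtils::AddSendEventId(node, 0U);
      (void)ge::NodeUtils::AddSendEventId(node, 1U);

      std::vector<uint32_t> send_ids;
      if (ge::NodeUtils::GetSendEventIdList(node, send_ids) == ge::GRAPH_SUCCESS) {
        // send_ids now holds {0, 1}; ClearSendInfo() would empty the whole registry.
      }
    }
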
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector &vec_recv) { - GE_CHECK_NOTNULL(node); - auto find = map_recv_info_.find(node); - if (find == map_recv_info_.end()) { - return GRAPH_FAILED; - } else { - vec_recv = find->second; - return GRAPH_SUCCESS; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() { - map_send_info_.clear(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() { - map_recv_info_.clear(); - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) { - GE_CHECK_NOTNULL(src); - NodePtr cur_ptr; - if (depth < 1) { - return GRAPH_FAILED; - } - for (int i = 0; i < depth; i++) { - if (src->GetOutDataNodes().size() != 1) { - return GRAPH_FAILED; - } - cur_ptr = src->GetOutDataNodes().at(0); - GE_CHECK_NOTNULL(cur_ptr); - } - dst = cur_ptr; - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, - InControlAnchorPtr &in_control) { - GE_CHECK_NOTNULL(node_ptr); - for (const auto &p : node_ptr->GetAllOutDataAnchors()) { - GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr"); - for (const auto &p_in : p->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr"); - out_data = p; - in_control = p_in; - return GRAPH_SUCCESS; - } - } - return GRAPH_FAILED; -} - -graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { - GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, - "node or in_data_anchor is nullptr"); - - bool find_flag = false; - uint32_t index = 0; - vector::iterator it = node_ptr->in_data_anchors_.end(); - for (const auto &tmp : node_ptr->in_data_anchors_) { - if (tmp == in_data_anchor) { - find_flag = true; - auto iter = node_ptr->in_data_anchors_.begin() + index; - if (iter != node_ptr->in_data_anchors_.end()) { - it = node_ptr->in_data_anchors_.erase(iter); - } - break; - } - index++; - } - for (; it != node_ptr->in_data_anchors_.end(); ++it) { - (*it)->SetIdx(index); - index++; - } - - if (!find_flag) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) { - GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr"); - GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed"); - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::SetAllAnchorStatus(Node &node) { - node.anchor_status_updated_ = true; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) { - GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr"); - return IsAnchorStatusSet(*node_ptr); -} - -bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; } - -graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) { - if ((origin_node == nullptr) || (new_node == nullptr)) { - return GRAPH_FAILED; - } - auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors(); - auto new_out_data_anchors = new_node->GetAllOutDataAnchors(); - if 
(origin_out_data_anchors.size() != new_out_data_anchors.size()) { - return GRAPH_FAILED; - } - - for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) { - for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, - "unlink peer_anchor failed"); - GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, - "linkto peer_anchor failed"); - } - - for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, - "unlink peer_anchor failed"); - GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, - "linkto peer_anchor failed"); - } - } - - auto origin_out_control_anchor = origin_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(origin_out_control_anchor); - auto new_out_control_anchor = new_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(new_out_control_anchor); - for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, - "linkto peer_anchor failed"); - } - for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, - "linkto peer_anchor failed"); - } - origin_out_control_anchor->UnlinkAll(); - - return GRAPH_SUCCESS; -} - -bool NodeUtils::IsConst(const Node &node) { - auto src_node_type = node.GetType(); - bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP)); - return is_const; -} - -void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) { - if (node_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "node is null"); - return; - } - UpdateIsInputConst(*node_ptr); -} - -/// -/// update is_input_const -/// @param node -/// @return void -/// -void NodeUtils::UpdateIsInputConst(Node &node) { - std::vector is_input_const; - size_t anchor_num = node.GetAllInDataAnchors().size(); - for (size_t i = 0; i < anchor_num; i++) { - auto in_anchor = node.GetInDataAnchor(static_cast(i)); - if (in_anchor == nullptr) { - is_input_const.push_back(false); - continue; - } - auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - is_input_const.push_back(false); - continue; - } - auto src_node = peer_out_anchor->GetOwnerNode(); - if (src_node == nullptr) { - is_input_const.push_back(false); - continue; - } - if (IsConst(*(src_node))) { - is_input_const.push_back(true); - } else { - is_input_const.push_back(false); - } - } - if (node.GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr"); - return; - } - node.GetOpDesc()->SetIsInputConst(is_input_const); -} - -void NodeUtils::UnlinkAll(const Node &node) { - for (const auto &anchor : node.GetAllOutAnchors()) { - anchor->UnlinkAll(); - } - for (const auto &anchor : node.GetAllInAnchors()) { - anchor->UnlinkAll(); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) { - if (node_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); - return GRAPH_FAILED; - } - auto op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - return GRAPH_FAILED; - } - bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); - if 
(is_unknown_graph) { - return GRAPH_SUCCESS; - } - for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { - auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); - auto out_dims = output_tensor->GetShape().GetDims(); - auto out_dtype = output_tensor->GetDataType(); - ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); - output_tensor->SetOriginShape(output_tensor->GetShape()); - output_tensor->SetOriginDataType(output_tensor->GetDataType()); - - GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", - node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), - TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), - TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); - - for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { - if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); - continue; - } - auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); - if (peer_input_desc == nullptr) { - GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); - continue; - } - // check shape and dtype continuity. do not stop process - auto peer_input_dims = peer_input_desc->GetShape().GetDims(); - auto peer_input_dtype = peer_input_desc->GetDataType(); - if (out_dtype != peer_input_dtype) { - GELOGW( - "current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th " - "input_dtype is [%s].The two dtype should be same! Please check graph and fix it", - node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(), - peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), - TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str()); - } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) { - string out_shape_str, peer_in_shape_str; - out_shape_str += "["; - for (int64_t dim : out_dims) { - out_shape_str += std::to_string(dim) + " "; - } - out_shape_str += "]"; - peer_in_shape_str += "["; - for (int64_t dim : peer_input_dims) { - peer_in_shape_str += std::to_string(dim) + " "; - } - peer_in_shape_str += "]"; - - GELOGW( - "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " - "input_shape is [%s].The two shape should be same! 
Please check graph and fix it", - node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(), - peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str()); - } - GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", - peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), - output_tensor->GetDataType(), output_tensor->GetOriginDataType()); - peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); - peer_input_desc->SetShape(output_tensor->GetShape()); - peer_input_desc->SetDataType(output_tensor->GetDataType()); - peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); - std::vector> shape_range; - (void)output_tensor->GetShapeRange(shape_range); - peer_input_desc->SetShapeRange(shape_range); - ge::TensorUtils::SetRealDimCnt(*peer_input_desc, - static_cast(output_tensor->GetShape().GetDims().size())); - GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", - peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), - peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, - uint32_t num) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Input node is null"); - return GRAPH_FAILED; - } - - GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - const auto &op_desc = node->GetOpDesc(); - for (size_t i = op_desc->GetInputsSize(); i < num; ++i) { - if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add input desc failed"); - return GRAPH_FAILED; - } - - auto anchor = ComGraphMakeShared(node, i); - if (anchor == nullptr) { - GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed."); - return GRAPH_FAILED; - } - node->in_data_anchors_.push_back(anchor); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, - uint32_t num) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Input node is null"); - return GRAPH_FAILED; - } - - const auto &op_desc = node->GetOpDesc(); - while (op_desc->GetInputsSize() > num) { - if (!OpDescUtils::ClearInputDesc(op_desc, num)) { - return GRAPH_FAILED; - } - } - - auto input_names = op_desc->GetAllInputName(); - (void)op_desc->UpdateInputName(input_names); - auto is_input_const = op_desc->GetIsInputConst(); - is_input_const.resize(num); - op_desc->SetIsInputConst(is_input_const); - - while (node->in_data_anchors_.size() > num) { - node->in_data_anchors_.pop_back(); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node, - uint32_t num) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Input node is null"); - return GRAPH_FAILED; - } - - GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - const OpDescPtr &op_desc = node->GetOpDesc(); - for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) { - if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add output desc failed"); - return GRAPH_FAILED; - } - - auto anchor = ComGraphMakeShared(node, i); - if (anchor == nullptr) { - GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed."); - return 
GRAPH_FAILED; - } - node->out_data_anchors_.push_back(anchor); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node, - uint32_t num) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Input node is null"); - return GRAPH_FAILED; - } - - const auto &op_desc = node->GetOpDesc(); - auto output_names = op_desc->GetAllOutputName(); - while (op_desc->GetOutputsSize() > num) { - if (!OpDescUtils::ClearOutputDesc(op_desc, num)) { - return GRAPH_FAILED; - } - } - (void)op_desc->UpdateOutputName(output_names); - - while (node->out_data_anchors_.size() > num) { - node->out_data_anchors_.pop_back(); - } - - return GRAPH_SUCCESS; -} - -bool NodeUtils::IsInNodesEmpty(const Node &node) { - for (const auto &in_anchor : node.in_data_anchors_) { - if (in_anchor != nullptr) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor != nullptr) { - if (out_anchor->GetOwnerNode() != nullptr) { - return false; - } - } - } - } - - if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) { - auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors(); - for (const auto &out_control_anchor : peer_out_control_anchors) { - if (out_control_anchor != nullptr) { - if (out_control_anchor->GetOwnerNode() != nullptr) { - return false; - } - } - } - } - - return true; -} -GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) { - auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return GeTensorDesc(); - } - return desc->GetOutputDesc(index); -} -GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) { - auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return GeTensorDesc(); - } - return desc->GetInputDesc(index); -} -graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) { - auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - auto output_desc = desc->MutableOutputDesc(index); - if (output_desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - output_desc->SetShape(shape); - return GRAPH_SUCCESS; -} -graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) { - auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - auto input_desc = desc->MutableInputDesc(index); - if (input_desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - input_desc->SetShape(shape); - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { - auto desc = node.GetOpDesc(); - GE_CHECK_NOTNULL(desc); - // check self - is_unknow = OpShapeIsUnknown(desc); - if (is_unknow) { - return GRAPH_SUCCESS; - } - auto sub_graph_names = desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - return GRAPH_SUCCESS; - } else { - auto owner_graph = node.GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_graph); - auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (root_graph == nullptr) { - GE_LOGE("Node %s gets null root graph", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - for (auto &sub_graph_name : sub_graph_names) { - auto sub_graph = root_graph->GetSubgraph(sub_graph_name); - GE_CHECK_NOTNULL(sub_graph); - for (const auto &node_ptr : sub_graph->GetDirectNode()) { - auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); - if (status != GRAPH_SUCCESS) { - GE_LOGE("get node unknown 
shape status failed!"); - return status; - } - if (is_unknow) { - return GRAPH_SUCCESS; - } - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor) { - GE_CHECK_NOTNULL(node_ptr); - return NodeUtils::GetInputConstData(*node_ptr, dst_name, ge_tensor); -} - -graphStatus NodeUtils::GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor) { - // For inner compute graph - auto op_desc = node.GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - auto index = op_desc->GetInputIndexByName(dst_name); - auto in_data_anchor = node.GetInDataAnchor(index); - GE_CHECK_NOTNULL(in_data_anchor); - auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(out_data_anchor); - auto peer_node = out_data_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(peer_node); - auto peer_op_desc = peer_node->GetOpDesc(); - GE_CHECK_NOTNULL(peer_op_desc); - auto peer_op_type = peer_op_desc->GetType(); - if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { - if (!AttrUtils::MutableTensor(peer_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor)) { - GELOGW("get attr name %s failed.", ATTR_NAME_WEIGHTS.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } else if (peer_op_type == DATA) { - auto parent_node = NodeUtils::GetParentInput(peer_node); - while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { - parent_node = NodeUtils::GetParentInput(parent_node); - } - if ((parent_node != nullptr) && ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { - if (!AttrUtils::MutableTensor(parent_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor)) { - GELOGW("get attr name %s failed.", ATTR_NAME_WEIGHTS.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - } - // Try get from runtime inference context - auto session_id = std::to_string(GetContext().SessionId()); - RuntimeInferenceContext *runtime_infer_ctx = nullptr; - if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { - GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); - auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), ge_tensor); - if (ret == GRAPH_SUCCESS) { - return GRAPH_SUCCESS; - } - } - GELOGW("node[%s]'s input[%s]'s peer node is not const", node.GetName().c_str(), dst_name.c_str()); - return GRAPH_FAILED; -} - -std::string NodeUtils::GetNodeType(const Node &node) { - if (node.GetType() != FRAMEWORKOP) { - return node.GetType(); - } - - std::string type; - (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); - return type; -} - -std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? 
"" : GetNodeType(*node); } - -ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { - auto op_desc = node.GetOpDesc(); - if (op_desc == nullptr) { - return nullptr; - } - auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (root_graph == nullptr) { - return nullptr; - } - return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); -} - -graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) { - if (subgraph == nullptr) { - GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); - return GRAPH_PARAM_INVALID; - } - auto op_desc = node.GetOpDesc(); - if (op_desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (root_graph == nullptr) { - GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); - if (ret != GRAPH_SUCCESS) { - GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); - return ret; - } - subgraph->SetParentNode(node.shared_from_this()); - subgraph->SetParentGraph(node.GetOwnerComputeGraph()); - return root_graph->AddSubgraph(subgraph); -} - -/// -/// Check if node is input of subgraph -/// @param [in] node -/// @return bool -/// -bool NodeUtils::IsSubgraphInput(const NodePtr &node) { - if ((node == nullptr) || (node->GetOpDesc() == nullptr) || - (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { - return false; - } - - auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); - if (parent_op_desc == nullptr) { - return false; - } - - // dynamic shape unknown graph false - // dynamic shape known graph with functional subgraph maybe true - if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { - if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { - return false; - } else { - if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { - return false; - } - } - } - - return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); -} - -/// -/// Check if node is output of subgraph -/// @param [in] node -/// @return bool -/// -bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { - if ((node == nullptr) || (node->GetOpDesc() == nullptr) || - (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { - return false; - } - - auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); - if (parent_op_desc == nullptr) { - return false; - } - - if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { - if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { - return false; - } else { - if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { - return false; - } - } - } - - for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { - if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { - return true; - } - } - - return false; -} - -/// -/// @brief Get subgraph original input node. 
-/// @param [in] node -/// @return Node -/// -NodePtr NodeUtils::GetParentInput(const Node &node) { - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - return nullptr; - } - - // Subgraph Data Node, check for constant input. - const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); - GE_CHECK_NOTNULL_EXEC(graph, return nullptr); - - const NodePtr &parent_node = graph->GetParentNode(); - GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr); - - const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index); - GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); - - const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); - - return peer_out_anchor->GetOwnerNode(); -} - -NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return node == nullptr ? node : GetParentInput(*node); } - -/// -/// @brief Get is dynamic shape graph from node. -/// @param [in] node -/// @return bool -/// -bool NodeUtils::IsDynamicShape(const Node &node) { - const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (graph == nullptr) { - return false; - } - - bool is_dynamic_shape = false; - (void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); - return is_dynamic_shape; -} - -bool NodeUtils::IsDynamicShape(const NodePtr &node) { return node == nullptr ? false : IsDynamicShape(*node); } - -/// -/// @brief Check is varying_input for while node -/// @param [in] node: Data node for subgraph -/// @return bool -/// -bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { - if (node == nullptr) { - return false; - } - if (node->GetType() != DATA) { - return false; // not input_node for subgraph - } - - const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); - if (parent_node == nullptr) { - return false; // root graph - } - - if (kWhileOpTypes.count(parent_node->GetType()) == 0) { - return false; // not input_node for while subgraph - } - - uint32_t index_i = 0; - if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { - GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); - return false; - } - bool varying_flag = true; - for (const auto &item : node->GetOutDataNodesAndAnchors()) { - if (item.first->GetType() != NETOUTPUT) { - continue; - } - OpDescPtr op_desc = item.first->GetOpDesc(); - uint32_t index_o = 0; - if ((op_desc == nullptr) || - !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { - continue; // input for while-cond subgraph - } - if (index_i != index_o) { - continue; // varying input for while-body subgraph - } - varying_flag = false; - break; - } - return varying_flag; -} - -/// -/// @brief Get subgraph input is constant. -/// @param [in] node -/// @param [out] string -/// @return bool -/// -bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) { - if (node == nullptr) { - return false; - } - - if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { - type = node->GetType(); - return true; - } - - if (node->GetType() != DATA) { - return false; // not subgraph input node - } - - const auto &parent = GetParentInput(node); - return GetConstOpType(parent, type); -} - -/// -/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. 
-/// @param [in] node -/// @return return GRAPH_SUCCESS if remove successfully, other for failed. -/// -Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - return GRAPH_SUCCESS; - } else { - auto owner_graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_graph); - auto root_graph = GraphUtils::FindRootGraph(owner_graph); - GE_CHECK_NOTNULL(root_graph); - - std::unordered_set subgraph_to_remove; - for (auto &subgraph_name : subgraph_names) { - std::deque queue; - queue.push_back(subgraph_name); - subgraph_to_remove.insert(subgraph_name); - op_desc->RemoveSubgraphInstanceName(subgraph_name); - while (!queue.empty()) { - auto graph_name = queue.front(); - queue.pop_front(); - - auto subgraph = root_graph->GetSubgraph(graph_name); - GE_CHECK_NOTNULL(subgraph); - for (const auto &sub_node : subgraph->GetDirectNode()) { - auto sub_op_desc = sub_node->GetOpDesc(); - GE_CHECK_NOTNULL(sub_op_desc); - auto sub_names = sub_op_desc->GetSubgraphInstanceNames(); - // Subgraph and all nodes in it will be removed later, - // no need to remove 'SubgraphInstanceName' in op desc here. - for (auto &name : sub_names) { - if (subgraph_to_remove.insert(name).second) { - queue.push_back(name); - } - } - } - } - } - // Remove subgraph from root_graph - for (const auto &name : subgraph_to_remove) { - GELOGI("Remove subgraph:%s.", name.c_str()); - root_graph->RemoveSubgraph(name); - } - } - - return GRAPH_SUCCESS; -} -/// -/// @brief Get subgraph input data node by index. -/// @param [in] node -/// @return Node -/// -vector NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { - vector in_data_node_vec; - auto op_desc = node.GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); - auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); - return in_data_node_vec; - } - auto compute_graph = node.GetOwnerComputeGraph(); - for (const std::string &instance_name : subgraph_names) { - auto subgraph = compute_graph->GetSubgraph(instance_name); - for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { - int parent_index = -1; - if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { - (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); - if (parent_index == index) { - in_data_node_vec.emplace_back(node_in_subgraph); - } - } - } - } - return in_data_node_vec; -} -/// -/// @brief Get subgraph input data node by index. 
-/// @param [in] node -/// @return Node -/// -vector NodeUtils::GetSubgraphOutputNodes(const Node &node) { - vector out_data_node_vec; - auto op_desc = node.GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); - auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); - return out_data_node_vec; - } - auto compute_graph = node.GetOwnerComputeGraph(); - for (const std::string &instance_name : subgraph_names) { - auto subgraph = compute_graph->GetSubgraph(instance_name); - for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { - if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { - out_data_node_vec.emplace_back(node_in_subgraph); - } - } - } - return out_data_node_vec; -} - -NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) { - if (node.GetInDataAnchor(index) == nullptr) { - return nullptr; - } - if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { - return nullptr; - } - return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); -} - -vector> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) { - vector> out_data_nodes; - auto out_data_anchor = node.GetOutDataAnchor(index); - if (out_data_anchor == nullptr) { - return out_data_nodes; - } - - for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - if (peer_in_anchor == nullptr) { - continue; - } - if (peer_in_anchor->GetOwnerNode() == nullptr) { - continue; - } - out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode())); - } - return out_data_nodes; -} - -ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); } -} // namespace ge diff --git a/src/common/graph/utils/op_desc_utils.cc b/src/common/graph/utils/op_desc_utils.cc deleted file mode 100644 index 17c80b2c..00000000 --- a/src/common/graph/utils/op_desc_utils.cc +++ /dev/null @@ -1,825 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "utils/op_desc_utils.h" -#include -#include "debug/ge_attr_define.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/anchor.h" -#include "graph/compute_graph.h" -#include "graph/ge_attr_value.h" -#include "utils/graph_utils.h" -#include "utils/node_utils.h" - -using std::vector; - -/*lint -e512 -e737 -e752*/ -namespace ge { -const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; -static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; - -bool OpDescUtils::ClearInputDesc(const NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); - GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr"); - vector index_list; - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - if (in_anchor->GetPeerOutAnchor() == nullptr) { - index_list.push_back(in_anchor->GetIdx()); - } - } - std::sort(index_list.begin(), index_list.end()); - // Node's in anchor index need shrink - for (size_t i = 0; i < index_list.size(); ++i) { - auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i]; - if (iter < node->GetOpDesc()->inputs_desc_.end()) { - (void)node->GetOpDesc()->inputs_desc_.erase(iter); - } else { - GELOGW("inputs_desc_ iterator out of range."); - } - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc, - const uint32_t index) { - GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr"); - GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index); - - auto iter = op_desc->inputs_desc_.begin() + index; - if (iter < op_desc->inputs_desc_.end()) { - (void)op_desc->inputs_desc_.erase(iter); - } else { - GELOGW("inputs_desc_ iterator out of range."); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) { - GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr"); - return op_desc->HasAttr(OP_DESC_QUANT_PARAMS); -} - -bool OpDescUtils::ClearOutputDesc(const NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); - GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr"); - vector index_list; - for (const auto &out_anchor : node->GetAllOutDataAnchors()) { - if (out_anchor->GetPeerInDataAnchors().empty()) { - index_list.push_back(out_anchor->GetIdx()); - } - } - std::sort(index_list.begin(), index_list.end()); - // Node's out anchor index need shrink - for (size_t i = 0; i < index_list.size(); ++i) { - auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i]; - if (iter < node->GetOpDesc()->outputs_desc_.end()) { - (void)node->GetOpDesc()->outputs_desc_.erase(iter); - } else { - GELOGW("outputs_desc_ iterator out of range."); - } - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc, - uint32_t index) { - GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr"); - GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index); - - auto iter = op_desc->outputs_desc_.begin() + index; - if (iter < op_desc->outputs_desc_.end()) { - (void)op_desc->outputs_desc_.erase(iter); - } else { - GELOGW("outputs_desc_ iterator out of range."); - } - return true; -} - -bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return 
op_desc.HasAttr(OP_DESC_QUANT_PARAMS); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) { - GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); - GeAttrValue attr_value; - GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, - "GetQuantizeFactorParams failed"); - return attr_value.GetValue(quant); -} - -graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) { - GeAttrValue attr_value; - GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, - "GetQuantizeFactorParams failed"); - return attr_value.GetValue(quant); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { - GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); - return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); // lint !e732 -} - -graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { - return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); // lint !e732 -} - -GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { - GeTensorPtr weight = nullptr; - if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) { - GELOGW("MutableTensor error"); - } - - return weight; -} - -GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) { - if (op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "op_desc is null"); - return nullptr; - } - return MutableWeights(*op_desc); -} - -graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) { - if (weight == nullptr) { - GELOGE(GRAPH_FAILED, "weight is null"); - return GRAPH_FAILED; - } - return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? 
GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) { - GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(weight); - return SetWeights(*op_desc, weight); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetWeights(const ge::Node &node) { - auto weights = MutableWeights(node); - vector ret(weights.size()); - std::copy(weights.begin(), weights.end(), ret.begin()); - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetWeights( - const ge::ConstNodePtr &node) { - if (node == nullptr) { - return vector(); - } - return GetWeights(*node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputNode( - const ge::Node &node) { - vector ret; - auto in_anchors = node.GetAllInDataAnchors(); - for (const auto &in_anchor : in_anchors) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - // normally out_anchor could be null, this is ok - GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str()); - continue; - } - auto in_node = out_anchor->GetOwnerNode(); - while (true) { - if (in_node == nullptr) { - break; - } - if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { - ret.push_back(in_node); - break; - } else if (in_node->GetType() == DATA) { - if (NodeUtils::IsWhileVaryingInput(in_node)) { - break; - } - in_node = NodeUtils::GetParentInput(in_node); - } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { - bool is_constant = false; - (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); - if (!is_constant) { - break; - } - // Enter node has and only has one input - if (in_node->GetInDataNodes().size() != 1) { - GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), - in_node->GetInDataNodes().size()); - break; - } - in_node = in_node->GetInDataNodes().at(0); - } else { - break; - } - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetInputData( - const vector &input_nodes) { - vector ret; - - for (const auto &input_node : input_nodes) { - auto temp_weight = MutableWeights(input_node->GetOpDesc()); - if (temp_weight == nullptr) { - GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str()); - return vector(); - } - ret.push_back(temp_weight); - } - - return ret; -} -size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) { - if (NodeUtils::IsAnchorStatusSet(node)) { - size_t input_num = 0; - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { - input_num++; - continue; - } - } - return input_num; // lint !e712 - } else { - GE_IF_BOOL_EXEC( - node.GetInDataNodes().size() < GetConstInputs(node).size(), - GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size()); - return 0); - return node.GetInDataNodes().size() - GetConstInputs(node).size(); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Node is nullptr"); - return 0; - } - return GetNonConstInputsSize(*node); -} - -GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) { - GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!"); - 
size_t i = 0; - if (NodeUtils::IsAnchorStatusSet(node)) { - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { - if (index_non_const == i) { - return node.GetOpDesc()->GetInputDesc(static_cast(anchor->GetIdx())); - } - ++i; - } - } - } else { - for (const auto &anchor : node.GetAllInDataAnchors()) { - auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - continue; - } - auto owner_node = peer_anchor->GetOwnerNode(); - if (owner_node == nullptr) { - continue; - } - if (owner_node->GetType() == CONSTANT) { - continue; - } - if (index_non_const == i) { - return node.GetOpDesc()->GetInputDesc(anchor->GetIdx()); - } - ++i; - } - } - return GeTensorDesc(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc -OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) { - CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc()); - return GetNonConstInputTensorDesc(*node, index_non_const); -} - -bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) { - bool ret = false; - size_t i = 0; - if (NodeUtils::IsAnchorStatusSet(node)) { - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { - if (index_non_const == i) { - index = static_cast(anchor->GetIdx()); - ret = true; - } - ++i; - } - } - } else { - for (const auto &anchor : node.GetAllInDataAnchors()) { - auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - continue; - } - auto owner_node = peer_anchor->GetOwnerNode(); - if (owner_node == nullptr) { - continue; - } - if (owner_node->GetType() == CONSTANT) { - continue; - } - if (index_non_const == i) { - index = static_cast(anchor->GetIdx()); - ret = true; - } - ++i; - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node, - size_t index_non_const, - size_t &index) { - CHECK_FALSE_EXEC(node != nullptr, return false); - return GetNonConstInputIndex(*node, index_non_const, index); -} - -bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) { - bool ret = false; - if (index < node.GetAllInDataAnchors().size()) { - if (NodeUtils::IsAnchorStatusSet(node)) { - ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast(index))) == ANCHOR_DATA); // lint !e712 - } else { - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (anchor->GetIdx() != static_cast(index)) { - continue; - } - auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - break; - } - auto owner_node = peer_anchor->GetOwnerNode(); - if (owner_node == nullptr) { - break; - } - ret = (owner_node->GetType() != CONSTANT); - } - } - } - - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node, - size_t index) { - CHECK_FALSE_EXEC(node != nullptr, return false); - return IsNonConstInput(*node, index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputs( - const ge::ConstNodePtr &node) { - if (node == nullptr) { - return vector(); - } - return GetConstInputs(*node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetNonConstTensorDesc( - const ge::ConstNodePtr &node) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - return vector(); - } - vector ret; - if 
(NodeUtils::IsAnchorStatusSet(*node)) { - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { - ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); - } - } - } else { - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { - continue; - } - if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { - ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); - } - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputs(const ge::Node &node) { - vector ret; - auto in_anchors = node.GetAllInDataAnchors(); - for (const auto &in_anchor : in_anchors) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) continue; - - auto in_node = out_anchor->GetOwnerNode(); - if (in_node->GetType() == CONSTANT) { - ret.push_back(in_node); - } else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) { - // const --> switch --> matmul - auto switch_input = GetConstInputs(*in_node); - if (switch_input.size() > 0) { - ret.insert(ret.end(), switch_input.begin(), switch_input.end()); - } - } else if (in_node->GetType() == DATA) { - auto parent = NodeUtils::GetParentInput(in_node); - if ((parent != nullptr) && (parent->GetType() == CONSTANT)) { - ret.push_back(parent); - } - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::Node &node) { - vector ret; - auto op_desc = node.GetOpDesc(); - GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); - // Place holder operator, try to get the weight from parent node - // when parent node is const operator - if (node.GetType() == PLACEHOLDER) { - std::string parent_op; - (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op); - // This if judgment is necessary because the current subgraph optimization is multithreaded - // and the parent node of the PLD operation should be a stable type, such as const - if (parent_op == CONSTANT || parent_op == CONSTANTOP) { - NodePtr parent_node = nullptr; - parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); - if (parent_node != nullptr) { - op_desc = parent_node->GetOpDesc(); - GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); - } - } - } - // Const operator, take the weight directly - if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { - auto weight = MutableWeights(op_desc); - if (weight == nullptr) { - GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); - return ret; - } - ret.push_back(weight); - return ret; - } - - if (node.GetType() == DATA) { - auto parent = NodeUtils::GetParentInput(node); - if ((parent != nullptr) && NodeUtils::IsConst(*parent)) { - auto weight = MutableWeights(parent->GetOpDesc()); - if (weight == nullptr) { - GELOGI("const op has no weight, op name:%s", parent->GetName().c_str()); - return ret; - } - ret.push_back(weight); - } - return ret; - } - - // Other operators, get weights from connected constop - auto input_nodes = GetConstInputs(node); - for (const auto &input_node : input_nodes) { - auto temp_weight = MutableWeights(input_node->GetOpDesc()); - if (temp_weight == nullptr) { - GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str()); - return 
vector(); - } - ret.push_back(temp_weight); - } - - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::NodePtr node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Node is nullptr"); - return vector(); - } - return MutableWeights(*node); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::SetWeights(ge::Node &node, const vector &weights) { - GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!"); - if (node.GetOpDesc()->GetType() == CONSTANT) { - if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) { - return SetWeights(node.GetOpDesc(), weights[0]); - } - GELOGI("const op weight size %zu should be 1", weights.size()); - return GRAPH_PARAM_INVALID; - } - - auto input_nodes = GetConstInputs(node); - if (weights.size() < input_nodes.size()) { - GELOGE(GRAPH_FAILED, "weights count can't be less than const input count"); - return GRAPH_PARAM_INVALID; - } - - ge::GeAttrValue::NAMED_ATTRS named_attrs; - (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); - vector copy_weights; - (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); - - for (size_t i = 0; i < input_nodes.size(); ++i) { - if (input_nodes[i]->GetOpDesc() != nullptr) { - SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]); - } - } - - // If set more weights than constop, need to add constop - for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) { - // Use org weight before SetWeights Overwrite - auto const_opdesc = CreateConstOp(copy_weights[i]); - GE_CHECK_NOTNULL(const_opdesc); - - auto owner_graph = node.GetOwnerComputeGraph(); - if (owner_graph == nullptr) { - GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - auto const_node = owner_graph->AddNodeFront(const_opdesc); - GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, return GRAPH_FAILED, "graph add link failedï¼"); - std::vector original_nodes; - ge::GraphUtils::RecordOriginalNames(original_nodes, const_node); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::SetWeights(ge::Node &node, const map &weights_map) { - GE_CHECK_NOTNULL(node.GetOpDesc()); - // 1. node is const - if (node.GetOpDesc()->GetType() == CONSTANT) { - if (weights_map.size() == CONST_OP_NORMAL_WEIGHT_SIZE) { - return SetWeights(node.GetOpDesc(), weights_map.begin()->second); - } - GELOGE(GRAPH_PARAM_INVALID, "const op %s weight size %zu should be 1", node.GetName().c_str(), weights_map.size()); - return GRAPH_PARAM_INVALID; - } - // 2. node is not const - for (const auto &pair : weights_map) { - auto in_data_anchor = node.GetInDataAnchor(pair.first); - if (in_data_anchor != nullptr && in_data_anchor->GetPeerOutAnchor() != nullptr) { - // a. update const input node - auto out_anchor = in_data_anchor->GetPeerOutAnchor(); - auto peer_node = out_anchor->GetOwnerNode(); - if (peer_node == nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "op %s [%d]'s input node is null", node.GetName().c_str(), pair.first); - return GRAPH_PARAM_INVALID; - } - if (peer_node->GetType() != CONSTANT) { - GELOGE(GRAPH_PARAM_INVALID, " op %s [%d]'s input node should be const, but is %s type:%s ", - node.GetName().c_str(), pair.first, peer_node->GetName().c_str(), peer_node->GetType().c_str()); - } - SetWeights(peer_node->GetOpDesc(), pair.second); - } else { - // b. 
create new const input node - auto const_opdesc = CreateConstOp(pair.second); - GE_CHECK_NOTNULL(const_opdesc); - auto owner_graph = node.GetOwnerComputeGraph(); - if (owner_graph == nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - auto const_node = owner_graph->AddNodeFront(const_opdesc); - if (node.AddLinkFrom(static_cast(pair.first), const_node) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "op %s add const to input index[%d] failed", node.GetName().c_str(), pair.first); - return GRAPH_FAILED; - } - } - } - NodeUtils::UpdateIsInputConst(node); - return GRAPH_SUCCESS; -} - -OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) { - GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!"); - shared_ptr const_opdesc = ComGraphMakeShared(); - if (const_opdesc == nullptr) { - GELOGE(GRAPH_FAILED, "failed to make_shared "); - return nullptr; - } - - CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr); - - const_opdesc->SetType(CONSTANT); - - thread_local int64_t const_count = 0; - const_opdesc->SetName("dynamic_const_" + std::to_string(GetTid()) + "_" + std::to_string(const_count)); - GELOGI("add const op: %s", const_opdesc->GetName().c_str()); - ++const_count; - - (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc()); - - GELOGI("after add const op: %s", const_opdesc->GetName().c_str()); - - return const_opdesc; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) { - GE_CHECK_NOTNULL(in_anchor); - GE_CHECK_NOTNULL(tensor_ptr); - auto const_opdesc = CreateConstOp(tensor_ptr); - GE_CHECK_NOTNULL(const_opdesc); - auto in_node = in_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(in_node); - auto owner_graph = in_node->GetOwnerComputeGraph(); - if (owner_graph == nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc); - GE_CHECK_NOTNULL(const_node); - if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) { - GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed."); - return GRAPH_PARAM_INVALID; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::SetWeights(ge::NodePtr node, const vector &weights) { - GE_CHECK_NOTNULL(node); - return SetWeights(*node, weights); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) { - GE_CHECK_NOTNULL(node); - auto const_ops = GetConstInputs(node); - auto graph = node->GetOwnerComputeGraph(); - if (graph == nullptr) { - GELOGE(GRAPH_FAILED, "Graph is nullptr"); - return GRAPH_PARAM_INVALID; - } - for (const auto &const_op : const_ops) { - GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed", - const_op->GetName().c_str(), const_op->GetType().c_str()); - GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op), - "Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(), - const_op->GetType().c_str()); - } - return GRAPH_SUCCESS; -} - -/// -/// @brief Add input -/// @param [in] name -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const 
std::string &name) { - inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); - return *this; -} - -/// -/// @brief Add input -/// @param [in] name -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name, - const GeTensorDesc &tensor) { - inputs_.emplace_back(std::make_pair(name, tensor)); - return *this; -} - -/// -/// @brief Add dynamic input -/// @param [in] name -/// @param [in] num -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, - uint32_t num) { - for (uint32_t i = 0; i < num; i++) { - inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); - } - return *this; -} - -/// -/// @brief Add dynamic input -/// @param [in] name -/// @param [in] num -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput( - const std::string &name, uint32_t num, const GeTensorDesc &tensor) { - for (uint32_t i = 0; i < num; i++) { - inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); - } - return *this; -} - -/// -/// @brief Add output -/// @param [in] name -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { - outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); - return *this; -} - -/// -/// @brief Add output -/// @param [in] name -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name, - const GeTensorDesc &tensor) { - outputs_.emplace_back(std::make_pair(name, tensor)); - return *this; -} - -/// -/// @brief Add dynamic output -/// @param [in] name -/// @param [in] num -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, - uint32_t num) { - for (uint32_t i = 0; i < num; i++) { - outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); - } - return *this; -} - -/// -/// @brief Add dynamic output -/// @param [in] name -/// @param [in] num -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput( - const std::string &name, uint32_t num, const GeTensorDesc &tensor) { - for (uint32_t i = 0; i < num; i++) { - outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); - } - return *this; -} - -/// -/// @brief Build op_desc -/// @return OpDescPtr -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { - OpDescPtr op_desc = shared_ptr(new (std::nothrow) OpDesc(name_, type_)); - if (op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "OpDesc is nullptr"); - return nullptr; - } - - for (auto &input : inputs_) { - if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add input_desc failed."); - return nullptr; - } - } - - for (auto &output : outputs_) { - if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add output_desc failed."); - return nullptr; - } - } - - return op_desc; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName( - const std::string &subgraph_name, 
const std::string &subgraph_instance_name, OpDescPtr &op_desc) { - const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); - auto iter = subgraph_names_to_index.find(subgraph_name); - if (iter == subgraph_names_to_index.end()) { - GELOGE(GRAPH_PARAM_INVALID, - "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists", - subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), - subgraph_name.c_str()); - return GRAPH_PARAM_INVALID; - } - - return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); -} -} // namespace ge -/*lint +e512 +e737 +e752*/ diff --git a/src/common/graph/utils/string_utils.h b/src/common/graph/utils/string_utils.h deleted file mode 100644 index a9700469..00000000 --- a/src/common/graph/utils/string_utils.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_UTILS_STRING_UTILS_H_ -#define COMMON_GRAPH_UTILS_STRING_UTILS_H_ - -#include -#include -#include -#include -#include -#include "securec.h" - -namespace ge { -class StringUtils { - public: - static std::string &Ltrim(std::string &s) { - (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); - return s; - } - - static std::string &Rtrim(std::string &s) { - (void)s.erase(std::find_if(s.rbegin(), s.rend(), [](int c) { return !std::isspace(c); }).base(), s.end()); - return s; - } - - /// @ingroup domi_common - /// @brief trim space - static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); } - - // split string - static std::vector Split(const std::string &str, char delim) { - std::vector elems; - - if (str.empty()) { - elems.emplace_back(""); - return elems; - } - - std::stringstream ss(str); - std::string item; - - while (getline(ss, item, delim)) { - elems.push_back(item); - } - auto str_size = str.size(); - if (str_size > 0 && str[str_size - 1] == delim) { - elems.emplace_back(""); - } - - return elems; - } -}; -} // namespace ge -#endif // COMMON_GRAPH_UTILS_STRING_UTILS_H_ diff --git a/src/common/graph/utils/tensor_utils.cc b/src/common/graph/utils/tensor_utils.cc deleted file mode 100644 index 26ac8cc8..00000000 --- a/src/common/graph/utils/tensor_utils.cc +++ /dev/null @@ -1,401 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
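The removed string_utils.h above encodes two easy-to-miss conventions in Split: an empty input produces a single empty element, and a trailing delimiter appends one more. A minimal standalone sketch of that contract (the sample inputs in the trailing comments are illustrative only, not taken from the patch):

#include <sstream>
#include <string>
#include <vector>

// Splits `str` on `delim`, mirroring the removed StringUtils::Split contract:
// an empty input yields {""} and a trailing delimiter appends a final "".
static std::vector<std::string> Split(const std::string &str, char delim) {
  std::vector<std::string> elems;
  if (str.empty()) {
    elems.emplace_back("");
    return elems;
  }
  std::stringstream ss(str);
  std::string item;
  while (std::getline(ss, item, delim)) {
    elems.push_back(item);
  }
  if (str.back() == delim) {
    elems.emplace_back("");
  }
  return elems;
}

// Split("a,b,", ',') -> {"a", "b", ""}   Split("", ',') -> {""}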
- */ - -#include "graph/utils/tensor_utils.h" -#include - -#include "debug/ge_log.h" -#include "framework/common/debug/ge_log.h" -#include "common/util/error_manager/error_manager.h" -#include "graph/ge_tensor.h" -#include "graph/types.h" -#include "graph/utils/type_utils.h" - -namespace ge { -namespace { -// When nc1hwc0 dim size = 5, calc element count directly. -const uint32_t kNc1hwc0CalcByDimsSize = 5; - -// Unknown shape element num -const int64_t kElementCntUnknownShape = -1; - -// Unknown shape mem size -const int64_t kMemSizeUnknownShape = -1; - -// Nchw and nhwc dim size must be 4 -const uint32_t kDimSize4d = 4; - -// C1HWNCoC0 dim size must be 6 -const uint32_t kDimSizeC1hwncoc0 = 6; - -// Cube size is 16 -const uint32_t kTheCubeSize = 16; - -// Default c0 size equals cube size. -const uint32_t kC0SizeDefault = kTheCubeSize; - -// Size equals int8 cube size is 32 -const uint32_t kC0SizeInt8 = 32; - -// NCHW dim N index -const int32_t kNchwDimIdxN = 0; -// NCHW dim C index -const int32_t kNchwDimIdxC = 1; -// NCHW dim H index -const int32_t kNchwDimIdxH = 2; -// NCHW dim W index -const int32_t kNchwDimIdxW = 3; - -const int kDataMemAlignSize = 32; -const int kNum2 = 2; -} // namespace - -/// -/// Check if a * b overflow. -/// @param a multiplier -/// @param b Multiplicand -/// @return true: overflow -/// false: not overflow -/// -static bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) { - if (a > 0) { - if (b > 0) { - if (a > (INT64_MAX / b)) { - return true; - } - } else { - if (b < (INT64_MIN / a)) { - return true; - } - } - } else { - if (b > 0) { - if (a < (INT64_MIN / b)) { - return true; - } - } else { - if ((a != 0) && (b < (INT64_MAX / a))) { - return true; - } - } - } - return false; -} - -/// -/// Calculate element num by dims directly. -/// @param dims dim info -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcElementCntByDims(const std::vector &dims, int64_t &element_cnt) { - element_cnt = 1; - for (int64_t dim : dims) { - if (CheckMultiplyOverflowInt64(element_cnt, dim)) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19013", {"function", "var1", "var2"}, - {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); - GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); - return GRAPH_FAILED; - } - element_cnt *= dim; - } - return GRAPH_SUCCESS; -} - -/// -/// Calculate fixed dims element num. -/// @param dims dim info -/// @param fixed_dim_size fixed dim size -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcElementCntOfFixedDims(const std::vector &dims, Format format, uint32_t fixed_dim_size, - int64_t &element_cnt) { - if (dims.size() != fixed_dim_size) { - GELOGW("Format %d(%s) need dim size=%u but %zu, calc as ND.", format, - TypeUtils::FormatToSerialString(format).c_str(), fixed_dim_size, dims.size()); - } - return CalcElementCntByDims(dims, element_cnt); -} - -/// -/// Get dim c0 size by type -/// @param data_type data type -/// @return c0 size -/// -static uint32_t GetDimC0(DataType &data_type) { - bool is_int8_size = (data_type == DT_INT8) || (data_type == DT_UINT8) || (data_type == DT_DUAL_SUB_UINT8) || - (data_type == DT_DUAL_SUB_INT8) || (data_type == DT_BOOL) || (data_type == DT_QINT8); - return is_int8_size ? kC0SizeInt8 : kC0SizeDefault; -} - -/// -/// Calculate nc1hwc0 element num. 
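The element-count helpers above never multiply before proving that the product fits in int64_t. A condensed, self-contained sketch of that guard-then-accumulate pattern (function names here are illustrative, not the GE API):

#include <cstdint>
#include <vector>

// True when a * b would overflow int64_t, using the same pre-division
// checks as CheckMultiplyOverflowInt64 above.
static bool MulOverflows(int64_t a, int64_t b) {
  if (a > 0) {
    return (b > 0) ? (a > INT64_MAX / b) : (b < INT64_MIN / a);
  }
  if (b > 0) {
    return a < INT64_MIN / b;
  }
  return (a != 0) && (b < INT64_MAX / a);
}

// Multiplies all dims into `count`, returning false on overflow.
// An empty dim list leaves count at 1 (scalar convention).
static bool CalcElementCount(const std::vector<int64_t> &dims, int64_t &count) {
  count = 1;
  for (int64_t dim : dims) {
    if (MulOverflows(count, dim)) {
      return false;
    }
    count *= dim;
  }
  return true;
}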
-/// @param dims dim info -/// @param data_type data type -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcElementCntOfNc1hwc0(const std::vector &dims, DataType data_type, int64_t &element_cnt) { - // When nc1hwc0 dims size = 5, no need split dim c - if (dims.size() == kNc1hwc0CalcByDimsSize) { - return CalcElementCntByDims(dims, element_cnt); - } else if (dims.size() != kDimSize4d) { - GELOGE(GRAPH_FAILED, "CalcElementCntOfNc1hwc0 failed as dims.size=%zu is not %u or %u.", dims.size(), kDimSize4d, - kNc1hwc0CalcByDimsSize); - return GRAPH_FAILED; - } - - auto c0 = static_cast(GetDimC0(data_type)); - // Nc1hwc0 dims is according to nchw, dim c index is 1. - auto c1 = static_cast(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0)); - // Store dims is split c to c1 and c0. - std::vector store_dims = {dims[kNchwDimIdxN], c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0}; - return CalcElementCntByDims(store_dims, element_cnt); -} - -/// -/// Calculate FractalZ element num. -/// @param dims dim info -/// @param data_type data type -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcElementCntOfFractalZ(const std::vector &dims, DataType data_type, - int64_t &element_cnt) { - static char *parser_priority = std::getenv("PARSER_PRIORITY"); - if (parser_priority != nullptr && string(parser_priority) == "cce") { - if (dims.size() != kDimSize4d) { - GELOGE(GRAPH_FAILED, "CalcElementCntOfFractalZ failed as dims.size=%zu is not %u.", dims.size(), kDimSize4d); - return GRAPH_FAILED; - } - auto c0 = static_cast(GetDimC0(data_type)); - // FractalZ dims is according to nchw, dim c index is 1. - auto c1 = static_cast(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0)); - - // Spread NC1HWC0 as a two dimension array, n as column dimension, - // C1HWC0 as row dimension - std::vector r_count_vec = {c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0}; - - int64_t r_count = 1; - graphStatus graph_status = CalcElementCntByDims(r_count_vec, r_count); - if (graph_status != GRAPH_SUCCESS) { - GELOGE(graph_status, "Calc [%ld, %ld, %ld, %ld] element count failed.", c1, dims[kNchwDimIdxH], - dims[kNchwDimIdxW], c0); - return graph_status; - } - - // Cube count in n - auto nc_cnt = static_cast(std::ceil(dims[kNchwDimIdxN] * 1.0 / kTheCubeSize)); - - // Cube count in vertical direction(C1HWC0) - int64_t vc_cnt = r_count / c0; - // Element count in each cube - int64_t cube_elem_cnt = c0 * kTheCubeSize; - - if (CheckMultiplyOverflowInt64(nc_cnt, vc_cnt)) { - GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", nc_cnt, vc_cnt); - return GRAPH_FAILED; - } - // Read data times needed by cube - int64_t c_cnt = nc_cnt * vc_cnt; - - if (CheckMultiplyOverflowInt64(c_cnt, cube_elem_cnt)) { - GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", c_cnt, cube_elem_cnt); - return GRAPH_FAILED; - } - // Element count after fractal arrangement - element_cnt = c_cnt * cube_elem_cnt; - return GRAPH_SUCCESS; - } else { - return CalcElementCntByDims(dims, element_cnt); - } -} - -/// -/// Calculate tensor element num. 
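CalcElementCntOfNc1hwc0 above counts storage elements, not logical ones: the channel dimension is padded up to a whole number of C0-wide blocks (C0 is 16 by default, 32 for 8-bit types). A small illustration of that reshaping, with example values chosen only for the arithmetic:

#include <cmath>
#include <cstdint>
#include <vector>

// Given 4-D NCHW dims, compute the dims actually stored for NC1HWC0:
// C is split into C1 = ceil(C / C0) blocks of width C0, so the stored
// element count is N * C1 * H * W * C0.
static std::vector<int64_t> Nc1hwc0StorageDims(const std::vector<int64_t> &nchw, int64_t c0) {
  const int64_t c1 = static_cast<int64_t>(std::ceil(nchw[1] * 1.0 / c0));
  return {nchw[0], c1, nchw[2], nchw[3], c0};
}

// Example: NCHW = {8, 3, 224, 224}, fp16 (C0 = 16)
//   -> stored dims {8, 1, 224, 224, 16}, i.e. 6,422,528 stored elements
//      versus 1,204,224 logical NCHW elements.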
-/// @param dims dim info -/// @param format tensor format -/// @param data_type data type -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcTensorElementCnt(const std::vector &dims, Format format, DataType data_type, - int64_t &element_cnt) { - const string format_str = TypeUtils::FormatToSerialString(format); - // Check dims - for (size_t i = 0; i < dims.size(); ++i) { - int64_t dim = dims[i]; - if (dim < 0) { - GELOGI("It's unknown shape, as dims[%zu]=%ld negative, format=%d(%s).", i, dim, format, format_str.c_str()); - element_cnt = kElementCntUnknownShape; - return GRAPH_SUCCESS; - } else if (dim == 0) { - GELOGI("No need calc element count, as dims[%zu]=%ld, format=%d(%s).", i, dim, format, format_str.c_str()); - element_cnt = 0; - return GRAPH_SUCCESS; - } - } - - graphStatus graph_status; - switch (format) { - case FORMAT_ND: - case FORMAT_MD: - graph_status = CalcElementCntByDims(dims, element_cnt); - break; - case FORMAT_NCHW: - case FORMAT_HWCN: - case FORMAT_NHWC: - case FORMAT_CHWN: - graph_status = CalcElementCntOfFixedDims(dims, format, kDimSize4d, element_cnt); - break; - case FORMAT_C1HWNCoC0: - graph_status = CalcElementCntOfFixedDims(dims, format, kDimSizeC1hwncoc0, element_cnt); - break; - case FORMAT_NC1HWC0: - graph_status = CalcElementCntOfNc1hwc0(dims, data_type, element_cnt); - break; - case FORMAT_FRACTAL_Z: - graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); - break; - case FORMAT_FRACTAL_NZ: - case FORMAT_FRACTAL_ZZ: - case FORMAT_NDHWC: - case FORMAT_NCDHW: - case FORMAT_DHWCN: - case FORMAT_DHWNC: - case FORMAT_FRACTAL_Z_3D: - case FORMAT_FRACTAL_Z_3D_TRANSPOSE: - case FORMAT_NDC1HWC0: - case FORMAT_FRACTAL_Z_C04: - case FORMAT_FRACTAL_ZN_LSTM: - case FORMAT_NC1HWC0_C04: - graph_status = CalcElementCntByDims(dims, element_cnt); - break; - default: - GELOGE(GRAPH_FAILED, "unsupported format, format=%d(%s).", format, format_str.c_str()); - graph_status = GRAPH_FAILED; - break; - } - - const string type_str = TypeUtils::DataTypeToSerialString(data_type); - if (graph_status == GRAPH_SUCCESS) { - GELOGD( - "CalcTensorElementCnt end, format=%d(%s)," - " data_type=%d(%s), element_cnt=%ld.", - format, format_str.c_str(), data_type, type_str.c_str(), element_cnt); - } else { - GELOGE(GRAPH_FAILED, "CalcTensorElementCnt failed, format=%d(%s), data_type=%d(%s).", format, format_str.c_str(), - data_type, type_str.c_str()); - } - return graph_status; -} - -/// -/// Calculate tensor mem size. 
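CalcTensorElementCnt above screens the dims before any per-format work: a negative dimension marks the whole shape unknown (count -1) and a zero dimension makes the count 0; only afterwards does it dispatch on format. A compact sketch of just that screening step, reusing the same return conventions:

#include <cstdint>
#include <vector>

constexpr int64_t kUnknownCount = -1;  // mirrors kElementCntUnknownShape

// Screens dims before any format-specific counting:
//  * a negative dim means the shape is not yet known -> count = -1
//  * a zero dim means the tensor holds no data       -> count = 0
// Returns true when screening alone already decided the count.
static bool ScreenDims(const std::vector<int64_t> &dims, int64_t &count) {
  for (int64_t dim : dims) {
    if (dim < 0) {
      count = kUnknownCount;
      return true;
    }
    if (dim == 0) {
      count = 0;
      return true;
    }
  }
  return false;  // fall through to the per-format calculation
}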
-/// @param shape tensor shape -/// @param format tensor format -/// @param data_type tensor data type -/// @param mem_size -1 means unknown shape,other means mem size -/// @return GRAPH_SUCCESS:success, other:failed -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTensorMemSize(const GeShape &shape, - Format format, - DataType data_type, - int64_t &mem_size) { - const string format_str = TypeUtils::FormatToSerialString(format); - const string type_str = TypeUtils::DataTypeToSerialString(data_type); - uint32_t type_size = 0; - bool result = TypeUtils::GetDataTypeLength(data_type, type_size); - if (!result) { - GELOGE(GRAPH_FAILED, "GetDataTypeLength failed, data_type=%d(%s).", data_type, type_str.c_str()); - return GRAPH_FAILED; - } - - std::vector dims = shape.GetDims(); - int64_t element_cnt = 0; - graphStatus status = CalcTensorElementCnt(dims, format, data_type, element_cnt); - if (status != GRAPH_SUCCESS) { - GELOGE(status, "CalcTensorElementCnt failed, status=%u format=%d(%s) data_type=%d(%s).", status, format, - format_str.c_str(), data_type, type_str.c_str()); - return status; - } - // Support unknown shape - if (element_cnt < 0) { - mem_size = kMemSizeUnknownShape; - GELOGD( - "element_cnt is unknown. " - "format=%d(%s), data_type=%d(%s), mem_size=%ld", - format, format_str.c_str(), data_type, type_str.c_str(), mem_size); - return GRAPH_SUCCESS; - } - auto type_size_int64 = static_cast(type_size); - if (CheckMultiplyOverflowInt64(element_cnt, type_size_int64)) { - GELOGE(GRAPH_FAILED, "CalcTensorMemSize overflow, when multiplying %ld and %ld, format=%d(%s), data_type=%d(%s).", - element_cnt, type_size_int64, format, format_str.c_str(), data_type, type_str.c_str()); - return GRAPH_FAILED; - } - mem_size = element_cnt * type_size_int64; - - GELOGD( - "CalcTensorMemSize end, " - "format=%d(%s), data_type=%d(%s), mem_size=%ld", - format, format_str.c_str(), data_type, type_str.c_str(), mem_size); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { - graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); - if (graph_status != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - // 64-byte alignment, if size is 0, align to 32 bytes - if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) { - GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp); - } else { - size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; - } - return GRAPH_SUCCESS; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { - GeShape output_shape = desc_temp.GetShape(); - Format format = desc_temp.GetFormat(); - DataType data_type = desc_temp.GetDataType(); - int64_t output_mem_size = 0; - graphStatus graph_status = CalcTensorMemSize(output_shape, format, data_type, output_mem_size); - if (graph_status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "CalcTensorMemSize failed!"); - return GRAPH_FAILED; - } - - if (output_mem_size < 0) { - GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]", - output_mem_size, INT64_MAX); - return GRAPH_FAILED; - } - - size_temp = output_mem_size; - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/src/common/graph/utils/tuning_utils.cc b/src/common/graph/utils/tuning_utils.cc deleted file mode 
100644 index 0f07a197..00000000 --- a/src/common/graph/utils/tuning_utils.cc +++ /dev/null @@ -1,684 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/tuning_utils.h" -#include "../debug/ge_util.h" -#include "../debug/ge_op_types.h" - -namespace ge { -const std::string peer_node_name_attr = "_peerNodeName"; -const std::string parent_node_name_attr = "_parentNodeName"; -const std::string alias_name_attr = "_aliasName"; -const std::string parent_node_attr = "parentNode"; -const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex"; -const std::string tuning_subgraph_prefix = "/aicore_subgraph_"; -const std::string non_tuning_subgraph_prefix = "/subgraph_"; -const std::set kPartitionOpTypes = {PLACEHOLDER, END}; -const std::set kExeTypes = {DATA, NETOUTPUT}; -NodeNametoNodeNameMap TuningUtils::data_2_netoutput_; -NodetoNodeNameMap TuningUtils::data_node_2_netoutput_; -NodetoNodeMap TuningUtils::data_node_2_netoutput_node_; -NodeSet TuningUtils::netoutput_nodes_; -NodeSet TuningUtils::merged_graph_nodes_; -SubgraphCreateOutNode TuningUtils::create_output_; -std::mutex TuningUtils::mutex_; - -std::string TuningUtils::PrintCheckLog() { - std::stringstream ss; - ss << "d2n:{"; - for (const auto &pair : data_2_netoutput_) { - ss << "data:" << pair.first << "-" - << "netoutput:" << pair.second; - ss << " | "; - } - ss << "}"; - ss << "netoutputs:{"; - for (const auto &node : netoutput_nodes_) { - ss << "netoutput:" << node->GetName(); - ss << " | "; - } - ss << "}"; - return ss.str(); -} - -std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) { - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Anchor is nullptr"); - return "Null"; - } - auto node = anchor->GetOwnerNode(); - return node == nullptr ? 
"Null" : node->GetName(); -} - -// part 1 -graphStatus TuningUtils::ConvertGraphToFile(std::vector tuning_subgraphs, - std::vector non_tuning_subgraphs, bool exe_flag, - const std::string &path, const std::string &user_path) { - int64_t i = 0; - int64_t j = 0; - std::lock_guard lock(mutex_); - for (auto &subgraph : tuning_subgraphs) { - create_output_.emplace(subgraph, nullptr); - auto help_info = HelpInfo{i, exe_flag, true, path, user_path}; - if (MakeExeGraph(subgraph, help_info) != SUCCESS) { - GELOGE(GRAPH_FAILED, "TUU:subgraph %zu generate exe graph failed", i); - return GRAPH_FAILED; - } - i++; - } - - for (auto &subgraph : non_tuning_subgraphs) { - create_output_.emplace(subgraph, nullptr); - auto help_info = HelpInfo{j, true, false, path, user_path}; - if (MakeExeGraph(subgraph, help_info) != SUCCESS) { - GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %zu generate exe graph failed", j); - return GRAPH_FAILED; - } - j++; - } - create_output_.clear(); - return SUCCESS; -} - -// +---------------+ -// | pld pld | -// | \ / | -// | relu relu | -// | \ / | -// | add | -// | | | -// | end | -// +---------------+ -// | -// | -// V -// +---------------+ -// | data data | -// | \ / | -// | relu relu | -// | \ / | -// | add | -// | | | -// | netoutput | -// +---------------+ -graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) { - GE_CHECK_NOTNULL(exe_graph); - // if not make exe, just dump and return - if (!help_info.exe_flag) { - DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); - GELOGI("TUU:just return, dump original sub_graph[%s]index[%d]", exe_graph->GetName().c_str(), help_info.index); - return SUCCESS; - } - // modify sub graph - for (NodePtr &node : exe_graph->GetDirectNode()) { - // 1.handle pld - if (node->GetType() == PLACEHOLDER) { - if (HandlePld(node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), - exe_graph->GetName().c_str()); - return FAILED; - } - } - // 2.handle end - if (node->GetType() == END) { - if (HandleEnd(node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), - exe_graph->GetName().c_str()); - return FAILED; - } - } - } - graphStatus ret = exe_graph->TopologicalSorting(); - if (ret != SUCCESS) { - GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret); - return ret; - } - // dump subgraphs which modified by us - if (help_info.user_path.empty()) { - DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); - } else { - GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path); - } - return SUCCESS; -} - -void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path) { - if (!path.empty()) { - if (is_tuning_graph) { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } else { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } - } else { - path = "./"; - if (is_tuning_graph) { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } else { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } - } -} - -graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) { - auto graph = 
node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - auto data_op_desc = ComGraphMakeShared(node->GetName(), DATA); - GE_CHECK_NOTNULL(data_op_desc); - auto pld_op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(pld_op_desc); - auto output_desc = pld_op_desc->GetOutputDesc(0); // only one output for pld and data - // data inputdesc & outputdesc set as same - if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) { - GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); - return FAILED; - } - if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) { - GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); - return FAILED; - } - data_node = graph->AddNode(data_op_desc); - GE_CHECK_NOTNULL(data_node); - if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); - return FAILED; - } - return SUCCESS; -} - -graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) { - auto op_desc = data_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - auto pld_desc = pld->GetOpDesc(); - GE_CHECK_NOTNULL(pld_desc); - // inherit - // a. set `end's input node type` as attr - std::string parent_op_type; - if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) { - GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void)AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type); - // b. set `end's input node name` as attr - std::string parent_op_name; - if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) { - GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void)AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name); - // c. set `end's input node's out anchor index` as attr - int parent_node_anchor_index; - if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) { - GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void)AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index); - GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(), - data_node->GetName().c_str(), data_node->GetType().c_str()); - // d. 
set `end node name` as attr - std::string peer_end_name; - if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) { - GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void)AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name); - GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(), - data_node->GetName().c_str(), data_node->GetType().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) { - auto type_pld = node->GetType(); - auto type_data = data_node->GetType(); - if (type_pld != PLACEHOLDER || type_data != DATA) { - GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s", node->GetName().c_str(), type_pld.c_str(), - type_data.c_str()); - return FAILED; - } - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - std::vector output_map(node->GetAllOutDataAnchorsSize()); - for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { - output_map[i] = static_cast(i); - } - - auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error node %u", node->GetName().c_str(), - data_node->GetName().c_str(), ret); - return FAILED; - } - - NodeUtils::UnlinkAll(*node); - - ret = GraphUtils::RemoveNodeWithoutRelink(graph, node); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); - return FAILED; - } - - GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", node->GetName().c_str(), - node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); - return ret; -} - -graphStatus TuningUtils::HandlePld(NodePtr &node) { - GE_CHECK_NOTNULL(node); - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - NodePtr data_node = nullptr; - - // 1. create data node - if (CreateDataNode(node, data_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 2. add necessary info to data_node for recovery whole graph - if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 3. 
replace pld node by data node created before - if (ChangePld2Data(node, data_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - GELOGD("TUU:pld[%s] handle success", node->GetName().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) { - GE_CHECK_NOTNULL(node); - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - auto search = create_output_.find(graph); - if (search == create_output_.end()) { - GELOGE(FAILED, "TUU:node %s's owner sub graph %s not exist in create_output map", node->GetName().c_str(), - graph->GetName().c_str()); - return FAILED; - } - if (search->second != nullptr) { - out_node = search->second; - GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str()); - return SUCCESS; - } - auto out_op_desc = ComGraphMakeShared(node->GetName(), NETOUTPUT); - GE_CHECK_NOTNULL(out_op_desc); - out_node = graph->AddNode(out_op_desc); - GE_CHECK_NOTNULL(out_node); - if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); - return FAILED; - } - create_output_[graph] = out_node; - return SUCCESS; -} - -graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) { - GE_CHECK_NOTNULL(end); - GE_CHECK_NOTNULL(out_node); - auto op_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::vector alias_names = {}; - (void)AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names); - alias_names.push_back(end->GetName()); - (void)AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names); - return SUCCESS; -} - -graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { - GE_CHECK_NOTNULL(end_node); - GE_CHECK_NOTNULL(out_node); - // get end in node is control node or normal node - AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr) - ? Anchor::DynamicAnchorCast(end_node->GetInControlAnchor()) - : Anchor::DynamicAnchorCast(end_node->GetInDataAnchor(0)); - auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1 - if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - // add edge between `end in node` and `out_node` - if (src_anchor->IsTypeOf()) { - std::shared_ptr anchor = - ComGraphMakeShared(out_node, out_node->GetAllInDataAnchors().size()); - GE_CHECK_NOTNULL(anchor); - out_node->in_data_anchors_.push_back(anchor); - if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. 
node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - auto end_op_desc = end_node->GetOpDesc(); - GE_CHECK_NOTNULL(end_op_desc); - auto out_node_op_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(out_node_op_desc); - // end node always has one input - if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str()); - return FAILED; - } - } else if (src_anchor->IsTypeOf()) { - auto anchor = out_node->GetInControlAnchor(); - if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - } else { - GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed", end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - - return SUCCESS; -} - -graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { - GE_CHECK_NOTNULL(end_node); - GE_CHECK_NOTNULL(out_node); - auto type_end = end_node->GetType(); - auto type_out = out_node->GetType(); - if (type_end != END || type_out != NETOUTPUT) { - GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s", end_node->GetName().c_str(), - type_end.c_str(), type_out.c_str()); - return FAILED; - } - // link all `end nodes's in node` to this out_node - if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) { - GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", end_node->GetName().c_str()); - return FAILED; - } - // remove `end node` - NodeUtils::UnlinkAll(*end_node); - auto graph = end_node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - if (GraphUtils::RemoveNodeWithoutRelink(graph, end_node) != SUCCESS) { - GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str()); - return FAILED; - } - return SUCCESS; -} - -graphStatus TuningUtils::HandleEnd(NodePtr &node) { - GE_CHECK_NOTNULL(node); - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - NodePtr out_node = nullptr; - - // 1. create net_output node , add only one NetOutput node to one subgraph - if (CreateNetOutput(node, out_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 2. add necessary info to out_node for recovery whole graph - if (AddAttrToNetOutputForMergeGraph(node, out_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 3. replace all end nodes by one output node created before - if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - GELOGD("TUU:end[%s] handle success", node->GetName().c_str()); - return SUCCESS; -} - -// part 2 -graphStatus TuningUtils::ConvertFileToGraph(const map &options, ge::Graph &graph) { - // 1. 
get all subgraph object - std::vector graphs; - // options format like {index:"subgraph_path"} - for (const auto &pair : options) { - ComputeGraphPtr compute_graph = ComGraphMakeShared(std::to_string(pair.first)); - if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) { - GELOGE(FAILED, "TUU:load graph from file failed"); - } - graphs.push_back(compute_graph); - } - // 2. merge graph - ComputeGraphPtr merged_graph = ComGraphMakeShared("whole_graph_after_tune"); - GE_CHECK_NOTNULL(merged_graph); - if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) { - GELOGE(FAILED, "TUU:MergeGraph failed"); - return FAILED; - } - // 3. set parent graph - for (const auto &node : merged_graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str()); - return FAILED; - } - } - graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph); - return SUCCESS; -} - -// +----------------------------------+ -// | const const | -// | \ / | -// | netoutput(end,end) | -// +----------------------------------+ -// + -// +----------------------------------+ -// | data(pld) data(pld) | -// | \ / | -// | relu relu | -// | \ / | -// | \ / | -// | add | -// | | | -// | netoutput(end) | -// +----------------------------------+ -// + -// +----------------------------------+ -// | data(pld) | -// | / | -// | netoutput | -// +----------------------------------+ -// | -// | -// V -// +----------------------------------+ -// | const const | -// | \ / | -// | relu relu | -// | \ / | -// | \ / | -// | add | -// | | | -// | netoutput | -// +----------------------------------+ -graphStatus TuningUtils::MergeAllSubGraph(std::vector &subgraphs, - ComputeGraphPtr &output_merged_compute_graph) { - GE_CHECK_NOTNULL(output_merged_compute_graph); - // 1. handle all subgraphs - for (auto &subgraph : subgraphs) { - Status ret_status = MergeSubGraph(subgraph); - if (ret_status != SUCCESS) { - GELOGE(ret_status, "TUU:subgraph %s merge failed", subgraph->GetName().c_str()); - return ret_status; - } - } - - for (const auto &node : merged_graph_nodes_) { - (void)output_merged_compute_graph->AddNode(node); - GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str()); - } - - // 2. 
remove data and output node added by us - if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str()); - return FAILED; - } - graphStatus ret = output_merged_compute_graph->TopologicalSorting(); - if (ret != SUCCESS) { - GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(), ret); - return ret; - } - GELOGD("TUU:Print-%s", PrintCheckLog().c_str()); - GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) { - for (auto &node : subgraph->GetDirectNode()) { - if (kPartitionOpTypes.count(node->GetType()) > 0) { - GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type"); - return FAILED; - } - // handle data converted from pld node - if (node->GetType() == DATA) { - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::string peer_out_name; - bool has_valid_str = (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty()); - if (has_valid_str) { - std::lock_guard lock(mutex_); - data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name); - data_node_2_netoutput_.emplace(node, peer_out_name); - continue; - } - } - // handle netoutput converted from end node - if (node->GetType() == NETOUTPUT) { - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::vector out_alias_name; - bool has_valid_str = - (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty()); - if (has_valid_str) { - std::lock_guard lock(mutex_); - netoutput_nodes_.insert(node); - } - } - { - std::lock_guard lock(mutex_); - merged_graph_nodes_.emplace(node); - } - GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str()); - } - GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) { - GE_CHECK_NOTNULL(graph); - // 1. traverse - for (auto &pair : data_node_2_netoutput_) { - auto data_node = pair.first; - GE_CHECK_NOTNULL(data_node); - auto netoutput_name = pair.second; - auto netoutput_node = graph->FindNode(netoutput_name); - GE_CHECK_NOTNULL(netoutput_node); - data_node_2_netoutput_node_.emplace(data_node, netoutput_node); - // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor` - AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr) - ? Anchor::DynamicAnchorCast(data_node->GetOutControlAnchor()) - : Anchor::DynamicAnchorCast(data_node->GetOutDataAnchor(0)); - AnchorPtr net_output_in_anchor = nullptr; - AnchorPtr src_out_anchor = nullptr; - if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed", - netoutput_node->GetName().c_str(), data_node->GetName().c_str()); - return FAILED; - } - // 3. relink - if (GraphUtils::RemoveEdge(src_out_anchor, net_output_in_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. 
node_name:(data:%s;netoutput:%s), graph_name:%s", - GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), - GetNodeNameByAnchor(net_output_in_anchor.get()).c_str(), net_output_in_anchor->GetIdx(), - data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - GE_CHECK_NOTNULL(data_out_anchor); - for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) { - if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", - GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), - data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", - GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), - data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - } - } - // 4. remove out nodes added by us - for (auto &node : netoutput_nodes_) { - NodeUtils::UnlinkAll(*node); - if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); - return FAILED; - } - GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str()); - } - return SUCCESS; -} - -graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor, - AnchorPtr &src_out_anchor) { - // 1. get `data parent node name`, i.e. `netoutput input node name` - std::string netoutput_input_name; - auto op_desc = data_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (!AttrUtils::GetStr(op_desc, parent_node_name_attr, netoutput_input_name)) { - GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", op_desc->GetName().c_str()); - return FAILED; - } - // 2. 
find index - int parent_node_anchor_index; - if (!AttrUtils::GetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index)) { - GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", op_desc->GetName().c_str()); - return FAILED; - } - // 3.find in data or ctrl anchor by 1&2 step - for (auto &in_anchor : out_node->GetAllInAnchors()) { - GE_CHECK_NOTNULL(in_anchor); - for (auto &src_anchor : in_anchor->GetPeerAnchors()) { // get all peer anchors for ctrl - GE_CHECK_NOTNULL(src_anchor); - auto src_node = src_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(src_node); - if (src_node->GetName() == netoutput_input_name && src_anchor->GetIdx() == parent_node_anchor_index) { - dest_in_anchor = in_anchor; - src_out_anchor = src_anchor; - GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s", - out_node->GetName().c_str(), dest_in_anchor->GetIdx(), netoutput_input_name.c_str(), - parent_node_anchor_index, data_node->GetName().c_str()); - break; - } - } - } - GE_CHECK_NOTNULL(dest_in_anchor); - GE_CHECK_NOTNULL(src_out_anchor); - return SUCCESS; -} - -} // namespace ge \ No newline at end of file diff --git a/src/common/graph/utils/type_utils.cc b/src/common/graph/utils/type_utils.cc deleted file mode 100644 index 2efc530e..00000000 --- a/src/common/graph/utils/type_utils.cc +++ /dev/null @@ -1,448 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/utils/type_utils.h" -#include "debug/ge_util.h" - -using domi::domiTensorFormat_t; - -namespace ge { -static const std::map kFormatToStringMap = { - {FORMAT_NCHW, "NCHW"}, - {FORMAT_NHWC, "NHWC"}, - {FORMAT_ND, "ND"}, - {FORMAT_NC1HWC0, "NC1HWC0"}, - {FORMAT_FRACTAL_Z, "FRACTAL_Z"}, - {FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, - {FORMAT_NHWC1C0, "NHWC1C0"}, - {FORMAT_FSR_NCHW, "FSR_NCHW"}, - {FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, - {FORMAT_C1HWNC0, "C1HWNC0"}, - {FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, - {FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, - {FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, - {FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, - {FORMAT_CHWN, "CHWN"}, - {FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, - {FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, - {FORMAT_BN_WEIGHT, "BN_WEIGHT"}, - {FORMAT_FILTER_HWCK, "FILTER_HWCK"}, - {FORMAT_HWCN, "HWCN"}, - {FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, - {FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, - {FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, - {FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, - {FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, - {FORMAT_MD, "MD"}, - {FORMAT_NDHWC, "NDHWC"}, - {FORMAT_NCDHW, "NCDHW"}, - {FORMAT_DHWCN, "DHWCN"}, - {FORMAT_DHWNC, "DHWNC"}, - {FORMAT_NDC1HWC0, "NDC1HWC0"}, - {FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, - {FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, - {FORMAT_C1HWNCoC0, "C1HWNCoC0"}, - {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, - {FORMAT_CN, "CN"}, - {FORMAT_NC, "NC"}, - {FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, - {FORMAT_FRACTAL_Z_G, "FRACTAL_Z_G"}, - {FORMAT_RESERVED, "FORMAT_RESERVED"}, - {FORMAT_ALL, "ALL"}}; - -static const std::map kDomiFormatToGeFormat = { - {domi::DOMI_TENSOR_NCHW, FORMAT_NCHW}, - {domi::DOMI_TENSOR_NHWC, FORMAT_NHWC}, - {domi::DOMI_TENSOR_ND, FORMAT_ND}, - {domi::DOMI_TENSOR_NC1HWC0, FORMAT_NC1HWC0}, - {domi::DOMI_TENSOR_FRACTAL_Z, FORMAT_FRACTAL_Z}, - {domi::DOMI_TENSOR_NC1C0HWPAD, FORMAT_NC1C0HWPAD}, - {domi::DOMI_TENSOR_NHWC1C0, FORMAT_NHWC1C0}, - {domi::DOMI_TENSOR_FSR_NCHW, FORMAT_FSR_NCHW}, - {domi::DOMI_TENSOR_FRACTAL_DECONV, FORMAT_FRACTAL_DECONV}, - {domi::DOMI_TENSOR_BN_WEIGHT, FORMAT_BN_WEIGHT}, - {domi::DOMI_TENSOR_CHWN, FORMAT_CHWN}, - {domi::DOMI_TENSOR_FILTER_HWCK, FORMAT_FILTER_HWCK}, - {domi::DOMI_TENSOR_NDHWC, FORMAT_NDHWC}, - {domi::DOMI_TENSOR_NCDHW, FORMAT_NCDHW}, - {domi::DOMI_TENSOR_DHWCN, FORMAT_DHWCN}, - {domi::DOMI_TENSOR_DHWNC, FORMAT_DHWNC}, - {domi::DOMI_TENSOR_RESERVED, FORMAT_RESERVED}}; - -static const std::unordered_set kInternalFormat = {"NC1HWC0", - "FRACTAL_Z", - "NC1C0HWPAD", - "NHWC1C0", - "FRACTAL_DECONV", - "C1HWNC0", - "FRACTAL_DECONV_TRANSPOSE", - "FRACTAL_DECONV_SP_STRIDE_TRANS", - "NC1HWC0_C04", - "FRACTAL_Z_C04", - "FRACTAL_DECONV_SP_STRIDE8_TRANS", - "NC1KHKWHWC0", - "C1HWNCoC0", - "FRACTAL_ZZ", - "FRACTAL_NZ", - "NDC1HWC0", - "FORMAT_FRACTAL_Z_3D", - "FORMAT_FRACTAL_Z_3D_TRANSPOSE", - "FORMAT_FRACTAL_ZN_LSTM", - "FORMAT_FRACTAL_Z_G"}; - -static const std::map kDataFormatMap = { - {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; - -static const std::map kStringToFormatMap = { - {"NCHW", FORMAT_NCHW}, - {"NHWC", FORMAT_NHWC}, - {"ND", FORMAT_ND}, - {"NC1HWC0", FORMAT_NC1HWC0}, - {"FRACTAL_Z", FORMAT_FRACTAL_Z}, - {"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, - {"NHWC1C0", FORMAT_NHWC1C0}, - {"FSR_NCHW", FORMAT_FSR_NCHW}, - {"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, - {"C1HWNC0", 
FORMAT_C1HWNC0}, - {"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, - {"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, - {"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, - {"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, - {"CHWN", FORMAT_CHWN}, - {"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, - {"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, - {"BN_WEIGHT", FORMAT_BN_WEIGHT}, - {"FILTER_HWCK", FORMAT_FILTER_HWCK}, - {"HWCN", FORMAT_HWCN}, - {"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, - {"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, - {"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, - {"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, - {"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, - {"MD", FORMAT_MD}, - {"C1HWNCoC0", FORMAT_C1HWNCoC0}, - {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, - {"NDHWC", FORMAT_NDHWC}, - {"NCDHW", FORMAT_NCDHW}, - {"DHWCN", FORMAT_DHWCN}, - {"DHWNC", FORMAT_DHWNC}, - {"NDC1HWC0", FORMAT_NDC1HWC0}, - {"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, - {"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, - {"CN", FORMAT_CN}, - {"NC", FORMAT_NC}, - {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, - {"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, - {"FORMAT_RESERVED", FORMAT_RESERVED}, - {"ALL", FORMAT_ALL}, - {"NULL", FORMAT_NULL}}; - -static const std::map kDataTypeToStringMap = { - {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. - {DT_FLOAT, "DT_FLOAT"}, // float type - {DT_FLOAT16, "DT_FLOAT16"}, // fp16 type - {DT_INT8, "DT_INT8"}, // int8 type - {DT_INT16, "DT_INT16"}, // int16 type - {DT_UINT16, "DT_UINT16"}, // uint16 type - {DT_UINT8, "DT_UINT8"}, // uint8 type - {DT_INT32, "DT_INT32"}, // uint32 type - {DT_INT64, "DT_INT64"}, // int64 type - {DT_UINT32, "DT_UINT32"}, // unsigned int32 - {DT_UINT64, "DT_UINT64"}, // unsigned int64 - {DT_BOOL, "DT_BOOL"}, // bool type - {DT_DOUBLE, "DT_DOUBLE"}, // double type - {DT_DUAL, "DT_DUAL"}, // dual output type - {DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type - {DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type - {DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type - {DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type - {DT_QINT8, "DT_QINT8"}, // qint8 type - {DT_QINT16, "DT_QINT16"}, // qint16 type - {DT_QINT32, "DT_QINT32"}, // qint32 type - {DT_QUINT8, "DT_QUINT8"}, // quint8 type - {DT_QUINT16, "DT_QUINT16"}, // quint16 type - {DT_RESOURCE, "DT_RESOURCE"}, // resource type - {DT_STRING_REF, "DT_STRING_REF"}, // string ref type - {DT_STRING, "DT_STRING"}, // string type -}; - -static const std::map kStringTodataTypeMap = { - {"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. 
- {"DT_FLOAT", DT_FLOAT}, // float type - { - "DT_FLOAT16", - DT_FLOAT16, - }, // fp16 type - {"DT_INT8", DT_INT8}, // int8 type - {"DT_INT16", DT_INT16}, // int16 type - {"DT_UINT16", DT_UINT16}, // uint16 type - {"DT_UINT8", DT_UINT8}, // uint8 type - {"DT_INT32", DT_INT32}, // uint32 type - {"DT_INT64", DT_INT64}, // int64 type - {"DT_UINT32", DT_UINT32}, // unsigned int32 - {"DT_UINT64", DT_UINT64}, // unsigned int64 - {"DT_BOOL", DT_BOOL}, // bool type - {"DT_DOUBLE", DT_DOUBLE}, // double type - {"DT_DUAL", DT_DUAL}, // dual output type - {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type - {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type - {"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type - {"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type - {"DT_QINT8", DT_QINT8}, // qint8 type - {"DT_QINT16", DT_QINT16}, // qint16 type - {"DT_QINT32", DT_QINT32}, // qint32 type - {"DT_QUINT8", DT_QUINT8}, // quint8 type - {"DT_QUINT16", DT_QUINT16}, // quint16 type - {"DT_RESOURCE", DT_RESOURCE}, // resource type - {"DT_STRING_REF", DT_STRING_REF}, // string ref type - {"DT_STRING", DT_STRING}, // string type -}; - -static const std::map kDataTypeToLength = { - {DT_BOOL, sizeof(bool)}, - {DT_INT64, sizeof(int64_t)}, - {DT_UINT64, sizeof(int64_t)}, - {DT_FLOAT, sizeof(float)}, - {DT_INT32, sizeof(int32_t)}, - {DT_UINT32, sizeof(int32_t)}, - {DT_INT8, sizeof(char)}, - {DT_UINT8, sizeof(char)}, - {DT_INT16, sizeof(int16_t)}, - {DT_UINT16, sizeof(int16_t)}, - {DT_FLOAT16, sizeof(int16_t)}, - {DT_DOUBLE, sizeof(double)}, - {DT_DUAL, sizeof(float) + sizeof(int8_t)}, - {DT_DUAL_SUB_INT8, sizeof(int8_t)}, - {DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, - {DT_COMPLEX64, sizeof(int64_t)}, - {DT_COMPLEX128, sizeof(int64_t) * 2}, - {DT_QINT8, sizeof(int8_t)}, - {DT_QINT16, sizeof(int16_t)}, - {DT_QINT32, sizeof(int32_t)}, - {DT_QUINT8, sizeof(uint8_t)}, - {DT_QUINT16, sizeof(uint16_t)}, - {DT_STRING_REF, sizeof(uint64_t) * 2}, - {DT_STRING, sizeof(uint64_t)}, - {DT_RESOURCE, sizeof(uint64_t)}, -}; - -static const std::map kFmkTypeToString = { - {domi::CAFFE, "caffe"}, {domi::MINDSPORE, "mindspore"}, {domi::TENSORFLOW, "tensorflow"}, - {domi::ANDROID_NN, "android_nn"}, {domi::ONNX, "onnx"}, {domi::FRAMEWORK_RESERVED, "framework_reserved"}, -}; - -static const std::map kImplyTypeToString = { - {domi::ImplyType::BUILDIN, "buildin"}, {domi::ImplyType::TVM, "tvm"}, {domi::ImplyType::CUSTOM, "custom"}, - {domi::ImplyType::AI_CPU, "ai_cpu"}, {domi::ImplyType::CCE, "cce"}, {domi::ImplyType::GELOCAL, "gelocal"}, - {domi::ImplyType::HCCL, "hccl"}, {domi::ImplyType::INVALID, "invalid"}}; - -std::string TypeUtils::ImplyTypeToSerialString(domi::ImplyType imply_type) { - auto it = kImplyTypeToString.find(imply_type); - if (it != kImplyTypeToString.end()) { - return it->second; - } else { - GELOGE(GRAPH_FAILED, "ImplyTypeToSerialString: imply_type not support %u", imply_type); - return "UNDEFINED"; - } -} - -bool TypeUtils::IsDataTypeValid(DataType dt) { - uint32_t num = static_cast(dt); - GE_CHK_BOOL_EXEC((num <= DT_UNDEFINED), return false, "The DataType is invalid"); - return true; -} - -std::string TypeUtils::DataTypeToSerialString(DataType data_type) { - auto it = kDataTypeToStringMap.find(data_type); - if (it != kDataTypeToStringMap.end()) { - return it->second; - } else { - GELOGE(GRAPH_FAILED, "DataTypeToSerialString: datatype not support %u", data_type); - return "UNDEFINED"; - } -} - -DataType TypeUtils::SerialStringToDataType(const std::string &str) { - auto it = 
kStringTodataTypeMap.find(str);
-  if (it != kStringTodataTypeMap.end()) {
-    return it->second;
-  } else {
-    GELOGE(GRAPH_FAILED, "SerialStringToDataType: datatype not support %s", str.c_str());
-    return DT_UNDEFINED;
-  }
-}
-
-bool TypeUtils::IsFormatValid(Format format) {
-  uint32_t num = static_cast<uint32_t>(format);
-  GE_CHK_BOOL_EXEC((num <= FORMAT_RESERVED), return false, "The Format is invalid");
-  return true;
-}
-
-bool TypeUtils::IsInternalFormat(Format format) {
-  std::string serial_format = FormatToSerialString(format);
-  auto iter = kInternalFormat.find(serial_format);
-  bool result = (iter == kInternalFormat.end()) ? false : true;
-  return result;
-}
-
-std::string TypeUtils::FormatToSerialString(Format format) {
-  auto it = kFormatToStringMap.find(format);
-  if (it != kFormatToStringMap.end()) {
-    return it->second;
-  } else {
-    GELOGE(GRAPH_FAILED, "Format not support %u", format);
-    return "RESERVED";
-  }
-}
-Format TypeUtils::SerialStringToFormat(const std::string &str) {
-  auto it = kStringToFormatMap.find(str);
-  if (it != kStringToFormatMap.end()) {
-    return it->second;
-  } else {
-    GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str());
-    return FORMAT_RESERVED;
-  }
-}
-
-Format TypeUtils::DataFormatToFormat(const std::string &str) {
-  auto it = kDataFormatMap.find(str);
-  if (it != kDataFormatMap.end()) {
-    return it->second;
-  } else {
-    GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str());
-    return FORMAT_RESERVED;
-  }
-}
-
-Format TypeUtils::DomiFormatToFormat(domi::domiTensorFormat_t domi_format) {
-  auto it = kDomiFormatToGeFormat.find(domi_format);
-  if (it != kDomiFormatToGeFormat.end()) {
-    return it->second;
-  }
-  GELOGE(GRAPH_FAILED, "do not find domi Format %d from map", domi_format);
-  return FORMAT_RESERVED;
-}
-
-std::string TypeUtils::FmkTypeToSerialString(domi::FrameworkType fmk_type) {
-  auto it = kFmkTypeToString.find(fmk_type);
-  if (it != kFmkTypeToString.end()) {
-    return it->second;
-  } else {
-    GELOGW("Framework type not support %d.", fmk_type);
-    return "";
-  }
-}
-
-static inline void CopyDataFromBuffer(vector<uint8_t> &data, const Buffer &buffer) {
-  data.clear();
-  if (buffer.GetData() != nullptr && buffer.GetSize() != 0) {
-    data.assign(buffer.GetData(), buffer.GetData() + buffer.GetSize());
-  }
-}
-
-graphStatus Usr2DefQuantizeFactor(const UsrQuantizeFactor &usr, QuantizeFactor &def) {
-  def.scale_mode = uint32_t(usr.scale_mode);
-  def.set_scale_value(usr.scale_value.data(), usr.scale_value.size());
-  def.scale_offset = usr.scale_offset;
-  def.set_offset_data_value(usr.offset_data_value.data(), usr.offset_data_value.size());
-  def.offset_data_offset = usr.offset_data_offset;
-  def.set_offset_weight_value(usr.offset_weight_value.data(), usr.offset_weight_value.size());
-  def.offset_weight_offset = usr.offset_weight_offset;
-  def.set_offset_pad_value(usr.offset_pad_value.data(), usr.offset_pad_value.size());
-  def.offset_pad_offset = usr.offset_pad_offset;
-  return GRAPH_SUCCESS;
-}
-graphStatus Def2UsrQuantizeFactor(const QuantizeFactor &def, UsrQuantizeFactor &usr) {
-  usr.scale_mode = UsrQuantizeScaleMode(def.scale_mode);
-  CopyDataFromBuffer(usr.scale_value, def.scale_value);
-  usr.scale_offset = def.scale_offset;
-  CopyDataFromBuffer(usr.offset_data_value, def.offset_data_value);
-  usr.offset_data_offset = def.offset_data_offset;
-  CopyDataFromBuffer(usr.offset_weight_value, def.offset_weight_value);
-  usr.offset_weight_offset = def.offset_weight_offset;
-  CopyDataFromBuffer(usr.offset_pad_value, def.offset_pad_value);
-  usr.offset_pad_offset =
def.offset_pad_offset; - return GRAPH_SUCCESS; -} -graphStatus Usr2DefUsrQuantizeCalcFactor(const UsrQuantizeCalcFactor &usr, QuantizeCalcFactor &def) { - def.set_offsetw(usr.offsetw.data(), usr.offsetw.size()); - def.offsetw_offset = usr.offsetw_offset; - def.set_offsetd(usr.offsetd.data(), usr.offsetd.size()); - def.offsetd_offset = usr.offsetd_offset; - def.set_scalereq(usr.scalereq.data(), usr.scalereq.size()); - def.scaledreq_offset = usr.scaledreq_offset; - def.set_offsetdnext(usr.offsetdnext.data(), usr.offsetdnext.size()); - def.offsetdnext_offset = usr.offsetdnext_offset; - return GRAPH_SUCCESS; -} -graphStatus Def2UsrQuantizeCalcFactor(const QuantizeCalcFactor &def, UsrQuantizeCalcFactor &usr) { - CopyDataFromBuffer(usr.offsetw, def.offsetw); - usr.offsetw_offset = def.offsetw_offset; - CopyDataFromBuffer(usr.offsetd, def.offsetd); - usr.offsetd_offset = def.offsetd_offset; - CopyDataFromBuffer(usr.scalereq, def.scalereq); - usr.scaledreq_offset = def.scaledreq_offset; - CopyDataFromBuffer(usr.offsetdnext, def.offsetdnext); - usr.offsetdnext_offset = def.offsetdnext_offset; - return GRAPH_SUCCESS; -} -graphStatus TypeUtils::Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def) { - def.quantize_algo = uint32_t(usr.quantize_algo); - def.scale_type = uint32_t(usr.scale_type); - GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.quantize_param, def.quantize_param), - "Usr2DefQuantizeFactor quantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.dequantize_param, def.dequantize_param), - "Usr2DefQuantizeFactor dequantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.requantize_param, def.requantize_param), - "Usr2DefQuantizeFactor requantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefUsrQuantizeCalcFactor(usr.quantizecalc_param, def.quantizecalc_param), - "Usr2DefQuantizeFactor quantizecalc_param failed"); - return GRAPH_SUCCESS; -} -graphStatus TypeUtils::Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr) { - usr.quantize_algo = UsrQuantizeAlgorithm(def.quantize_algo); - usr.scale_type = UsrQuantizeScaleType(def.scale_type); - GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.quantize_param, usr.quantize_param), - "Def2UsrQuantizeFactor quantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.dequantize_param, usr.dequantize_param), - "Def2UsrQuantizeFactor dequantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.requantize_param, usr.requantize_param), - "Def2UsrQuantizeFactor requantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeCalcFactor(def.quantizecalc_param, usr.quantizecalc_param), - "Def2UsrQuantizeCalcFactor quantizecalc_param failed"); - return GRAPH_SUCCESS; -} -bool TypeUtils::GetDataTypeLength(ge::DataType data_type, uint32_t &length) { - auto it = kDataTypeToLength.find(data_type); - if (it != kDataTypeToLength.end()) { - length = it->second; - return true; - } else { - GELOGE(GRAPH_FAILED, "data_type not support %d", data_type); - return false; - } -} -bool TypeUtils::CheckUint64MulOverflow(uint64_t a, uint32_t b) { - // Not overflow - if (a == 0) { - return false; - } - if ((ULLONG_MAX / a) >= b) { - return false; - } - return true; -} -} // namespace ge diff --git a/src/ge/CMakeLists.txt b/src/ge/CMakeLists.txt deleted file mode 100755 index 69159c16..00000000 --- a/src/ge/CMakeLists.txt +++ /dev/null @@ -1,382 +0,0 @@ -# Copyright 2019-2020 Huawei 
Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# libge_compiler.so & libge_runner.so -# will later be integrated into libgraph_runner.so, works for both training and inference -# compiling proto files generates some warnings, use no-unused-variable to suppress them -set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") -file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "../proto/fusion_model.proto" - "../proto/optimizer_priority.proto" - ) -file(GLOB PROTO_CLIENT_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "../proto/ge_api.proto" - ) -file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "../proto/om.proto" - "../proto/task.proto" - "../proto/insert_op.proto" - "../proto/ge_ir.proto" - "../proto/fwk_adapter.proto" - "../proto/op_mapping_info.proto" - "../proto/dump_task.proto" - ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) -protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${GE_SOURCE_DIR}) -include_directories(${GE_SOURCE_DIR}/src) -include_directories(${GE_SOURCE_DIR}/src/ge/analyzer) -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/common/util) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/inc/external/graph) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/inc/framework/common) -include_directories(${GE_SOURCE_DIR}/inc/graph) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/toolchain) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) - -######### libge_runner.so ############# -# need to remove dependencies on pb files later -file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "analyzer/analyzer.cc" - "client/ge_prof.cc" - "client/ge_api.cc" - "common/dump/dump_manager.cc" - "common/dump/dump_properties.cc" - "common/dump/dump_op.cc" - "common/formats/format_transfers/*.cc" - "common/formats/formats.cc" - "common/formats/utils/formats_trans_utils.cc" - "common/fp16_t.cc" - "common/ge/op_tiling_manager.cc" - "common/ge/plugin_manager.cc" - "common/helper/model_cache_helper.cc" - "common/profiling/profiling_manager.cc" - "engine_manager/dnnengine_manager.cc" - "executor/ge_executor.cc" - "ge_local_engine/engine/host_cpu_engine.cc" - "generator/ge_generator.cc" - "generator/generator_api.cc" - "graph/build/*.cc" - "graph/common/*.cc" - "graph/execute/graph_execute.cc" - "graph/label/*.cc" - "graph/load/graph_loader.cc" - "graph/load/new_model_manager/*.cc" - 
"graph/load/new_model_manager/task_info/end_graph_task_info.cc" - "graph/load/new_model_manager/task_info/event_record_task_info.cc" - "graph/load/new_model_manager/task_info/event_wait_task_info.cc" - "graph/load/new_model_manager/task_info/fusion_start_task_info.cc" - "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" - "graph/load/new_model_manager/task_info/hccl_task_info.cc" - "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" - "graph/load/new_model_manager/task_info/kernel_task_info.cc" - "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" - "graph/load/new_model_manager/task_info/label_set_task_info.cc" - "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" - "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" - "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" - "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" - "graph/load/new_model_manager/task_info/stream_active_task_info.cc" - "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" - "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" - "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" - "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" - "graph/load/new_model_manager/task_info/task_info.cc" - "graph/manager/graph_context.cc" - "graph/manager/graph_manager.cc" - "graph/manager/graph_manager_utils.cc" - "graph/manager/graph_mem_allocator.cc" - "graph/manager/graph_caching_allocator.cc" - "graph/manager/graph_var_manager.cc" - "graph/manager/model_manager/event_manager.cc" - "graph/manager/rdma_pool_allocator.cc" - "graph/manager/trans_var_data_utils.cc" - "graph/manager/util/debug.cc" - "graph/manager/util/hcom_util.cc" - "graph/manager/util/rt_context_util.cc" - "graph/manager/util/variable_accelerate_ctrl.cc" - "graph/manager/util/debug.cc" - "graph/manager/util/hcom_util.cc" - "graph/manager/util/rt_context_util.cc" - "graph/manager/util/variable_accelerate_ctrl.cc" - "graph/optimize/graph_optimize.cc" - "graph/optimize/mem_rw_conflict_optimize.cc" - "graph/optimize/optimizer/allreduce_fusion_pass.cc" - "graph/optimize/summary_optimize.cc" - "graph/partition/dynamic_shape_partition.cc" - "graph/partition/engine_place.cc" - "graph/partition/graph_partition.cc" - "graph/passes/*.cc" - "graph/preprocess/graph_preprocess.cc" - "graph/preprocess/insert_op/ge_aipp_op.cc" - "graph/preprocess/insert_op/util_insert_aipp_op.cc" - "graph/preprocess/multi_batch_copy_graph.cc" - "graph/preprocess/multi_batch_options.cc" - "host_kernels/add_kernel.cc" - "host_kernels/broadcast_args_kernel.cc" - "host_kernels/broadcast_gradient_args_kernel.cc" - "host_kernels/cast_kernel.cc" - "host_kernels/concat_offset_kernel.cc" - "host_kernels/concat_v2_kernel.cc" - "host_kernels/dynamic_stitch_kernel.cc" - "host_kernels/empty_kernel.cc" - "host_kernels/expanddims_kernel.cc" - "host_kernels/fill_kernel.cc" - "host_kernels/floordiv_kernel.cc" - "host_kernels/floormod_kernel.cc" - "host_kernels/gather_v2_kernel.cc" - "host_kernels/greater_kernel.cc" - "host_kernels/identity_kernel.cc" - "host_kernels/kernel_utils.cc" - "host_kernels/maximum_kernel.cc" - "host_kernels/mul_kernel.cc" - "host_kernels/pack_kernel.cc" - "host_kernels/permute_kernel.cc" - "host_kernels/range_kernel.cc" - "host_kernels/rank_kernel.cc" - "host_kernels/reduce_prod_kernel.cc" - "host_kernels/reshape_kernel.cc" - "host_kernels/rsqrt_kernel.cc" - "host_kernels/shape_kernel.cc" - 
"host_kernels/shape_n_kernel.cc" - "host_kernels/size_kernel.cc" - "host_kernels/slice_d_kernel.cc" - "host_kernels/slice_kernel.cc" - "host_kernels/squeeze_kernel.cc" - "host_kernels/ssd_prior_box_kernel.cc" - "host_kernels/strided_slice_kernel.cc" - "host_kernels/sub_kernel.cc" - "host_kernels/transdata_kernel.cc" - "host_kernels/transpose_kernel.cc" - "host_kernels/unpack_kernel.cc" - "host_kernels/unsqueeze_kernel.cc" - "hybrid/common/npu_memory_allocator.cc" - "hybrid/common/tensor_value.cc" - "hybrid/executor/*.cc" - "hybrid/executor/worker/*.cc" - "hybrid/hybrid_davinci_model.cc" - "hybrid/model/*.cc" - "hybrid/node_executor/aicore/*.cc" - "hybrid/node_executor/aicpu/aicpu_ext_info.cc" - "hybrid/node_executor/aicpu/aicpu_node_executor.cc" - "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" - "hybrid/node_executor/controlop/control_op_executor.cc" - "hybrid/node_executor/ge_local/ge_local_node_executor.cc" - "hybrid/node_executor/hccl/hccl_node_executor.cc" - "hybrid/node_executor/hostcpu/ge_local_node_executor.cc" - "hybrid/node_executor/host_cpu/host_cpu_node_executor.cc" - "hybrid/node_executor/host_cpu/kernel_factory.cc" - "hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc" - "hybrid/node_executor/host_cpu/kernel/variable_kernel.cc" - "hybrid/node_executor/host_cpu/kernel/assign_kernel.cc" - "hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc" - "hybrid/node_executor/node_executor.cc" - "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" - "hybrid/node_executor/rts/rts_node_executor.cc" - "hybrid/node_executor/task_context.cc" - "init/gelib.cc" - "model/ge_model.cc" - "model/ge_root_model.cc" - "omm/csa_interact.cc" - "opskernel_manager/ops_kernel_manager.cc" - "session/inner_session.cc" - "session/session_manager.cc" - "single_op/*.cc" - "single_op/task/*.cc" - ) - - -######### libge_runner.so ############# -add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS} ${PROTO_HEADER_HDRS}) -target_compile_definitions(ge_runner PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - DAVINCI_SUPPORT_PROFILING - REUSE_MEMORY=1 - DAVINCI_CLOUD) -target_link_libraries(ge_runner - graph - ge_common - ge_memory - #${PROTOBUF_LIBRARY} - protobuf - ${register} - ${c_sec} - ${slog} - ${mmpa} - ${hccl} - ${msprof} - ${runtime} - ${resouce} - ${ascend_hal} - ${adump_server} - ${msprofiler} - rt - dl) - -######### libge_compiler.so ############# -# need to remove dependencies on pb files later -file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "analyzer/analyzer.cc" - "common/dump/dump_properties.cc" - "common/dump/dump_manager.cc" - "common/dump/dump_op.cc" - "common/dump/dump_server.cc" - "common/formats/format_transfers/*.cc" - "common/formats/formats.cc" - "common/formats/utils/formats_trans_utils.cc" - "common/fp16_t.cc" - "common/ge/op_tiling_manager.cc" - "common/ge/plugin_manager.cc" - "common/helper/model_cache_helper.cc" - "common/profiling/profiling_manager.cc" - "engine_manager/dnnengine_manager.cc" - "ge_local_engine/engine/host_cpu_engine.cc" - "generator/ge_generator.cc" - "generator/generator_api.cc" - "graph/build/*.cc" - "graph/common/*.cc" - "graph/execute/graph_execute.cc" - "graph/label/*.cc" - "graph/load/graph_loader.cc" - "graph/load/new_model_manager/*.cc" - "graph/load/new_model_manager/task_info/end_graph_task_info.cc" - "graph/load/new_model_manager/task_info/event_record_task_info.cc" - "graph/load/new_model_manager/task_info/event_wait_task_info.cc" - 
"graph/load/new_model_manager/task_info/fusion_start_task_info.cc" - "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" - "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" - "graph/load/new_model_manager/task_info/kernel_task_info.cc" - "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" - "graph/load/new_model_manager/task_info/label_set_task_info.cc" - "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" - "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" - "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" - "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" - "graph/load/new_model_manager/task_info/stream_active_task_info.cc" - "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" - "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" - "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" - "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" - "graph/load/new_model_manager/task_info/task_info.cc" - "graph/manager/graph_caching_allocator.cc" - "graph/manager/graph_context.cc" - "graph/manager/graph_manager.cc" - "graph/manager/graph_manager_utils.cc" - "graph/manager/graph_mem_allocator.cc" - "graph/manager/trans_var_data_utils.cc" - "graph/manager/graph_var_manager.cc" - "graph/manager/model_manager/event_manager.cc" - "graph/manager/rdma_pool_allocator.cc" - "graph/manager/util/debug.cc" - "graph/manager/util/rt_context_util.cc" - "graph/manager/util/variable_accelerate_ctrl.cc" - "graph/optimize/graph_optimize.cc" - "graph/optimize/mem_rw_conflict_optimize.cc" - "graph/optimize/summary_optimize.cc" - "graph/partition/dynamic_shape_partition.cc" - "graph/partition/engine_place.cc" - "graph/partition/graph_partition.cc" - "graph/passes/*.cc" - "graph/preprocess/graph_preprocess.cc" - "graph/preprocess/insert_op/ge_aipp_op.cc" - "graph/preprocess/insert_op/util_insert_aipp_op.cc" - "graph/preprocess/multi_batch_copy_graph.cc" - "graph/preprocess/multi_batch_options.cc" - "host_kernels/add_kernel.cc" - "host_kernels/broadcast_args_kernel.cc" - "host_kernels/broadcast_gradient_args_kernel.cc" - "host_kernels/cast_kernel.cc" - "host_kernels/concat_offset_kernel.cc" - "host_kernels/concat_v2_kernel.cc" - "host_kernels/dynamic_stitch_kernel.cc" - "host_kernels/empty_kernel.cc" - "host_kernels/expanddims_kernel.cc" - "host_kernels/fill_kernel.cc" - "host_kernels/floordiv_kernel.cc" - "host_kernels/floormod_kernel.cc" - "host_kernels/gather_v2_kernel.cc" - "host_kernels/greater_kernel.cc" - "host_kernels/identity_kernel.cc" - "host_kernels/kernel_utils.cc" - "host_kernels/maximum_kernel.cc" - "host_kernels/mul_kernel.cc" - "host_kernels/pack_kernel.cc" - "host_kernels/permute_kernel.cc" - "host_kernels/range_kernel.cc" - "host_kernels/rank_kernel.cc" - "host_kernels/reduce_prod_kernel.cc" - "host_kernels/reshape_kernel.cc" - "host_kernels/rsqrt_kernel.cc" - "host_kernels/shape_kernel.cc" - "host_kernels/shape_n_kernel.cc" - "host_kernels/size_kernel.cc" - "host_kernels/slice_d_kernel.cc" - "host_kernels/slice_kernel.cc" - "host_kernels/squeeze_kernel.cc" - "host_kernels/ssd_prior_box_kernel.cc" - "host_kernels/strided_slice_kernel.cc" - "host_kernels/sub_kernel.cc" - "host_kernels/transdata_kernel.cc" - "host_kernels/transpose_kernel.cc" - "host_kernels/unpack_kernel.cc" - "host_kernels/unsqueeze_kernel.cc" - "hybrid/hybrid_davinci_model_stub.cc" - "hybrid/node_executor/aicpu/aicpu_ext_info.cc" - 
"init/gelib.cc" - "ir_build/atc_ir_common.cc" - "ir_build/ge_ir_build.cc" - "model/ge_model.cc" - "model/ge_root_model.cc" - "omm/csa_interact.cc" - "opskernel_manager/ops_kernel_manager.cc" - "session/inner_session.cc" - "session/session_manager.cc" - "single_op/*.cc" - "single_op/task/*.cc" - ) - -add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) -target_compile_definitions(ge_compiler PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - REUSE_MEMORY=1 - FMK_HOST_INFER - FMK_SUPPORT_DUMP - COMPILE_OMG_PACKAGE - REUSE_MEMORY=1) -target_link_libraries(ge_compiler - graph - ge_common - ge_memory - #${PROTOBUF_LIBRARY} - protobuf - ${register} - ${c_sec} - ${slog} - ${mmpa} - ${msprof} - ${runtime} - ${resouce} - ${error_manager} - rt - dl) diff --git a/src/ge/client/CMakeLists.txt b/src/ge/client/CMakeLists.txt deleted file mode 100755 index 962df4af..00000000 --- a/src/ge/client/CMakeLists.txt +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# libge_client.so -# add all proto files, generate corresponding .h and .cc files -set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") -file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "../../proto/ge_api.proto" - ) - -file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "../../proto/ge_ir.proto" - "../../proto/task.proto" - "../../proto/om.proto" - "../../proto/insert_op.proto" - ) - -file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "ge_api.cc" - "ge_prof.cc" - ) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) - -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${GE_SOURCE_DIR}/src/ge) -include_directories(${GE_SOURCE_DIR}/src) -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/inc/external/graph) -include_directories(${GE_SOURCE_DIR}/inc/common) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/inc/graph) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) - -############ libge_client.so ################ -add_library(ge_client SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) -target_compile_definitions(ge_client PRIVATE - Werror - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - REUSE_MEMORY=1 - PLATFORM_CLOUD) -target_link_libraries(ge_client - graph - ge_compiler - ge_common - #${PROTOBUF_LIBRARY} - protobuf - ${register} - ${c_sec} - ${slog} - ${mmpa} - ${runtime} - ${msprof} - ${msprofiler} - ${ascend_hal} - rt - dl) diff --git a/src/ge/common/CMakeLists.txt b/src/ge/common/CMakeLists.txt deleted file mode 
100755 index ccb42214..00000000 --- a/src/ge/common/CMakeLists.txt +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# libge_common.so -file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "../../proto/om.proto" - "../../proto/ge_ir.proto" - "../../proto/task.proto" - "../../proto/insert_op.proto" - ) - -file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "../model/ge_model.cc" - "auth/file_saver.cc" - "context/ctx.cc" - "cust_aicpu_kernel_store.cc" - "debug/memory_dumper.cc" - "dump/dump_properties.cc" - "fmk_error_codes.cc" - "formats/format_transfers/datatype_transfer.cc" - "formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" - "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" - "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" - "formats/format_transfers/format_transfer_fractal_nz.cc" - "formats/format_transfers/format_transfer_fractal_z.cc" - "formats/format_transfers/format_transfer_fractal_zz.cc" - "formats/format_transfers/format_transfer_fracz_hwcn.cc" - "formats/format_transfers/format_transfer_fracz_nchw.cc" - "formats/format_transfers/format_transfer_fracz_nhwc.cc" - "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" - "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" - "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "formats/format_transfers/format_transfer_nchw_fz_c04.cc" - "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" - "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" - "formats/format_transfers/format_transfer_transpose.cc" - "formats/formats.cc" - "formats/utils/formats_trans_utils.cc" - "fp16_t.cc" - "ge/datatype_util.cc" - "ge/tbe_plugin_manager.cc" - "ge_format_util.cc" - "helper/model_helper.cc" - "helper/om_file_helper.cc" - "kernel_store.cc" - "math/fp16_math.cc" - "model_parser/base.cc" - "model_saver.cc" - "op/attr_value_util.cc" - "op/ge_op_utils.cc" - "properties_manager.cc" - "tbe_kernel_store.cc" - "thread_pool.cc" - "types.cc" - "util.cc" - ) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) - -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${CMAKE_CURRENT_LIST_DIR}/op) -include_directories(${GE_SOURCE_DIR}/src/ge) -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/common/util) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/inc/external/graph) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/inc/graph) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) - -############ libge_common.so ################ -add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) 
-target_compile_definitions(ge_common PUBLIC - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - HOST_VISIBILITY - OS_CENTOS) -target_link_libraries(ge_common - graph - #${PROTOBUF_LIBRARY} - protobuf - ${register} - ${c_sec} - ${slog} - ${mmpa} - ${resource} - ${error_manager} - rt - dl) diff --git a/src/ge/executor/CMakeLists.txt b/src/ge/executor/CMakeLists.txt deleted file mode 100755 index a6962f31..00000000 --- a/src/ge/executor/CMakeLists.txt +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# libge_executor.so -# add all proto files, generate corresponding .h and .cc files -# add src files -file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "../../proto/task.proto" - "../../proto/om.proto" - "../../proto/insert_op.proto" - "../../proto/op_mapping_info.proto" - "../../proto/ge_ir.proto" - "../../proto/dump_task.proto" - ) - -file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "ge_executor.cc" - "../common/dump/dump_properties.cc" - "../common/dump/dump_manager.cc" - "../common/dump/dump_op.cc" - "../common/ge/op_tiling_manager.cc" - "../common/ge/plugin_manager.cc" - "../common/profiling/profiling_manager.cc" - "../graph/execute/graph_execute.cc" - "../graph/load/graph_loader.cc" - "../graph/load/new_model_manager/aipp_utils.cc" - "../graph/load/new_model_manager/cpu_queue_schedule.cc" - "../graph/load/new_model_manager/data_dumper.cc" - "../graph/load/new_model_manager/data_inputer.cc" - "../graph/load/new_model_manager/davinci_model.cc" - "../graph/load/new_model_manager/davinci_model_parser.cc" - "../graph/load/new_model_manager/model_manager.cc" - "../graph/load/new_model_manager/model_utils.cc" - "../graph/load/new_model_manager/task_info/end_graph_task_info.cc" - "../graph/load/new_model_manager/task_info/event_record_task_info.cc" - "../graph/load/new_model_manager/task_info/event_wait_task_info.cc" - "../graph/load/new_model_manager/task_info/fusion_start_task_info.cc" - "../graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" - "../graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" - "../graph/load/new_model_manager/task_info/kernel_task_info.cc" - "../graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" - "../graph/load/new_model_manager/task_info/label_set_task_info.cc" - "../graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" - "../graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" - "../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" - "../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" - "../graph/load/new_model_manager/task_info/stream_active_task_info.cc" - "../graph/load/new_model_manager/task_info/stream_switch_task_info.cc" - "../graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" - "../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" - 
"../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" - "../graph/load/new_model_manager/task_info/task_info.cc" - "../graph/load/new_model_manager/tbe_handle_store.cc" - "../graph/load/new_model_manager/zero_copy_offset.cc" - "../graph/load/new_model_manager/zero_copy_task.cc" - "../graph/manager/graph_caching_allocator.cc" - "../graph/manager/graph_manager_utils.cc" - "../graph/manager/graph_mem_allocator.cc" - "../graph/manager/graph_var_manager.cc" - "../graph/manager/rdma_pool_allocator.cc" - "../graph/manager/trans_var_data_utils.cc" - "../graph/manager/util/debug.cc" - "../hybrid/hybrid_davinci_model_stub.cc" - "../hybrid/node_executor/aicpu/aicpu_ext_info.cc" - "../model/ge_model.cc" - "../model/ge_root_model.cc" - "../omm/csa_interact.cc" - "../single_op/single_op.cc" - "../single_op/single_op_manager.cc" - "../single_op/single_op_model.cc" - "../single_op/stream_resource.cc" - "../single_op/task/aicpu_task_builder.cc" - "../single_op/task/build_task_utils.cc" - "../single_op/task/op_task.cc" - "../single_op/task/tbe_task_builder.cc" - ) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) - -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${GE_SOURCE_DIR}/src/ge) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/inc/external/graph) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/graph) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) - -######## libge_executor.so ######## -add_library(ge_executor SHARED ${SRC_LIST} ${PROTO_HDRS}) -target_compile_definitions(ge_executor PRIVATE - Werror - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - DAVINCI_SUPPORT_PROFILING - FMK_HOST_INFER) -target_link_libraries(ge_executor - ge_common - graph - ${PROTOBUF_LIBRARY} - protobuf - ${register} - ${c_sec} - ${runtime} - ${slog} - ${mmpa} - ${msprof} - ${error_manager} - ${ascend_hal} - rt - dl) - diff --git a/src/ge/ge_local_engine/CMakeLists.txt b/src/ge/ge_local_engine/CMakeLists.txt deleted file mode 100755 index 2d1d30c6..00000000 --- a/src/ge/ge_local_engine/CMakeLists.txt +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ - -# libge_local_engine.so -# add all proto files, generate corresponding .h and .cc files -file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "../../proto/task.proto" - ) - -file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "engine/ge_local_engine.cc" - "ops_kernel_store/*.cc" - "ops_kernel_store/op/*.cc" - ) - -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) - -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${GE_SOURCE_DIR}/src/ge) -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/inc/external/graph) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/inc/graph) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) - -######### libge_local_engine.so ############# -add_library(ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) -target_compile_definitions(ge_local_engine PRIVATE Werror COMPILE_OMG_PACKAGE) -target_link_libraries(ge_local_engine - graph - #${PROTOBUF_LIBRARY} - protobuf - ${register} - ${c_sec} - ${slog} - ${runtime}) diff --git a/src/ge/ge_runtime/CMakeLists.txt b/src/ge/ge_runtime/CMakeLists.txt deleted file mode 100755 index d316f738..00000000 --- a/src/ge/ge_runtime/CMakeLists.txt +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ - -# libge_runtime.so -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${GE_SOURCE_DIR}/src/ge) -include_directories(${GE_SOURCE_DIR}/src) -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/graph) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/inc/framework/common) -include_directories(${GE_SOURCE_DIR}/inc/framework/ge_runtime) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) - -######### libge_runtime.so ############# -file(GLOB_RECURSE GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "model_runner.cc" - "runtime_model.cc" - "output.cc" - "task/*.cc" - ) - -add_library(ge_runtime SHARED ${GE_SRC_LIST}) -target_compile_definitions(ge_runtime PUBLIC - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - Werror) -target_link_libraries(ge_runtime - graph - protobuf - ${slog} - ${runtime} - ${c_sec} - rt - dl - ) diff --git a/src/ge/ge_runtime/proto/task.pb.h b/src/ge/ge_runtime/proto/task.pb.h deleted file mode 100644 index 490289ac..00000000 --- a/src/ge/ge_runtime/proto/task.pb.h +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: task.proto - -#ifndef STUB_TASK_PROTO_H -#define STUB_TASK_PROTO_H - -namespace domi { -class TaskDef; -} - -#endif // STUB_TASK_PROTO_H diff --git a/src/ge/graph/build/memory/CMakeLists.txt b/src/ge/graph/build/memory/CMakeLists.txt deleted file mode 100644 index ea87b906..00000000 --- a/src/ge/graph/build/memory/CMakeLists.txt +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ - -# libge_memosy.a -file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "memory_assigner.cc" - "graph_mem_assigner.cc" - "binary_block_mem_assigner.cc" - "block_mem_assigner.cc" - "hybrid_mem_assigner.cc" - "max_block_mem_assigner.cc" - "var_mem_assign_util.cc" - ) - -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${GE_SOURCE_DIR}/src) -include_directories(${GE_SOURCE_DIR}/src/ge) -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/inc/external/graph) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) - -######### libge_memory.a ############# -add_library(ge_memory STATIC ${SRC_LIST}) -target_compile_definitions(ge_memory PRIVATE - Werror - DAVINCI_CLOUD) -target_link_libraries(ge_memory - graph - ge_common - ${PROTOBUF_LIBRARY} - ${c_sec} - ${slog} - rt - dl) diff --git a/src/ge/host_cpu_engine/proto/task.proto b/src/ge/host_cpu_engine/proto/task.proto deleted file mode 120000 index 36ae4847..00000000 --- a/src/ge/host_cpu_engine/proto/task.proto +++ /dev/null @@ -1 +0,0 @@ -../../proto/task.proto \ No newline at end of file diff --git a/src/ge/plugin/engine/CMakeLists.txt b/src/ge/plugin/engine/CMakeLists.txt deleted file mode 100644 index a3f14ee2..00000000 --- a/src/ge/plugin/engine/CMakeLists.txt +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# libengine.so -file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "*.cc" - ) - -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${GE_SOURCE_DIR}) -include_directories(${GE_SOURCE_DIR}/src) -include_directories(${GE_SOURCE_DIR}/src/ge) -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/inc/framework/common) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) -include_directories(${GE_SOURCE_DIR}/build) - -######### libengine.so ############# -add_library(engine SHARED ${SRC_LIST}) -target_compile_definitions(engine PRIVATE - REUSE_MEMORY=1 - PLATFORM_CLOUD - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - Werror) -target_link_libraries(engine - ${slog} - rt - dl) diff --git a/src/proto/onnx.proto b/src/proto/onnx.proto deleted file mode 100644 index 093fcf99..00000000 --- a/src/proto/onnx.proto +++ /dev/null @@ -1,569 +0,0 @@ -// -// WARNING: This file is automatically generated! Please edit onnx.in.proto. 
-// - - -// Copyright (c) ONNX Project Contributors. -// Licensed under the MIT license. - -syntax = "proto2"; - -package onnx; - -// Overview -// -// ONNX is an open specification that is comprised of the following components: -// -// 1) A definition of an extensible computation graph model. -// 2) Definitions of standard data types. -// 3) Definitions of built-in operators. -// -// This document describes the syntax of models and their computation graphs, -// as well as the standard data types. Together, they are referred to as the ONNX -// Intermediate Representation, or 'IR' for short. -// -// The normative semantic specification of the ONNX IR is found in docs/IR.md. -// Definitions of the built-in neural network operators may be found in docs/Operators.md. - -// Notes -// -// Release -// -// We are still in the very early stage of defining ONNX. The current -// version of ONNX is a starting point. While we are actively working -// towards a complete spec, we would like to get the community involved -// by sharing our working version of ONNX. -// -// Protobuf compatibility -// -// To simplify framework compatibility, ONNX is defined using the subset of protobuf -// that is compatible with both protobuf v2 and v3. This means that we do not use any -// protobuf features that are only available in one of the two versions. -// -// Here are the most notable contortions we have to carry out to work around -// these limitations: -// -// - No 'map' (added protobuf 3.0). We instead represent mappings as lists -// of key-value pairs, where order does not matter and duplicates -// are not allowed. - - -// Versioning -// -// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md -// -// To be compatible with both proto2 and proto3, we will use a version number -// that is not defined by the default value but an explicit enum number. -enum Version { - // proto3 requires the first enum value to be zero. - // We add this just to appease the compiler. - _START_VERSION = 0; - // The version field is always serialized and we will use it to store the - // version that the graph is generated from. This helps us set up version - // control. - // For the IR, we are using simple numbers starting with with 0x00000001, - // which was the version we published on Oct 10, 2017. - IR_VERSION_2017_10_10 = 0x0000000000000001; - - // IR_VERSION 2 published on Oct 30, 2017 - // - Added type discriminator to AttributeProto to support proto3 users - IR_VERSION_2017_10_30 = 0x0000000000000002; - - // IR VERSION 3 published on Nov 3, 2017 - // - For operator versioning: - // - Added new message OperatorSetIdProto - // - Added opset_import in ModelProto - // - For vendor extensions, added domain in NodeProto - IR_VERSION_2017_11_3 = 0x0000000000000003; - - // IR VERSION 4 published on Jan 22, 2019 - // - Relax constraint that initializers should be a subset of graph inputs - // - Add type BFLOAT16 - IR_VERSION_2019_1_22 = 0x0000000000000004; - - // IR VERSION 5 published on March 18, 2019 - // - Add message TensorAnnotation. - // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. - IR_VERSION_2019_3_18 = 0x0000000000000005; - - // IR VERSION 6 published on Sep 19, 2019 - // - Add support for sparse tensor constants stored in model. 
- // - Add message SparseTensorProto - // - Add sparse initializers - IR_VERSION = 0x0000000000000006; -} - -// Attributes -// -// A named attribute containing either singular float, integer, string, graph, -// and tensor values, or repeated float, integer, string, graph, and tensor values. -// An AttributeProto MUST contain the name field, and *only one* of the -// following content fields, effectively enforcing a C/C++ union equivalent. -message AttributeProto { - - // Note: this enum is structurally identical to the OpSchema::AttrType - // enum defined in schema.h. If you rev one, you likely need to rev the other. - enum AttributeType { - UNDEFINED = 0; - FLOAT = 1; - INT = 2; - STRING = 3; - TENSOR = 4; - GRAPH = 5; - SPARSE_TENSOR = 11; - - FLOATS = 6; - INTS = 7; - STRINGS = 8; - TENSORS = 9; - GRAPHS = 10; - SPARSE_TENSORS = 12; - } - - // The name field MUST be present for this version of the IR. - optional string name = 1; // namespace Attribute - - // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. - // In this case, this AttributeProto does not contain data, and it's a reference of attribute - // in parent scope. - // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. - optional string ref_attr_name = 21; - - // A human-readable documentation for this attribute. Markdown is allowed. - optional string doc_string = 13; - - // The type field MUST be present for this version of the IR. - // For 0.0.1 versions of the IR, this field was not defined, and - // implementations needed to use has_field hueristics to determine - // which value field was in use. For IR_VERSION 0.0.2 or later, this - // field MUST be set and match the f|i|s|t|... field in use. This - // change was made to accomodate proto3 implementations. - optional AttributeType type = 20; // discriminator that indicates which field below is in use - - // Exactly ONE of the following fields must be present for this version of the IR - optional float f = 2; // float - optional int64 i = 3; // int - optional bytes s = 4; // UTF-8 string - optional TensorProto t = 5; // tensor value - optional GraphProto g = 6; // graph - optional SparseTensorProto sparse_tensor = 22; // sparse tensor value - // Do not use field below, it's deprecated. - // optional ValueProto v = 12; // value - subsumes everything but graph - - repeated float floats = 7; // list of floats - repeated int64 ints = 8; // list of ints - repeated bytes strings = 9; // list of UTF-8 strings - repeated TensorProto tensors = 10; // list of tensors - repeated GraphProto graphs = 11; // list of graph - repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors -} - -// Defines information on value, including the name, the type, and -// the shape of the value. -message ValueInfoProto { - // This field MUST be present in this version of the IR. - optional string name = 1; // namespace Value - // This field MUST be present in this version of the IR for - // inputs and outputs of the top-level graph. - optional TypeProto type = 2; - // A human-readable documentation for this value. Markdown is allowed. - optional string doc_string = 3; -} - -// Nodes -// -// Computation graphs are made up of a DAG of nodes, which represent what is -// commonly called a "layer" or "pipeline stage" in machine learning frameworks. -// -// For example, it can be a node of type "Conv" that takes in an image, a filter -// tensor and a bias tensor, and produces the convolved output. 
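The AttributeProto comments above describe a C/C++-union-like contract: name must be present, the type discriminator must be set, and exactly one value field may be used. A hedged sketch of that contract through the protoc-generated C++ classes (the generated header name onnx.pb.h and MakeFloatAttribute are assumptions, not part of this file):

#include <iostream>
#include <string>
#include "onnx.pb.h"  // header name assumed; generated by protoc from this .proto

// Builds a float attribute the way the comments above require: name set,
// type discriminator set, and exactly one value field (f) populated.
static onnx::AttributeProto MakeFloatAttribute(const std::string &name, float value) {
  onnx::AttributeProto attr;
  attr.set_name(name);
  attr.set_type(onnx::AttributeProto::FLOAT);
  attr.set_f(value);
  return attr;
}

int main() {
  const onnx::AttributeProto alpha = MakeFloatAttribute("alpha", 0.5f);
  // Readers dispatch on type() rather than probing every value field.
  if (alpha.type() == onnx::AttributeProto::FLOAT) {
    std::cout << alpha.name() << " = " << alpha.f() << std::endl;
  }
  return 0;
}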
-message NodeProto { - repeated string input = 1; // namespace Value - repeated string output = 2; // namespace Value - - // An optional identifier for this node in a graph. - // This field MAY be absent in ths version of the IR. - optional string name = 3; // namespace Node - - // The symbolic identifier of the Operator to execute. - optional string op_type = 4; // namespace Operator - // The domain of the OperatorSet that specifies the operator named by op_type. - optional string domain = 7; // namespace Domain - - // Additional named attributes. - repeated AttributeProto attribute = 5; - - // A human-readable documentation for this node. Markdown is allowed. - optional string doc_string = 6; -} - -// Models -// -// ModelProto is a top-level file/container format for bundling a ML model and -// associating its computation graph with metadata. -// -// The semantics of the model are described by the associated GraphProto. -message ModelProto { - // The version of the IR this model targets. See Version enum above. - // This field MUST be present. - optional int64 ir_version = 1; - - // The OperatorSets this model relies on. - // All ModelProtos MUST have at least one entry that - // specifies which version of the ONNX OperatorSet is - // being imported. - // - // All nodes in the ModelProto's graph will bind against the operator - // with the same-domain/same-op_type operator with the HIGHEST version - // in the referenced operator sets. - repeated OperatorSetIdProto opset_import = 8; - - // The name of the framework or tool used to generate this model. - // This field SHOULD be present to indicate which implementation/tool/framework - // emitted the model. - optional string producer_name = 2; - - // The version of the framework or tool used to generate this model. - // This field SHOULD be present to indicate which implementation/tool/framework - // emitted the model. - optional string producer_version = 3; - - // Domain name of the model. - // We use reverse domain names as name space indicators. For example: - // `com.facebook.fair` or `com.microsoft.cognitiveservices` - // - // Together with `model_version` and GraphProto.name, this forms the unique identity of - // the graph. - optional string domain = 4; - - // The version of the graph encoded. See Version enum below. - optional int64 model_version = 5; - - // A human-readable documentation for this model. Markdown is allowed. - optional string doc_string = 6; - - // The parameterized graph that is evaluated to execute the model. - optional GraphProto graph = 7; - - // Named metadata values; keys should be distinct. - repeated StringStringEntryProto metadata_props = 14; -}; - -// StringStringEntryProto follows the pattern for cross-proto-version maps. -// See https://developers.google.com/protocol-buffers/docs/proto3#maps -message StringStringEntryProto { - optional string key = 1; - optional string value= 2; -}; - -message TensorAnnotation { - optional string tensor_name = 1; - // pairs to annotate tensor specified by above. - // The keys used in the mapping below must be pre-defined in ONNX spec. - // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as - // quantization parameter keys. - repeated StringStringEntryProto quant_parameter_tensor_names = 2; -} - - - -// Graphs -// -// A graph defines the computational logic of a model and is comprised of a parameterized -// list of nodes that form a directed acyclic graph based on their inputs and outputs. 
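As the ModelProto comments above note, a model must carry ir_version and at least one opset_import entry, and its nodes bind to the highest matching operator-set version. A small inspection sketch over the generated classes (onnx.pb.h and InspectModel are assumed names; OperatorSetIdProto's domain and version fields are defined later in this file):

#include <fstream>
#include <iostream>
#include <string>
#include "onnx.pb.h"  // header name assumed

// Parses a serialized ModelProto, prints the version information the comments
// above call out as mandatory, then walks the graph's nodes.
static void InspectModel(const std::string &path) {
  onnx::ModelProto model;
  std::ifstream in(path, std::ios::binary);
  if (!in || !model.ParseFromIstream(&in)) {
    std::cerr << "failed to parse " << path << std::endl;
    return;
  }
  std::cout << "ir_version: " << model.ir_version() << std::endl;
  for (const onnx::OperatorSetIdProto &opset : model.opset_import()) {
    // An empty domain string denotes the default ONNX operator set.
    std::cout << "opset '" << opset.domain() << "' version " << opset.version() << std::endl;
  }
  for (const onnx::NodeProto &node : model.graph().node()) {
    std::cout << node.op_type() << " (" << node.name() << ")" << std::endl;
  }
}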
-// This is the equivalent of the "network" or "graph" in many deep learning -// frameworks. -message GraphProto { - // The nodes in the graph, sorted topologically. - repeated NodeProto node = 1; - - // The name of the graph. - optional string name = 2; // namespace Graph - - // A list of named tensor values, used to specify constant inputs of the graph. - // Each TensorProto entry must have a distinct name (within the list) that - // MAY also appear in the input list. - repeated TensorProto initializer = 5; - - // Initializers (see above) stored in sparse format. - repeated SparseTensorProto sparse_initializer = 15; - - // A human-readable documentation for this graph. Markdown is allowed. - optional string doc_string = 10; - - // The inputs and outputs of the graph. - repeated ValueInfoProto input = 11; - repeated ValueInfoProto output = 12; - - // Information for the values in the graph. The ValueInfoProto.name's - // must be distinct. It is optional for a value to appear in value_info list. - repeated ValueInfoProto value_info = 13; - - // This field carries information to indicate the mapping among a tensor and its - // quantization parameter tensors. For example: - // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, - // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. - repeated TensorAnnotation quantization_annotation = 14; - - // DO NOT USE the following fields, they were deprecated from earlier versions. - // repeated string input = 3; - // repeated string output = 4; - // optional int64 ir_version = 6; - // optional int64 producer_version = 7; - // optional string producer_tag = 8; - // optional string domain = 9; -} - -// Tensors -// -// A serialized tensor value. -message TensorProto { - enum DataType { - UNDEFINED = 0; - // Basic types. - FLOAT = 1; // float - UINT8 = 2; // uint8_t - INT8 = 3; // int8_t - UINT16 = 4; // uint16_t - INT16 = 5; // int16_t - INT32 = 6; // int32_t - INT64 = 7; // int64_t - STRING = 8; // string - BOOL = 9; // bool - - // IEEE754 half-precision floating-point format (16 bits wide). - // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. - FLOAT16 = 10; - - DOUBLE = 11; - UINT32 = 12; - UINT64 = 13; - COMPLEX64 = 14; // complex with float32 real and imaginary components - COMPLEX128 = 15; // complex with float64 real and imaginary components - - // Non-IEEE floating-point format based on IEEE754 single-precision - // floating-point number truncated to 16 bits. - // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. - BFLOAT16 = 16; - - // Future extensions go here. - } - - // The shape of the tensor. - repeated int64 dims = 1; - - // The data type of the tensor. - // This field MUST have a valid TensorProto.DataType value - optional int32 data_type = 2; - - // For very large tensors, we may want to store them in chunks, in which - // case the following fields will specify the segment that is stored in - // the current TensorProto. - message Segment { - optional int64 begin = 1; - optional int64 end = 2; - } - optional Segment segment = 3; - - // Tensor content must be organized in row-major order. - // - // Depending on the data_type field, exactly one of the fields below with - // name ending in _data is used to store the elements of the tensor. 
- - // For float and complex64 values - // Complex64 tensors are encoded as a single array of floats, - // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the - // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] - // is encoded as [1.0, 2.0 ,3.0 ,4.0] - // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. - repeated float float_data = 4 [packed = true]; - - // For int32, uint8, int8, uint16, int16, bool, and float16 values - // float16 values must be bit-wise converted to an uint16_t prior - // to writing to the buffer. - // When this field is present, the data_type field MUST be - // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 - repeated int32 int32_data = 5 [packed = true]; - - // For strings. - // Each element of string_data is a UTF-8 encoded Unicode - // string. No trailing null, no leading BOM. The protobuf "string" - // scalar type is not used to match ML community conventions. - // When this field is present, the data_type field MUST be STRING - repeated bytes string_data = 6; - - // For int64. - // When this field is present, the data_type field MUST be INT64 - repeated int64 int64_data = 7 [packed = true]; - - // Optionally, a name for the tensor. - optional string name = 8; // namespace Value - - // A human-readable documentation for this tensor. Markdown is allowed. - optional string doc_string = 12; - - // Serializations can either use one of the fields above, or use this - // raw bytes field. The only exception is the string case, where one is - // required to store the content in the repeated bytes string_data field. - // - // When this raw_data field is used to store tensor value, elements MUST - // be stored in as fixed-width, little-endian order. - // Floating-point data types MUST be stored in IEEE 754 format. - // Complex64 elements must be written as two consecutive FLOAT values, real component first. - // Complex128 elements must be written as two consecutive DOUBLE values, real component first. - // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). - // - // Note: the advantage of specific field rather than the raw_data field is - // that in some cases (e.g. int data), protobuf does a better packing via - // variable length storage, and may lead to smaller binary footprint. - // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED - optional bytes raw_data = 9; - - // Data can be stored inside the protobuf file using type-specific fields or raw_data. - // Alternatively, raw bytes data can be stored in an external file, using the external_data field. - // external_data stores key-value pairs describing data location. Recognized keys are: - // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX - // protobuf model was stored - // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. - // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. - // - "length" (optional) - number of bytes containing data. Integer stored as string. - // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. - repeated StringStringEntryProto external_data = 13; - - // Location of the data for this tensor. MUST be one of: - // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. 
- // - EXTERNAL - data stored in an external location as described by external_data field. - enum DataLocation { - DEFAULT = 0; - EXTERNAL = 1; - } - - // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. - optional DataLocation data_location = 14; - - // For double - // Complex128 tensors are encoded as a single array of doubles, - // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the - // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] - // is encoded as [1.0, 2.0 ,3.0 ,4.0] - // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 - repeated double double_data = 10 [packed = true]; - - // For uint64 and uint32 values - // When this field is present, the data_type field MUST be - // UINT32 or UINT64 - repeated uint64 uint64_data = 11 [packed = true]; -} - -// A serialized sparse-tensor value -message SparseTensorProto { - // The sequence of non-default values are encoded as a tensor of shape [NNZ]. - // The default-value is zero for numeric tensors, and empty-string for string tensors. - optional TensorProto values = 1; - - // The indices of the non-default values, which may be stored in one of two formats. - // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value - // corresponding to the j-th index of the i-th value (in the values tensor). - // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value - // must be the linearized-index of the i-th value (in the values tensor). - // The linearized-index can be converted into an index tuple (k_1,...,k_rank) - // using the shape provided below. - // The indices must appear in ascending order without duplication. - // In the first format, the ordering is lexicographic-ordering: - // e.g., index-value [1,4] must appear before [2,1] - optional TensorProto indices = 2; - - // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] - repeated int64 dims = 3; -} - -// Defines a tensor shape. A dimension can be either an integer value -// or a symbolic variable. A symbolic variable represents an unknown -// dimension. -message TensorShapeProto { - message Dimension { - oneof value { - int64 dim_value = 1; - string dim_param = 2; // namespace Shape - }; - // Standard denotation can optionally be used to denote tensor - // dimensions with standard semantic descriptions to ensure - // that operations are applied to the correct axis of a tensor. - // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition - // for pre-defined dimension denotations. - optional string denotation = 3; - }; - repeated Dimension dim = 1; -} - -// Types -// -// The standard ONNX data types. -message TypeProto { - - message Tensor { - // This field MUST NOT have the value of UNDEFINED - // This field MUST have a valid TensorProto.DataType value - // This field MUST be present for this version of the IR. - optional int32 elem_type = 1; - optional TensorShapeProto shape = 2; - } - - // repeated T - message Sequence { - // The type and optional shape of each element of the sequence. - // This field MUST be present for this version of the IR. - optional TypeProto elem_type = 1; - }; - - // map - message Map { - // This field MUST have a valid TensorProto.DataType value - // This field MUST be present for this version of the IR. 
- // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING - optional int32 key_type = 1; - // This field MUST be present for this version of the IR. - optional TypeProto value_type = 2; - }; - - - oneof value { - // The type of a tensor. - Tensor tensor_type = 1; - - // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values - // as input and output to graphs and nodes. These types are needed to naturally - // support classical ML operators. DNN operators SHOULD restrict their input - // and output types to tensors. - - // The type of a sequence. - Sequence sequence_type = 4; - - // The type of a map. - Map map_type = 5; - - } - - // An optional denotation can be used to denote the whole - // type with a standard semantic description as to what is - // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition - // for pre-defined type denotations. - optional string denotation = 6; -} - -// Operator Sets -// -// OperatorSets are uniquely identified by a (domain, opset_version) pair. -message OperatorSetIdProto { - // The domain of the operator set being identified. - // The empty string ("") or absence of this field implies the operator - // set that is defined as part of the ONNX specification. - // This field MUST be present in this version of the IR when referring to any other operator set. - optional string domain = 1; - - // The version of the operator set being identified. - // This field MUST be present in this version of the IR. - optional int64 version = 2; -} \ No newline at end of file diff --git a/third_party/patch/securec/0001-add-securec-cmake-script.patch b/third_party/patch/securec/0001-add-securec-cmake-script.patch new file mode 100644 index 00000000..0fcf50c4 --- /dev/null +++ b/third_party/patch/securec/0001-add-securec-cmake-script.patch @@ -0,0 +1,105 @@ +From 455c9812d70646fe725896d597d6c953bf5a09ac Mon Sep 17 00:00:00 2001 +From: taoxiangdong +Date: Wed, 14 Oct 2020 22:14:01 +0800 +Subject: [PATCH] add securec cmake script + +--- + CMakeLists.txt | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++ + 1 file changed, 86 insertions(+) + create mode 100755 CMakeLists.txt + +diff --git a/CMakeLists.txt b/CMakeLists.txt +new file mode 100755 +index 0000000..9b91fb2 +--- /dev/null ++++ b/CMakeLists.txt +@@ -0,0 +1,86 @@ ++cmake_minimum_required(VERSION 3.14) ++project(Securec) ++file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} ++ "src/vsprintf_s.c" ++ "src/wmemmove_s.c" ++ "src/strncat_s.c" ++ "src/vsnprintf_s.c" ++ "src/fwscanf_s.c" ++ "src/scanf_s.c" ++ "src/strcat_s.c" ++ "src/sscanf_s.c" ++ "src/secureprintoutput_w.c" ++ "src/wmemcpy_s.c" ++ "src/wcsncat_s.c" ++ "src/secureprintoutput_a.c" ++ "src/secureinput_w.c" ++ "src/memcpy_s.c" ++ "src/fscanf_s.c" ++ "src/vswscanf_s.c" ++ "src/secureinput_a.c" ++ "src/sprintf_s.c" ++ "src/memmove_s.c" ++ "src/swscanf_s.c" ++ "src/snprintf_s.c" ++ "src/vscanf_s.c" ++ "src/vswprintf_s.c" ++ "src/wcscpy_s.c" ++ "src/vfwscanf_s.c" ++ "src/memset_s.c" ++ "src/wscanf_s.c" ++ "src/vwscanf_s.c" ++ "src/strtok_s.c" ++ "src/wcsncpy_s.c" ++ "src/vfscanf_s.c" ++ "src/vsscanf_s.c" ++ "src/wcstok_s.c" ++ "src/securecutil.c" ++ "src/gets_s.c" ++ "src/swprintf_s.c" ++ "src/strcpy_s.c" ++ "src/wcscat_s.c" ++ "src/strncpy_s.c" ++ ) ++ ++include_directories(./include) ++include_directories(./src) ++add_library(shared_c_sec SHARED ${SRC_LIST}) ++ ++target_compile_options(shared_c_sec PRIVATE ++ -I/usr/local/include ++ -Werror ++ -Wall ++ 
-O1 ++) ++target_compile_definitions(shared_c_sec PRIVATE ++ NDEBUG ++ SECUREC_SUPPORT_STRTOLD=1 ++ ) ++ ++add_library(static_c_sec STATIC ${SRC_LIST}) ++ ++target_compile_options(static_c_sec PRIVATE ++ -I/usr/local/include ++ -Werror ++ -Wall ++ -O1 ++) ++ ++target_compile_definitions(static_c_sec PRIVATE ++ NDEBUG ++ SECUREC_SUPPORT_STRTOLD=1 ++ ) ++ ++set_target_properties(static_c_sec ++ PROPERTIES ++ OUTPUT_NAME c_sec ++) ++set_target_properties(shared_c_sec ++ PROPERTIES ++ OUTPUT_NAME c_sec ++) ++install(TARGETS shared_c_sec static_c_sec OPTIONAL ++ DESTINATION lib) ++install(FILES "./include/securec.h" ++ "./include/securectype.h" ++ DESTINATION include) +-- +2.17.1 + From cd9185189b10fe2ecf300dcae60f754a076671ef Mon Sep 17 00:00:00 2001 From: w00562650 Date: Tue, 24 Nov 2020 14:57:53 +0800 Subject: [PATCH 02/13] update json --- ge/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index cd4d0c92..62b8a69b 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -385,7 +385,6 @@ target_link_libraries(ge_runner error_manager ascend_hal_stub -Wl,--as-needed - json -lrt -ldl ) From 0cfdcdb0966a26ef19d81831eebb46402ff0d8b8 Mon Sep 17 00:00:00 2001 From: w00562650 Date: Tue, 24 Nov 2020 15:49:38 +0800 Subject: [PATCH 03/13] update submodule --- .gitmodules | 2 -- metadef | 1 + parser | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) create mode 160000 metadef create mode 160000 parser diff --git a/.gitmodules b/.gitmodules index a2b1f260..039be1d4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,8 +1,6 @@ [submodule "metadef"] path = metadef url = https://gitee.com/ascend/metadef.git - branch = master [submodule "parser"] path = parser url = https://gitee.com/ascend/parser.git - branch = master diff --git a/metadef b/metadef new file mode 160000 index 00000000..31a89522 --- /dev/null +++ b/metadef @@ -0,0 +1 @@ +Subproject commit 31a89522398f697410087724885fc7f74d9e7117 diff --git a/parser b/parser new file mode 160000 index 00000000..76a862b1 --- /dev/null +++ b/parser @@ -0,0 +1 @@ +Subproject commit 76a862b1bced4c0c2ca675f0a619ba06ada973b0 From 61f5835f2ecb90789c88362f575f872eea128e0c Mon Sep 17 00:00:00 2001 From: w00562650 Date: Tue, 24 Nov 2020 16:16:22 +0800 Subject: [PATCH 04/13] =?UTF-8?q?Update=EF=BC=9AThe=20directory=20must=20b?= =?UTF-8?q?e=20the=20same=20as=20the=20dev=20branch.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ge/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 62b8a69b..4847daf9 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -384,6 +384,7 @@ target_link_libraries(ge_runner resource error_manager ascend_hal_stub + json -Wl,--as-needed -lrt -ldl From 5292a506a76d4a449d15bdd7255611d2283e02d8 Mon Sep 17 00:00:00 2001 From: wqtshg Date: Tue, 24 Nov 2020 19:53:35 +0800 Subject: [PATCH 05/13] update datatype --- ge/graph/manager/util/hcom_util.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ge/graph/manager/util/hcom_util.h b/ge/graph/manager/util/hcom_util.h index 064058f8..6ae1c019 100644 --- a/ge/graph/manager/util/hcom_util.h +++ b/ge/graph/manager/util/hcom_util.h @@ -39,6 +39,8 @@ static std::map kConstOpHcclDataType = { {ge::DT_FLOAT16, HCCL_DATA_TYPE_FP16}, {ge::DT_INT8, HCCL_DATA_TYPE_INT8}, {ge::DT_INT32, HCCL_DATA_TYPE_INT32}, + {ge::DT_INT64, HCCL_DATA_TYPE_INT64}, + {ge::DT_UINT64, HCCL_DATA_TYPE_UINT64}, }; static std::map kConstOpHcclDataTypeSize = { @@ -46,6 +48,8 @@ 
static std::map kConstOpHcclDataTypeSize = { {HCCL_DATA_TYPE_FP16, sizeof(float) / 2}, {HCCL_DATA_TYPE_INT8, sizeof(int8_t)}, {HCCL_DATA_TYPE_INT32, sizeof(int32_t)}, + {HCCL_DATA_TYPE_INT64, sizeof(int64_t)}, + {HCCL_DATA_TYPE_UINT64, sizeof(uint64_t)}, }; static std::map kHorovodRedOpToHcclRedOp = { From 73e3484b2c9758f73c071cb772c92164e7132582 Mon Sep 17 00:00:00 2001 From: wqtshg Date: Tue, 24 Nov 2020 20:35:40 +0800 Subject: [PATCH 06/13] update build.sh --- build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.sh b/build.sh index b693ba74..73bb7aa2 100644 --- a/build.sh +++ b/build.sh @@ -169,7 +169,7 @@ build_graphengine() elif [ "x${PLATFORM}" = "xall" ] then # build all the target - TARGET="" + TARGET="ge_runner ge_local_engine host_cpu_engine ge_compiler atc_ge_local_engine atc_host_cpu_engine atc opensrc_ascendcl ${TARGET}" fi make ${VERBOSE} ${TARGET} -j${THREAD_NUM} && make install From f5e161b87635ee688a286c1c43521686525e32c1 Mon Sep 17 00:00:00 2001 From: wqtshg Date: Wed, 25 Nov 2020 10:40:57 +0800 Subject: [PATCH 07/13] update build.sh --- build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.sh b/build.sh index 73bb7aa2..b8fd9e9a 100644 --- a/build.sh +++ b/build.sh @@ -276,7 +276,7 @@ generate_package() fi for lib in "${PLUGIN_OPSKERNEL[@]}"; do - find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth ${MAX_DEPTH} -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${OPSKERNEL_PATH} \; + find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH}/${OPSKERNEL_PATH} \; find ${OUTPUT_PATH}/${GRAPHENGINE_LIB_PATH} -maxdepth ${MAX_DEPTH} -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH}/${OPSKERNEL_PATH} \; done From b409e6a442f2ab848dc6c16a3bb2870a44612bc1 Mon Sep 17 00:00:00 2001 From: wqtshg Date: Wed, 25 Nov 2020 11:38:03 +0800 Subject: [PATCH 08/13] update master directory --- CMakeLists.txt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9a9a7a9d..7c2fef72 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,13 +103,13 @@ if (ENABLE_OPEN_SRC) elseif(PLATFORM STREQUAL "all") find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) - find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) - find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) + find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) + find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) - find_module(resource libresource.so ${ASCEND_ATC_DIR}) - find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) + find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) + find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) - find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) + find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR}) find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) else() message(FATAL_ERROR "PLATFORM param is invalid, should be train or inference, build terminated") From 3ad7da206e6439a1a544d27db2fd111f62bcce91 Mon Sep 17 00:00:00 2001 From: wqtshg Date: Wed, 25 Nov 2020 14:55:55 +0800 Subject: [PATCH 09/13] update master directory --- CMakeLists.txt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt 
b/CMakeLists.txt index 7c2fef72..39903194 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,13 +103,13 @@ if (ENABLE_OPEN_SRC) elseif(PLATFORM STREQUAL "all") find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) - find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) - find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) + find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) + find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) - find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) - find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) + find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) + find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) - find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR}) + find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR}) find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) else() message(FATAL_ERROR "PLATFORM param is invalid, should be train or inference, build terminated") From 06e33a7c0453e0b7b3ff3306bc4c868d781d4cf7 Mon Sep 17 00:00:00 2001 From: wqtshg Date: Wed, 25 Nov 2020 15:07:05 +0800 Subject: [PATCH 10/13] update master directory --- CMakeLists.txt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 39903194..9a9a7a9d 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,13 +103,13 @@ if (ENABLE_OPEN_SRC) elseif(PLATFORM STREQUAL "all") find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) - find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) - find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) + find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) + find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) - find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) - find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) + find_module(resource libresource.so ${ASCEND_ATC_DIR}) + find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) - find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR}) + find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) else() message(FATAL_ERROR "PLATFORM param is invalid, should be train or inference, build terminated") From d16cf136c27a907abeea2993e7925cd01123f983 Mon Sep 17 00:00:00 2001 From: wqtshg Date: Wed, 25 Nov 2020 15:29:51 +0800 Subject: [PATCH 11/13] update master directory --- CMakeLists.txt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9a9a7a9d..c511c3a6 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,13 +103,13 @@ if (ENABLE_OPEN_SRC) elseif(PLATFORM STREQUAL "all") find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) - find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) - find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) + find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) find_module(runtime_compile libruntime_compile.so 
${ASCEND_ATC_DIR}) - find_module(resource libresource.so ${ASCEND_ATC_DIR}) - find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) + find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) + find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) + find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) - find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) + find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR}) find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) else() message(FATAL_ERROR "PLATFORM param is invalid, should be train or inference, build terminated") From 66e9b44b89b1978182d7f3bfb20299100a5c0dab Mon Sep 17 00:00:00 2001 From: wqtshg Date: Wed, 25 Nov 2020 15:36:45 +0800 Subject: [PATCH 12/13] update master directory --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c511c3a6..3a9f9477 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -105,7 +105,7 @@ if (ENABLE_OPEN_SRC) find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) - find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) + find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) From 66f563a97d5e9d4d9e671398b1e228965eb562e6 Mon Sep 17 00:00:00 2001 From: wqtshg Date: Wed, 25 Nov 2020 16:08:25 +0800 Subject: [PATCH 13/13] update master directory --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a9f9477..2a69b12c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -105,8 +105,8 @@ if (ENABLE_OPEN_SRC) find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) - find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) - find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) + find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) + find_module(resource libresource.so ${ASCEND_ATC_DIR}) find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR})
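PATCH 05/13 above widens the ge::DataType-to-HCCL-type table and the HCCL-type-to-size table with the 64-bit integer entries. The standalone C++ sketch below shows the usual way such a pair of tables is consumed: translate the framework dtype, then multiply element count by element size, rejecting unmapped types. The enum and map names here are stand-ins for illustration, not the real ge:: or HCCL identifiers.

// Standalone illustration only -- mirrors the shape of the tables added in
// PATCH 05/13 but does not depend on the real ge:: or HCCL headers.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <map>

enum class GeDataType { DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32, DT_INT64, DT_UINT64 };
enum class HcclDataType { FP32, FP16, INT8, INT32, INT64, UINT64 };

static const std::map<GeDataType, HcclDataType> kGeToHccl = {
    {GeDataType::DT_FLOAT, HcclDataType::FP32},  {GeDataType::DT_FLOAT16, HcclDataType::FP16},
    {GeDataType::DT_INT8, HcclDataType::INT8},   {GeDataType::DT_INT32, HcclDataType::INT32},
    {GeDataType::DT_INT64, HcclDataType::INT64}, {GeDataType::DT_UINT64, HcclDataType::UINT64},
};

static const std::map<HcclDataType, size_t> kHcclTypeSize = {
    {HcclDataType::FP32, sizeof(float)},    {HcclDataType::FP16, sizeof(float) / 2},
    {HcclDataType::INT8, sizeof(int8_t)},   {HcclDataType::INT32, sizeof(int32_t)},
    {HcclDataType::INT64, sizeof(int64_t)}, {HcclDataType::UINT64, sizeof(uint64_t)},
};

// Returns false if the dtype has no HCCL mapping; otherwise writes the buffer size in bytes.
bool GetHcclBufferSize(GeDataType dtype, size_t element_count, size_t *out_bytes) {
  auto type_it = kGeToHccl.find(dtype);
  if (type_it == kGeToHccl.end()) return false;
  auto size_it = kHcclTypeSize.find(type_it->second);
  if (size_it == kHcclTypeSize.end()) return false;
  *out_bytes = element_count * size_it->second;
  return true;
}

int main() {
  size_t bytes = 0;
  if (GetHcclBufferSize(GeDataType::DT_INT64, 1024, &bytes)) {
    std::printf("1024 x DT_INT64 -> %zu bytes\n", bytes);  // prints 8192
  }
  return 0;
}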