
!22 Update GraphEngine to synchronize with latest Ascend driver software suite (4 May 2020)

Merge pull request !22 from yanghaoran/master
tags/v0.3.0-alpha
mindspore-ci-bot committed 5 years ago
commit 63cb729373
100 changed files with 3829 additions and 1166 deletions
1. +1 -2 inc/common/blocking_queue.h
2. +8 -8 inc/common/dynamic_aipp.h
3. +4 -4 inc/common/npu_error_define.h
4. +1 -2 inc/common/opskernel/ge_task_info.h
5. +8 -8 inc/common/opskernel/ops_kernel_info_store.h
6. +1 -0 inc/common/opskernel/ops_kernel_info_types.h
7. +4 -5 inc/common/optimizer/graph_optimizer.h
8. +1 -3 inc/common/optimizer/graph_optimizer_types.h
9. +51 -2 inc/external/ge/ge_api_types.h
10. +75 -0 inc/external/ge/ge_ir_build.h
11. +1 -1 inc/external/graph/attr_value.h
12. +1 -1 inc/external/graph/graph.h
13. +3 -3 inc/external/graph/inference_context.h
14. +3 -3 inc/external/graph/operator.h
15. +2 -2 inc/external/graph/operator_factory.h
16. +23 -23 inc/external/graph/operator_reg.h
17. +2 -2 inc/external/graph/tensor.h
18. +3 -1 inc/external/graph/types.h
19. +35 -0 inc/external/register/register.h
20. +2 -1 inc/external/register/register_error_codes.h
21. +4 -0 inc/external/register/register_types.h
22. +40 -27 inc/framework/common/debug/ge_log.h
23. +56 -99 inc/framework/common/debug/log.h
24. +4 -4 inc/framework/common/ge_inner_error_codes.h
25. +27 -7 inc/framework/common/ge_types.h
26. +23 -19 inc/framework/common/helper/model_helper.h
27. +6 -3 inc/framework/common/helper/om_file_helper.h
28. +1 -1 inc/framework/common/l2_cache_optimize.h
29. +4 -0 inc/framework/common/op/attr_define.h
30. +3 -2 inc/framework/common/op/attr_value_util.h
31. +21 -6 inc/framework/common/op/ge_op_utils.h
32. +2 -2 inc/framework/common/op/op_parser_util.h
33. +1 -1 inc/framework/common/scope_guard.h
34. +20 -3 inc/framework/common/types.h
35. +68 -96 inc/framework/common/util.h
36. +2 -2 inc/framework/dlog/log.h
37. +1 -1 inc/framework/engine/dnnengine.h
38. +60 -10 inc/framework/executor/ge_executor.h
39. +5 -1 inc/framework/generator/ge_generator.h
40. +0 -1 inc/framework/generator/generator_api.h
41. +2 -2 inc/framework/memory/memory_assigner.h
42. +17 -16 inc/framework/omg/omg_inner_types.h
43. +0 -8 inc/framework/omg/version.h
44. +16 -10 inc/graph/anchor.h
45. +6 -7 inc/graph/attr_value_serializable.h
46. +2 -3 inc/graph/buffer.h
47. +60 -17 inc/graph/compute_graph.h
48. +276 -46 inc/graph/debug/ge_attr_define.h
49. +3 -6 inc/graph/def_types.h
50. +1 -2 inc/graph/detail/attributes_holder.h
51. +4 -6 inc/graph/detail/model_serialize_imp.h
52. +2 -4 inc/graph/ge_attr_value.h
53. +1 -2 inc/graph/ge_context.h
54. +0 -2 inc/graph/ge_local_context.h
55. +16 -6 inc/graph/ge_tensor.h
56. +1 -2 inc/graph/model.h
57. +1 -2 inc/graph/model_serialize.h
58. +7 -7 inc/graph/node.h
59. +23 -4 inc/graph/op_desc.h
60. +0 -2 inc/graph/operator_factory_impl.h
61. +4 -2 inc/graph/shape_refiner.h
62. +3 -3 inc/graph/usr_types.h
63. +3 -2 inc/graph/utils/attr_utils.h
64. +321 -4 inc/graph/utils/graph_utils.h
65. +5 -0 inc/graph/utils/node_utils.h
66. +89 -38 inc/graph/utils/op_desc_utils.h
67. +7 -8 inc/graph/utils/tensor_utils.h
68. +1 -0 src/common/graph/CMakeLists.txt
69. +2 -0 src/common/graph/anchor.cc
70. +1 -2 src/common/graph/buffer.cc
71. +200 -23 src/common/graph/compute_graph.cc
72. +35 -69 src/common/graph/debug/ge_log.h
73. +0 -1 src/common/graph/debug/ge_util.h
74. +0 -2 src/common/graph/debug/graph_debug.cc
75. +0 -2 src/common/graph/debug/graph_debug.h
76. +0 -2 src/common/graph/detail/attributes_holder.cc
77. +18 -17 src/common/graph/format_refiner.cc
78. +248 -33 src/common/graph/ge_attr_define.cc
79. +42 -7 src/common/graph/ge_attr_value.cc
80. +31 -3 src/common/graph/ge_tensor.cc
81. +13 -4 src/common/graph/model.cc
82. +23 -14 src/common/graph/model_serialize.cc
83. +4 -3 src/common/graph/node.cc
84. +158 -35 src/common/graph/op_desc.cc
85. +1 -2 src/common/graph/op_imp.cc
86. +31 -22 src/common/graph/operator.cc
87. +2 -4 src/common/graph/opsproto/opsproto_manager.cc
88. +1 -1 src/common/graph/option/ge_context.cc
89. +144 -6 src/common/graph/shape_refiner.cc
90. +5 -3 src/common/graph/tensor.cc
91. +163 -59 src/common/graph/utils/ge_ir_utils.cc
92. +14 -0 src/common/graph/utils/ge_ir_utils.h
93. +879 -65 src/common/graph/utils/graph_utils.cc
94. +44 -0 src/common/graph/utils/node_utils.cc
95. +79 -4 src/common/graph/utils/op_desc_utils.cc
96. +16 -9 src/common/graph/utils/tensor_utils.cc
97. +161 -156 src/common/graph/utils/type_utils.cc
98. +40 -33 src/ge/CMakeLists.txt
99. +1 -1 src/ge/client/CMakeLists.txt
100. +25 -17 src/ge/client/ge_api.cc

+1 -2 inc/common/blocking_queue.h

@@ -18,7 +18,6 @@
#define INC_COMMON_BLOCKING_QUEUE_H_

#include <stdint.h>

#include <condition_variable>
#include <list>
#include <mutex>
@@ -87,7 +86,7 @@ class BlockingQueue {
is_stoped_ = false;
}

// if the queue stop , the function to release the unprocessed items will be call
// if the queue is stoped ,need call this function to release the unprocessed items
std::list<T> GetRemainItems() {
std::unique_lock<std::mutex> lock(mutex_);



+8 -8 inc/common/dynamic_aipp.h

@@ -19,10 +19,10 @@

#include <stdint.h>

///
/// @ingroup dnn
/// @brief struct define of dynamic aipp batch parameter.
///
/**
* @ingroup dnn
* @brief struct define of dynamic aipp batch parameter.
*/
typedef struct tagAippDynamicBatchPara {
int8_t cropSwitch; // crop switch
int8_t scfSwitch; // resize switch
@@ -66,10 +66,10 @@ typedef struct tagAippDynamicBatchPara {
int8_t reserve1[16]; // 32B assign, for ub copy
} kAippDynamicBatchPara;

///
/// @ingroup dnn
/// @brief struct definition of dynamic aipp parameter. lite:64+96*batchNum byte ; tiny:64+64*batchNum byte
///
/**
* @ingroup dnn
* @brief struct define of dynamic aipp parameter. lite:64+96*batchNum byte ; tiny:64+64*batchNum byte
*/
typedef struct tagAippDynamicPara {
uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8
int8_t cscSwitch; // csc switch


+4 -4 inc/common/npu_error_define.h

@@ -61,19 +61,19 @@ typedef enum tagHiAiNpuModuleId {
HIAI_DP = 23,
} HiAiNpuModuleId;

// bit 31-bit30 to be hiai local
/* bit 31-bit30 to be hiai local */
#define HIAI_NPULOCAL_MASK 0xC0000000
#define SHIFT_LOCAL_MASK 30
#define HIAI_NPULOCAL_VAL_MASK 0x3
// bit 29 -bit28 to be hiai aicpu code type
/* bit 29 -bit28 to be hiai aicpu code type */
#define HIAI_CODE_TYPE_MASK 0x30000000
#define SHIFT_CODE_MASK 28
#define HIAI_CODE_TYPE_VAL_MASK 0x3
// bit 27 -bit25 to be hiai error level
/* bit 27 -bit25 to be hiai error level */
#define HIAI_ERROR_LEVEL_MASK 0x0E000000
#define SHIFT_ERROR_LVL_MASK 25
#define HIAI_ERROR_LEVEL_VAL_MASK 0x7
// bit 24 -bit17 to be hiai mod
/* bit 24 -bit17 to be hiai mod */
#define HIAI_MODE_ID_MASK 0x01FE0000
#define SHIFT_MODE_MASK 17
#define HIAI_MODE_ID_VAL_MASK 0xFF
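For orientation, the masks and shifts above carve a 32-bit HiAI error code into bit fields (local flag, code type, error level, module id). A minimal sketch of decoding such a code; the helper function, the sample value, and the local copies of the constants are purely illustrative and not part of the header:

```cpp
#include <cstdint>
#include <cstdio>

// Constants copied from inc/common/npu_error_define.h (see the hunk above).
#define HIAI_NPULOCAL_MASK 0xC0000000
#define SHIFT_LOCAL_MASK 30
#define HIAI_ERROR_LEVEL_MASK 0x0E000000
#define SHIFT_ERROR_LVL_MASK 25
#define HIAI_MODE_ID_MASK 0x01FE0000
#define SHIFT_MODE_MASK 17

// Hypothetical decoder: extract the module id stored in bits 24-17.
static inline uint32_t ExtractModuleId(uint32_t error_code) {
  return (error_code & HIAI_MODE_ID_MASK) >> SHIFT_MODE_MASK;
}

int main() {
  const uint32_t code = 0x42E60000;  // made-up error code, for illustration only
  std::printf("local=%u level=%u module=%u\n",
              (code & HIAI_NPULOCAL_MASK) >> SHIFT_LOCAL_MASK,
              (code & HIAI_ERROR_LEVEL_MASK) >> SHIFT_ERROR_LVL_MASK,
              ExtractModuleId(code));
  return 0;
}
```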


+1 -2 inc/common/opskernel/ge_task_info.h

@@ -19,13 +19,12 @@

#include <runtime/rt.h>
#include <stdint.h>

#include <string>
#include <vector>

using std::string;
namespace ge {
// DAVINCI_TRAIN/DAVINCI_CLOUD is not needed when GETaskKernelHcclInfo needed
// when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD
struct GETaskKernelHcclInfo {
string hccl_type;
void *inputDataAddr;


+8 -8 inc/common/opskernel/ops_kernel_info_store.h

@@ -21,7 +21,6 @@
#include <map>
#include <string>
#include <vector>

#include "./ge_task_info.h"
#include "./ops_kernel_info_types.h"
#include "cce/aicpu_engine_struct.h"
@@ -29,7 +28,6 @@
#include "common/ge_inner_error_codes.h"
#include "graph/node.h"
#include "proto/task.pb.h"

using std::map;
using std::string;
using std::to_string;
@@ -47,7 +45,7 @@ class OpsKernelInfoStore {
// initialize opsKernelInfoStore
virtual Status Initialize(const map<string, string> &options) = 0;

// finalize opsKernelInfoStore
// close opsKernelInfoStore
virtual Status Finalize() = 0;

virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; }
@@ -57,18 +55,20 @@ class OpsKernelInfoStore {
// get all opsKernelInfo
virtual void GetAllOpsKernelInfo(map<string, OpInfo> &infos) const = 0;

// check whether opsKernelInfoStore is supported based on the operator attribute
// whether the opsKernelInfoStore is supported based on the operator attribute
virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0;

virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason,
bool realQuery = false) const {
return CheckSupported(opDescPtr, un_supported_reason);
}
// opsFlag opsFlag[0] indicates constant folding is supported or not
virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){};

// requirement of memory allocation
// memory allocation requirement
virtual Status CalcOpRunningParam(Node &node) = 0;

// generate task for op
// generate task for op
virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0;

// only call fe engine interface to compile single op
@@ -77,10 +77,10 @@ class OpsKernelInfoStore {
// load task for op
virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; }

// only to call aicpu interface for generating task struct
// only call aicpu interface to generate task struct
virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; }

// only to call aicpu interface for generating task struct
// only call aicpu interface to generate task struct
virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; }
};
} // namespace ge
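For readers new to this interface, here is a rough sketch of a store that overrides the virtual methods visible in the hunk above. The class name, the trivial bodies, and the include path are assumptions; the real interface may declare further methods outside this excerpt, so treat it as an outline rather than a finished plugin.

```cpp
#include <map>
#include <string>
#include <vector>

#include "common/opskernel/ops_kernel_info_store.h"  // assumed include path

namespace ge {
// Hypothetical store: every override below mirrors a signature shown in the diff.
class DummyOpsKernelInfoStore : public OpsKernelInfoStore {
 public:
  Status Initialize(const std::map<std::string, std::string> &options) override { return SUCCESS; }
  Status Finalize() override { return SUCCESS; }
  void GetAllOpsKernelInfo(std::map<std::string, OpInfo> &infos) const override { infos.clear(); }
  bool CheckSupported(const OpDescPtr &op_desc, std::string &un_supported_reason) const override {
    un_supported_reason.clear();
    return true;  // pretend every op is supported, purely for illustration
  }
  Status CalcOpRunningParam(Node &node) override { return SUCCESS; }
  Status GenerateTask(const Node &node, RunContext &context,
                      std::vector<domi::TaskDef> &tasks) override { return SUCCESS; }
};
}  // namespace ge
```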


+1 -0 inc/common/opskernel/ops_kernel_info_types.h

@@ -37,6 +37,7 @@ struct RunContext {
ge::Buffer weightsBuffer;
std::vector<rtStream_t> graphStreamList; // all streams of graph, order by ge stream id(0,1,...)
std::vector<rtEvent_t> graphEventList; // all events of graph, order by ge event id(0,1,...)
std::vector<rtLabel_t> graphLabelList; // all labels of graph, order by ge label id(0,1,...)
};

struct Task {


+4 -5 inc/common/optimizer/graph_optimizer.h

@@ -19,7 +19,6 @@

#include <map>
#include <string>

#include "./graph_optimizer_types.h"
#include "common/ge_inner_error_codes.h"
#include "common/opskernel/ops_kernel_info_types.h"
@@ -39,19 +38,19 @@ class GraphOptimizer {
// close graphOptimizer
virtual Status Finalize() = 0;

// optimize original graph for FE quant optimization
// optimize original graph for FE quant optimize
virtual Status OptimizeGraphPrepare(ComputeGraph &graph) { return SUCCESS; }

// optimize original graph used in the graph preparation stage
// optimize original graph, using in graph preparation stage
virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0;

// optimize fused graph
virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0;

// optimize the whole graph which will be used after graph merged
// optimize whole graph, using after graph merged stage
virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0;

// get attributes of graph optimizer
// get attribute of graph optimizer
virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0;

// optimize streamed Graph


+1 -3 inc/common/optimizer/graph_optimizer_types.h

@@ -19,8 +19,6 @@

#include <stdint.h>
#include <string>

using std::string;
namespace ge {
enum OPTIMIZER_SCOPE {
UNIT = 0,
@@ -28,7 +26,7 @@ enum OPTIMIZER_SCOPE {
};

struct GraphOptimizerAttribute {
string engineName;
std::string engineName;
OPTIMIZER_SCOPE scope;
};
} // namespace ge


+51 -2 inc/external/ge/ge_api_types.h

@@ -20,6 +20,7 @@
#include <cstdint>
#include <string>
#include <vector>
#include <set>

namespace ge {
// Option key: graph run mode
@@ -38,9 +39,11 @@ const char *const GE_AICPU_FLAG = "ge.aicpuFlag";
const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath";
const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump";
const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath";
const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep";
// Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0
const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag";
const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic";
const char *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory";

// Option key: memory init
const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize";
@@ -141,19 +144,43 @@ const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum";
// congigure outputDatatype to setting net output type
const std::string OUTPUT_DATATYPE = "ge.outputDatatype";

// congigure opSelectImplmode to setting op select implmode
const std::string kOpSelectImplmode = "ge.opSelectImplmode";

// configure whether to enable hcom parallel by session constructor options param,
// its value should be "0" or "1", default value is "0"
const std::string HCOM_PARALLEL = "ge.hcomParallel";

// configure whether to use dynamic batch size
const char *const kDynamicBatchSize = "ge.dynamicBatchSize";

// configure whether to use dynamic image size
const char *const kDynamicImageSize = "ge.dynamicImageSize";

// Configure auto tune mode, this option only take effect while AUTO_TUNE_FLAG is Y,
// example: GA|RL, support configure multiple, split by |
const std::string AUTO_TUNE_MODE = "ge.autoTuneMode";

// Configure soc version , example: "Ascend310"
const std::string SOC_VERSION = "ge.socVersion";

// Configure core type "VectorEngine", default value is "AIcoreEngine"
const std::string CORE_TYPE = "ge.engineType";

// Configure soc version , example: "Ascend310"
const std::string SOC_VERSION = "ge.socVersion";
// Configure AICORE NUM
const std::string AICORE_NUM = "ge.aicoreNum";

// Configure L1FUSION
const std::string L1_FUSION = "ge.l1Fusion";

// Configure Small Channel flag
const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel";

// Configure Compress Weight flag
const std::string ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight";

// Configure fusion switch file path
const std::string FUSION_SWITCH_FILE = "ge.fusionSwitchFile";

// Save original model
const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel";
@@ -194,6 +221,28 @@ struct TensorInfo {
DataDesc data; // tensor data
ShapeDesc shapeInfo; // tensor shape
};
// for ir build
namespace ir_option {
static const char *const INPUT_FORMAT = "input_format";
static const char *const INPUT_SHAPE = "input_shape";
static const char *const OP_NAME_MAP = "op_name_map";
static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize;
static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize;
static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str();
static const char *const PRECISION_MODE = ge::PRECISION_MODE.c_str();
static const char *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY;
static const char *const HEAD_STREAM = ge::HEAD_STREAM.c_str();
static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str();
static const char *const CORE_TYPE = ge::CORE_TYPE.c_str();
static const char *const SOC_VERSION = ge::SOC_VERSION.c_str();
// for interface: aclgrphBuildModel
const std::set<std::string> ir_builder_suppported_options = {
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE,
DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY,
AUTO_TUNE_MODE};
// for interface: aclgrphBuildInitialize
const std::set<std::string> global_options = {HEAD_STREAM, CORE_TYPE, SOC_VERSION};
} // namespace ir_option
} // namespace ge

#endif // INC_EXTERNAL_GE_GE_API_TYPES_H_
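The new ir_option namespace closes the header by listing which keys aclgrphBuildModel (per-model) and aclgrphBuildInitialize (global) will accept. A hedged sketch of screening caller-supplied options against ir_builder_suppported_options (the option values and the include path are invented; the set keeps its original spelling):

```cpp
#include <iostream>
#include <map>
#include <string>

#include "ge/ge_api_types.h"  // assumed include path for the header above

int main() {
  const std::map<std::string, std::string> build_options = {
      {ge::ir_option::INPUT_FORMAT, "NCHW"},             // sample value
      {ge::ir_option::INPUT_SHAPE, "data:1,3,224,224"},  // sample value
      {ge::ir_option::DYNAMIC_BATCH_SIZE, "1,2,4"},      // sample value
  };
  for (const auto &kv : build_options) {
    if (ge::ir_option::ir_builder_suppported_options.count(kv.first) == 0) {
      std::cout << "unsupported build option: " << kv.first << std::endl;
    }
  }
  return 0;
}
```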

+75 -0 inc/external/ge/ge_ir_build.h

@@ -0,0 +1,75 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GE_IR_BUILD_H_
#define INC_EXTERNAL_GE_IR_BUILD_H_

#include <string>
#include <map>
#include <memory>
#include "graph/graph.h"
#include "graph/ge_error_codes.h"

namespace ge {

struct ModelBufferData {
std::shared_ptr<uint8_t> data = nullptr;
uint64_t length;
};

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
* @param global_options[IN] global init params for build
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options);

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
*/
void aclgrphBuildFinalize();

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
*
* @param graph[IN] the graph ready to build
* @param options[IN] options used for build
* @param model[OUT] builded model
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options,
ModelBufferData &model);

/**
* @ingroup AscendCL
* @brief save model buffer to file
*
* @param output_file[IN] the file path to be saved
* @param model[IN] model buffer data
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model);

}; // namespace ge
#endif
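The new header boils down to a four-call flow: initialize, build the graph into an in-memory ModelBufferData, optionally save it, then finalize. A minimal sketch under the assumption that a valid ge::Graph has already been constructed elsewhere; the option maps, output path, and include path are placeholders:

```cpp
#include <map>
#include <string>

#include "ge/ge_ir_build.h"  // assumed include path for the header above
#include "graph/graph.h"

ge::graphStatus BuildAndSave(const ge::Graph &graph) {
  std::map<std::string, std::string> global_options;  // e.g. SOC_VERSION, CORE_TYPE
  std::map<std::string, std::string> build_options;   // per-model options
  ge::graphStatus ret = ge::aclgrphBuildInitialize(global_options);
  if (ret != ge::GRAPH_SUCCESS) {
    return ret;
  }
  ge::ModelBufferData model;  // the built model lives in this buffer
  ret = ge::aclgrphBuildModel(graph, build_options, model);
  if (ret == ge::GRAPH_SUCCESS) {
    ret = ge::aclgrphSaveModel("./built_model.om", model);  // placeholder path
  }
  ge::aclgrphBuildFinalize();
  return ret;
}
```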

+1 -1 inc/external/graph/attr_value.h

@@ -22,7 +22,7 @@
#include <string>
#include <vector>

#include "external/graph/ge_error_codes.h"
#include "./ge_error_codes.h"

using std::make_shared;
using std::map;


+1 -1 inc/external/graph/graph.h

@@ -22,7 +22,7 @@
#include <utility>
#include <vector>

#include "external/graph/operator.h"
#include "./operator.h"

namespace ge {
class GraphImpl;


+3 -3 inc/external/graph/inference_context.h

@@ -21,8 +21,8 @@
#include <string>
#include <vector>

#include "external/graph/tensor.h"
#include "external/graph/types.h"
#include "./tensor.h"
#include "./types.h"

namespace ge {
class InferenceContext;
@@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext {
static std::unique_ptr<InferenceContext> Create();

private:
InferenceContext(std::unique_ptr<InferenceContextImpl> &impl);
explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl);
std::shared_ptr<InferenceContextImpl> inference_context_impl_;
};
} // namespace ge


+3 -3 inc/external/graph/operator.h

@@ -23,9 +23,9 @@
#include <string>
#include <vector>

#include "external/graph/ge_error_codes.h"
#include "external/graph/inference_context.h"
#include "external/graph/tensor.h"
#include "./ge_error_codes.h"
#include "./inference_context.h"
#include "./tensor.h"

#ifndef USER_GE_LOGI
#define USER_GE_LOGI(...)


+2 -2 inc/external/graph/operator_factory.h

@@ -22,8 +22,8 @@
#include <string>
#include <vector>

#include "external/graph//operator.h"
#include "external/graph/ge_error_codes.h"
#include "./operator.h"
#include "./ge_error_codes.h"

namespace ge {
using OpCreator = std::function<Operator(const std::string &)>;


+23 -23 inc/external/graph/operator_reg.h

@@ -22,10 +22,10 @@
#include <string>
#include <vector>

#include "external/graph/operator.h"
#include "external/graph/operator_factory.h"
#include "external/graph/tensor.h"
#include "external/graph/types.h"
#include "./operator.h"
#include "./operator_factory.h"
#include "./tensor.h"
#include "./types.h"

namespace ge {
using std::function;
@@ -60,7 +60,7 @@ class OpReg {
\
private: \
void __##x() { \
OpReg()
OpReg()

#define ATTR(x, Type, ...) \
N(); \
@@ -86,7 +86,7 @@ class OpReg {
void __attr_##x() { \
Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \
string attr_name(#x); \
(void)OpReg()
(void)OpReg()

#define REQUIRED_ATTR(x, Type) \
N(); \
@@ -112,7 +112,7 @@ class OpReg {
void __required_attr_##x() { \
Operator::RequiredAttrRegister(#x); \
string attr_name(#x); \
(void)OpReg()
(void)OpReg()

#define INPUT(x, t) \
N(); \
@@ -137,7 +137,7 @@ class OpReg {
private: \
void __input_##x() { \
Operator::InputRegister(#x); \
(void)OpReg()
(void)OpReg()

#define OPTIONAL_INPUT(x, t) \
N(); \
@@ -162,7 +162,7 @@ class OpReg {
private: \
void __optional_input_##x() { \
Operator::OptionalInputRegister(#x); \
(void)OpReg()
(void)OpReg()

#define OUTPUT(x, t) \
N(); \
@@ -179,7 +179,7 @@ class OpReg {
private: \
void __out_##x() { \
Operator::OutputRegister(#x); \
(void)OpReg()
(void)OpReg()

#define DYNAMIC_INPUT(x, t) \
N(); \
@@ -206,7 +206,7 @@ class OpReg {
\
private: \
void __dy_input_##x() { \
(void)OpReg()
(void)OpReg()

#define DYNAMIC_OUTPUT(x, t) \
N(); \
@@ -227,18 +227,18 @@ class OpReg {
\
private: \
void __dy_output_##x() { \
(void)OpReg()
(void)OpReg()

#define PASTE(g_register, y) g_register##y
#define __OP_END_IMPL__(x, y) \
N(); \
} \
static_assert( \
std::is_same<x, _THIS_TYPE>::value, \
"The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \
} \
; \
static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \
#define __OP_END_IMPL__(x, y) \
N(); \
} \
static_assert( \
std::is_same<x, _THIS_TYPE>::value, \
"The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \
} \
; \
static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \
}
#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__)

@@ -286,7 +286,7 @@ class OpReg {
// Common shape inferencer

#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \
[](Operator op)->graphStatus { \
[](Operator op) -> graphStatus { \
auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \
auto x_type = op.GetInputDesc(in_name).GetDataType(); \
TensorDesc op_output_desc = op.GetOutputDesc(out_name); \
@@ -300,7 +300,7 @@ graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape,
const function<void(const vector<int64_t> &y_shape)> &set_out_shape);

#define BROADCAST_INFER(in1_name, in2_name, out_name) \
[](Operator op)->graphStatus { \
[](Operator op) -> graphStatus { \
return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \
[&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \
[&](const vector<int64_t> &y_shape) { \


+2 -2 inc/external/graph/tensor.h

@@ -22,8 +22,8 @@
#include <string>
#include <vector>

#include "external/graph/ge_error_codes.h"
#include "external/graph/types.h"
#include "./ge_error_codes.h"
#include "./types.h"

namespace ge {
class ShapeImpl;


+3 -1 inc/external/graph/types.h

@@ -133,11 +133,13 @@ enum Format {
FORMAT_FRACTAL_ZZ,
FORMAT_FRACTAL_NZ,
FORMAT_NCDHW,
FORMAT_DHWCK, // 3D filter input tensor format
FORMAT_DHWCN, // 3D filter input tensor format
FORMAT_NDC1HWC0,
FORMAT_FRACTAL_Z_3D,
FORMAT_CN,
FORMAT_NC,
FORMAT_DHWNC,
FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format
FORMAT_RESERVED,
FORMAT_ALL
};


+35 -0 inc/external/register/register.h

@@ -47,6 +47,12 @@ class Tensor;
class TBEPluginManager;
} // namespace ge

namespace google {
namespace protobuf {
class Message;
}
} // namespace google

namespace domi {
Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op);
Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op,
@@ -56,6 +62,8 @@ using google::protobuf::Message;
class OpRegistrationDataImpl;

using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>;
using FusionParseParamFunc =
std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>;

class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
public:
@@ -71,15 +79,20 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {

OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn);

OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn);

OpRegistrationData &ImplyType(const domi::ImplyType &imply_type);

OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue);

OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type);

domi::ImplyType GetImplyType() const;
std::string GetOmOptype() const;
std::set<std::string> GetOriginOpTypeSet() const;
domi::FrameworkType GetFrameworkType() const;
ParseParamFunc GetParseParamFn() const;
FusionParseParamFunc GetFusionParseParamFn() const;

private:
std::shared_ptr<OpRegistrationDataImpl> impl_;
@@ -103,5 +116,27 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver {
namespace ge {
using OpRegistrationData = domi::OpRegistrationData;
using OpReceiver = domi::OpReceiver;

class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp {
public:
HostCpuOp() = default;
virtual ~HostCpuOp() = default;

virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs,
std::map<std::string, Tensor> &outputs) = 0;
};

class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar {
public:
HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)());
};

#define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op)

#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op)

#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \
static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \
::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); })
} // namespace ge
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_
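The HostCpuOp base class and the REGISTER_HOST_CPU_OP_BUILDER macro added above form a host-CPU kernel plugin point. A hedged registration sketch; the op type string, the class, its do-nothing Compute body, and the include paths are invented for illustration:

```cpp
#include <map>
#include <string>

#include "graph/ge_error_codes.h"  // for GRAPH_SUCCESS; assumed include path
#include "register/register.h"     // assumed include path for the header above

namespace {
class IdentityHostOp : public ge::HostCpuOp {
 public:
  ge::graphStatus Compute(ge::Operator &op, const std::map<std::string, const ge::Tensor> &inputs,
                          std::map<std::string, ge::Tensor> &outputs) override {
    // A real kernel would read `inputs` and fill `outputs`; this stub only reports success.
    return ge::GRAPH_SUCCESS;
  }
};
}  // namespace

// Expands to a static ge::HostCpuOpRegistrar that builds an IdentityHostOp on demand.
REGISTER_HOST_CPU_OP_BUILDER("IdentityHost", IdentityHostOp);
```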

+2 -1 inc/external/register/register_error_codes.h

@@ -22,7 +22,7 @@

#define DECLARE_ERRORNO(sysid, modid, name, value) \
const domi::Status name = \
((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value));
((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value));

#define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value)

@@ -33,6 +33,7 @@ using Status = uint32_t;
DECLARE_ERRORNO(0, 0, SUCCESS, 0);
DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF);
DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649
DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201);
} // namespace domi

#endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_
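DECLARE_ERRORNO packs a system id, a module id, and a 16-bit value into one 32-bit status word. Working backwards from the comment above, PARAM_INVALID = 50331649 = 0x03000001, which suggests SYSID_FWK is 3 and MODID_COMMON is 0; the sketch below reproduces the packing under that assumption (the local variables are not part of the header):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint8_t sysid = 3;   // assumed value of SYSID_FWK, inferred from the 50331649 comment
  const uint8_t modid = 0;   // assumed value of MODID_COMMON
  const uint16_t value = 1;  // PARAM_INVALID
  // Same bit layout as DECLARE_ERRORNO: sysid in bits 31-24, modid in 23-16, value in 15-0.
  const uint32_t status = ((0xFF & sysid) << 24) | ((0xFF & modid) << 16) | (0xFFFF & value);
  assert(status == 50331649u);  // 0x03000001
  return 0;
}
```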

+4 -0 inc/external/register/register_types.h

@@ -48,6 +48,10 @@ typedef enum tagDomiTensorFormat {
DOMI_TENSOR_BN_WEIGHT,
DOMI_TENSOR_CHWN, // Android NN Depth CONV
DOMI_TENSOR_FILTER_HWCK, // filter input tensor format
DOMI_TENSOR_NDHWC,
DOMI_TENSOR_NCDHW,
DOMI_TENSOR_DHWCN, // 3D filter input tensor format
DOMI_TENSOR_DHWNC,
DOMI_TENSOR_RESERVED
} domiTensorFormat_t;
} // namespace domi


+40 -27 inc/framework/common/debug/ge_log.h

@@ -18,11 +18,13 @@
#define INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_

#include <cstdint>
#include <unistd.h>
#include <sys/syscall.h>

#include "framework/common/ge_inner_error_codes.h"
#include "toolchain/slog.h"

#define GE_MODULE_NAME GE
#define GE_MODULE_NAME static_cast<int>(GE)

// trace status of log
enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP };
@@ -35,15 +37,20 @@ enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP };
#define GELOGO(...) GE_LOG_OPLOG(GE_MODULE_NAME, __VA_ARGS__)
#define GELOGT(VALUE, ...) GE_LOG_TRACE(GE_MODULE_NAME, VALUE, __VA_ARGS__)

inline bool IsLogEnable(int module_name, int log_level) noexcept {
int32_t enable_event = 0;
int32_t dlog_level = dlog_getlevel(module_name, &enable_event);
if (dlog_level <= log_level) {
inline bool IsLogEnable(int module_name, int log_level) {
int32_t enable = CheckLogLevel(module_name, log_level);
// 1:enable, 0:disable
if (enable == 1) {
return true;
}
return false;
}

inline pid_t GetTid() {
thread_local static pid_t tid = syscall(__NR_gettid);
return tid;
}

#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap()

#define GE_TIMESTAMP_END(stage, stage_name) \
@@ -68,29 +75,35 @@ inline bool IsLogEnable(int module_name, int log_level) noexcept {
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \
call_num_of##stage)

#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \
dlog_error(static_cast<int>(MOD_NAME), "%s: ErrorNo: %d(%s) " fmt, __FUNCTION__, ERROR_CODE, \
#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \
dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \
((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__)
#define GE_LOG_WARN(MOD_NAME, fmt, ...) \
if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_WARN)) \
dlog_warn(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__)
#define GE_LOG_INFO(MOD_NAME, fmt, ...) \
if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_INFO)) \
dlog_info(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__)
#define GE_LOG_DEBUG(MOD_NAME, fmt, ...) \
if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_DEBUG)) \
dlog_debug(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__)
#define GE_LOG_EVENT(MOD_NAME, fmt, ...) dlog_event(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__)
#define GE_LOG_WARN(MOD_NAME, fmt, ...) \
if (IsLogEnable(MOD_NAME, DLOG_WARN)) dlog_warn(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__)
#define GE_LOG_INFO(MOD_NAME, fmt, ...) \
if (IsLogEnable(MOD_NAME, DLOG_INFO)) dlog_info(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__)
#define GE_LOG_DEBUG(MOD_NAME, fmt, ...) \
if (IsLogEnable(MOD_NAME, DLOG_DEBUG)) dlog_debug(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__)
#define GE_LOG_EVENT(MOD_NAME, fmt, ...) dlog_event(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__)
#define GE_LOG_OPLOG(MOD_NAME, fmt, ...) \
Dlog(static_cast<int>(MOD_NAME), DLOG_OPLOG, "%s:" fmt, __FUNCTION__, ##__VA_ARGS__)
#define GE_LOG_TRACE(MOD_NAME, value, fmt, ...) \
do { \
TraceStatus stat = value; \
const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \
int idx = static_cast<int>(stat); \
char *k = const_cast<char *>("status"); \
char *v = const_cast<char *>(TraceStatStr[idx]); \
KeyValue kv = {k, v}; \
DlogWithKV(static_cast<int>(MOD_NAME), DLOG_TRACE, &kv, 1, "%s:" fmt, __FUNCTION__, ##__VA_ARGS__); \
Dlog(MOD_NAME, DLOG_OPLOG, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__)

#define GE_LOG_TRACE(MOD_NAME, value, fmt, ...) \
do { \
TraceStatus stat = value; \
const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \
int idx = static_cast<int>(stat); \
char *k = const_cast<char *>("status"); \
char *v = const_cast<char *>(TraceStatStr[idx]); \
KeyValue kv = {k, v}; \
DlogWithKV(static_cast<int>(MOD_NAME), DLOG_TRACE, &kv, 1, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__); \
} while (0)

// print memory when it is greater than 1KB.
#define GE_PRINT_DYNAMIC_MEMORY(FUNC, PURPOSE, SIZE) \
do { \
if ((SIZE) > 1024) { \
GELOGI("MallocMemory, func=%s, size=%zu, purpose=%s", (#FUNC), static_cast<size_t>(SIZE), (PURPOSE)); \
} \
} while (0);
#endif // INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_
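One of the smaller changes here, GE_PRINT_DYNAMIC_MEMORY, only logs allocations larger than 1 KB through GELOGI. A hedged usage sketch; the wrapper function and the include path are hypothetical:

```cpp
#include <cstddef>
#include <cstdlib>

#include "framework/common/debug/ge_log.h"  // assumed include path for the header above

void *AllocTracked(std::size_t size) {
  void *buf = std::malloc(size);
  // Logs "MallocMemory, func=malloc, size=..., purpose=..." only when size > 1024 bytes.
  // No trailing semicolon: the macro definition above already ends with one.
  GE_PRINT_DYNAMIC_MEMORY(malloc, "workspace buffer", size)
  return buf;
}
```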

+56 -99 inc/framework/common/debug/log.h

@@ -29,7 +29,18 @@
using cce::CC_STATUS_SUCCESS;
using cce::ccStatus_t;

#define GE_LOGE(...) DAV_LOGE("GE", __VA_ARGS__)
#if !defined(__ANDROID__) && !defined(ANDROID)
#define DOMI_LOGE(...) DAV_LOGE("DOMI", __VA_ARGS__)
#else
#include <android/log.h>
#if defined(BUILD_VERSION_PERF)
#define DOMI_LOGE(fmt, ...)
#else
// The Android system has strict log control. Do not modify the log.
#define DOMI_LOGE(fmt, ...) \
__android_log_print(ANDROID_LOG_ERROR, "NPU_FMK", "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#endif
#endif

// ge marco
#define GE_LOGI_IF(condition, ...) \
@@ -44,7 +55,7 @@ using cce::ccStatus_t;

#define GE_LOGE_IF(condition, ...) \
if ((condition)) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
}

// If expr is not SUCCESS, print the log and return the same value
@@ -52,7 +63,7 @@ using cce::ccStatus_t;
do { \
const ge::Status _status = (expr); \
if (_status != ge::SUCCESS) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
return _status; \
} \
} while (0);
@@ -62,7 +73,7 @@ using cce::ccStatus_t;
do { \
const ge::Status _status = (expr); \
if (_status != ge::SUCCESS) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
} \
} while (0);

@@ -75,6 +86,15 @@ using cce::ccStatus_t;
} \
} while (0);

// If expr is not GRAPH_SUCCESS, print the log and return FAILED
#define GE_CHK_GRAPH_STATUS_RET(expr, ...) \
do { \
if ((expr) != ge::GRAPH_SUCCESS) { \
DOMI_LOGE(__VA_ARGS__); \
return FAILED; \
} \
} while (0);

// If expr is not SUCCESS, print the log and execute a custom statement
#define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \
do { \
@@ -91,25 +111,11 @@ using cce::ccStatus_t;
(void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \
(void)msg.append( \
ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \
GE_LOGE("%s", msg.c_str()); \
DOMI_LOGE("%s", msg.c_str()); \
return _status; \
} \
} while (0);

// If expr is not true, print the Info log and return the specified status
#define GE_CHK_BOOL_RET_STATUS_LOGI(expr, _status, ...) \
do { \
bool b = (expr); \
if (!b) { \
std::string msg; \
(void)msg.append(StringUtils::FormatString(__VA_ARGS__)); \
(void)msg.append( \
StringUtils::FormatString(" Check result false, status: 0x%X %s", _status, GET_ERRORNO_STR(_status).c_str())); \
GELOGI("%s", msg.c_str()); \
return _status; \
} \
} while (0);

// If expr is not true, print the log and return the specified status
#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \
do { \
@@ -124,7 +130,7 @@ using cce::ccStatus_t;
{ \
bool b = (expr); \
if (!b) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
exec_expr; \
} \
};
@@ -163,7 +169,7 @@ using cce::ccStatus_t;
{ \
bool b = (expr); \
if (b) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
exec_expr; \
} \
};
@@ -182,7 +188,7 @@ using cce::ccStatus_t;
{ \
bool b = (expr); \
if (b) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
exec_expr; \
return; \
} \
@@ -193,7 +199,7 @@ using cce::ccStatus_t;
{ \
bool b = (expr); \
if (b) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
exec_expr; \
return _status; \
} \
@@ -210,62 +216,42 @@ using cce::ccStatus_t;

// -----------------runtime related macro definitions-------------------------------
// If expr is not RT_ERROR_NONE, print the log
#define GE_CHK_RT(expr) \
do { \
rtError_t _rt_ret = (expr); \
if (_rt_ret != RT_ERROR_NONE) { \
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \
} \
#define GE_CHK_RT(expr) \
do { \
rtError_t _rt_ret = (expr); \
if (_rt_ret != RT_ERROR_NONE) { \
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \
} \
} while (0);

// If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression
#define GE_CHK_RT_EXEC(expr, exec_expr) \
{ \
rtError_t _rt_ret = (expr); \
if (_rt_ret != RT_ERROR_NONE) { \
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \
exec_expr; \
} \
#define GE_CHK_RT_EXEC(expr, exec_expr) \
{ \
rtError_t _rt_ret = (expr); \
if (_rt_ret != RT_ERROR_NONE) { \
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \
exec_expr; \
} \
}

// If expr is not RT_ERROR_NONE, print the log and return
#define GE_CHK_RT_RET(expr) \
do { \
rtError_t _rt_ret = (expr); \
if (_rt_ret != RT_ERROR_NONE) { \
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \
return ge::RT_FAILED; \
} \
#define GE_CHK_RT_RET(expr) \
do { \
rtError_t _rt_ret = (expr); \
if (_rt_ret != RT_ERROR_NONE) { \
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \
return ge::RT_FAILED; \
} \
} while (0);

// ------------------------cce related macro definitions----------------------------
// If expr is not CC_STATUS_SUCCESS, print the log
#define GE_CHK_CCE(expr) \
do { \
ccStatus_t _cc_ret = (expr); \
if (_cc_ret != CC_STATUS_SUCCESS) { \
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \
} \
} while (0);

// If expr is not CC_STATUS_SUCCESS, print the log and execute the exec_expr expression
#define GE_CHK_CCE_EXEC(expr, exec_expr) \
do { \
ccStatus_t _cc_ret = (expr); \
if (_cc_ret != CC_STATUS_SUCCESS) { \
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \
exec_expr; \
} \
} while (0);

// If expr is not CC_STATUS_SUCCESS, print the log and return
#define GE_CHK_CCE_RET(expr) \
do { \
ccStatus_t _cc_ret = (expr); \
if (_cc_ret != CC_STATUS_SUCCESS) { \
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \
return ge::CCE_FAILED; \
} \
#define GE_CHK_CCE(expr) \
do { \
ccStatus_t _cc_ret = (expr); \
if (_cc_ret != CC_STATUS_SUCCESS) { \
DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \
} \
} while (0);

// If expr is true, execute exec_expr without printing logs
@@ -281,37 +267,8 @@ using cce::ccStatus_t;
try { \
exec_expr0; \
} catch (const std::bad_alloc &) { \
GE_LOGE("Make shared failed"); \
DOMI_LOGE("Make shared failed"); \
exec_expr1; \
}

#define GE_CHECK_INT32_MUL_OVERFLOW(a, b, ...) \
do { \
if ((a) > 0) { \
if ((b) > 0) { \
if ((a) > (INT32_MAX / (b))) { \
GE_LOGE(__VA_ARGS__); \
return ge::FAILED; \
} \
} else { \
if ((b) < (INT32_MIN / (a))) { \
GE_LOGE(__VA_ARGS__); \
return ge::FAILED; \
} \
} \
} else { \
if ((b) > 0) { \
if ((a) < (INT32_MAX / (b))) { \
GE_LOGE(__VA_ARGS__); \
return ge::FAILED; \
} \
} else { \
if (((a) != 0) && ((b) < (INT32_MAX / (a)))) { \
GE_LOGE(__VA_ARGS__); \
return ge::FAILED; \
} \
} \
} \
} while (0);

#endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_
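The surviving check macros all follow one pattern: evaluate, log through DOMI_LOGE, then return or run a caller-supplied statement. A sketch using the two that are fully visible above, GE_CHK_RT_RET and GE_CHK_GRAPH_STATUS_RET; the stub callees and the include paths are hypothetical:

```cpp
#include "framework/common/debug/log.h"  // assumed include path for the header above
#include "graph/ge_error_codes.h"        // for GRAPH_SUCCESS; assumed include path
#include "runtime/rt.h"                  // for rtError_t / RT_ERROR_NONE

namespace ge {
// Hypothetical stand-ins for a runtime call and a graph-stage call.
inline rtError_t DoDeviceWork() { return RT_ERROR_NONE; }
inline graphStatus PrepareGraph() { return GRAPH_SUCCESS; }

Status RunStep() {
  // On failure: DOMI_LOGE("Call rt api failed, ...") and return ge::RT_FAILED.
  GE_CHK_RT_RET(DoDeviceWork());
  // On failure: DOMI_LOGE with the message below and return FAILED.
  GE_CHK_GRAPH_STATUS_RET(PrepareGraph(), "prepare graph failed");
  return SUCCESS;
}
}  // namespace ge
```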

+4 -4 inc/framework/common/ge_inner_error_codes.h

@@ -152,7 +152,6 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258
@@ -204,15 +203,16 @@ GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_GET_GRAPH_REBUILD_FAILED, 60,
GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61,
"Failed set graph finish rebuild in node searcher."); // 1343242301
GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302
// Optimize errocode
GE_ERRORNO_GRAPH(TO_BE_DELETED, 200, "The node of the graph to be deleted.");
GE_ERRORNO_GRAPH(NOT_CHANGED, 201, "NThe node of the graph not changed.");

// Engine_manager module error code definition
GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336
GE_ERRORNO_ENGINE(GE_ENG_FINALIZE_FAILED, 1, "Engine finalize failed."); // 1343246337
GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338

// Optimize errocode
GE_ERRORNO_GRAPH(TO_BE_DELETED, 63, "The node of the graph to be deleted."); // 1343242303
GE_ERRORNO_GRAPH(NOT_CHANGED, 64, "The node of the graph no changed."); // 1343242304

// Ops module error code definition
GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432
GE_ERRORNO_OPS(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, 1, "Failed to initialize GraphOptimizer."); // 1343250433


+27 -7 inc/framework/common/ge_types.h

@@ -24,8 +24,7 @@

#include "common/fmk_error_codes.h"
#include "ge/ge_api_error_codes.h"

using std::string;
#include "external/graph/types.h"

namespace ge {
enum RuntimeType { HOST = 0, DEVICE = 1 };
@@ -56,7 +55,7 @@ struct DataBuffer {

///
/// @ingroup domi_ome
/// @brief External inputdata
/// @brief External input data
///
struct InputData {
uint32_t index; // Index of input data
@@ -65,13 +64,14 @@ struct InputData {
uint32_t model_id; // Model ID required for data processing
uint64_t request_id = 0; // Request ID
std::vector<DataBuffer> blobs; // Actual input data, currently only supports one input
bool is_dynamic_batch = false; // Whether is dynamic batch size scene, default:false
std::string batch_label; // Gear used for current inference in dynamic batch scene
};

// The definition of output result structure
/// Output result structure definition
struct OutputData {
uint32_t index; // Index of input data
uint32_t model_id; // The model ID corresponding to the processing result

/// Output data cache, arranged in sequence of output operators.
/// If the operator has multiple outputs,
/// the data buffer order of the operator is the same as that defined in the
@@ -142,11 +142,31 @@ struct Options {
bool deployMode;
bool isAICPUMode;
bool enable_atomic;
string podName;
std::string podName;
int64_t rankId;
string rankTableFile;
std::string rankTableFile;
int32_t ge_hccl_flag = 0;
int32_t physical_device_id;
};

// Profiling info of task
struct TaskDescInfo {
std::string op_name;
uint32_t block_dim;
uint32_t task_id;
uint32_t stream_id;
};

// Profiling info of graph
struct ComputeGraphDescInfo {
std::string op_name;
std::string op_type;
std::vector<Format> input_format;
std::vector<std::vector<int64_t>> input_shape;
std::vector<DataType> input_data_type;
std::vector<Format> output_format;
std::vector<std::vector<int64_t>> output_shape;
std::vector<DataType> output_data_type;
};
} // namespace ge
#endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_
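ge_types.h gains two plain structs that carry per-task and per-graph profiling data. A small sketch that fills the first one; the helper function, the block_dim value, and the include path are illustrative:

```cpp
#include <cstdint>
#include <string>

#include "framework/common/ge_types.h"  // assumed include path for the header above

ge::TaskDescInfo MakeTaskDesc(const std::string &op_name, uint32_t task_id, uint32_t stream_id) {
  ge::TaskDescInfo info;
  info.op_name = op_name;
  info.block_dim = 1;  // hypothetical single-block kernel
  info.task_id = task_id;
  info.stream_id = stream_id;
  return info;
}
```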

+23 -19 inc/framework/common/helper/model_helper.h

@@ -19,7 +19,6 @@

#include <memory>
#include <string>
#include <memory>

#include "common/fmk_types.h"
#include "common/helper/om_file_helper.h"
@@ -33,36 +32,41 @@ class ModelHelper {
ModelHelper() = default;
~ModelHelper();

Status SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, const std::string &output_file);
Status SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file);
Status LoadModel(const ge::ModelData &model_data);
Status SaveToOmModel(const GeModelPtr& ge_model, const SaveParam& save_param, const std::string& output_file,
ge::ModelBufferData& model);
Status SaveOriginalGraphToOmModel(const ge::Graph& graph, const std::string& output_file);
Status LoadModel(const ge::ModelData& model_data);
Status GetModelBufferData(ge::ModelBufferData& model);

ModelFileHeader *GetFileHeader() { return file_header_; }
ModelFileHeader* GetFileHeader() { return file_header_; }

GeModelPtr GetGeModel();
void SetSaveMode(bool val) { is_offline_ = val; }
bool GetSaveMode(void) const { return is_offline_; }

static Status TransModelToGeModel(const ModelPtr &model, GeModelPtr &ge_model);
static Status TransGeModelToModel(const GeModelPtr &geModelPtr, ModelPtr &modelPtr);
static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model);
static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr);

private:
bool is_assign_model_ = false;
ModelFileHeader *file_header_ = nullptr;
bool is_offline_ = true;
ModelFileHeader* file_header_ = nullptr;
// Encrypted model need delete temp model and unencrypted model need not delete model
uint8_t *model_addr_tmp_ = nullptr;
uint8_t* model_addr_tmp_ = nullptr;
uint32_t model_len_tmp_ = 0;
GeModelPtr model_;

ModelHelper(const ModelHelper &);
ModelHelper &operator=(const ModelHelper &);
Status GenerateGeModel(OmFileLoadHelper &om_load_helper);
Status LoadModelData(OmFileLoadHelper &om_load_helper);
void SetModelToGeModel(ge::Model &model);
Status LoadWeights(OmFileLoadHelper &om_load_helper);
Status LoadTask(OmFileLoadHelper &om_load_helper);
Status LoadTBEKernelStore(OmFileLoadHelper &om_load_helper);
ModelHelper(const ModelHelper&);
ModelHelper& operator=(const ModelHelper&);
Status GenerateGeModel(OmFileLoadHelper& om_load_helper);
Status LoadModelData(OmFileLoadHelper& om_load_helper);
void SetModelToGeModel(ge::Model& model);
Status LoadWeights(OmFileLoadHelper& om_load_helper);
Status LoadTask(OmFileLoadHelper& om_load_helper);
Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper);
Status ReleaseLocalModelData() noexcept;
Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type,
const uint8_t *data, size_t size);
Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type,
const uint8_t* data, size_t size);
};
} // namespace ge
#endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_

+6 -3 inc/framework/common/helper/om_file_helper.h

@@ -20,10 +20,12 @@
#include <string>
#include <vector>

#include "external/ge/ge_ir_build.h"
#include "framework/common/fmk_types.h"
#include "framework/common/ge_types.h"
#include "framework/common/types.h"
#include "framework/common/ge_types.h"

using ProcParam = struct PROC_PARAM;
using std::string;
using std::vector;

@@ -80,9 +82,10 @@ class OmFileSaveHelper {

const std::vector<ModelPartition> &GetModelPartitions() const;

Status SaveModel(const SaveParam &save_param, const char *target_file);
Status SaveModel(const SaveParam &save_param, const char *target_file, ge::ModelBufferData &model,
bool is_offline = true);

Status SaveModelToFile(const char *output_file);
Status SaveModelToFile(const char *output_file, ge::ModelBufferData &model, bool is_offline = true);

ModelFileHeader model_header_;
OmFileContext context_;


+1 -1 inc/framework/common/l2_cache_optimize.h

@@ -120,4 +120,4 @@ class L2CacheOptimize {
};
} // namespace ge

#endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_
#endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_

+4 -0 inc/framework/common/op/attr_define.h

@@ -649,6 +649,8 @@ extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_M

extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM;

extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_LABEL_NUM;

extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE;

extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE;
@@ -801,6 +803,8 @@ extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_N

extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT;
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE;
// For constant folding
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NO_NEED_CONSTANT_FOLDING;
} // namespace domi

#endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_

+3 -2 inc/framework/common/op/attr_value_util.h

@@ -17,11 +17,12 @@
#ifndef INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_
#define INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_

#include <google/protobuf/map.h>
#include <unordered_map>
#include <string>
#include <google/protobuf/map.h>
#include "graph/debug/ge_attr_define.h"

#include "common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "proto/om.pb.h"

using domi::AttrDef;


+21 -6 inc/framework/common/op/ge_op_utils.h

@@ -18,7 +18,6 @@
#define INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_

#include <cce/dnn.h>

#include <memory>
#include <vector>

@@ -56,6 +55,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_TR
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT;

// FunctionOp
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t IF_COND_INPUT;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_START_INPUT;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT_INPUT;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT;

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE;

class OpUtils {
public:
///
@@ -164,15 +172,23 @@ class OpUtils {
///
static Status ConvertAippParams(const GeAttrValue::NamedAttrs &aipp_attr, domi::AippOpParams *aipp_params);
static Status TransferDim(const std::vector<int64_t> &dim, std::vector<int64_t> &dim_vector);
static void SliceData(std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, int64_t begin,
int64_t out_dim, int64_t stride);
template <typename T>
static void SliceData(const std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output,
int64_t begin, int64_t out_dim, int64_t stride);
template <typename T>
static Status SetDataByDataType(size_t out_size, const std::vector<char *> &chunk_input,
const std::vector<char *> &chunk_output, GeTensor *output);
template <typename T>
static Status SetOutputSliceDataByDataType(void *data, int64_t data_size, const std::vector<int64_t> &input_dims,
const std::vector<int64_t> &begin, const std::vector<int64_t> &output_dims,
ge::GeTensor *output, const std::vector<int64_t> &stride);
static Status SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int64_t> &input_dims,
std::vector<int64_t> &begin, std::vector<int64_t> &output_dims, ge::GeTensor *output,
std::vector<int64_t> &stride);

///
/// @ingroup domi_omg
/// @brief Convert the convolution weight data from [h, w, c, k] to [k, c, h, w]
/// @brief Convert the convolutional weight data from [h, w, c, k] to [k, c, h, w]
/// @param [in] input Weight data in HWCK format
/// @param [in] H value of H dimension
/// @param [in] W value of W dimension
@@ -183,7 +199,7 @@ class OpUtils {
static void TransDataHWCK2KCHW(const void *input, int64_t H, int64_t W, int64_t C, int64_t K, void **output);
///
/// @ingroup domi_omg
/// @brief Converts the convolution weight data from [k, c, h, w] to [h, w, c, k].
/// @brief Converts the convolutional weight data from [k, c, h, w] to [h, w, c, k].
/// @param [in] input Weight data in HWCK format
/// @param [in] K value of K dimension
/// @param [in] C value of C dimension
@@ -222,7 +238,6 @@ using CceTensorDescriptorPtr = std::shared_ptr<CceTensorDescriptor>;
class CceTensorDescriptor {
public:
explicit CceTensorDescriptor(ccTensorDescriptor_t cc_tensor);

CceTensorDescriptor(const CceTensorDescriptor &) = delete;
CceTensorDescriptor &operator=(const CceTensorDescriptor &) = delete;



+2 -2 inc/framework/common/op/op_parser_util.h

@@ -22,7 +22,7 @@
#include <math.h>
#include <stdint.h>

namespace domi {
namespace ge {
// general
const float DEFAULT_ALPHA_VALUE = 1.0;
const float DEFAULT_BETA_VALUE = 0.0;
@@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2;

// Shufflechannel
const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1;
} // namespace domi
} // namespace ge
#endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_

+1 -1 inc/framework/common/scope_guard.h

@@ -25,7 +25,7 @@
/// MAKE_GUARD([&] { Release Resource 1 })
/// Acquire Resource 2
// MAKE_GUARD([&] { Release Resource 2 })
#define GE_MAKE_GUARD(var, callback) ge::ScopeGuard make_guard_##var(callback)
#define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback)
#define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss()

namespace ge {
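The one-line change above drops the ge:: qualifier from the GE_MAKE_GUARD expansion, so call sites now need ScopeGuard visible unqualified (inside namespace ge or via a using-declaration). A hedged sketch of the acquire/guard/dismiss pattern described in the header's own comment; the resource helper and the assumption that ScopeGuard accepts a lambda are illustrative:

```cpp
#include <cstddef>
#include <cstdlib>

#include "framework/common/scope_guard.h"  // assumed include path for the header above

using ge::ScopeGuard;  // needed because the new macro expansion uses the unqualified name

void *AcquireBuffer(std::size_t size) { return std::malloc(size); }  // hypothetical resource

bool Process(bool transfer_ownership) {
  void *buffer = AcquireBuffer(1024);
  if (buffer == nullptr) {
    return false;
  }
  // Frees the buffer on every exit path from Process()...
  GE_MAKE_GUARD(buffer_guard, [&] { std::free(buffer); });
  if (transfer_ownership) {
    // ...unless ownership is handed elsewhere, in which case the guard is cancelled.
    GE_DISMISS_GUARD(buffer_guard);
  }
  return true;
}
```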


+20 -3 inc/framework/common/types.h

@@ -156,6 +156,7 @@ REGISTER_OPTYPE_DECLARE(GATHER, "Gather");
REGISTER_OPTYPE_DECLARE(REALDIV, "RealDiv");
REGISTER_OPTYPE_DECLARE(PACK, "Pack");
REGISTER_OPTYPE_DECLARE(SLICE, "Slice");
REGISTER_OPTYPE_DECLARE(SLICED, "SliceD");
REGISTER_OPTYPE_DECLARE(FLOORDIV, "FloorDiv");
REGISTER_OPTYPE_DECLARE(SQUEEZE, "Squeeze");
REGISTER_OPTYPE_DECLARE(STRIDEDSLICE, "StridedSlice");
@@ -188,6 +189,19 @@ REGISTER_OPTYPE_DECLARE(REFNEXTITERATION, "RefNextIteration");
REGISTER_OPTYPE_DECLARE(EXIT, "Exit");
REGISTER_OPTYPE_DECLARE(REFEXIT, "RefExit");
REGISTER_OPTYPE_DECLARE(CONTROLTRIGGER, "ControlTrigger");
REGISTER_OPTYPE_DECLARE(SYMBOLICGRADIENT, "SymbolicGradient");
REGISTER_OPTYPE_DECLARE(REMOTECALL, "RemoteCall");
REGISTER_OPTYPE_DECLARE(_IF, "_If");
REGISTER_OPTYPE_DECLARE(STATELESSIF, "StatelessIf");
REGISTER_OPTYPE_DECLARE(IF, "If");
REGISTER_OPTYPE_DECLARE(CASE, "Case");
REGISTER_OPTYPE_DECLARE(_WHILE, "_While");
REGISTER_OPTYPE_DECLARE(WHILE, "While");
REGISTER_OPTYPE_DECLARE(STATELESSWHILE, "StatelessWhile");
REGISTER_OPTYPE_DECLARE(FOR, "For");
REGISTER_OPTYPE_DECLARE(PARTITIONEDCALL, "PartitionedCall");
REGISTER_OPTYPE_DECLARE(STATEFULPARTITIONEDCALL, "StatefulPartitionedCall");
REGISTER_OPTYPE_DECLARE(FAKEPARAM, "FakeParam");
REGISTER_OPTYPE_DECLARE(TRANSPOSE, "Transpose");
REGISTER_OPTYPE_DECLARE(TRANSPOSED, "TransposeD");
REGISTER_OPTYPE_DECLARE(CAST, "Cast");
@@ -424,6 +438,12 @@ REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge");
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph");
REGISTER_OPTYPE_DECLARE(SEND, "Send");
REGISTER_OPTYPE_DECLARE(RECV, "Recv");

REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet");
REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto");
REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch");
REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex");

REGISTER_OPTYPE_DECLARE(ATOMICADDRCLEAN, "AtomicAddrClean");

REGISTER_OPTYPE_DECLARE(ABS_GRAD, "AbsGrad");
@@ -1032,14 +1052,11 @@ struct BasicInfo {
uint32_t workspace_size; // workspace
uint32_t total_size; // total memory size
};

#pragma pack() // Cancels single-byte alignment
} // namespace ge

namespace domi {

/// @brief Data structure definition related to task sinking
/// Build model
enum BuildMode {
GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled)
GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled)


+68 -96 inc/framework/common/util.h

@@ -30,6 +30,14 @@
#include "framework/common/ge_inner_error_codes.h"
#include "mmpa/mmpa_api.h"

#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \
do { \
if (size <= 0) { \
DOMI_LOGE(param[#size] is not a positive number); \
return PARAM_INVALID; \
} \
} while (0)

#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \
{ \
bool b = (expr); \
@@ -50,21 +58,6 @@
if (var) GE_CHK_RT(rtStreamDestroy(var)); \
});

#define GE_MAKE_GUARD_RTEVENT(var) \
GE_MAKE_GUARD(var, [&] { \
if (var) GE_CHK_RT(rtEventDestroy(var)); \
});

#define GE_MAKE_GUARD_TENSOR(var) \
GE_MAKE_GUARD(var, [&] { \
if (var) GE_CHK_CCE(ccDestroyTensorDescriptor(&var)); \
});

#define GE_MAKE_GUARD_FILTER_DESC(var) \
GE_MAKE_GUARD(var, [&] { \
if (var) GE_CHK_CCE(ccDestroyFilterDescriptor(&var)); \
});

// For propagating errors when calling a function.
#define GE_RETURN_IF_ERROR(expr) \
do { \
@@ -76,7 +69,7 @@
do { \
const ::ge::Status _status = (expr); \
if (_status) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
return _status; \
} \
} while (0)
@@ -85,7 +78,7 @@
#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \
do { \
if (condition) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
return ge::FAILED; \
} \
} while (0)
@@ -95,7 +88,7 @@
do { \
bool _condition = (condition); \
if (!_condition) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
return ge::FAILED; \
} \
} while (0)
@@ -104,7 +97,7 @@
#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \
do { \
if (condition) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
return ge::PARAM_INVALID; \
} \
} while (0)
@@ -114,111 +107,90 @@
do { \
bool _condition = (condition); \
if (!_condition) { \
GE_LOGE(__VA_ARGS__); \
DOMI_LOGE(__VA_ARGS__); \
return ge::PARAM_INVALID; \
} \
} while (0)

// Check if the parameter is null. If yes, return PARAM_INVALID and record the error
#define GE_CHECK_NOTNULL(val) \
do { \
if (val == nullptr) { \
GE_LOGE(param[#val] must not be null.); \
return ge::PARAM_INVALID; \
} \
#define GE_CHECK_NOTNULL(val) \
do { \
if (val == nullptr) { \
DOMI_LOGE(param[#val] must not be null.); \
return ge::PARAM_INVALID; \
} \
} while (0)

// Check if the parameter is null. If yes, return PARAM_INVALID and record the error
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \
do { \
if (val == nullptr) { \
GE_LOGE(param[#val] must not be null.); \
return; \
} \
// Check if the parameter is null. If yes, just return and record the error
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \
do { \
if (val == nullptr) { \
DOMI_LOGE(param[#val] must not be null.); \
return; \
} \
} while (0)

// Check whether the parameter is null. If so, execute the exec_expr expression and record the error log
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \
do { \
if (val == nullptr) { \
GE_LOGE(param[#val] must not be null.); \
exec_expr; \
} \
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \
do { \
if (val == nullptr) { \
DOMI_LOGE(param[#val] must not be null.); \
exec_expr; \
} \
} while (0)

// Check whether the parameter is null. If yes, return directly and record the error log
#define GE_RT_VOID_CHECK_NOTNULL(val) \
do { \
if (val == nullptr) { \
GE_LOGE(param[#val] must not be null.); \
return; \
} \
#define GE_RT_VOID_CHECK_NOTNULL(val) \
do { \
if (val == nullptr) { \
DOMI_LOGE(param[#val] must not be null.); \
return; \
} \
} while (0)

// Check if the parameter is null. If yes, return false and record the error log
#define GE_RT_FALSE_CHECK_NOTNULL(val) \
do { \
if (val == nullptr) { \
GE_LOGE(param[#val] must not be null.); \
return false; \
} \
#define GE_RT_FALSE_CHECK_NOTNULL(val) \
do { \
if (val == nullptr) { \
DOMI_LOGE(param[#val] must not be null.); \
return false; \
} \
} while (0)

// Check if the parameter is out of bounds
#define GE_CHECK_SIZE(size) \
do { \
if (size == 0) { \
GE_LOGE(param[#size] is out of range); \
return ge::PARAM_INVALID; \
} \
#define GE_CHECK_SIZE(size) \
do { \
if (size == 0) { \
DOMI_LOGE(param[#size] is out of range); \
return ge::PARAM_INVALID; \
} \
} while (0)

// Macros that define the size variable
#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \
uint32_t _var_name; \
do { \
uint32_t _expr_size = (_expr); \
uint32_t _sizeof_size = (_sizeof); \
if (_expr_size > (0xffffffff) / _sizeof_size) { \
GE_LOGE(byte size : #_var_name is out of range); \
return ge::PARAM_INVALID; \
} \
_var_name = _sizeof_size * _expr_size; \
} while (0);

// Check if the container is empty
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \
do { \
if (vector.empty()) { \
GE_LOGE(param[#vector] is empty !); \
return ge::FAILED; \
} \
} while (0)

#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \
do { \
if (size <= 0) { \
GE_LOGE(param[#size] is not a positive number); \
return ge::PARAM_INVALID; \
} \
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \
do { \
if (vector.empty()) { \
DOMI_LOGE(param[#vector] is empty !); \
return ge::FAILED; \
} \
} while (0)

// Check if the value on the left is greater than or equal to the value on the right
#define GE_CHECK_GE(lhs, rhs) \
do { \
if (lhs < rhs) { \
GE_LOGE(param[#lhs] is less than[#rhs]); \
return ge::PARAM_INVALID; \
} \
#define GE_CHECK_GE(lhs, rhs) \
do { \
if (lhs < rhs) { \
DOMI_LOGE(param[#lhs] is less than[#rhs]); \
return ge::PARAM_INVALID; \
} \
} while (0)

// Check if the value on the left is less than or equal to the value on the right
#define GE_CHECK_LE(lhs, rhs) \
do { \
if (lhs > rhs) { \
GE_LOGE(param[#lhs] is greater than[#rhs]); \
return ge::PARAM_INVALID; \
} \
#define GE_CHECK_LE(lhs, rhs) \
do { \
if (lhs > rhs) { \
DOMI_LOGE(param[#lhs] is greater than[#rhs]); \
return ge::PARAM_INVALID; \
} \
} while (0)

#define GE_DELETE_NEW_SINGLE(var) \

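A hedged sketch of how the rewritten check macros above are used inside a function returning ge::Status; ValidateBuffer, desc and data_size are placeholder names, not part of util.h:

namespace ge {
Status ValidateBuffer(const void *desc, int64_t data_size) {
  GE_CHECK_NOTNULL(desc);                   // logs via DOMI_LOGE and returns ge::PARAM_INVALID when desc is null
  GE_CHECK_POSITIVE_SIZE_RANGE(data_size);  // returns PARAM_INVALID when data_size <= 0
  GE_RETURN_WITH_LOG_IF_TRUE(data_size > 0x7fffffff, "data_size is out of range");  // logs and returns ge::FAILED
  return SUCCESS;
}
}  // namespace ge
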

+ 2
- 2
inc/framework/dlog/log.h View File

@@ -52,10 +52,10 @@
#define DLOG_DECLARE(level) \
void Log_##level(const char *mod_name, const char *func, const char *file, int line, const char *format, ...)

namespace ge {
namespace domi {
DLOG_DECLARE(INFO);
DLOG_DECLARE(WARNING);
DLOG_DECLARE(ERROR);
} // namespace ge
} // namespace domi

#endif // INC_FRAMEWORK_DLOG_LOG_H_
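
For reference, DLOG_DECLARE(INFO) above expands mechanically to the declaration below; the only change in this hunk is that the three declarations now live in namespace domi instead of namespace ge:

void Log_INFO(const char *mod_name, const char *func, const char *file, int line, const char *format, ...);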

+ 1
- 1
inc/framework/engine/dnnengine.h View File

@@ -38,7 +38,7 @@ struct DNNEngineAttribute {
std::vector<std::string> mem_type;
uint32_t compute_cost;
enum RuntimeType runtime_type; // HOST, DEVICE
// set this attribute if the inputformat of engine must be specific, otherwise set FORMAT_RESERVED
// If engine input format must be specific, set this attribute, else set FORMAT_RESERVED
Format engine_input_format;
Format engine_output_format;
};


+ 60
- 10
inc/framework/executor/ge_executor.h View File

@@ -26,6 +26,7 @@
#include "common/types.h"
#include "graph/tensor.h"
#include "runtime/base.h"
#include "common/dynamic_aipp.h"

namespace ge {
class ModelListenerAdapter;
@@ -33,12 +34,15 @@ class ModelListenerAdapter;
class SingleOp;

struct RunModelData {
uint32_t index; // Data index
uint32_t model_id; // Model id
std::vector<DataBuffer> blobs; // All input/output data buffer
uint32_t timestamp; // Data creation time
uint32_t timeout; // Processing timeout
uint64_t request_id = 0; // Request ID
uint32_t index; // Data index
uint32_t modelId;
std::vector<DataBuffer> blobs; // All input/output data buffer
uint32_t timestamp; // Data creation time
uint32_t timeout; // Processing timeout
uint64_t request_id = 0; // Request ID
uint64_t dynamic_batch_size = 0; // Dynamic batch size scene, set dynamic size, not supported by default:0
uint64_t dynamic_image_height = 0; // Dynamic image size scene, set image height, not supported by default:0
uint64_t dynamic_image_width = 0; // Dynamic image size scene, set image width, not supported by default:0
};

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
@@ -46,12 +50,13 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
GeExecutor();
~GeExecutor() = default;
ge::Status Initialize();
ge::Status Finalize();

// Load model
ge::Status LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key, int32_t priority,
std::shared_ptr<ge::ModelListener> listener);

ge::Status UnloadModel(uint32_t model_id);
ge::Status UnloadModel(uint32_t modelId);

ge::Status RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data);

@@ -59,6 +64,52 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc,
std::vector<ge::TensorDesc> &output_desc);

///
/// @ingroup ge
/// @brief Set dynamic batch size
/// @param [in] model_id: model id allocate from manager
/// @param [in] dynamic_input_addr: dynamic input addr created by user
/// @param [in] length: length of dynamic input addr
/// @param [in] batch_size: batch size entered by user in dynamic multi-batch scenario
/// @return execute result
///
ge::Status SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t batch_size);

///
/// @ingroup ge
/// @brief Set dynamic image info
/// @param [in] model_id: model id allocate from manager
/// @param [in] dynamic_input_addr: dynamic input addr created by user
/// @param [in] length: length of dynamic input addr
/// @param [in] image_height: image height entered by user in dynamic multi-resolution scenario
/// @param [in] image_width: image width entered by user in dynamic multi-resolution scenario
/// @return execute result
///
ge::Status SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height,
uint64_t image_width);
///
/// @ingroup ge
/// @brief Get dynamic batch_info
/// @param [in] model_id
/// @param [out] batch_info
/// @return execute result
///
ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info);

///
/// @ingroup ge
/// @brief Set dynamic aipp data
/// @param [in] model_id: model id allocate from manager
/// @param [in] dynamic_input_addr: dynamic input addr created by user
/// @param [in] length: length of dynamic input addr
/// @param [in] aippBatchPara: kAippDynamicBatchPara vector by user in dynamic aipp
/// @param [in] aippParms: kAippDynamicPara by user in dynamic aipp
/// @return execute result
///
ge::Status SetDynamicAippData(uint32_t model_id, void *dynamic_input_addr, uint64_t length,
const std::vector<kAippDynamicBatchPara> &aippBatchPara,
const kAippDynamicPara &aippParms);

ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc,
std::vector<ge::TensorDesc> &output_desc);

@@ -147,7 +198,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
///
ge::Status GetMemAndWeightSize(const void *model_data, size_t model_size, size_t &mem_size, size_t &weight_size);

static ge::Status LoadSingleOp(const std::string &model_name, const ge::ModelData &model_data, void *stream,
static ge::Status LoadSingleOp(const std::string &modelName, const ge::ModelData &modelData, void *stream,
SingleOp **single_op);

static ge::Status ExecuteAsync(SingleOp *executor, const std::vector<DataBuffer> &inputs,
@@ -156,8 +207,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
static ge::Status ReleaseSingleOpResource(void *stream);

private:
static bool is_init_;
std::vector<std::shared_ptr<ModelListenerAdapter>> listener_adapters_;
static bool isInit_;
};

ge::Status ModelInfoParser(const ge::ModelData &model, ge::ModelInfo &model_info);
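
A sketch of the dynamic-batch additions above; RunWithDynamicBatch and its arguments are placeholders, and error handling beyond the status checks is omitted:

ge::Status RunWithDynamicBatch(ge::GeExecutor &executor, uint32_t model_id,
                               void *dynamic_input_addr, uint64_t length,
                               ge::RunModelData &input_data, ge::RunModelData &output_data) {
  // Tell the loaded model which batch size this run uses (8 is an example value).
  ge::Status ret = executor.SetDynamicBatchSize(model_id, dynamic_input_addr, length, 8);
  if (ret != ge::SUCCESS) {
    return ret;
  }
  input_data.dynamic_batch_size = 8;  // mirror the chosen batch size in the new RunModelData field
  return executor.RunModel(input_data, output_data);
}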


+ 5
- 1
inc/framework/generator/ge_generator.h View File

@@ -21,7 +21,7 @@
#include <memory>
#include <string>
#include <vector>
#include "ge/ge_ir_build.h"
#include "common/ge_inner_error_codes.h"
#include "graph/ge_tensor.h"
#include "graph/graph.h"
@@ -45,6 +45,8 @@ class GeGenerator {
Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix,
const std::vector<GeTensor> &inputs = std::vector<GeTensor>());

Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData &model);

///
/// @ingroup ge
/// @brief: Build single OP in Model.
@@ -58,6 +60,8 @@ class GeGenerator {
const std::vector<GeTensor> &outputs, const std::string &model_file_name);

private:
Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
ge::ModelBufferData &model, bool is_offline = true);
class Impl;

std::shared_ptr<Impl> impl_;
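
The new GenerateOnlineModel entry point mirrors GenerateOfflineModel but fills a ModelBufferData in memory instead of writing files derived from a prefix. A hedged sketch follows; graph and inputs are placeholders, and any required initialization of the generator is omitted:

ge::GeGenerator generator;
// Offline path: output is written to files derived from the given prefix.
ge::Status ret = generator.GenerateOfflineModel(graph, "./my_model", inputs);

// Online path: the built model stays in memory in a ge::ModelBufferData.
ge::ModelBufferData model_buf;
ret = generator.GenerateOnlineModel(graph, inputs, model_buf);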


+ 0
- 1
inc/framework/generator/generator_api.h View File

@@ -24,7 +24,6 @@ extern "C" {
#endif

typedef uint32_t Status_t;
using Status_t = uint32_t;

typedef void *OpAttr_t;
typedef void *OpTensor_t;


+ 2
- 2
inc/framework/memory/memory_assigner.h View File

@@ -23,7 +23,7 @@
#include "graph/node.h"

namespace ge {
const int64_t kMemAlignSize = 512;
const int64_t MEM_ALIGN_SIZE = 512;
class MemoryAssigner {
public:
explicit MemoryAssigner(ge::ComputeGraphPtr compute_graph) : compute_graph_(std::move(compute_graph)) {}
@@ -39,4 +39,4 @@ class MemoryAssigner {
ge::ComputeGraphPtr compute_graph_;
};
} // namespace ge
#endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_
#endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_

+ 17
- 16
inc/framework/omg/omg_inner_types.h View File

@@ -31,7 +31,6 @@
using domi::DOMI_TENSOR_ND;
using domi::DOMI_TENSOR_RESERVED;
using domi::domiTensorFormat_t;
using domi::FMK_TYPE_RESERVED;
using domi::FrameworkType;
using std::map;
using std::string;
@@ -44,10 +43,10 @@ namespace ge {
* @brief run model
*/
enum RunMode {
kGeOmModel = 0, // generate offline model file
kModelToJson = 1, // convert to JSON file
kOnlyPreCheck = 3, // only for pre-check
kPbtxtToJson = 5 // pbtxt to json
GEN_OM_MODEL = 0, // generate offline model file
MODEL_TO_JSON = 1, // convert to JSON file
ONLY_PRE_CHECK = 3, // only for pre-check
PBTXT_TO_JSON = 5 // pbtxt to json
};

///
@@ -56,10 +55,10 @@ enum RunMode {
///
enum HighPrecisionMode {
// the FP16 high-precision function is disabled in common mode
kHighPrecisonDefault = 0,
HIGH_PRECISION_DEFAULT = 0,

// high-precision mode, in which FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved) is enable
kHighPrecisionFP16 = 1
// high-precision mode, enabling FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved)
HIGH_PRECISION_FP16 = 1
};

///
@@ -99,21 +98,23 @@ struct OmgContext {
// preferential format used by the entire network
domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED;
domi::FrameworkType type = domi::FMK_TYPE_RESERVED;
RunMode run_mode = kOnlyPreCheck;
RunMode run_mode = ONLY_PRE_CHECK;
bool train_flag = false;
// whether to use FP16 high precision
int32_t fp16_high_precision = kHighPrecisonDefault;
int32_t fp16_high_precision = HIGH_PRECISION_DEFAULT;

std::string output_type;

// Save the name of the entire network: Some special operators are used to determine a network. Some operators in the
// network require special processing based on the specific network.
// e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope
// fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The
// convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel.
// network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module
// is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the
// FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape
// operator needs to be deleted for the convolution kernel.
std::string net_name;
// whether to enable dynamic batch
bool enable_l2dynamic = false;
// Whether to use dynamic batch size or dynamic image size
bool is_dynamic_input = false;
std::string dynamic_batch_size;
std::string dynamic_image_size;
};
} // namespace ge
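
A small sketch of the renamed RunMode values and the new dynamic-input fields in OmgContext; the batch-size string is a placeholder value, not a documented format:

ge::OmgContext context;
context.run_mode = ge::GEN_OM_MODEL;   // previously kGeOmModel
context.is_dynamic_input = true;       // new flag for dynamic batch / dynamic image size
context.dynamic_batch_size = "8";      // placeholder; dynamic_image_size covers the image variant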



+ 0
- 8
inc/framework/omg/version.h View File

@@ -32,15 +32,7 @@ class PlatformVersionManager {
PlatformVersionManager() = delete;
~PlatformVersionManager() = delete;
static Status GetPlatformVersion(std::string &ver) {
#if defined PLATFORM_PHOENIX
ver = "3.51.z";
#elif defined PLATFORM_ORLANDO
ver = "3.31.z";
#elif defined PLATFORM_MINI
ver = "1.11.z";
#elif defined PLATFORM_CLOUD
ver = "1.61.z";
#endif
std::vector<std::string> version_splits = StringUtils::Split(ver, '.');
GE_IF_BOOL_EXEC(version_splits.size() < 3, GELOGW("Read platform version error!"); return FAILED;);



+ 16
- 10
inc/graph/anchor.h View File

@@ -20,13 +20,17 @@
#include <memory>
#include <string>
#include <vector>

#include "graph/ge_error_codes.h"
#include "graph/range_vistor.h"
#include "graph/types.h"

namespace ge {
enum AnchorStatus { ANCHOR_SUSPEND = 0, ANCHOR_CONST = 1, ANCHOR_DATA = 2, ANCHOR_RESERVED = 3 };
enum AnchorStatus {
ANCHOR_SUSPEND = 0,  // data null
ANCHOR_CONST = 1,
ANCHOR_DATA = 2, // Effective
ANCHOR_RESERVED = 3
};
using std::string;
using std::vector;

@@ -81,17 +85,19 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable
virtual ~Anchor() = default;

protected:
// Whether the two anchors are equal
// Whether the two anchor is equal
virtual bool Equal(AnchorPtr anchor) const = 0;
virtual bool IsTypeOf(TYPE type) const;

public:
// Get all peer anchors connected to current anchor
Vistor<AnchorPtr> GetPeerAnchors() const;
// Get the first peer anchor
// Get peer anchor size
size_t GetPeerAnchorsSize() const;
// Get first peer anchor
AnchorPtr GetFirstPeerAnchor() const;

// Get the node which is the owner of the anchor
// Get the anchor belong to which node
NodePtr GetOwnerNode() const;

// Remove all links with the anchor
@@ -100,22 +106,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable
// Remove link with the given anchor
graphStatus Unlink(const AnchorPtr &peer);

// Replace the peeranchor with the new peeranchor
// Replace peer with new peers
graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer);

// Judge if the anchor is linked with the given anchor
bool IsLinkedWith(const AnchorPtr &peer);

// Get the anchor index of the node
// Get anchor index of the node
int GetIdx() const;

// Set the anchor index of the node
// set anchor index of the node
void SetIdx(int index);

protected:
// All peer anchors connected to current anchor
vector<std::weak_ptr<Anchor>> peer_anchors_;
// The owner nodes of the anchor
// The owner node of anchor
std::weak_ptr<Node> owner_node_;
// The index of current anchor
int idx_;
@@ -167,7 +173,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataA

virtual ~InDataAnchor() = default;

// Get source out data anchor
// Get source out data anchor
OutDataAnchorPtr GetPeerOutAnchor() const;

// Build connection from OutDataAnchor to InDataAnchor
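
A short sketch of the anchor queries above; in_anchor is a placeholder ge::InDataAnchorPtr:

ge::OutDataAnchorPtr peer = in_anchor->GetPeerOutAnchor();  // source anchor feeding this input
if (peer != nullptr) {
  ge::NodePtr src_node = peer->GetOwnerNode();              // node that owns the source anchor
  int src_index = peer->GetIdx();                           // output index on that node
}
size_t peer_count = in_anchor->GetPeerAnchorsSize();        // newly added: peer count without building the visitor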


+ 6
- 7
inc/graph/attr_value_serializable.h View File

@@ -19,10 +19,10 @@

#include <string>
#include <vector>

#include "graph/ge_attr_value.h"

namespace ge {

class GeAttrValue;
class _GeSerializable {
public:
@@ -107,7 +107,6 @@ class _GeSerializable {
static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; }
};


#define _GE_FI(a) #a, a
#define _GE_MAP_FIELDS1(a1) _GE_FI(a1)
#define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2)
@@ -130,23 +129,23 @@ class _GeSerializable {
#define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
_GE_FI(a11)
_GE_FI(a11)
#define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
_GE_FI(a11), _GE_FI(a12)
_GE_FI(a11), _GE_FI(a12)
#define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13)
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13)
#define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14)
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14)
#define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \
_GE_FI(a1) \
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15)
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15)

#define _GE_PRIVATE_ARGS_GLUE(x, y) x y



+ 2
- 3
inc/graph/buffer.h View File

@@ -17,12 +17,11 @@
#ifndef INC_GRAPH_BUFFER_H_
#define INC_GRAPH_BUFFER_H_

#include <graph/types.h>
#include <memory>
#include <string>
#include <vector>

#include "detail/attributes_holder.h"
#include "graph/types.h"

namespace ge {
#ifdef HOST_VISIBILITY
@@ -72,7 +71,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer {
GeIrProtoHelper<proto::AttrDef> data_;
std::string *buffer_ = nullptr;

// Create buffer from protobuf obj
// Create from protobuf obj
Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer);
Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer);



+ 60
- 17
inc/graph/compute_graph.h View File

@@ -17,7 +17,6 @@
#ifndef INC_GRAPH_COMPUTE_GRAPH_H_
#define INC_GRAPH_COMPUTE_GRAPH_H_

#include <deque>
#include <map>
#include <memory>
#include <string>
@@ -63,7 +62,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>;

explicit ComputeGraph(const std::string &name);
virtual ~ComputeGraph();
~ComputeGraph() override;

std::string GetName() const;
void SetName(const std::string &name);
@@ -81,7 +80,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
Vistor<NodePtr> GetOutputNodes() const;

NodePtr FindNode(const std::string &name) const;
// Add node
// AddNode with NodePtr
NodePtr AddNode(NodePtr node);
NodePtr AddNode(OpDescPtr op);
NodePtr AddNodeFront(NodePtr node);
@@ -94,9 +93,40 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
graphStatus RemoveOutputNode(const NodePtr &node);
graphStatus RemoveConstInput(const NodePtr &node);

/// Add a subgraph to this graph. The subgraph must have a parent graph and a parent node,
/// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph
/// must be called before adding it to the root graph, and subgraph->GetParentNode()->GetOwnerGraph()
/// must equal subgraph->GetOwnerGraph().
/// Subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph.
/// The subgraph's name SHOULD (not must) be the same as the parameter `name`.
graphStatus AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph);
graphStatus AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph);

void RemoveSubgraph(const std::string &name);
void RemoveSubgraph(const std::shared_ptr<ComputeGraph> &subgraph);

std::shared_ptr<ComputeGraph> GetSubgraph(const std::string &name) const;
std::vector<std::shared_ptr<ComputeGraph>> GetAllSubgraphs() const;

// obsolete
std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph);
// obsolete
graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph);

///
/// @brief Update input-mapping
/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input
/// @return graphStatus
///
graphStatus UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping);

///
/// @brief Update output-mapping
/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output
/// @return graphStatus
///
graphStatus UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping);

graphStatus TopologicalSorting();
bool IsValid() const;
void Dump() const;
@@ -127,6 +157,11 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
}
}

shared_ptr<ComputeGraph> GetParentGraph();
void SetParentGraph(const shared_ptr<ComputeGraph> &parent);
shared_ptr<Node> GetParentNode();
void SetParentNode(const shared_ptr<Node> &parent);

const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; }

void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; }
@@ -138,8 +173,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
uint32_t GetInputSize() const { return input_size_; }

///
/// Set iteration needed.
/// If set is true, it means this graph need run iteration some
/// Set is need train iteration.
/// If set true, it means this graph need to be run iteration some
/// times(according variant "npu_runconfig/iterations_per_loop").
/// @param need_iteration is need iteration
///
@@ -150,7 +185,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
const std::string GetOutput();

///
/// Get need_iteration.
/// Get is need train iteration.
/// @return is need iteration
///
bool GetNeedIteration() const { return need_iteration_; }
@@ -201,6 +236,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
std::deque<NodePtr> &stack);
graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num,
std::map<string, NodePtr> &breadth_node_map);
graphStatus TopologicalSortingSubgraph();
graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum);
size_t GetInEdgeSize(const NodePtr &node);
size_t GetOutEdgeSize(const NodePtr &node);
@@ -210,31 +246,38 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector,
const std::vector<NodePtr> &l_node_ptr_vector) const;

ProtoAttrMapHelper attrs_;

friend class ModelSerializeImp;
friend class GraphDebugImp;
friend class OnnxUtils;

std::string name_;
uint32_t graph_id_ = 0;
ProtoAttrMapHelper attrs_;
std::vector<NodePtr> nodes_;
std::map<OperatorImplPtr, NodePtr> all_nodes_infos_;
std::vector<NodePtr> target_nodes_info_;

std::vector<NodePtr> input_nodes_;
std::vector<std::string> inputs_order_;
uint32_t input_size_ = 1;
std::map<std::string, std::vector<int32_t>> out_nodes_map_;
uint32_t output_size_ = 1;
std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_;

std::vector<std::shared_ptr<ComputeGraph>> sub_graph_;
std::string name_;
std::map<std::string, std::shared_ptr<ComputeGraph>> names_to_subgraph_;
std::weak_ptr<ComputeGraph> parent_graph_;
std::weak_ptr<Node> parent_node_;

// the following members should not be in the ComputeGraph class
bool is_valid_flag_;
bool is_summary_graph_ = false;
// Indicates whether iteration is needed
bool need_iteration_ = false;
std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_;
std::map<std::string, std::vector<int32_t>> out_nodes_map_;
// TaskIdx -> op_name Map
std::map<uint32_t, std::string> op_name_map_;
std::vector<std::string> inputs_order_;
uint32_t output_size_ = 1;
uint32_t input_size_ = 1;
std::map<OperatorImplPtr, NodePtr> all_nodes_infos_;
std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_;
std::vector<NodePtr> target_nodes_info_;
uint64_t session_id_ = 0;
uint32_t graph_id_ = 0;
ge::Format data_format_ = ge::FORMAT_ND;
};
} // namespace ge
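
A sketch of the attachment order required by the AddSubgraph comment above; sub_graph, parent_graph, parent_node and root_graph are placeholder shared pointers:

// Parent links must be set before the subgraph is handed to the root graph.
sub_graph->SetParentGraph(parent_graph);  // parent_graph owns parent_node
sub_graph->SetParentNode(parent_node);
ge::graphStatus ret = root_graph->AddSubgraph(sub_graph->GetName(), sub_graph);  // root_graph has no parent graph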


+ 276
- 46
inc/graph/debug/ge_attr_define.h View File

@@ -18,7 +18,6 @@
#define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_

#include <string>

#include "graph/types.h"

namespace ge {
@@ -59,6 +58,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS;
@@ -75,8 +76,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE;

// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string
// ATTR_NAME_WEIGHTS;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE;

@@ -124,6 +124,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT;
@@ -141,10 +148,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM;
@@ -166,15 +178,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER;

// _Arg
@@ -263,7 +275,29 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION;

// Huberloss
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA;

// SSDRealDivTileMul
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA;

// SSDSumMulRealDivMean
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM;
/// ConcatFive2Four
/// ConcatFour2Five
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH;
// Scale
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS;
@@ -300,7 +334,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS;

// Roipooling
@@ -313,6 +346,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI

// DetectionOutput
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD;
@@ -371,6 +405,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_

// Permute
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM;

// SSD Normalize
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL;
@@ -411,9 +446,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT;

// Log
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE;
// Pack
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM;

// Dynamic stitch
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM;
// Unpack
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM;
// Gathernd
@@ -422,8 +463,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND

// Argmax
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS;

// Upsample
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W;
// Relu
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE;

@@ -439,6 +488,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_AT
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_WORKSPACE_TYPE;

// Squeeze
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS;
@@ -461,6 +511,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_TF;

// Generate_rpn_proposal
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK;
@@ -493,6 +544,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT;
static const std::string NOT_NET_OUTPUT = "not_net_output";

// ENTER
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME;
@@ -518,6 +570,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA;

// RetinaNet
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION;
// MatMul
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W;
@@ -566,10 +621,30 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS;

// Upsample
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE;

// PadV2
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE;

// MirrorPad
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE;
// Filler
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE;
@@ -590,36 +665,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP;
@@ -634,6 +679,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE;
@@ -642,12 +689,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME;

// Public attribute
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE;

@@ -685,11 +726,178 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE;

// Used for operators that do not generate task
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOTASK;

// Used for operators that output reuse input
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_REUSE_INPUT;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET;

// L2_normalize
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD;
// HCOM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE;

// Log time stamp
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY;
// SpaceToDepth/DepthToSpace
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE;

// SparseSoftmaxCrossEntropyWithLogits
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES;

// MaxPoolGradWithArgmax
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE;

// AvgPoolGrad
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE;

// Variable
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX;

// Assign
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE;

// ShapeN
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE;

// Space2batch batch2space
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING;
// Depth_to_space space_to_depth
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE;
// FakeQuantWithMinMaxVars
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN;
// Mobilenet_ssd_conv_fusion
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM;

// Lsh project
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE;

// Control flow
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG;

// GatherV2 attr def
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS;

// Reshape attr def
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC;

// Axis attr def
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP;
// The node linked with SparseSoftmaxCrossEntropyWithLogits
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE;
// For constant folding
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING;

// Used to mark the active label list for finding the stream of the activated node
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE;

// Multi batch
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM;
@@ -697,7 +905,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM

// Control flow
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE;

@@ -709,6 +916,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION;

// Function Op
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX;

// Used to mark that the active node belongs to a loop, type: bool
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE;

@@ -742,6 +952,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INS
// For inserted op
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE;

// For compress weight
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPRESS_WEIGHT;

// For data dump
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP;
@@ -752,6 +965,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE;

// Used for L1 fusion and other fusions in the future
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED;

// Used for label switch
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST;
// Variable
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME;
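
A minimal usage sketch of how these attribute-name constants are typically consumed, assuming the AttrUtils::SetStr/GetStr helpers declared in graph/utils/attr_utils.h; the function name, the op_desc argument and the "var_0" value are illustrative only.

// Hypothetical sketch, not part of this change set.
#include <string>
#include "graph/debug/ge_attr_define.h"
#include "graph/op_desc.h"
#include "graph/utils/attr_utils.h"

void TagRefVar(const ge::OpDescPtr &op_desc) {
  // Record which variable this ref output aliases.
  (void)ge::AttrUtils::SetStr(op_desc, ge::REF_VAR_SRC_VAR_NAME, "var_0");
  std::string src_var;
  if (ge::AttrUtils::GetStr(op_desc, ge::REF_VAR_SRC_VAR_NAME, src_var)) {
    // src_var now holds "var_0".
  }
}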


+ 3
- 6
inc/graph/def_types.h View File

@@ -20,10 +20,8 @@
#include <atomic>
#include <memory>
#include <vector>

#include "graph/attr_value_serializable.h"
#include "graph/buffer.h"

namespace ge {
#define DEF_TYPE_DEC(type, name) \
inline void set_##name(const type &value) { name = value; } \
@@ -49,10 +47,9 @@ namespace ge {
inline void add_##name(type value) { name.push_back(value); } \
inline std::vector<type> *mutable_##name() { return &name; }

#define DEF_TYPE_BYTES_DEC(name) \
inline void clear_##name() { name.ClearBuffer(); } \
inline void set_##name(const void *value, size_t size) { \
name = Buffer::CopyFrom((const uint8_t *)(value), size); } \
#define DEF_TYPE_BYTES_DEC(name) \
inline void clear_##name() { name.ClearBuffer(); } \
inline void set_##name(const void *value, size_t size) { name = Buffer::CopyFrom((const uint8_t *)(value), size); } \
inline Buffer *mutable_##name() { return &name; }

struct CompressInfo {
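
A rough sketch (not part of the diff) of what the reworked macro generates: for a member declared as Buffer data; inside such a struct, DEF_TYPE_BYTES_DEC(data) expands to accessors along these lines, written out by hand from the macro shown above.

// Hand-written expansion sketch of DEF_TYPE_BYTES_DEC(data) for a member `Buffer data;`.
inline void clear_data() { data.ClearBuffer(); }
inline void set_data(const void *value, size_t size) { data = Buffer::CopyFrom((const uint8_t *)(value), size); }
inline Buffer *mutable_data() { return &data; }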


+ 1
- 2
inc/graph/detail/attributes_holder.h View File

@@ -23,7 +23,6 @@
#include <unordered_set>
#include <utility>
#include <vector>

#include "graph/detail/any_map.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"
@@ -96,7 +95,7 @@ class GeIrProtoHelper {
}
}

// protoMsg_ is part of protoOwner_ and they have the same runtime
// protoMsg_ is part of protoOwner_, they have the same runtime
ProtoMsgOwner protoOwner_ = nullptr;
ProtoType *protoMsg_ = nullptr;
friend class GeIrProtoHelper<typename std::conditional<


+ 4
- 6
inc/graph/detail/model_serialize_imp.h View File

@@ -21,9 +21,7 @@
#include <memory>
#include <string>
#include <vector>

#include "graph/anchor.h"
#include "graph/model.h"
#include "detail/attributes_holder.h"
#include "graph/ge_tensor.h"
#include "graph/graph.h"
@@ -48,15 +46,15 @@ struct NodeNameNodeReq {

class ModelSerializeImp {
public:
bool SerializeModel(const Model &model, proto::ModelDef *modeProto);
bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false);

bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto);
bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false);

bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto);

bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto);
bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false);

bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto);
bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false);

bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto);



+ 2
- 4
inc/graph/ge_attr_value.h View File

@@ -23,7 +23,6 @@
#include <string>
#include <utility>
#include <vector>

#include "graph/buffer.h"
#include "detail/attributes_holder.h"
#include "graph/ge_error_codes.h"
@@ -139,15 +138,14 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue {

template <typename vector_type>
// To cols
using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE,
int>::type;
using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, int>::type;

template <typename one_type>
using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type;

template <typename val_type>
using enable_if_type_valid_t =
typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type;
typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type;

template <typename seriliable_type>
using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable;


+ 1
- 2
inc/graph/ge_context.h View File

@@ -18,7 +18,6 @@
#define INC_GRAPH_GE_CONTEXT_H_

#include <string>

#include "graph/ge_error_codes.h"

namespace ge {
@@ -42,4 +41,4 @@ class GEContext {
GEContext &GetContext();
} // namespace ge

#endif // INC_GRAPH_GE_CONTEXT_H_
#endif // INC_GRAPH_GE_CONTEXT_H_

+ 0
- 2
inc/graph/ge_local_context.h View File

@@ -20,7 +20,6 @@
#include <map>
#include <string>
#include <vector>

#include "graph/ge_error_codes.h"

using std::map;
@@ -42,5 +41,4 @@ class GEThreadLocalContext {

GEThreadLocalContext &GetThreadLocalContext();
} // namespace ge

#endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_

+ 16
- 6
inc/graph/ge_tensor.h View File

@@ -21,12 +21,10 @@
#include <memory>
#include <string>
#include <vector>

#include "detail/attributes_holder.h"
#include "graph/buffer.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"

namespace ge {
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape {
public:
@@ -43,6 +41,18 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape {
int64_t GetShapeSize() const;
std::string ToString() const;

///
/// @brief Check whether the shape is unknown
/// @return bool
///
bool IsUnknownShape() const;

///
/// @brief Check whether the shape is a scalar
/// @return bool
///
bool IsScalar() const;

GeShape(const GeShape &other);
GeShape(GeShape &&other);
GeShape &operator=(const GeShape &other);
@@ -51,7 +61,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape {
private:
GeIrProtoHelper<proto::ShapeDef> shape_def_;
friend class GeTensorDesc;
// Create geshape from proto obj
// Create from proto obj
GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg);

void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; }
@@ -112,7 +122,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH

void Init();

// Create getensordesc from proto obj
// Create from proto obj
GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg);
friend class GeTensor;
friend class GeAttrValueImp;
@@ -159,10 +169,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor {
friend class GeAttrValueImp;
friend class ModelSerializeImp;
friend class OnnxUtils;
// Create getensor from proto obj
// Create from proto obj
GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg);
GeIrProtoHelper<proto::TensorDef> tensor_def_;
// Reference from tensorDef_, cab not use it directly
// Reference from tensorDef_, do not use it directly
mutable GeTensorDesc __desc_;
GeTensorDesc &DescReference() const;
};
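
A short sketch of the new GeShape checks, assuming IsUnknownShape() reports shapes with still-undetermined dimensions and IsScalar() reports rank-0 shapes; the function is illustrative only.

// Usage sketch (assumption): route tensors with undetermined dims to a dynamic-shape path.
#include "graph/ge_tensor.h"

bool NeedsDynamicShapePath(const ge::GeShape &shape) {
  if (shape.IsScalar()) {
    return false;  // scalars never need the dynamic path
  }
  return shape.IsUnknownShape();  // true when any dimension is still undetermined
}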


+ 1
- 2
inc/graph/model.h View File

@@ -21,7 +21,6 @@
#include <memory>
#include <string>
#include <vector>

#include "detail/attributes_holder.h"
#include "graph/ge_attr_value.h"
#include "graph/graph.h"
@@ -62,7 +61,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder {
using AttrHolder::HasAttr;
using AttrHolder::SetAttr;

graphStatus Save(Buffer &buffer) const;
graphStatus Save(Buffer &buffer, bool is_dump = false) const;

graphStatus SaveToFile(const string &file_name) const;
// Model will be rewrite


+ 1
- 2
inc/graph/model_serialize.h View File

@@ -19,7 +19,6 @@

#include <map>
#include <string>

#include "graph/buffer.h"
#include "graph/compute_graph.h"
#include "graph/model.h"
@@ -27,7 +26,7 @@
namespace ge {
class ModelSerialize {
public:
Buffer SerializeModel(const Model &model);
Buffer SerializeModel(const Model &model, bool is_dump = false);

Model UnserializeModel(const uint8_t *data, size_t len);
Model UnserializeModel(ge::proto::ModelDef &model_def);
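
A minimal sketch of the new is_dump flag, assuming the default (false) keeps the original serialization behaviour and true produces the dump-oriented form.

// Sketch: serialize the same model for storage and for dumping.
#include "graph/model_serialize.h"

void SerializeBothWays(const ge::Model &model) {
  ge::ModelSerialize serializer;
  ge::Buffer storage_buf = serializer.SerializeModel(model);     // is_dump defaults to false
  ge::Buffer dump_buf = serializer.SerializeModel(model, true);  // dump-oriented serialization
  (void)storage_buf;
  (void)dump_buf;
}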


+ 7
- 7
inc/graph/node.h View File

@@ -113,25 +113,25 @@ class Node : public std::enable_shared_from_this<Node> {

bool IsAllInNodesSeen(std::unordered_set<Node *> &nodes_seen) const;

// All inData nodes
// All in Data nodes
Vistor<NodePtr> GetInDataNodes() const;
// All inControl nodes
// All in Control nodes
Vistor<NodePtr> GetInControlNodes() const;
// GetInAllNodes = InDataNodes + InControlNodes
Vistor<NodePtr> GetInAllNodes() const;

// All outData nodes
// All out Data nodes
Vistor<NodePtr> GetOutDataNodes() const;
uint32_t GetOutDataNodesSize() const;
// All outControl nodes
// All out Control nodes
Vistor<NodePtr> GetOutControlNodes() const;
// GetOutAllNodes = OutDataNodes + OutControlNodes
Vistor<NodePtr> GetOutAllNodes() const;

// Get all indata nodes and its outanchor
// Get all in data nodes and its out-anchor
Vistor<std::pair<NodePtr, OutDataAnchorPtr>> GetInDataNodesAndAnchors() const;

// Get all outdata nodes and its inanchor
// Get all out data nodes and its in-anchor
Vistor<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesAndAnchors() const;

graphStatus InferShapeAndType() const;
@@ -176,7 +176,7 @@ class Node : public std::enable_shared_from_this<Node> {

void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; }

NodePtr GetOrigNode(void) { return orig_node_; }
NodePtr GetOrigNode() { return orig_node_; }

private:
bool NodeMembersAreEqual(const Node &r_node) const;


+ 23
- 4
inc/graph/op_desc.h View File

@@ -23,7 +23,6 @@
#include <string>
#include <unordered_set>
#include <vector>

#include "detail/attributes_holder.h"
#include "graph/range_vistor.h"

@@ -108,6 +107,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {

size_t GetInputsSize() const;

size_t GetAllInputsSize() const;

graphStatus AddOutputDesc(const GeTensorDesc &output_desc);

graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc);
@@ -122,6 +123,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {

GeTensorDescPtr MutableOutputDesc(uint32_t index) const;

uint32_t GetAllOutputsDescSize() const;

Vistor<GeTensorDesc> GetAllOutputsDesc() const;

Vistor<GeTensorDescPtr> GetAllOutputsDescPtr() const;
@@ -132,6 +135,10 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {

ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const;

ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const;

ConstGeTensorDescPtr GetInputDescPtr(const string &name) const;

graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true);

graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true);
@@ -140,7 +147,11 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {

bool IsOptionalInput(uint32_t index) const;

std::map<string, uint32_t> GetAllInputName();
std::map<string, uint32_t> GetAllInputName() const;

void SetAllInputName(const std::map<string, uint32_t> &input_name_idx);

std::vector<string> GetAllOptionalInputName() const;

std::map<string, uint32_t> GetAllOutputName();

@@ -225,6 +236,14 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {

std::string GetOpEngineName() const;

graphStatus AddSubgraphName(const std::string &name);
const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const;

std::string GetSubgraphInstanceName(uint32_t index) const;
const std::vector<std::string> &GetSubgraphInstanceNames() const;
void AddSubgraphInstanceName(std::string name);
void RemoveSubgraphInstanceName(const std::string &name);

protected:
ProtoAttrMapHelper MutableAttrMap() override;
ConstProtoAttrMapHelper GetAttrMap() const override;
@@ -236,9 +255,9 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {
bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const;

GeIrProtoHelper<ge::proto::OpDef> op_def_;
std::vector<std::string> subgraph_instance_names_;
std::map<std::string, uint32_t> subgraph_names_to_index_;
vector<GeTensorDescPtr> inputs_desc_{};
map<string, uint32_t> input_name_idx_{};
std::unordered_set<string> optional_input_names_{};
vector<GeTensorDescPtr> outputs_desc_{};
map<string, uint32_t> output_name_idx_{};
std::function<graphStatus(Operator &)> infer_func_ = nullptr;
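
A minimal sketch of the new subgraph bookkeeping on OpDesc, assuming the usual two-step flow of declaring a named subgraph slot and then binding a concrete instance name; the names used here are illustrative.

// Sketch: declare a subgraph slot on an If-like op and bind an instance graph name.
#include "graph/op_desc.h"

void BindThenBranch(const ge::OpDescPtr &if_desc) {
  if (if_desc->AddSubgraphName("then_branch") != ge::GRAPH_SUCCESS) {
    return;
  }
  if_desc->AddSubgraphInstanceName("then_branch_graph_0");
  // GetSubgraphInstanceName(0) now returns "then_branch_graph_0".
}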


+ 0
- 2
inc/graph/operator_factory_impl.h View File

@@ -21,7 +21,6 @@
#include <memory>
#include <string>
#include <vector>

#include "graph/operator_factory.h"

namespace ge {
@@ -47,7 +46,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl {

static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func);

private:
static shared_ptr<std::map<string, OpCreator>> operator_creators_;
static shared_ptr<std::map<string, InferShapeFunc>> operator_infershape_funcs_;
static shared_ptr<std::map<string, InferFormatFunc>> operator_inferformat_funcs_;


+ 4
- 2
inc/graph/shape_refiner.h View File

@@ -18,8 +18,8 @@
#define INC_GRAPH_SHAPE_REFINER_H_

#include <string>

#include "external/graph/inference_context.h"

#include "external/graph/ge_error_codes.h"
#include "graph/node.h"

@@ -27,8 +27,10 @@ namespace ge {
// ShapeRefiner performs shape inference for compute graphs
class ShapeRefiner {
public:
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op);
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph);
static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph);
static graphStatus InferShapeAndType(const NodePtr &node);
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op);

private:
static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase);
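
A one-line sketch of the new overload; before_subgraph is assumed to mean "run inference before descending into the node's subgraphs".

// Sketch: infer shape and type prior to entering the node's subgraphs.
#include "graph/shape_refiner.h"

ge::graphStatus InferBeforeSubgraph(const ge::NodePtr &node) {
  return ge::ShapeRefiner::InferShapeAndType(node, true /* before_subgraph */);
}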


+ 3
- 3
inc/graph/usr_types.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_
#define INC_EXTERNAL_GRAPH_USR_TYPES_H_
#ifndef INC_GRAPH_USR_TYPES_H_
#define INC_GRAPH_USR_TYPES_H_

#include <atomic>
#include <memory>
@@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams {
#undef USR_TYPE_BYTES_DEC
} // namespace ge

#endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_
#endif // INC_GRAPH_USR_TYPES_H_

+ 3
- 2
inc/graph/utils/attr_utils.h View File

@@ -99,8 +99,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils {
static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer);
static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer);
// Value will be moved
static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name,
vector<Buffer> &listBuffer);
static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer);
static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer);

static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value);
@@ -116,6 +115,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils {

static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc);

static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj);

class AttrHolderAdapter {
public:
AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {}


+ 321
- 4
inc/graph/utils/graph_utils.h View File

@@ -137,6 +137,18 @@ class GraphUtils {
static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor,
const std::vector<OpDescPtr> &vec_op_desc);

///
/// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst
/// @param [in] src
/// @param [in] dsts
/// @param [in] insert_node
/// @param [in] input_index
/// @param [in] output_index
/// @return graphStatus
///
static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);

static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node);

static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node);
@@ -145,16 +157,12 @@ class GraphUtils {

static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node);

static bool CheckIsTrainGraph(const ge::ComputeGraphPtr &compute_graph);

static bool MatchDumpStr(const std::string &suffix);

static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false);

static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph);

static bool CheckGlobalStepNode(const ge::NodePtr &node);

static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos);

static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix);
@@ -252,6 +260,315 @@ class GraphUtils {
/// @return success: GRAPH_SUCCESS
///
static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node);

static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph);
};

class ComputeGraphBuilder {
public:
ComputeGraphBuilder() : owner_graph_(nullptr) {}
ComputeGraphBuilder(const ComputeGraphBuilder &) = delete;
ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete;
ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete;
ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete;
~ComputeGraphBuilder() = default;

///
/// @brief Add node to graph
/// @param [in] op_desc
/// @return ComputeGraphBuilder
///
virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc);

///
/// @brief Add data-link among nodes in graph
/// @param [in] src_name
/// @param [in] out_anchor_ind
/// @param [in] dst_name
/// @param [in] in_anchor_ind
/// @return ComputeGraphBuilder
///
virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind,
const std::string &dst_name, uint32_t in_anchor_ind);

///
/// @brief Add ctrl-link among nodes in graph
/// @param [in] src_name
/// @param [in] dst_name
/// @return ComputeGraphBuilder
///
virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name);

///
/// @brief Build graph
/// @param [out] error_code
/// @param [out] error_msg
/// @return ComputeGraphPtr
///
virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0;

/// @brief Get node with name
/// @param [in] name
/// @return NodePtr
///
NodePtr GetNode(const std::string &name);

protected:
///
/// @brief Build nodes
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildNodes(graphStatus &error_code, std::string &error_msg);

///
/// @brief Build data-links
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildDataLinks(graphStatus &error_code, std::string &error_msg);

///
/// @brief Build ctrl-links
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg);

ComputeGraphPtr owner_graph_;

// node_name -> node
std::map<std::string, NodePtr> node_names_;
std::vector<OpDescPtr> nodes_;

// <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind>
std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_;
// src_node_name -> dst_node_name
std::vector<std::pair<std::string, std::string>> ctrl_links_;
};

class CompleteGraphBuilder : public ComputeGraphBuilder {
public:
explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {}
CompleteGraphBuilder(const CompleteGraphBuilder &) = delete;
CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete;
CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete;
CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete;
~CompleteGraphBuilder() = default;

///
/// @brief Add node to graph
/// @param [in] op_desc
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override;

///
/// @brief Add data-link among nodes in graph
/// @param [in] src_name
/// @param [in] out_anchor_ind
/// @param [in] dst_name
/// @param [in] in_anchor_ind
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
uint32_t in_anchor_ind) override;

///
/// @brief Add ctrl-link among nodes in graph
/// @param [in] src_name
/// @param [in] dst_name
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;

///
/// @brief Set index_th input anchor for graph
/// @param [in] index
/// @param [in] node_names
/// @param [in] anchor_inds
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names,
const std::vector<uint32_t> &anchor_inds);

///
/// @brief Set index_th input of graph as useless
/// @param [in] index
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &SetUselessInput(uint32_t index);

///
/// @brief Add output anchor for graph
/// @param [in] owner_node_name
/// @param [in] anchor_ind
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind);

///
/// @brief Set parent-node of graph
/// @param [in] parent_node
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node);

///
/// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node
/// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping);

///
/// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind
/// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping);

///
/// @brief Build graph
/// @param [out] error_code
/// @param [out] error_msg
/// @return ComputeGraphPtr
///
ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;

private:
///
/// @brief Build inputs
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildInputs(graphStatus &error_code, std::string &error_msg);

///
/// @brief Add data node
/// @param [in] index
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
NodePtr AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg);

///
/// @brief Build outputs
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildOutputs(graphStatus &error_code, std::string &error_msg);

///
/// @brief Add NetOutput node
/// @param [out] error_code
/// @param [out] error_msg
/// @return NodePtr
///
NodePtr AddNetOutputNode(graphStatus &error_code, std::string &error_msg);

///
/// @brief Add input/output tensor for NetOutput node
/// @param [in] out_nodes_info
/// @param [out] net_output_desc
/// @return graphStatus
///
graphStatus BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info,
OpDescPtr &net_output_desc);

///
/// @brief Add edge for NetOutput node
/// @param [in] out_nodes_info
/// @param [out] net_output_node
/// @return graphStatus
///
graphStatus AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info,
const NodePtr &net_output_node);

std::string name_;
NodePtr parent_node_;
std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_;
std::vector<std::pair<std::string, uint32_t>> graph_outputs_;

// index_of_graph_input -> in_anchor_index_of_parent_node
std::map<uint32_t, uint32_t> input_mapping_;
// index_of_graph_output -> out_anchor_index_of_parent_node
std::map<uint32_t, uint32_t> output_mapping_;
};

class PartialGraphBuilder : public ComputeGraphBuilder {
public:
PartialGraphBuilder() = default;
PartialGraphBuilder(const PartialGraphBuilder &) = delete;
PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete;
PartialGraphBuilder(const PartialGraphBuilder &&) = delete;
PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete;
~PartialGraphBuilder() = default;

///
/// @brief Add node to graph
/// @param [in] op_desc
/// @return PartialGraphBuilder
///
PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override;

///
/// @brief Add data-link among nodes in graph
/// @param [in] src_name
/// @param [in] out_anchor_ind
/// @param [in] dst_name
/// @param [in] in_anchor_ind
/// @return PartialGraphBuilder
///
PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
uint32_t in_anchor_ind) override;

///
/// @brief Add ctrl-link among nodes in graph
/// @param [in] src_name
/// @param [in] dst_name
/// @return PartialGraphBuilder
///
PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;

///
/// @brief Set owner graph
/// @param [in] graph
/// @return PartialGraphBuilder
///
PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph);

///
/// @brief Add exist node
/// @param [in] node
/// @return PartialGraphBuilder
///
PartialGraphBuilder &AddExistNode(const NodePtr &node);

///
/// @brief Build multi nodes with links
/// @param [out] error_code
/// @param [out] error_msg
/// @return ComputeGraphPtr
///
ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;

private:
///
/// @brief Build exist nodes
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildExistNodes(graphStatus &error_code, std::string &error_msg);

std::vector<NodePtr> exist_nodes_;
};
} // namespace ge
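
A minimal sketch of the new CompleteGraphBuilder, assuming the src/dst names passed to AddDataLink are the OpDesc names ("add" and "relu" here are illustrative) and that SetInput/AddOutput index the graph's inputs and outputs in order.

// Sketch: assemble a small branch graph with CompleteGraphBuilder.
#include <string>
#include "graph/utils/graph_utils.h"

ge::ComputeGraphPtr BuildThenBranch(const ge::OpDescPtr &add_desc, const ge::OpDescPtr &relu_desc,
                                    const ge::NodePtr &parent_node) {
  ge::graphStatus error_code = ge::GRAPH_SUCCESS;
  std::string error_msg;
  ge::CompleteGraphBuilder builder("then_branch_graph_0");
  builder.AddNode(add_desc)
      .AddNode(relu_desc)
      .AddDataLink("add", 0, "relu", 0)  // add:out0 -> relu:in0
      .SetInput(0, {"add"}, {0})         // graph input 0 feeds add:in0
      .AddOutput("relu", 0)              // relu:out0 becomes graph output 0
      .SetParentNode(parent_node);
  return builder.Build(error_code, error_msg);
}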



+ 5
- 0
inc/graph/utils/node_utils.h View File

@@ -56,6 +56,11 @@ class NodeUtils {
static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape);
static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape);

static std::string GetNodeType(const Node &node);

static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index);
static graphStatus AddSubgraph(Node &node, const ComputeGraphPtr &subgraph);

private:
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_;
static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_;
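
A short sketch of the new subgraph helpers, assuming index 0 addresses the first subgraph attached to the node; the function is illustrative only.

// Sketch: attach a built subgraph to its parent node and read it back.
#include "graph/utils/node_utils.h"

void AttachSubgraph(ge::Node &if_node, const ge::ComputeGraphPtr &branch) {
  if (ge::NodeUtils::AddSubgraph(if_node, branch) != ge::GRAPH_SUCCESS) {
    return;
  }
  ge::ComputeGraphPtr attached = ge::NodeUtils::GetSubgraph(if_node, 0);  // first subgraph
  (void)attached;
}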


+ 89
- 38
inc/graph/utils/op_desc_utils.h View File

@@ -20,7 +20,6 @@
#include <memory>
#include <string>
#include <vector>

#include "graph/def_types.h"
#include "graph/node.h"
#include "graph/op_desc.h"
@@ -29,7 +28,6 @@

namespace ge {
class OpDesc;

using OpDescPtr = std::shared_ptr<OpDesc>;

class OpDescUtils {
@@ -39,55 +37,108 @@ class OpDescUtils {

OpDescUtils() = default;
~OpDescUtils() = default;
static bool HasQuantizeFactorParams(const OpDescPtr &op_desc);
static bool HasQuantizeFactorParams(const OpDesc &op_desc);
static graphStatus GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant);
static graphStatus GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant);
static graphStatus SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant);
static graphStatus SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant);
static vector<ge::NodePtr> GetConstInputNode(const ge::Node &node);
static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr> &input_nodes);
static vector<ConstGeTensorPtr> GetWeights(const ge::Node &node);
static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr &node);
static vector<GeTensorPtr> MutableWeights(const ge::Node &node);
static bool HasQuantizeFactorParams(const OpDescPtr& op_desc);
static bool HasQuantizeFactorParams(const OpDesc& op_desc);
static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant);
static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant);
static graphStatus SetQuantizeFactorParams(const OpDescPtr& op_desc, const QuantizeFactorParams& quant);
static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant);
static vector<ge::NodePtr> GetConstInputNode(const ge::Node& node);
static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr>& input_nodes);
static vector<ConstGeTensorPtr> GetWeights(const ge::Node& node);
static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr& node);
static vector<GeTensorPtr> MutableWeights(const ge::Node& node);
static vector<GeTensorPtr> MutableWeights(const ge::NodePtr node);
static graphStatus SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights);
static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights);
static graphStatus SetWeights(ge::Node& node, const vector<ge::GeTensorPtr>& weights);
static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr>& weights);
static graphStatus ClearWeights(ge::NodePtr node);

static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index);
static bool ClearInputDesc(const ge::NodePtr &node);
static bool ClearOutputDesc(const ge::OpDescPtr &op_desc, uint32_t index);
static bool ClearOutputDesc(const ge::NodePtr &node);
static vector<ge::NodePtr> GetConstInputs(const ge::Node &node);
static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr &node);
static size_t GetNonConstInputsSize(const ge::Node &node);
static bool ClearInputDesc(const ge::NodePtr& node);
static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index);
static bool ClearOutputDesc(const ge::NodePtr& node);
static vector<ge::NodePtr> GetConstInputs(const ge::Node& node);
static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr& node);
static size_t GetNonConstInputsSize(const ge::Node& node);
static size_t GetNonConstInputsSize(ge::ConstNodePtr node);
// Index: Indicate the index of all non const inputs
static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const = 0);
static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const = 0);
static bool GetNonConstInputIndex(const ge::Node &node, size_t index_non_const, size_t &index);
static bool GetNonConstInputIndex(const ge::ConstNodePtr &node, size_t index_non_const, size_t &index);
// Index: Indicate the index of all inputs
static bool IsNonConstInput(const ge::Node &node, size_t index = 0);
static bool IsNonConstInput(const ge::ConstNodePtr &node, size_t index = 0);
static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr &node);
static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr);
// Index: Indicates the index of all non const inputs
static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0);
static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0);
static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index);
static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index);
// Index: Indicates the index of all inputs
static bool IsNonConstInput(const ge::Node& node, size_t index = 0);
static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0);
static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr& node);
static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr);

static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc);
static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr);
static OpDescPtr GetOpDescFromOperator(const Operator &oprt);
static OpDescPtr GetOpDescFromOperator(const Operator& oprt);

static OpDescPtr CreateConstOp(const GeTensorPtr &tensor_ptr);
static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr);

private:
static GeTensorPtr MutableWeights(ge::OpDesc &op_desc);
static GeTensorPtr MutableWeights(ge::OpDesc& op_desc);
static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc);
static graphStatus SetWeights(ge::OpDesc &op_desc, const GeTensorPtr weight);
static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight);
static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight);
};

class OpDescBuilder {
public:
OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {}
OpDescBuilder(const OpDescBuilder&) = delete;
OpDescBuilder& operator=(const OpDescBuilder&) = delete;
OpDescBuilder(const OpDescBuilder&&) = delete;
OpDescBuilder& operator=(const OpDescBuilder&&) = delete;
~OpDescBuilder() = default;

///
/// @brief Add input
/// @param [in] name
/// @return OpDescBuilder
///
OpDescBuilder& AddInput(const std::string& name);

///
/// @brief Add dynamic input
/// @param [in] name
/// @param [in] num
/// @return OpDescBuilder
///
OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num);

///
/// @brief Add output
/// @param [in] name
/// @return OpDescBuilder
///
OpDescBuilder& AddOutput(const std::string& name);

///
/// @brief Add dynamic output
/// @param [in] name
/// @param [in] num
/// @return OpDescBuilder
///
OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num);

///
/// @brief Build op_desc
/// @return OpDescPtr
///
OpDescPtr Build();

private:
std::string name_;
std::string type_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
};
} // namespace ge

#endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_
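
A minimal sketch of the new OpDescBuilder; the "add"/"Add" names and the port names are illustrative.

// Sketch: build an OpDesc for a two-input, one-output op.
#include "graph/utils/op_desc_utils.h"

ge::OpDescPtr MakeAddDesc() {
  return ge::OpDescBuilder("add", "Add")
      .AddInput("x1")
      .AddInput("x2")
      .AddOutput("y")
      .Build();
}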

+ 7
- 8
inc/graph/utils/tensor_utils.h View File

@@ -18,15 +18,14 @@
#define INC_GRAPH_UTILS_TENSOR_UTILS_H_

#include <vector>

#include "graph/def_types.h"
#include "graph/ge_error_codes.h"
#include "graph/ge_tensor.h"
namespace ge {
class TensorUtils {
public:
static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, uint32_t &size);
static void SetSize(GeTensorDesc &tensorDesc, uint32_t size);
static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size);
static void SetSize(GeTensorDesc &tensorDesc, int64_t size);
static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr);
static uint32_t GetWeightSize(const GeTensor &tensor);
static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc);
@@ -62,16 +61,16 @@ class TensorUtils {
static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc);

///
/// calculate mem size of the tensor.
/// calculate tensor mem size.
/// @param shape tensor shape
/// @param format tensor format
/// @param data_type tensor data type
/// @param mem_size -1 means unknown shape,others means mem size
/// @return GRAPH_SUCCESS:success, others:failed
/// @param mem_size -1 means unknown shape, other values mean mem size
/// @return GRAPH_SUCCESS: success, other: failed
///
static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size);
static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp);
static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp);
static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp);
static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp);
};
} // namespace ge
#endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_
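
A short sketch of the widened size interface; the -1 convention for unknown shapes is taken from the CalcTensorMemSize comment above and is assumed to apply here as well.

// Sketch: query a tensor's memory size with the new int64_t interface.
#include "graph/utils/tensor_utils.h"

int64_t QueryMemSize(const ge::GeTensorDesc &desc) {
  int64_t mem_size = 0;
  if (ge::TensorUtils::GetTensorMemorySizeInBytes(desc, mem_size) != ge::GRAPH_SUCCESS) {
    return -1;
  }
  return mem_size;  // -1 when the shape is unknown, the byte count otherwise
}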

+ 1
- 0
src/common/graph/CMakeLists.txt View File

@@ -58,6 +58,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph)
include_directories(${GE_SOURCE_DIR}/inc/graph)
include_directories(${GE_SOURCE_DIR}/inc/common)
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc)
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops)
include_directories(${GE_SOURCE_DIR}/third_party/securec/include)
include_directories(${CMAKE_BINARY_DIR})
include_directories(${CMAKE_BINARY_DIR}/proto/ge)


+ 2
- 0
src/common/graph/anchor.cc View File

@@ -26,6 +26,8 @@ Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), id

bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf<Anchor>(), type) == 0; }

size_t Anchor::GetPeerAnchorsSize() const { return peer_anchors_.size(); }

Anchor::Vistor<AnchorPtr> Anchor::GetPeerAnchors() const {
vector<AnchorPtr> ret;
for (const auto &anchor : peer_anchors_) {


+ 1
- 2
src/common/graph/buffer.cc View File

@@ -32,8 +32,7 @@ Buffer::Buffer(const Buffer &other) {
buffer_ = other.buffer_;
}

// default
Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() {
Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default
auto proto_msg = data_.GetProtoMsg();
if (proto_msg != nullptr) {
try {


+ 200
- 23
src/common/graph/compute_graph.cc View File

@@ -15,9 +15,7 @@
*/

#include "graph/compute_graph.h"

#include <deque>

#include "./format_refiner.h"
#include "./ge_context.h"
#include "debug/ge_attr_define.h"
@@ -41,7 +39,7 @@ const size_t OUTPUT_PARAM_SIZE = 2;
} // namespace

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name)
: nodes_(), input_nodes_(), sub_graph_(), name_(name), is_valid_flag_(false), need_iteration_(false) {
: name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) {
attrs_.InitDefault();
}
ComputeGraph::~ComputeGraph() {}
@@ -154,7 +152,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNod

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual(
const ComputeGraph &r_graph) const {
return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.sub_graph_.size()") &&
return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") &&
IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") &&
VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) &&
IsEqual(this->name_, r_graph.name_, "graph.name_") &&
@@ -398,6 +396,165 @@ graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &su
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph) {
if (subgraph == nullptr) {
GE_LOGE("Try to add a null subgraph, name %s", name.c_str());
return GRAPH_PARAM_INVALID;
}
auto parent_graph = subgraph->GetParentGraph();
if (parent_graph == nullptr) {
GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str());
return GRAPH_PARAM_INVALID;
}
auto parent_node = subgraph->GetParentNode();
if (parent_node == nullptr) {
GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str());
return GRAPH_PARAM_INVALID;
}
if (parent_node->GetOwnerComputeGraph() != parent_graph) {
GE_LOGE(
"Try to add a subgraph which parent node's parent graph is not equal to "
"the subgraph's parent graph, subgraph name %s, parent node name %s",
subgraph->GetName().c_str(), parent_graph->GetName().c_str());
return GRAPH_PARAM_INVALID;
}
if (!this->parent_graph_.expired()) {
GE_LOGE("The subgraphs can only be added to the root graph");
return GRAPH_PARAM_INVALID;
}
if (name != subgraph->GetName()) {
GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str());
}
sub_graph_.push_back(subgraph);
names_to_subgraph_[name] = subgraph;
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
ComputeGraph::AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph) {
if (subgraph == nullptr) {
return GRAPH_PARAM_INVALID;
}
return AddSubgraph(subgraph->GetName(), subgraph);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) {
auto iter = names_to_subgraph_.find(name);
if (iter == names_to_subgraph_.end()) {
return;
}
for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) {
if (*vec_iter == iter->second) {
sub_graph_.erase(vec_iter);
break;
}
}
names_to_subgraph_.erase(iter);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(
const std::shared_ptr<ComputeGraph> &subgraph) {
if (subgraph != nullptr) {
RemoveSubgraph(subgraph->GetName());
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph(
const std::string &name) const {
auto iter = names_to_subgraph_.find(name);
return iter == names_to_subgraph_.end() ? nullptr : iter->second;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>>
ComputeGraph::GetAllSubgraphs() const {
return sub_graph_;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<ComputeGraph> ComputeGraph::GetParentGraph() {
return parent_graph_.lock();
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph(
const shared_ptr<ComputeGraph> &parent) {
parent_graph_ = parent;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<Node> ComputeGraph::GetParentNode() {
return parent_node_.lock();
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr<Node> &parent) {
parent_node_ = parent;
}

///
/// @brief Update input-mapping
/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input
/// @return graphStatus
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) {
for (auto &input : input_nodes_) {
uint32_t cur_index = 0;
if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) {
continue;
}
auto iter = input_mapping.find(cur_index);
if (iter == input_mapping.end()) {
continue;
}
if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) {
GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed.");
return GRAPH_FAILED;
}
}

return GRAPH_SUCCESS;
}

///
/// @brief Update output-mapping
/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output
/// @return graphStatus
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) {
NodePtr net_output = FindNode(kNodeNameNetOutput);
if (net_output == nullptr) {
GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", kNodeNameNetOutput);
return GRAPH_FAILED;
}
OpDescPtr op_desc = net_output->GetOpDesc();
if (op_desc == nullptr) {
GE_LOGE("UpdateOutputMapping failed: op_desc is NULL.");
return GRAPH_FAILED;
}

size_t num = op_desc->GetInputsSize();
for (size_t i = 0; i < num; i++) {
GeTensorDesc tensor = op_desc->GetInputDesc(i);
uint32_t cur_index = 0;
if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) {
continue;
}
auto iter = output_mapping.find(cur_index);
if (iter == output_mapping.end()) {
continue;
}
if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) {
GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed.");
return GRAPH_FAILED;
}
if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) {
GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i);
return GRAPH_FAILED;
}
}

return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() {
std::vector<NodePtr> node_vec = nodes_;
for (const auto &node : GetAllNodes()) {
@@ -551,6 +708,23 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() {
auto ret = TopologicalSortingSubgraph();
if (ret != SUCCESS) {
GELOGE(ret, "Sub graph partition Failed");
return ret;
}
// partition sub graph
for (const auto &sub_graph : GetAllSubgraphs()) {
ret = sub_graph->TopologicalSortingSubgraph();
if (ret != SUCCESS) {
GELOGE(ret, "Sub graph topological sort Failed");
return ret;
}
}
return SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSortingSubgraph() {
std::vector<NodePtr> node_vec;
std::map<NodePtr, uint32_t> map_in_edge_num;
bool use_BFS = false;
@@ -598,6 +772,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog
node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null]
nodes_.push_back(node);
}

is_valid_flag_ = true;
return GRAPH_SUCCESS;
}
@@ -614,7 +789,7 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt
verify_isolated = true;
}
}
for (const auto &node : GetAllNodes()) {
for (const auto &node : GetDirectNode()) {
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node));
if (map_in_edge_num[node] == 0) {
@@ -640,16 +815,16 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt
/// 2. Compare two indices, if not match, swap the positions of two inputs
/// *: Remind: stack is reverse-order
for (size_t i = 0; i < stack.size(); ++i) {
// [stack: should not be null]
// If not found in 'inputs_order_', skip it
auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName());
GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue);
auto inx_i = it_i - inputs_order_.begin();
for (size_t j = i + 1; j < stack.size(); ++j) {
// If not found in 'inputs_order_', skip it
auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName());
GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue);
auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName());
GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue);

// Compare index, swap them if it should be
auto inx_i = it_i - inputs_order_.begin();
auto inx_j = it_j - inputs_order_.begin();
GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j]));
}
@@ -663,7 +838,7 @@ size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) {
return in_edge_size;
}
for (const auto &anchor : node->GetAllInDataAnchors()) {
in_edge_size = in_edge_size + anchor->GetPeerAnchors().size();
in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize();
// Break flow control data loop.
OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor();
if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) {
@@ -680,10 +855,11 @@ size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) {
}
}
if (node->GetInControlAnchor() != nullptr) {
in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchors().size();
in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize();
}
return in_edge_size;
}

size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) {
size_t out_edge_size = 0;
if (node == nullptr) {
@@ -699,7 +875,7 @@ size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) {
}
}
if (node->GetOutControlAnchor() != nullptr) {
if (out_edge_size > (UINT32_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) {
if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) {
return 0;
}
out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size();
@@ -724,17 +900,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const {
peer_in_anchor->GetOwnerNode()->GetName().c_str()));
}
}
GE_IF_BOOL_EXEC(node->GetOutControlAnchor() == nullptr, GELOGE(GRAPH_FAILED, "Out control anchor is null");
return );
for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) {
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
peer_in_anchor->GetOwnerNode()->GetName().c_str()));
}
for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInDataAnchors()) {
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
peer_in_anchor->GetOwnerNode()->GetName().c_str()));
auto out_control_anchor = node->GetOutControlAnchor();
if (out_control_anchor != nullptr) {
for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) {
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
peer_in_anchor->GetOwnerNode()->GetName().c_str()));
}
for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) {
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr,
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(),
peer_in_anchor->GetOwnerNode()->GetName().c_str()));
}
}
}
}
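
A minimal sketch of the call order implied by the checks above (an assumption, not stated in the diff): the subgraph must already know its parent graph and parent node before AddSubgraph, only the root graph may own subgraphs, and TopologicalSorting afterwards also sorts every registered subgraph.

// Sketch: register a branch graph on the root graph, then sort root and subgraphs together.
#include "graph/compute_graph.h"

void RegisterBranch(const ge::ComputeGraphPtr &root, const ge::ComputeGraphPtr &branch,
                    const ge::NodePtr &parent_node) {
  branch->SetParentGraph(parent_node->GetOwnerComputeGraph());
  branch->SetParentNode(parent_node);
  if (root->AddSubgraph(branch->GetName(), branch) != ge::GRAPH_SUCCESS) {
    return;
  }
  (void)root->TopologicalSorting();  // now also topologically sorts each subgraph
}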


+ 35
- 69
src/common/graph/debug/ge_log.h View File

@@ -18,21 +18,9 @@
#define COMMON_GRAPH_DEBUG_GE_LOG_H_

#include "graph/ge_error_codes.h"
#include "toolchain/slog.h"
#include "framework/common/debug/ge_log.h"

#define GE_MOD_ID GE

#ifdef _MSC_VER
#define FUNC_NAME __FUNCTION__
#else
#define FUNC_NAME __PRETTY_FUNCTION__
#endif

#define D_GE_LOGE(fmt, ...) \
dlog_error(static_cast<int>(GE_MOD_ID), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__)

#define GE_LOGE(...) D_GE_LOGE(__VA_ARGS__)
#define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__)

#define GE_LOGI_IF(condition, ...) \
if ((condition)) { \
@@ -44,15 +32,15 @@
GELOGW(__VA_ARGS__); \
}

#define GE_LOGE_IF(condition, ...) \
if ((condition)) { \
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \
#define GE_LOGE_IF(condition, ...) \
if ((condition)) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
}

#define GE_CHK_STATUS_RET_NOLOG(expr) \
do { \
const ge::graphStatus _status = (expr); \
if (_status != ge::GRAPH_SUCCESS) { \
if (ge::SUCCESS != _status) { \
return _status; \
} \
} while (0)
@@ -61,7 +49,7 @@
do { \
bool b = (expr); \
if (!b) { \
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \
GELOGE(ge::FAILED, __VA_ARGS__); \
return _status; \
} \
} while (0)
@@ -85,7 +73,7 @@
do { \
const ge::graphStatus _status = (expr); \
if (_status) { \
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \
GELOGE(ge::FAILED, __VA_ARGS__); \
return _status; \
} \
} while (0)
@@ -95,7 +83,7 @@
{ \
bool b = (expr); \
if (b) { \
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \
GELOGE(ge::FAILED, __VA_ARGS__); \
exec_expr; \
} \
}
@@ -119,63 +107,41 @@
} while (0)

// If expr is not true, the log is printed and a custom statement is executed
#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (!b) { \
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \
exec_expr; \
} \
}

// If expr is not true, the log is printed and a custom statement is executed
#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (!b) { \
GELOGI(__VA_ARGS__); \
exec_expr; \
} \
#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (!b) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
exec_expr; \
} \
}

// If expr is not true, the log is printed and a custom statement is executed
#define GE_CHK_BOOL_EXEC_DEBUG(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (!b) { \
GELOGD(__VA_ARGS__); \
exec_expr; \
} \
#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \
{ \
bool b = (expr); \
if (!b) { \
GELOGI(__VA_ARGS__); \
exec_expr; \
} \
}

// If expr is not GRAPH_SUCCESS, print the log and return the same value
#define GE_CHK_STATUS_RET(expr, ...) \
do { \
const ge::graphStatus _status = (expr); \
if (_status != ge::GRAPH_SUCCESS) { \
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \
return _status; \
} \
#define GE_CHK_STATUS_RET(expr, ...) \
do { \
const ge::graphStatus _status = (expr); \
if (ge::SUCCESS != _status) { \
GELOGE(ge::FAILED, __VA_ARGS__); \
return _status; \
} \
} while (0)

#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \
try { \
exec_expr0; \
} catch (...) { \
GELOGE(ge::GRAPH_FAILED, "Make shared failed"); \
exec_expr1; \
#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \
try { \
exec_expr0; \
} catch (...) { \
GELOGE(ge::FAILED, "Make shared failed"); \
exec_expr1; \
}

/// CCE related macro definition
/// If expr is not CC_STATUS_GRAPH_SUCCESS, print the log and return
#define GE_CHK_CCE_RET(expr) \
do { \
ccgraphStatus_t _cc_ret = (expr); \
if (_cc_ret != CC_STATUS_GRAPH_SUCCESS) { \
GELOGE(ge::GRAPH_FAILED, "Call cce api failed, ret: 0x%X", _cc_ret); \
return ge::GRAPH_FAILED; \
} \
} while (0)

#endif // COMMON_GRAPH_DEBUG_GE_LOG_H_
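
Illustrative usage of the reworked check macros above — not part of the diff; LoadStep and RunPipeline are hypothetical, and only GE_CHK_STATUS_RET, GE_CHK_BOOL_RET_STATUS and the ge status codes come from the headers (assuming the usual in-tree include path):

#include "debug/ge_log.h"

// Hypothetical callee; stands in for any function returning ge::graphStatus.
static ge::graphStatus LoadStep() { return ge::GRAPH_SUCCESS; }

ge::graphStatus RunPipeline(bool ready) {
  // Logs with ge::FAILED and returns GRAPH_PARAM_INVALID when the condition is false.
  GE_CHK_BOOL_RET_STATUS(ready, ge::GRAPH_PARAM_INVALID, "pipeline is not ready");
  // Logs and forwards the callee's status when it is not ge::SUCCESS.
  GE_CHK_STATUS_RET(LoadStep(), "load step failed");
  return ge::GRAPH_SUCCESS;
}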


+ 0
- 1
src/common/graph/debug/ge_util.h View File

@@ -25,7 +25,6 @@
#include <string>
#include <utility>
#include <vector>

#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_log.h"
#include "graph/ge_error_codes.h"


+ 0
- 2
src/common/graph/debug/graph_debug.cc View File

@@ -15,12 +15,10 @@
*/

#include "graph/debug/graph_debug.h"

#include <algorithm>
#include <unordered_set>
#include <vector>
#include "debug/ge_util.h"

#include "framework/common/debug/ge_log.h"

#define TAB " "


+ 0
- 2
src/common/graph/debug/graph_debug.h View File

@@ -16,13 +16,11 @@

#ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_
#define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_

#include <cstdint>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

#include "external/graph/graph.h"
#include "./ge_error_codes.h"
#include "graph/compute_graph.h"


+ 0
- 2
src/common/graph/detail/attributes_holder.cc View File

@@ -15,9 +15,7 @@
*/

#include "detail/attributes_holder.h"

#include <map>

#include "debug/ge_log.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"


+ 18
- 17
src/common/graph/format_refiner.cc View File

@@ -14,14 +14,12 @@
* limitations under the License.
*/

#include "graph/format_refiner.h"

#include "format_refiner.h"
#include <deque>
#include <iostream>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "./compute_graph.h"
#include "./ge_error_codes.h"
#include "./graph/ge_tensor.h"
@@ -57,6 +55,7 @@ graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) {
}
return GRAPH_SUCCESS;
}

graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points,
std::vector<ge::NodePtr> &data_nodes,
std::unordered_map<ge::NodePtr, bool> &node_status) {
@@ -82,10 +81,10 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
// consider special node save process
// get all input desc format
bool node_is_all_nd = false;
for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetInputsSize()); i++) {
auto input_desc = op_desc->GetInputDesc(i);
auto input_size = static_cast<uint32_t>(op_desc->GetInputsSize());
for (uint32_t i = 0; i < input_size; i++) {
// Operator pre-set format but not origin format
auto input_format = input_desc.GetFormat();
auto input_format = op_desc->MutableInputDesc(i)->GetFormat();
// Pre-save data node and default infer fail
if (node_ptr->GetType() == DATA) {
data_nodes.push_back(node_ptr);
@@ -95,9 +94,9 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
}
}
// Get all output desc format
for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetOutputsSize()); i++) {
GeTensorDesc output_desc = op_desc->GetOutputDesc(i);
auto output_format = output_desc.GetFormat();
auto output_size = static_cast<uint32_t>(op_desc->GetOutputsSize());
for (uint32_t i = 0; i < output_size; i++) {
auto output_format = op_desc->MutableOutputDesc(i)->GetFormat();
if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) {
node_is_all_nd = true;
}
@@ -145,7 +144,8 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge::
for (const auto &in_anchor : node->GetAllInDataAnchors()) {
GELOGD("Node is [%s] [B]", (node->GetName()).c_str());
auto in_data_anchor_idx = in_anchor->GetIdx();
auto to_be_set_format = (node->GetOpDesc()->GetInputDesc(in_data_anchor_idx)).GetOriginFormat();
auto to_be_set_format =
node->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx))->GetOriginFormat();
if (to_be_set_format == FORMAT_ND) {
GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str());
continue;
@@ -162,7 +162,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge::
}
// Check format whether have been set
int idx = peer_out_data_anchor->GetIdx();
auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(idx);
auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx));
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) {
auto dim_num = ge_tensor_desc.GetShape().GetDimNum();
if (dim_num == 0) {
@@ -182,7 +182,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge::

ge_tensor_desc.SetOriginFormat(to_be_set_format);
ge_tensor_desc.SetFormat(to_be_set_format);
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(idx, ge_tensor_desc);
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc);

// Call operator infer format api (forward) to get out format
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str());
@@ -205,7 +205,8 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g
GELOGD("Node is [%s] [F]", (node->GetName()).c_str());
GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue);
auto out_data_anchor_idx = out_data_anchor->GetIdx();
auto to_be_set_format = (node->GetOpDesc()->GetOutputDesc(out_data_anchor_idx)).GetOriginFormat();
auto to_be_set_format =
node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(out_data_anchor_idx))->GetOriginFormat();
if (to_be_set_format == FORMAT_ND) {
GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str());
continue;
@@ -222,7 +223,7 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g
}
// Check format whether have been set
int idx = peer_in_data_anchor->GetIdx();
auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(idx);
auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx));
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) {
auto dim_num = ge_tensor_desc.GetShape().GetDimNum();
if (dim_num == 0) {
@@ -285,9 +286,9 @@ void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer =
graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format,
std::unordered_map<ge::NodePtr, bool> &node_status) {
bool is_internal_format = TypeUtils::IsInternalFormat(data_format);
bool need_process = ((!is_first_infer) && (is_internal_format == false) && (data_format != FORMAT_ND));
bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND);
if (!need_process) {
GELOGI("no necessary to do DataNodeFormatProcess.IsFirstInfer: %d, data_format:%s", is_first_infer,
GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer,
TypeUtils::FormatToSerialString(data_format).c_str());
return GRAPH_SUCCESS;
}
@@ -378,9 +379,9 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph)
/// Notice: ignore 5D formats
auto data_format = graph->GetDataFormat();
status = DataNodeFormatProcess(data_nodes, data_format, node_status);

// Set infer flag to false
SetInferOrigineFormatFlag(false);

return status;
}
} // namespace ge
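
The accessor change running through format_refiner.cc above can be summarized by this minimal sketch — the helper is hypothetical; only GetOutputDesc, UpdateOutputDesc, MutableOutputDesc and SetOriginFormat are taken from OpDesc/GeTensorDesc:

#include "graph/ge_tensor.h"
#include "graph/op_desc.h"

// Old style: GetOutputDesc(i) hands back a copy, so changes must be written back
// with UpdateOutputDesc; new style: MutableOutputDesc(i) returns a pointer into
// the OpDesc and the format can be read or set in place.
void RefreshOriginFormat(const ge::OpDescPtr &op_desc, ge::Format fmt) {
  ge::GeTensorDesc copy = op_desc->GetOutputDesc(0U);
  copy.SetOriginFormat(fmt);
  (void)op_desc->UpdateOutputDesc(0U, copy);

  auto desc_ptr = op_desc->MutableOutputDesc(0U);
  if (desc_ptr != nullptr) {
    desc_ptr->SetOriginFormat(fmt);  // no write-back step needed
  }
}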

+ 248
- 33
src/common/graph/ge_attr_define.cc View File

@@ -42,6 +42,8 @@ const std::string ATTR_NAME_BIAS = "bias";

const std::string ATTR_NAME_BIAS_TERM = "bias_term";

const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value";

const std::string ATTR_NAME_PAD = "pad";

const std::string ATTR_NAME_PADS = "pad";
@@ -83,6 +85,7 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta";
const std::string ATTR_NAME_AXIS = "axis";
const std::string ATTR_NAME_BROADCAST = "broadcast";

const std::string ATTR_NAME_OUTPUT = "output";
const std::string ATTR_NAME_OUTPUT_NUM = "output_num";
const std::string ATTR_NAME_TIDX = "t_idx";

@@ -103,6 +106,13 @@ const std::string ATTR_NAME_TSHAPE = "Tshape";
const std::string ATTR_NAME_NAN_OPT = "nan_opt";

const std::string ATTR_NAME_AIPP = "aipp";
const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp";

const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id";

const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist";
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size";
const std::string ATTR_MODEL_BATCH_NUM = "batch_num";

const std::string ATTR_NAME_INPUT_FORMAT = "input_format";
const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format";
@@ -111,6 +121,7 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def";
const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def";
const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type";
const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def";
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type";

const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc";
const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc";
@@ -122,15 +133,11 @@ const std::string ATTR_NAME_WEIGHTS = "value";
const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data";
const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt";
const std::string ATTR_NAME_DIM_ALIGN = "dim_align";
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type";

const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id";

const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start";
const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size";
const std::string ATTR_MODEL_BATCH_NUM = "batch_num";
const std::string ATTR_NAME_STREAM_LABEL = "_stream_label";
const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event";
const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id";
const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start";
const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size";

// To be deleted
const std::string ATTR_TO_BE_DELETED = "to_be_deleted";
@@ -144,15 +151,13 @@ const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion";
const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num";
const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion";

const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag";

// Refinedet
const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion";
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion";
const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion";
const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num";
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance";
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num";
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion";
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag";

// _Arg
const std::string ATTR_NAME_INDEX = "index";
@@ -242,6 +247,30 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean";
const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance";
const std::string BATCHNORM_ATTR_SCALE = "scale";
const std::string BATCHNORM_ATTR_BIAS = "bias";
const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format";
const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training";
const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion";

// huberloss
const std::string HUBER_LOSS_ATTR_DELTA = "delta";

// SSDRealDivTileMul
const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara";

// SSDSumMulRealDivMean
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices";
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis";
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para";
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum";

// ConcatFive2Four
// ConcatFour2Five
const std::string SSD_BOX_TYPE_NUM = "box_type_num";
const std::string SSD_CLASS_NUM = "class_num";
const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode";
const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size";
const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high";
const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width";

// Scale
const std::string SCALE_ATTR_SCALE = "scale";
@@ -346,6 +375,7 @@ const std::string SOFTMAX_ATTR_AXIS = "axis";

// Permute
const std::string PERMUTE_ATTR_ORDER = "order";
const std::string PERMUTE_ATTR_PERM = "perm";

// SSD Normalize
const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial";
@@ -373,6 +403,10 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num";
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance";
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num";

// RefinedetDetectionOutput
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num";
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance";

// PRelu
const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared";

@@ -386,11 +420,16 @@ const std::string POWER_ATTR_NAME_POWER = "power";
const std::string POWER_ATTR_NAME_SCALE = "scale";
const std::string POWER_ATTR_NAME_SHIFT = "shift";

// log
const std::string LOG_ATTR_NAME_SCALE = "scale";
const std::string LOG_ATTR_NAME_SHIFT = "shift";
const std::string LOG_ATTR_NAME_BASE = "base";
// Pack
const std::string PACK_ATTR_NAME_NUM = "N";

// Unpack
const std::string UNPACK_ATTR_NAME_NUM = "num";
const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_";
// Gathernd
const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices";
const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams";
@@ -400,6 +439,13 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk";
const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size";
const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride";
const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval";
const std::string ARGMAX_ATTR_NAME_AXIS = "axis";
const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type";
const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims";

// upsample
const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h";
const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w";

// Relu
const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope";
@@ -416,6 +462,7 @@ const std::string SPLIT_ATTR_NAME_NUM_SPLIT = "num_split";
const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic";
const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim";
const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata";
const std::string TVM_ATTR_NAME_WORKSPACE_TYPE = "tvm_workspace_type";

// Squeeze
const std::string SQUEEZE_ATTR_AXIS = "axis";
@@ -438,6 +485,7 @@ const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale";
const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio";
const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h";
const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w";
const std::string ROIALIGN_ATTR_NAME_TF = "roialign_tf";

// Generate_rpn_proposal
const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk";
@@ -536,19 +584,42 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape
const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape";

// Rnn
const std::string RNN_MODE_ = "rnn_";
const std::string CNN_RNN = "cnn_rnn";
const std::string RNN_TENSORFLOW = "rnn_tensorflow";
const std::string RNN_MODE_STATIC = "rnn_static";
const std::string MUTI_RNN = "multi_rnn";
const std::string CNN_RNN = "cnn_rnn";
const std::string RNN_MODE_ = "rnn_";

const std::string CELL_MODE = "mode";
const std::string LSTM_CELL = "lstm_cell";
const std::string GRU_CELL = "gru_cell";
const std::string RNN_HT = "ht";
const std::string RNN_XT_HT = "xt_ht";
const std::string RNN_BATCH_SIZE = "batch_size";
const std::string LSTM_CELL_CLIP = "lstm_cell_clip";
const std::string LSTM_PROJ_CLIP = "lstm_proj_clip";
const std::string LSTM_ACTIVATE = "lstm_activate";
const std::string LSTM_OUT_MAP = "lstm_out_map";
const std::string LSTM_OUT_MODE = "lstm_out_mode";
const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode";
const std::string LSTM_TIME_MAJOR = "lstm_time_major";
const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process";

// Upsample
const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale";

// PadV2
const std::string PADV2_ATTR_NAME_MODE = "mode";
const std::string PADV2_ATTR_NAME_PADS = "paddings";
const std::string PADV2_ATTR_NAME_T = "T";
const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format";
const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value";

// MirrorPad
const std::string MIRRORPAD_ATTR_NAME_MODE = "mode";
const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings";
const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format";
const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value";
// Filler
const std::string FILLER_TYPE = "filler_type";
const std::string FILLER_VALUE = "filler_value";
@@ -559,9 +630,6 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group";
// TopKV2
const std::string TOPKV2_ATTR_K = "k";

const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size";
const std::string L2_NORMALIZE_ATTR_EPS = "eps";

// Calibaration
const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX";
const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX";
@@ -616,6 +684,8 @@ const std::string ATTR_MODEL_STREAM_NUM = "stream_num";

const std::string ATTR_MODEL_EVENT_NUM = "event_num";

const std::string ATTR_MODEL_LABEL_NUM = "label_num";

const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size";

const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size";
@@ -630,6 +700,8 @@ const std::string ATTR_MODEL_VAR_SIZE = "variable_size";

const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name";

const std::string ATTR_MODEL_CORE_TYPE = "core_type";

// Public attribute
const std::string ATTR_NAME_IMPLY_TYPE = "imply_type";

@@ -661,17 +733,145 @@ const std::string TARGET_TYPE_TINY = "TINY";

const std::string TARGET_TYPE_LITE = "LITE";

// l2_normalize
const std::string L2_NORMALIZE_ATTR_AXIS = "axis";
const std::string L2_NORMALIZE_ATTR_EPS = "eps";

const std::string POOL_PARAMA_ATTR_WINDOW = "window";
const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode";
const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode";
const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling";
const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt";
const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode";

// HCOM
const std::string HCOM_ATTR_ROOT_RANK = "root_rank";
const std::string HCOM_ATTR_RANK_SIZE = "rank_size";

const std::string HCOM_ATTR_REDUCE_TYPE = "reduction";
const std::string HCOM_ATTR_GROUP = "group";
const std::string HCOM_ATTR_SR_TAG = "sr_tag";
const std::string HCOM_ATTR_SRC_RANK = "src_rank";
const std::string HCOM_ATTR_DEST_RANK = "dest_rank";
const std::string HCOM_ATTR_FUSION = "fusion";
const std::string HCOM_ATTR_SHAPE = "shape";
const std::string HCOM_ATTR_DATA_TYPE = "dtype";

// SpaceToDepth/DepthToSpace
const std::string ATTR_NAME_BLOCK_SIZE = "block_size";

// SparseSoftmaxCrossEntropyWithLogits
const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels";

// MaxPoolGradWithArgmax
const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape";

// AvgPoolGrad
const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape";

// Pad
const std::string ATTR_PAD_FORMAT = "attr_pad_format";

// Variable
const std::string VAR_ATTR_FORMAT = "_var_format";
const std::string VAR_ATTR_NAME = "var_name";
const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ";
const std::string VAR_ATTR_4D_FORMAT = "4D";
const std::string VAR_ATTR_5D_FORMAT = "5D";
const std::string VAR_ATTR_DATA_TYPE = "data_format";
const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name";
const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index";
const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index";
const std::string VAR_ATTR_SHAPE = "shape";
const std::string HALF_VAR_NAME_END = "_fp16";
const std::string VAR_ATTR_INITED = "var_is_inited";

const std::string VAR_ATTR_CONTAINER = "container";
const std::string VAR_ATTR_SHARED_NAME = "shared_name";
const std::string VAR_ATTR_DTYPE = "dtype";

const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name";
const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save";
const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore";
const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast";
const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name";
const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index";

// Assign
const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape";

// space2batch batch2space
const std::string BATCH_SPACE_ATTR_BLOCK = "block";
const std::string BATCH_SPACE_ATTR_PADDING = "padding";

// depth_to_space space_to_depth
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size";

// FakeQuantWithMinMaxVars
const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max";
const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min";

// mobilenet_ssd_conv_fusion
const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion";
const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion";
const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num";

// lsh project
const std::string LSH_PROJ_TYPE = "lsh_project_type";

// log time stamp
const std::string LOG_TIME_STAMP_LOGID = "logid";
const std::string LOG_TIME_STAMP_NOTIFY = "notify";

// ShapeN
const std::string SHAPEN_ATTR_N = "N";
const std::string SHAPEN_ATTR_IN_TYPE = "in_type";
const std::string SHAPEN_ATTR_OUT_TYPE = "dtype";

// GatherV2 attr def
const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis";
const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices";
const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams";

// Reshape attr def
const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape";
const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape";

// axis attr def
const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op";

const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse";

const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format";
const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype";

// For constant folding
const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding";

const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input";

const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output";

const std::string ATTR_NAME_REFERENCE = "reference";

const std::string ATTR_NAME_NOTASK = "_no_task";

const std::string ATTR_NAME_OUTPUT_REUSE_INPUT = "_output_reuse_input";

const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX = "_reuse_input_on_dim_index";

const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT = "_no_padding_continuous_input";

const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT = "_no_padding_continuous_output";

const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index";

// Used for mark the active label list stream of activated node
const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list";

// Used for l2cache, true: the memory of all inputs is used for the last time.
const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_lifecycle";

// Multi batch
const std::string ATTR_NAME_PRED_VALUE = "_pred_value";
const std::string ATTR_NAME_BATCH_NUM = "_batch_num";
@@ -682,6 +882,8 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition";
const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream";
const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list";
const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value";
const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop";
const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node";

const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label";
const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag";
@@ -691,6 +893,9 @@ const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG = "_cyclic_dependence_flag";

const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node";

// Function Op
const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index";

// Used for mark the active node is for loop, type:bool
const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active";

@@ -702,6 +907,20 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace";

const std::string MODEL_ATTR_SESSION_ID = "session_id";

// l1 fusion and other fusion in future
const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id";
const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key";
const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op";
const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type";
const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type";
const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type";
const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content";
const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size";
const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison";
const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion";
const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split";
const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed";

// Atomic addr clean attrs
const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index";
const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index";
@@ -722,6 +941,9 @@ const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node";
// For inserted op
const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge";

// For compress weight
const std::string ATTR_NAME_COMPRESS_WEIGHT = "_is_compress_weight";

// For data dump
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names";
const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop";
@@ -732,24 +954,17 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format";
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type";

// Variable
const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name";
const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name";
const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index";
const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast";
const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore";

// HCOM
const std::string HCOM_ATTR_ROOT_RANK = "root_rank";
const std::string HCOM_ATTR_RANK_SIZE = "rank_size";
const std::string HCOM_ATTR_SHAPE = "shape";
const std::string HCOM_ATTR_DATA_TYPE = "dtype";
// functional ops attr
const std::string ATTR_NAME_TCOND = "Tcond";
const std::string ATTR_NAME_TIN = "Tin";
const std::string ATTR_NAME_TOUT = "Tout";
const std::string ATTR_NAME_THEN_BRANCH = "then_branch";
const std::string ATTR_NAME_ELSE_BRANCH = "else_branch";

const std::string HCOM_ATTR_REDUCE_TYPE = "reduction";
// used for label switch
const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index";
const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list";

const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype";
const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype";

// Dynamic stitch
const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_";
} // namespace ge
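
As a hedged illustration of how the attribute keys added above are consumed, consider this hypothetical helper; only AttrUtils::SetInt and ATTR_NAME_PARENT_NODE_INDEX are taken from the sources:

#include "graph/debug/ge_attr_define.h"
#include "graph/ge_attr_value.h"
#include "graph/op_desc.h"

// Tag a subgraph Data op with the input index it maps to on its parent node.
void MarkParentNodeIndex(const ge::OpDescPtr &data_op, int64_t parent_index) {
  if (data_op == nullptr) {
    return;
  }
  (void)ge::AttrUtils::SetInt(data_op, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index);
}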

+ 42
- 7
src/common/graph/ge_attr_value.cc View File

@@ -22,7 +22,7 @@
#include "graph/model_serialize.h"
#include "proto/ge_ir.pb.h"
#include "detail/model_serialize_imp.h"
#include "graph/debug/ge_attr_define.h"
#include "debug/ge_attr_define.h"
#include "debug/ge_log.h"
#include "debug/ge_util.h"

@@ -53,7 +53,7 @@ string GeAttrValue::NamedAttrs::GetName() const {

GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const {
GeAttrValue value;
(void)GetAttr(key, value);
GetAttr(key, value);
return value;
}

@@ -1081,6 +1081,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstA
if (!GetListInt(std::move(obj), name, int64_list)) {
return false;
}

for (size_t i = 0; i < int64_list.size(); ++i) {
if (int64_list[i] > INT32_MAX) {
GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]);
@@ -1098,6 +1099,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstA
if (!GetListInt(std::move(obj), name, int64_list)) {
return false;
}

for (size_t i = 0; i < int64_list.size(); ++i) {
if (int64_list[i] > UINT32_MAX) {
GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]);
@@ -1215,6 +1217,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(
GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed");
op_desc->extAttrs_ = org_op_desc->extAttrs_;

if (op_desc->HasAttr("_input_name_idx_key")) {
if (op_desc->DelAttr("_input_name_idx_key") != SUCCESS) {
GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_key failed.");
}
}

if (op_desc->HasAttr("_input_name_idx_value")) {
if (op_desc->DelAttr("_input_name_idx_value") != SUCCESS) {
GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_value failed.");
}
}

if (op_desc->HasAttr("_opt_input")) {
if (op_desc->DelAttr("_opt_input") != SUCCESS) {
GELOGE(GRAPH_FAILED, "DelAttr _opt_input failed.");
}
}
return op_desc;
}

@@ -1237,11 +1256,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c

op_desc->extAttrs_ = org_op_desc->extAttrs_;

op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end());
op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(),
org_op_desc->optional_input_names_.end());
op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end());

op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end());

op_desc->infer_func_ = org_op_desc->infer_func_;
@@ -1250,4 +1264,25 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c

return op_desc;
}
std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) {
auto holder = obj.get();
if (holder == nullptr) {
return "";
}
auto attrs_map = holder->GetAttrMap();
if (attrs_map.GetProtoMsg() == nullptr) {
return "";
}

std::map<std::string, std::string> ordered_attrs;
for (auto &attr : *(attrs_map.GetProtoMsg())) {
ordered_attrs[attr.first] = attr.second.SerializeAsString();
}

std::stringstream ss;
for (auto &attr : ordered_attrs) {
ss << attr.first << ":" << attr.second << ";";
}
return ss.str();
}
} // namespace ge
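
A small usage sketch for the new AttrUtils::GetAllAttrsStr above — the wrapper is hypothetical and simply forwards an OpDescPtr through the ConstAttrHolderAdapter:

#include <string>
#include "graph/ge_attr_value.h"
#include "graph/op_desc.h"

// Render every attribute of an op as "name:serialized_value;" pairs, e.g. for debug logs.
std::string DumpOpAttrs(const ge::OpDescPtr &op_desc) {
  if (op_desc == nullptr) {
    return "";
  }
  return ge::AttrUtils::GetAllAttrsStr(op_desc);
}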

+ 31
- 3
src/common/graph/ge_tensor.cc View File

@@ -163,6 +163,34 @@ int64_t GeShape::GetShapeSize() const {
return res;
}

///
/// @brief Check whether the shape is unknown
/// @return bool
///
bool GeShape::IsUnknownShape() const {
auto proto_msg = shape_def_.GetProtoMsg();
if (proto_msg != nullptr) {
for (auto i : proto_msg->dim()) {
if (i < 0) {
return true;
}
}
}
return false;
}

///
/// @brief Check is a scalar
/// @return bool
///
bool GeShape::IsScalar() const {
auto proto_msg = shape_def_.GetProtoMsg();
if (proto_msg != nullptr) {
return proto_msg->dim().empty();
}
return false;
}

const string TENSOR_UTILS_SIZE = "size";
const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size";
const string TENSOR_UTILS_REUSE_INPUT = "reuse_input";
@@ -639,14 +667,14 @@ GeTensor &GeTensor::operator=(const GeTensor &other) {
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc,
uint32_t &size) {
int64_t &size) {
auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
GE_CHECK_NOTNULL(tensor_descriptor_msg);
size = static_cast<uint32_t>(tensor_descriptor_msg->size());
size = static_cast<int64_t>(tensor_descriptor_msg->size());
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, uint32_t size) {
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) {
auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg();
if (tensor_descriptor_msg != nullptr) {
tensor_descriptor_msg->set_size(size);
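
A minimal sketch of how the widened size accessor and the new shape predicate above might be combined — the helper and the include paths are assumptions; IsUnknownShape and TensorUtils::GetSize are the calls taken from the diff:

#include "graph/ge_tensor.h"
#include "graph/utils/tensor_utils.h"

// Returns -1 for unknown shapes instead of trusting a statically stored size.
int64_t GetKnownTensorSize(const ge::GeTensorDesc &desc) {
  if (desc.GetShape().IsUnknownShape()) {
    return -1;
  }
  int64_t size = 0;
  (void)ge::TensorUtils::GetSize(desc, size);  // size is now int64_t, not uint32_t
  return size;
}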


+ 13
- 4
src/common/graph/model.cc View File

@@ -49,6 +49,7 @@ void Model::Init() {
(void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0);
(void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI);
version_ = 0;
@@ -77,9 +78,9 @@ void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; }

Graph Model::GetGraph() const { return graph_; }

graphStatus Model::Save(Buffer &buffer) const {
graphStatus Model::Save(Buffer &buffer, bool is_dump) const {
ModelSerialize serialize;
buffer = serialize.SerializeModel(*this);
buffer = serialize.SerializeModel(*this, is_dump);
return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED;
}

@@ -113,7 +114,7 @@ graphStatus Model::SaveToFile(const string &file_name) const {
}
int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS);
if (fd < 0) {
GELOGE(GRAPH_FAILED, "open file failed, file path [%s] ", real_path);
GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno));
return GRAPH_FAILED;
}
bool ret = ge_proto.SerializeToFileDescriptor(fd);
@@ -129,6 +130,10 @@ graphStatus Model::SaveToFile(const string &file_name) const {
GELOGE(GRAPH_FAILED, "close file descriptor fail.");
return GRAPH_FAILED;
}
if (!ret) {
GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed");
return GRAPH_FAILED;
}
}
return GRAPH_SUCCESS;
}
@@ -152,7 +157,7 @@ graphStatus Model::LoadFromFile(const string &file_name) {
}
int fd = open(real_path, O_RDONLY);
if (fd < 0) {
GELOGE(GRAPH_FAILED, "open file failed");
GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno));
return GRAPH_FAILED;
}

@@ -170,6 +175,10 @@ graphStatus Model::LoadFromFile(const string &file_name) {
GELOGE(GRAPH_FAILED, "close file descriptor fail.");
return GRAPH_FAILED;
}
if (!ret) {
GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed");
return GRAPH_FAILED;
}
return Load(model_def);
}



+ 23
- 14
src/common/graph/model_serialize.cc View File

@@ -15,10 +15,8 @@
*/

#include "graph/model_serialize.h"

#include <google/protobuf/text_format.h>
#include <iostream>

#include "debug/ge_attr_define.h"
#include "debug/ge_log.h"
#include "debug/ge_util.h"
@@ -26,6 +24,7 @@
#include "graph/detail/model_serialize_imp.h"
#include "proto/ge_ir.pb.h"
#include "utils/graph_utils.h"
#include "debug/ge_op_types.h"

using std::string;

@@ -84,20 +83,29 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_
return true;
}

bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) {
bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) {
if (op_desc == nullptr || op_def_proto == nullptr) {
GELOGE(GRAPH_FAILED, "Input Para Invalid");
return false;
}
if (op_desc->op_def_.GetProtoMsg() != nullptr) {
*op_def_proto = *op_desc->op_def_.GetProtoMsg();
// Delete unnecessary attr
if (is_dump) {
auto attr = op_def_proto->mutable_attr();
attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF);
attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF);
attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF);
GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP),
attr->erase(ATTR_NAME_WEIGHTS));
}
op_def_proto->clear_input_desc();
op_def_proto->clear_output_desc();
// Input descs
if (op_desc->GetInputsSize() > 0) {
auto size = static_cast<uint32_t>(op_desc->GetInputsSize());
if (op_desc->GetAllInputsSize() > 0) {
auto size = static_cast<uint32_t>(op_desc->GetAllInputsSize());
for (uint32_t i = 0; i < size; i++) {
auto tensor_desc = op_desc->GetInputDescPtr(i);
auto tensor_desc = op_desc->GetInputDescPtrDfault(i);
if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
*op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
}
@@ -117,12 +125,12 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op
return true;
}

bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto) {
bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) {
if (node == nullptr || op_def_proto == nullptr) {
GELOGE(GRAPH_FAILED, "Input Para Node Invalid");
return false;
}
if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto)) {
if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) {
GELOGE(GRAPH_FAILED, "Serialize OpDesc failed");
return false;
}
@@ -134,7 +142,8 @@ bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph,
proto::GraphDef *graph_proto) {
proto::GraphDef *graph_proto,
bool is_dump) {
if (graph == nullptr || graph_proto == nullptr) {
GELOGE(GRAPH_FAILED, "Input para Invalid");
return false;
@@ -156,7 +165,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize
*graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg();
}
for (const auto &node : graph->GetDirectNode()) {
if (!SerializeNode(node, graph_proto->add_op())) {
if (!SerializeNode(node, graph_proto->add_op(), is_dump)) {
if (node->GetOpDesc() != nullptr) {
GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str());
}
@@ -166,7 +175,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize
return true;
}

bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto) {
bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) {
if (model_proto == nullptr) {
GELOGE(GRAPH_FAILED, "model_proto para Invalid");
return false;
@@ -183,7 +192,7 @@ bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *mode
GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr");
return false;
}
if (!SerializeGraph(compute_graph, model_proto->add_graph())) {
if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) {
GELOGE(GRAPH_FAILED, "SerializeGraph fail");
return false;
}
@@ -390,10 +399,10 @@ bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::
return true;
}

Buffer ModelSerialize::SerializeModel(const Model &model) {
Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) {
proto::ModelDef model_def;
ModelSerializeImp imp;
if (!imp.SerializeModel(model, &model_def)) {
if (!imp.SerializeModel(model, &model_def, is_dump)) {
return Buffer();
}
#if !defined(__ANDROID__) && !defined(ANDROID)
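
For the is_dump flag threaded through the serializers above, a hedged end-to-end sketch — SaveModelForDump is hypothetical; Model::Save(buffer, is_dump) is the entry point shown in model.cc earlier in this diff:

#include "graph/buffer.h"
#include "graph/model.h"

// Serialize for dumping: per SerializeOpDesc above, framework node/op/func defs
// and constant weights are stripped from the resulting proto.
ge::graphStatus SaveModelForDump(const ge::Model &model, ge::Buffer &buffer) {
  return model.Save(buffer, true);
}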


+ 4
- 3
src/common/graph/node.cc View File

@@ -401,7 +401,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get
vec.push_back(in_anchor);
}
}
// Push back in_control_anchor_
// Push back in_control_anchor_
if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) ||
(in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) {
auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_);
@@ -512,7 +512,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn

auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors();
for (const auto &out_anchor : peer_out_anchors) {
GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, " in_control_anchor_ peer out data anchors is nullptr");
GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr");
auto node = out_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
@@ -521,7 +521,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn
auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors();
for (const auto &out_control_anchor : peer_out_control_anchors) {
GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue,
" in_control_anchor_ peer out control anchors is nullptr");
"in_control_anchor_ peer out control anchors is nullptr");
auto node = out_control_anchor->GetOwnerNode();
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr");
vec.push_back(node);
@@ -785,6 +785,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(co
GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID,
"Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(),
op_desc->GetInputsSize());

GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID,
"Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(),
op_desc->GetOutputsSize());


+ 158
- 35
src/common/graph/op_desc.cc View File

@@ -61,6 +61,12 @@ const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes";

const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const";

const std::string ATTR_NAME_OPT_INPUT = "_opt_input";

const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key";

const std::string ATTR_NAME_INPUT_NAME_IDX_VALUE = "_input_name_idx_value";

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() {
op_def_.InitDefault();
if (op_def_.GetProtoMsg() != nullptr) {
@@ -202,7 +208,8 @@ graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_d
}

graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) {
if (input_name_idx_.find(name) != input_name_idx_.end()) {
auto input_name_idx = GetAllInputName();
if (input_name_idx.find(name) != input_name_idx.end()) {
GELOGI("input %s is exist, update it", name.c_str());
graphStatus ret = UpdateInputDesc(name, input_desc);
return ret;
@@ -214,15 +221,17 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp
return GRAPH_FAILED;
}
inputs_desc_.push_back(in_desc);
(void)input_name_idx_.insert(make_pair(name, index));
(void)input_name_idx.insert(make_pair(name, index));
SetAllInputName(input_name_idx);
return GRAPH_SUCCESS;
}
}

graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) {
auto input_name_idx = GetAllInputName();
for (unsigned int i = 0; i < num; i++) {
string input_name = name + std::to_string(i);
GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED,
GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED,
"Add input tensor_desc is existed. name[%s]", input_name.c_str());

std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
@@ -234,12 +243,13 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n
(void)inputs_desc_.insert(inputs_desc_.begin(), in_desc);

// Update index in input_name_idx
for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) {
for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) {
it->second += 1;
}

(void)input_name_idx_.insert(make_pair(input_name, 0));
(void)input_name_idx.insert(make_pair(input_name, 0));
}
SetAllInputName(input_name_idx);

return GRAPH_SUCCESS;
}
@@ -270,10 +280,19 @@ graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int

graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) {
if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED;
(void)optional_input_names_.insert(name);
vector<string> optional_input_names;
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names);
optional_input_names.push_back(name);
(void)AttrUtils::SetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names);
return GRAPH_SUCCESS;
}

std::vector<string> OpDesc::GetAllOptionalInputName() const {
vector<string> optional_input_names;
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names);
return optional_input_names;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) {
GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index);
@@ -288,11 +307,12 @@ OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) {
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const {
return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") &&
IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") &&
IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") &&
IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") &&
IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_"));
return (
IsEqual(this->GetAllInputName(), r_op_desc.GetAllInputName(), "OpDesc.GetAllInputName()") &&
IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") &&
IsEqual(this->GetAllOptionalInputName(), r_op_desc.GetAllOptionalInputName(), "OpDesc.GetAllOptionalInputName()") &&
IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") &&
IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_"));
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const {
@@ -366,8 +386,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpD
}

graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) {
auto it = input_name_idx_.find(name);
if (it == input_name_idx_.end()) {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.find(name);
if (it == input_name_idx.end()) {
GELOGW("Cann't find the input desc. name[%s]", name.c_str());
return GRAPH_FAILED;
}
@@ -387,8 +408,9 @@ graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &
}

bool OpDesc::InputIsSet(const string &name) const {
auto it = input_name_idx_.find(name);
if (it != input_name_idx_.end()) {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.find(name);
if (it != input_name_idx.end()) {
GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false);
auto tensor_desc = inputs_desc_[it->second];
GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false);
@@ -406,18 +428,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc
}

GeTensorDesc OpDesc::GetInputDesc(const string &name) const {
auto it = input_name_idx_.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc());
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), GeTensorDesc());
GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc());
return *(inputs_desc_[it->second].get());
}

GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const {
auto input_name_idx = GetAllInputName();
vector<string> names;
if (input_name_idx_.empty()) {
if (input_name_idx.empty()) {
return OpDesc::Vistor<string>(shared_from_this(), names);
}
for (std::pair<string, uint32_t> input : input_name_idx_) {
for (std::pair<string, uint32_t> input : input_name_idx) {
names.push_back(input.first);
}
return OpDesc::Vistor<string>(shared_from_this(), names);
@@ -483,6 +507,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() co
return size;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) {
int index = static_cast<int>(outputs_desc_.size());
return AddOutputDesc("__output" + std::to_string(index), output_desc);
@@ -548,6 +574,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu
return outputs_desc_[index];
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const {
return static_cast<uint32_t>(outputs_desc_.size());
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllOutputsDesc() const {
vector<GeTensorDesc> temp{};
for (const auto &it : outputs_desc_) {
@@ -580,6 +610,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetI
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr
OpDesc::GetInputDescPtrDfault(uint32_t index) const {
GE_CHK_BOOL_RET_STATUS_NOLOG((index) < (uint32_t)(inputs_desc_.size()), nullptr);
return inputs_desc_[(int32_t)index];
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), shared_ptr<const GeTensorDesc>());
return inputs_desc_[it->second];
}

graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) {
if (is_push_back) {
for (unsigned int i = 0; i < num; i++) {
@@ -603,12 +646,45 @@ graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int
}

bool OpDesc::IsOptionalInput(const string &name) const {
return optional_input_names_.find(name) != optional_input_names_.end();
vector<string> optional_input_names;
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names);
for (auto &item : optional_input_names) {
if (item == name) {
return true;
}
}
return false;
}

bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); }

std::map<string, uint32_t> OpDesc::GetAllInputName() { return input_name_idx_; }
std::map<string, uint32_t> OpDesc::GetAllInputName() const {
std::map<string, uint32_t> input_name_idx;
std::vector<string> key;
std::vector<uint32_t> value;
(void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key);
(void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value);

if (key.size() != value.size()) {
GE_LOGE("the two vector sizes are different. key_size: %zu, value_size: %zu.", key.size(), value.size());
} else {
for (uint32_t i = 0; i < key.size(); ++i) {
input_name_idx.insert(std::pair<string, uint32_t>(key.at(i), value.at(i)));
}
}
return input_name_idx;
}

void OpDesc::SetAllInputName(const std::map<string, uint32_t> &input_name_idx) {
std::vector<string> key;
std::vector<uint32_t> value;
for (auto &item : input_name_idx) {
key.emplace_back(item.first);
value.emplace_back(item.second);
}
(void)AttrUtils::SetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key);
(void)AttrUtils::SetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value);
}

std::map<string, uint32_t> OpDesc::GetAllOutputName() { return output_name_idx_; }

@@ -619,6 +695,7 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) {
auto factory_map_size = input_name_idx.size();
// It indicates that some inputs have no optionalname.
// The redundant optionalname of factory needs to be deleted and then assigned
auto all_input_name_idx = GetAllInputName();
if (input_map_size < factory_map_size) {
GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size,
factory_map_size);
@@ -631,22 +708,23 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) {
}
if (input_name_idx.size() == input_map_size) {
GELOGI("UpdateInputName");
input_name_idx_ = input_name_idx;
all_input_name_idx = input_name_idx;
} else {
ret = false;
GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size());
}
} else if (input_map_size == factory_map_size) {
input_name_idx_ = input_name_idx;
all_input_name_idx = input_name_idx;
} else {
ret = false;
GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size);
}
SetAllInputName(all_input_name_idx);
return ret;
}

bool OpDesc::UpdateOutputName(std::map<string, uint32_t> output_name_idx) {
size_t output_map_size = GetAllOutputsDesc().size();
size_t output_map_size = GetAllOutputsDescSize();
size_t factory_map_size = output_name_idx.size();
if (output_map_size < factory_map_size) {
GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size,
@@ -754,17 +832,17 @@ graphStatus OpDesc::OpVerify() {
}

graphStatus OpDesc::CommonVerify() const {
for (string iname : GetAllInputNames()) {
for (const string &iname : GetAllInputNames()) {
// Checking shape of all inputs
vector<int64_t> ishape = GetInputDesc(iname).GetShape().GetDims();
vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims();
for (int64_t dim : ishape) {
GE_CHK_BOOL_RET_STATUS(dim >= -1, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.",
iname.c_str());
}
}
// Check all attributes defined
const auto all_attributes = GetAllAttrs();
for (const auto name : GetAllAttrNames()) {
const auto &all_attributes = GetAllAttrs();
for (const auto &name : GetAllAttrNames()) {
GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED,
"operator attribute %s is empty.", name.c_str());
}
@@ -773,19 +851,21 @@ graphStatus OpDesc::CommonVerify() const {
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const {
auto it = input_name_idx_.begin();
for (; it != input_name_idx_.end(); ++it) {
auto input_name_idx = GetAllInputName();
auto it = input_name_idx.begin();
for (; it != input_name_idx.end(); ++it) {
if (it->second == index) {
break;
}
}
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), "");
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), "");
return it->first;
}

int OpDesc::GetInputIndexByName(const string &name) const {
auto it_find = input_name_idx_.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1);
auto input_name_idx = GetAllInputName();
auto it_find = input_name_idx.find(name);
GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx.end(), -1);
return static_cast<int>(it_find->second);
}

@@ -1065,10 +1145,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<bool> OpDesc::GetIsInputCo

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name,
const int &index) {
if (input_name_idx_.find(name) != input_name_idx_.end()) {
auto input_name_idx = GetAllInputName();
if (input_name_idx.find(name) != input_name_idx.end()) {
GELOGI("Restore input name index is existed. name[%s]", name.c_str());
}
(void)input_name_idx_.insert(make_pair(name, index));
(void)input_name_idx.insert(make_pair(name, index));
SetAllInputName(input_name_idx);
return GRAPH_SUCCESS;
}

@@ -1104,4 +1186,45 @@ graphStatus OpDesc::CallInferFormatFunc(Operator &op) {
}
return (graphStatus)infer_format_func_(op);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const {
if (static_cast<size_t>(index) >= subgraph_instance_names_.size()) {
return "";
}
return subgraph_instance_names_.at(index);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector<std::string> &OpDesc::GetSubgraphInstanceNames()
const {
return subgraph_instance_names_;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::AddSubgraphInstanceName(std::string name) {
subgraph_instance_names_.emplace_back(std::move(name));
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) {
for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) {
if (*iter == name) {
subgraph_instance_names_.erase(iter);
return;
}
}
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) {
auto iter = subgraph_names_to_index_.find(name);
if (iter != subgraph_names_to_index_.end()) {
GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second);
return GRAPH_FAILED;
}
auto size = subgraph_names_to_index_.size();
subgraph_names_to_index_[name] = size;
return GRAPH_SUCCESS;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, uint32_t> &OpDesc::GetSubgraphNameIndexes()
const {
return subgraph_names_to_index_;
}
} // namespace ge
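A minimal usage sketch of the subgraph-name bookkeeping added above, assuming an If-like op built through the (name, type) OpDesc constructor; the op and graph names are illustrative, and instance slots are assumed to follow the order in which the formal names were registered:

// Illustrative only.
std::shared_ptr<OpDesc> if_desc = std::make_shared<OpDesc>("cond_op", "If");
// Register the formal subgraph names; indexes are assigned in insertion order.
(void)if_desc->AddSubgraphName("then_branch");   // index 0
(void)if_desc->AddSubgraphName("else_branch");   // index 1
// Bind concrete graph instances to those slots.
if_desc->AddSubgraphInstanceName("then_graph_0");
if_desc->AddSubgraphInstanceName("else_graph_0");
// Look-ups used later by the shape refiner and graph utilities.
std::string first_instance = if_desc->GetSubgraphInstanceName(0);   // "then_graph_0"
const auto &name_to_index = if_desc->GetSubgraphNameIndexes();      // {"then_branch": 0, "else_branch": 1}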

+ 1
- 2
src/common/graph/op_imp.cc

@@ -20,8 +20,7 @@
#include "debug/ge_log.h"
#include "debug/ge_util.h"

using std::function;
using std::vector;
using namespace std;

namespace ge {



+ 31
- 22
src/common/graph/operator.cc

@@ -15,13 +15,12 @@
*/

#include "external/graph/operator.h"

#include <stdint.h>
#include <algorithm>
#include <mutex>
#include <queue>
#include <set>
#include "array_ops.h"
#include "debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
@@ -33,7 +32,6 @@
#include "graph/ge_tensor.h"
#include "graph/node.h"
#include "graph/op_desc.h"
#include "graph/operator_factory.h"
#include "graph/usr_types.h"
#include "utils/graph_utils.h"
#include "utils/op_desc_utils.h"
@@ -48,10 +46,6 @@ using std::string;
using std::to_string;
using std::vector;

namespace {
const char *const kValue = "value";
} // namespace

namespace ge {
class OpIO {
public:
@@ -148,6 +142,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> {
for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) {
is_input_const.push_back(false);
}

is_input_const[dst_index] = is_const;
op_desc_->SetIsInputConst(is_input_const);

@@ -179,8 +174,8 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> {
GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(),
op_desc_->GetName().c_str());
auto out_op_impl = out_handler->GetOwner();
GE_CHK_BOOL_EXEC(out_op_impl && out_op_impl->GetOpDescImpl(), return, "out_handler invalid. name[%s]",
dst_name.c_str());
GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return,
"out_handler invalid. name[%s]", dst_name.c_str());
bool is_const = false;
if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) {
is_const = true;
@@ -193,7 +188,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> {
op_desc_->SetIsInputConst(is_input_const);

OpIO in_handler(dst_name, dst_index, shared_from_this());
GE_CHK_BOOL_EXEC(!!out_op_impl, return, "Get out_handler's impl failed.");
GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed.");

out_op_impl->UpdateLinkMapImpl(src_name, in_handler);
auto src_output_desc = out_op_impl->GetOutputDesc(src_name);
@@ -210,7 +205,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> {

void AddControlInputImp(const ge::Operator &src_oprt) {
if (src_oprt.operator_impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "Src operator impl is nullptr");
GELOGE(FAILED, "Src operator impl is nullptr");
return;
}
for (auto &input : control_input_link_) {
@@ -520,9 +515,9 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co
if (peer_node_ptr->GetOpDesc() != nullptr) {
const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType();
if (op_descType == CONSTANTOP) {
return const_op.GetAttr(kValue, data);
return const_op.GetAttr(op::Constant::name_attr_value(), data);
} else if (op_descType == CONSTANT) {
return const_op.GetAttr(kValue, data);
return const_op.GetAttr(op::Const::name_attr_value(), data);
}
}
} else {
@@ -542,9 +537,9 @@ graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data)
Operator const_op(out_handle.GetOwner());
const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType();
if (op_desc_impl_type == CONSTANTOP) {
return const_op.GetAttr(kValue, data);
return const_op.GetAttr(op::Constant::name_attr_value(), data);
} else if (op_desc_impl_type == CONSTANT) {
return const_op.GetAttr(kValue, data);
return const_op.GetAttr(op::Const::name_attr_value(), data);
}
}
return GRAPH_FAILED;
@@ -709,6 +704,7 @@ void Operator::InputRegister(const string &name) {
void Operator::OptionalInputRegister(const string &name) {
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
// [No need to verify return value]
(void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name,
GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED));
}
@@ -716,24 +712,28 @@ void Operator::OptionalInputRegister(const string &name) {
void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) {
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
// [No need to verify return value]
(void)operator_impl_->GetOpDescImpl()->AddInferFunc(func);
}

void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) {
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
// [No need to verify return value]
(void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func);
}

void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) {
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
// [No need to verify return value]
(void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func);
}

void Operator::OutputRegister(const string &name) {
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
// [No need to verify return value]
(void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc());
}

@@ -757,7 +757,8 @@ int Operator::GetDynamicInputNum(const string &name) const {
void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) {
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr.");
(void)AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num);
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return,
"Set %s int failed", name.c_str());
(void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back);
}

@@ -765,7 +766,8 @@ int Operator::GetDynamicOutputNum(const string &name) const {
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr.");
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr.");
int num = 0;
(void)AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num);
GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num,
"Get %s int failed", name.c_str());
return num;
}

@@ -1141,7 +1143,9 @@ class GraphBuilderImpl {
GELOGW("Input operator should be Data, Variable operator or operator that has output but no input.");
}
}

GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr,
"User Input do not include operator such as \
Data, Variable operator or operator that has output but no input.");
auto ret = WalkAllOperators(vec_inputs);
GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed.");

@@ -1163,7 +1167,8 @@ class GraphBuilderImpl {
que.pop();
for (const auto &op_impl : vec_tem) {
GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.")
GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue)
GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue,
"This node %s has created.", op_impl->GetName().c_str())
auto node_ptr = graph_->AddNode(op_impl->op_desc_);
GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed.");
all_nodes_info_.insert(std::make_pair(op_impl, node_ptr));
@@ -1202,10 +1207,13 @@ class GraphBuilderImpl {
for (const auto &node_info : all_nodes_info_) {
auto src_op_impl_ptr = node_info.first;
auto src_node_ptr = node_info.second;

GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue);
auto out_links = src_op_impl_ptr->output_links_;
GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED,
"Src operator impl's op_desc is null.");
auto &op_desc = src_op_impl_ptr->op_desc_;

GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
for (const auto &out : out_links) {
auto src_idx = op_desc->GetOutputIndexByName(out.first);
GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed");
@@ -1216,7 +1224,9 @@ class GraphBuilderImpl {
for (const auto &dst_opio : out.second) {
auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner());
GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed.");

GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue);

auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex());
GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed.");

@@ -1260,8 +1270,7 @@ inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) {
ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) {
auto graph_builder_impl = GraphBuilderImpl(name);
ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs);
GE_IF_BOOL_EXEC(compute_graph == nullptr, return compute_graph);

GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Compute graph is nullptr");
compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo());
if (HasSameNameNode(compute_graph)) {
GELOGW("Compute do not allow has same name nodes.");


+ 2
- 4
src/common/graph/opsproto/opsproto_manager.cc

@@ -15,13 +15,11 @@
*/

#include "graph/opsproto_manager.h"

#include <algorithm>
#include <cstdlib>
#include <algorithm>
#include <functional>
#include <iostream>
#include <sstream>

#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_log.h"
@@ -155,7 +153,7 @@ void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) {

// Load .so file
for (auto elem : file_list) {
void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE);
void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (handle == nullptr) {
GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror());
continue;
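A self-contained sketch of the plugin-loading pattern above, using only standard POSIX dlfcn calls; the library name is illustrative. Dropping RTLD_NODELETE means a handle can actually be released again with dlclose:

#include <dlfcn.h>
#include <cstdio>

int main() {
  // Same flags as the updated loader: resolve symbols now, export them globally.
  void *handle = dlopen("libopsproto_demo.so", RTLD_NOW | RTLD_GLOBAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // ... look up registration symbols with dlsym(handle, "...") ...
  dlclose(handle);  // unloading is possible again without RTLD_NODELETE
  return 0;
}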


+ 1
- 1
src/common/graph/option/ge_context.cc

@@ -15,7 +15,6 @@
*/

#include "./ge_context.h"

#include "./ge_global_options.h"
#include "./ge_local_context.h"
#include "framework/common/debug/ge_log.h"
@@ -87,4 +86,5 @@ uint32_t GEContext::DeviceId() { return device_id_; }
uint64_t GEContext::TraceId() { return trace_id_; }

void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; }

} // namespace ge

+ 144
- 6
src/common/graph/shape_refiner.cc

@@ -22,6 +22,7 @@
#include <utility>
#include <vector>

#include "graph/utils/graph_utils.h"
#include "debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "external/graph/operator.h"
@@ -34,6 +35,122 @@
#include "utils/type_utils.h"

namespace ge {
namespace {
constexpr const char *kRefIndex = "parent_node_index";
graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
return GRAPH_SUCCESS;
}

auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
for (const auto &name : sub_graph_names) {
auto sub_graph = root_graph->GetSubgraph(name);
if (sub_graph == nullptr) {
GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
return GRAPH_FAILED;
}
for (const auto &node_sub : sub_graph->GetDirectNode()) {
if (node_sub->GetType() != DATA) {
continue;
}
int ref_i;
auto data_opdesc = node_sub->GetOpDesc();
if (data_opdesc == nullptr) {
GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
if (!AttrUtils::GetInt(node_sub->GetOpDesc(), kRefIndex, ref_i)) {
GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
auto input_desc = op_desc->MutableInputDesc(ref_i);
if (input_desc == nullptr) {
GE_LOGE(
"The ref index(%d) on the data %s on the sub graph %s "
"parent node %s are incompatible, inputs num %u",
ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize());
return GRAPH_FAILED;
}
auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
if (ret != GRAPH_SUCCESS) {
GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s",
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
return ret;
}
ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
if (ret != GRAPH_SUCCESS) {
GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s",
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
return ret;
}
}
}
return GRAPH_SUCCESS;
}
graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
auto op_desc = node->GetOpDesc();
auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
if (sub_graph_names.empty()) {
return GRAPH_SUCCESS;
}

auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
for (const auto &name : sub_graph_names) {
auto sub_graph = root_graph->GetSubgraph(name);
if (sub_graph == nullptr) {
GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str());
return GRAPH_FAILED;
}
NodePtr netoutput = nullptr;
auto sub_nodes = sub_graph->GetDirectNode();
for (size_t i = sub_nodes.size(); i > 0; --i) {
auto sub_node = sub_nodes.at(i - 1);
if (sub_node->GetType() == NETOUTPUT) {
netoutput = sub_node;
break;
}
}
if (netoutput == nullptr) {
GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str());
return GRAPH_FAILED;
}
auto netoutput_opdesc = netoutput->GetOpDesc();
if (netoutput_opdesc == nullptr) {
GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
node->GetName().c_str());
return GRAPH_FAILED;
}
for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
if (edge_desc == nullptr) {
GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(),
node->GetName().c_str(), edge_anchor->GetIdx());
return GRAPH_FAILED;
}
int ref_i;
if (!AttrUtils::GetInt(edge_desc, kRefIndex, ref_i)) {
// If there is no ref index on the TensorDesc, the output data is ignored by the outer graph.
continue;
}
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i));
if (output_desc == nullptr) {
GE_LOGE(
"The ref index(%d) on the input %d of netoutput %s on the sub graph %s "
"parent node %s are incompatible, outputs num %u",
ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(),
node->GetAllOutDataAnchorsSize());
return GRAPH_FAILED;
}
op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc);
}
}
return GRAPH_SUCCESS;
}
} // namespace
void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "node is null");
@@ -42,7 +159,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str
ge::OpDescPtr op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return );
std::string str;
if (!op_desc->GetAllInputsDescPtr().empty()) {
if (op_desc->GetInputsSize() != 0) {
std::string input_desc_str = "input shape: ";
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
input_desc_str += "[";
@@ -56,7 +173,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str
str += input_desc_str;
}

if (!op_desc->GetAllOutputsDescPtr().empty()) {
if (op_desc->GetAllOutputsDescSize() != 0) {
std::string output_desc_str = "output shape: ";
for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
if (output_desc == nullptr) {
@@ -76,13 +193,24 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str
}

graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) {
return InferShapeAndType(node, op, true);
}
graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) {
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
auto op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);
const auto &op_type = op_desc->GetType();

graphStatus ret;
if (before_subgraph) {
ret = UpdateSubGraphDataNodes(node);
if (ret != GRAPH_SUCCESS) {
return ret;
}
}

// Get infer func and execute
graphStatus ret = op_desc->CallInferFunc(op);
ret = op_desc->CallInferFunc(op);
if (ret == GRAPH_PARAM_INVALID) {
// Op ir no infer func, try to get infer func from operator factory
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
@@ -113,7 +241,14 @@ graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &
ret = op_desc->CallInferFunc(op);
GELOGI("op CallInferFunc second. ret: %u", ret);
}
return ret;
if (ret != GRAPH_SUCCESS) {
return ret;
}

if (!before_subgraph) {
return UpdateParentNodeOutTensor(node);
}
return GRAPH_SUCCESS;
}

InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map,
@@ -179,8 +314,11 @@ InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, Inf
namespace {
std::unordered_map<NodePtr, InferenceContextPtr> context_map;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) {
return InferShapeAndType(node, true);
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node,
bool before_subgraph) {
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
if (node->Verify() != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str());
@@ -199,7 +337,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh

Operator op = OpDescUtils::CreateOperatorFromNode(node);
op.SetInferenceContext(inference_context);
graphStatus status = InferShapeAndType(node, op);
graphStatus status = InferShapeAndType(node, op, before_subgraph);
if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
(void)ge::NodeUtils::UpdatePeerNodeInputDesc(node);
} else {
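One plausible way a caller drives the new before_subgraph flag, sketched here with a hypothetical recursion helper (not part of this change): infer once before descending into the node's subgraphs so the subgraph Data nodes pick up the parent's input descs, and once after, so the parent's outputs pick up the refined NetOutput descs:

// Hypothetical driver, for illustration only.
graphStatus RefineNodeWithSubgraphs(const NodePtr &node) {
  // Phase 1: push parent input descs into subgraph Data nodes, then infer.
  if (ShapeRefiner::InferShapeAndType(node, true) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  // ... recursively refine the node's subgraphs here ...
  // Phase 2: infer again and pull NetOutput descs back onto the parent outputs.
  return ShapeRefiner::InferShapeAndType(node, false);
}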


+ 5
- 3
src/common/graph/tensor.cc

@@ -353,6 +353,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size)
}
}
}

impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size);
}

@@ -516,13 +517,14 @@ graphStatus Tensor::IsValid() {
GELOGW("mul overflow: %lu, %u", shape_size, type_length);
} else {
if (shape_size * type_length != data_size) {
// [Just log] Constructor
GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length,
data_size, TypeUtils::DataTypeToSerialString(data_type).c_str());
return GRAPH_FAILED;
}
}
}
}

return GRAPH_SUCCESS;
}

@@ -539,7 +541,7 @@ GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_des
tensor_desc.GetDataType());
ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims()));
ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat());
auto size = static_cast<uint32_t>(tensor_desc.GetSize());
auto size = tensor_desc.GetSize();
TensorUtils::SetSize(ge_tensor_desc, size);

auto real_dim_cnt = static_cast<uint32_t>(tensor_desc.GetRealDimCnt());
@@ -552,7 +554,7 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_
ge_tensor_desc.GetDataType());
tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims()));
tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat());
uint32_t size = 0;
int64_t size = 0;
(void)TensorUtils::GetSize(ge_tensor_desc, size);
tensor_desc.SetSize(size);
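A quick, self-contained illustration of why the size field moves from uint32_t to int64_t: even a modest 4-D float tensor can exceed the 32-bit range the old code clamped to:

#include <cstdint>
#include <cstdio>

int main() {
  int64_t elems = 1024LL * 1024 * 1024 * 2;   // 2^31 elements
  int64_t bytes = elems * 4;                  // float32 payload: 8 GiB
  std::printf("%lld bytes, UINT32_MAX is %u\n", static_cast<long long>(bytes), UINT32_MAX);
  return 0;
}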



+ 163
- 59
src/common/graph/utils/ge_ir_utils.cc

@@ -15,18 +15,21 @@
*/

#include "graph/utils/ge_ir_utils.h"

#include <utility>

#include "framework/common/debug/ge_log.h"

namespace {
const char *const kControlAnchorIndex = ":-1";
const char *const kNodeTypeForSubgraph = "subgraph";
const char *const kPrefixForInputDesc = "input_desc_attr_";
const char *const kPrefixForOutputDesc = "output_desc_attr_";
const char *const kDumpGEGraph = "DUMP_GE_GRAPH";
const int8_t kMaxRecursionDepth = 10;
const char *const kDumpGeGraph = std::getenv(kDumpGEGraph);
const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP;
const int64_t kInputPrefixLength = 5;
const int64_t kOutputPrefixLength = 6;
using AttrDefPair = ::google::protobuf::MapPair<std::string, ge::proto::AttrDef>;
} // namespace

namespace ge {
@@ -198,7 +201,7 @@ void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_A
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
::google::protobuf::RepeatedField<bool> data) {
if (node_proto == nullptr) {
GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
return;
}
if (!data.empty()) {
@@ -320,7 +323,16 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const
auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
"input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset);
const auto &tensor_desc_map = tensor_descriptor->attr();
std::string suffix = ":" + std::to_string(i);
AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix);
} else {
GELOGW("Tensor descriptor is nullptr");
continue;
}
} else {
GELOGW("Input desc is nullptr");
continue;
}
}
}
@@ -360,16 +372,25 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const
auto real_dim_cnt = tensor_descriptor->real_dim_cnt();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
"output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt);
const auto &tensor_desc_map = tensor_descriptor->attr();
std::string suffix = ":" + std::to_string(i);
AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix);
} else {
GELOGW("Tensor descriptor is nullptr");
continue;
}
} else {
GELOGW("Output desc is nullptr");
continue;
}
}
}
}

void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto) {
GE_CHK_BOOL_EXEC(op_def != nullptr, return, "Opdef is nullptr");
const auto &op_def_attr_map = op_def->attr();
for (const auto &item : op_def_attr_map) {
void OnnxUtils::AddAttrProtoForAttrsFromAttrMap(
const ::google::protobuf::Map<std::string, ::ge::proto::AttrDef> &attr_map, onnx::NodeProto *node_proto,
const std::string &prefix, const std::string &suffix) {
for (const auto &item : attr_map) {
auto attr_name = item.first;
auto attr_def = item.second;
auto attr_type = attr_def.value_case();
@@ -377,36 +398,40 @@ void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, on
const auto &tensor_def = attr_def.t();
const auto &tensor_desc = tensor_def.desc();
auto data_type = ge::proto::DataType_Name(tensor_desc.dtype());
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_desc_dtype:", &data_type);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix,
&data_type);
auto dims = tensor_desc.shape().dim();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name + "_desc_shape:", dims);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix,
dims);
auto layout = tensor_desc.layout();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_desc_layout:", &layout);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix,
&layout);
auto device_type = tensor_desc.device_type();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
attr_name + "_desc_device_type:", &device_type);
prefix + attr_name + "_desc_device_type" + suffix, &device_type);
if (kDumpLevel == DUMP_ALL) {
auto data = tensor_def.data();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_data", &data);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix,
&data);
}
}
if (attr_type == ge::proto::AttrDef::kS) {
if (kDumpLevel == DUMP_ALL) {
auto str_value = attr_def.s();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name, &str_value);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value);
}
}
if (attr_type == ge::proto::AttrDef::kI) {
auto int_value = attr_def.i();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, attr_name, &int_value);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
}
if (attr_type == ge::proto::AttrDef::kF) {
auto float_value = attr_def.f();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, attr_name, &float_value);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value);
}
if (attr_type == ge::proto::AttrDef::kB) {
auto int_value = static_cast<int64_t>(attr_def.b());
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, attr_name, &int_value);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
}
if (attr_type == ge::proto::AttrDef::kList) {
const auto &list_value = attr_def.list();
@@ -415,21 +440,21 @@ void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, on
ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) {
if (kDumpLevel == DUMP_ALL) {
const auto &strings = list_value.s();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, attr_name, strings);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings);
}
}
if (list_value_type ==
ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) {
const auto &floats = list_value.f();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, attr_name, floats);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats);
}
if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) {
const auto &ints = list_value.i();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name, ints);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints);
}
if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) {
const auto &bools = list_value.b();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name, bools);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools);
}
}
}
@@ -481,8 +506,15 @@ void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes);
const auto &is_input_const = op_def->is_input_const();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const);
AddAttrProtoForAttrsFromOpDef(op_def, node_proto);
const auto &op_def_attr_map = op_def->attr();
AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto);
} else {
GELOGE(FAILED, "Opdef is nullptr");
return;
}
} else {
GELOGE(FAILED, "Opdesc is nullptr");
return;
}
}

@@ -526,15 +558,13 @@ bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto)
node_proto->clear_input();
// 1. Add input by in data edge
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
if (in_data_anchor != nullptr) {
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) {
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
std::to_string(peer_out_anchor->GetIdx()));
} else {
// Add "" input
node_proto->add_input("");
}
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) {
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
std::to_string(peer_out_anchor->GetIdx()));
} else {
// Add "" input
node_proto->add_input("");
}
}

@@ -547,6 +577,9 @@ bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto)
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex);
}
}
} else {
GELOGE(FAILED, "Incontrol anchor is nullptr");
return false;
}

// 3. Add output for Netron visual support
@@ -584,7 +617,7 @@ void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_T
}
const auto &op_desc = node->GetOpDesc();
if (op_desc != nullptr) {
auto size_out = op_desc->GetOutputsSize();
uint32_t size_out = static_cast<uint32_t>(op_desc->GetOutputsSize());
if (size_out > 0) {
for (uint32_t i = 0; i < size_out; i++) {
const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i);
@@ -598,7 +631,13 @@ void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_T
auto dim = shape->add_dim();
dim->set_dim_value(d);
}
} else {
GELOGW("Shape is nullptr");
continue;
}
} else {
GELOGW("Ge tensor is nullptr");
continue;
}
}
}
@@ -666,7 +705,7 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr
}

// For subgraphs: a subgraph is represented by a node
for (const auto &sub_compute_graph : compute_graph->sub_graph_) {
for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) {
if (sub_compute_graph != nullptr) {
auto node_proto = graph_proto->add_node();
if (node_proto == nullptr) {
@@ -679,6 +718,10 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr
attr->set_name("graph");
attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
auto sub_graph_proto = attr->mutable_g();
if (sub_graph_proto == nullptr) {
GELOGW("Sub graph proto is nullptr");
continue;
}
if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) {
GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str());
continue;
@@ -831,56 +874,116 @@ void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t
value = attr_proto.i();
}

void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_input_output_desc, int32_t index,
OpDescPtr &op_desc) {
if (op_desc == nullptr || op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) {
GELOGE(GRAPH_FAILED, "op_desc or op_desc->MutableInputDesc(index) is nullptr");
void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_input_desc, int32_t index,
OpDescPtr &op_desc) {
if (op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) {
GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast<uint32_t>(index)) is nullptr",
op_desc->GetName().c_str(), attr_name_for_input_desc.c_str());
return;
}
if (attr_name_for_input_output_desc == "input_desc_dtype") {
if (attr_name_for_input_desc == "input_desc_dtype") {
auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s());
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetDataType(data_type);
} else if (attr_name_for_input_output_desc == "input_desc_shape") {
} else if (attr_name_for_input_desc == "input_desc_shape") {
std::vector<std::int64_t> ints;
DecodeAttribute(attr_proto, ints);
GeShape ge_shape(ints);
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape);
} else if (attr_name_for_input_output_desc == "input_desc_layout") {
} else if (attr_name_for_input_desc == "input_desc_layout") {
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetFormat(data_format);
} else if (attr_name_for_input_output_desc == "input_desc_origin_shape") {
} else if (attr_name_for_input_desc == "input_desc_origin_shape") {
std::vector<std::int64_t> ints;
DecodeAttribute(attr_proto, ints);
GeShape ge_shape(ints);
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape);
} else if (attr_name_for_input_output_desc == "input_desc_origin_layout") {
} else if (attr_name_for_input_desc == "input_desc_origin_layout") {
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format);
} else if (attr_name_for_input_output_desc == "output_desc_dtype") {
} else if (attr_name_for_input_desc == "input_desc_size") {
int64_t input_size = 0;
auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
DecodeAttribute(attr_proto, input_size);
tensor_descriptor->set_size(input_size);
} else if (attr_name_for_input_desc == "input_desc_data_offset") {
auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
int64_t offset = 0;
DecodeAttribute(attr_proto, offset);
tensor_descriptor->set_data_offset(offset);
} else {
return;
}
}

void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_output_desc, int32_t index,
OpDescPtr &op_desc) {
if (op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) == nullptr) {
GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) is nullptr",
op_desc->GetName().c_str(), attr_name_for_output_desc.c_str());
return;
}
if (attr_name_for_output_desc == "output_desc_dtype") {
auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s());
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetDataType(data_type);
} else if (attr_name_for_input_output_desc == "output_desc_shape") {
} else if (attr_name_for_output_desc == "output_desc_shape") {
std::vector<std::int64_t> ints;
DecodeAttribute(attr_proto, ints);
GeShape ge_shape(ints);
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape);
} else if (attr_name_for_input_output_desc == "output_desc_layout") {
} else if (attr_name_for_output_desc == "output_desc_layout") {
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetFormat(data_format);
} else if (attr_name_for_input_output_desc == "output_desc_origin_shape") {
} else if (attr_name_for_output_desc == "output_desc_origin_shape") {
std::vector<std::int64_t> ints;
DecodeAttribute(attr_proto, ints);
GeShape ge_shape(ints);
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape);
} else if (attr_name_for_input_output_desc == "output_desc_origin_layout") {
} else if (attr_name_for_output_desc == "output_desc_origin_layout") {
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format);
} else if (attr_name_for_output_desc == "output_desc_size") {
int64_t output_size = 0;
auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
DecodeAttribute(attr_proto, output_size);
tensor_descriptor->set_size(output_size);
} else if (attr_name_for_output_desc == "output_desc_data_offset") {
auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
int64_t offset = 0;
DecodeAttribute(attr_proto, offset);
tensor_descriptor->set_data_offset(offset);
} else {
return;
}
}

void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_input_output_desc, int32_t index,
OpDescPtr &op_desc) {
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "op_desc is nullptr");
return;
}
if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") {
DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
} else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") {
DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
} else {
return;
}
}

void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) {
auto attr_map = op_def.mutable_attr();
const auto &attr_name = attr_proto.name();
ge::proto::AttrDef op_attr;
int64_t value = 0;
DecodeAttribute(attr_proto, value);
op_attr.set_i(value);
attr_map->insert(AttrDefPair(attr_name, op_attr));
}

void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) {
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr");
@@ -910,6 +1013,16 @@ void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_pr
std::vector<std::int64_t> ints;
DecodeAttribute(attr_proto, ints);
op_desc->SetDstIndex(ints);
} else if (attr_name == "fusion_scope") {
DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg());
} else if (attr_name == "input_i") {
std::vector<std::int64_t> ints;
DecodeAttribute(attr_proto, ints);
op_desc->SetInputOffset(ints);
} else if (attr_name == "output_i") {
std::vector<std::int64_t> ints;
DecodeAttribute(attr_proto, ints);
op_desc->SetOutputOffset(ints);
} else {
return;
}
@@ -939,20 +1052,14 @@ bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_
auto size_in = attr.i();
for (int64_t i = 0; i < size_in; i++) {
GeTensorDesc ge_tensor_desc;
if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) {
GELOGW("Add inputdesc failed");
continue;
}
GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed.");
}
}
if (attr.name() == "output_desc_nums") {
auto size_out = attr.i();
for (int64_t i = 0; i < size_out; i++) {
GeTensorDesc ge_tensor_desc;
if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) {
GELOGW("add inputdesc failed");
continue;
}
GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed.");
}
}
}
@@ -970,10 +1077,7 @@ bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_p
}

graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name());
if (graph == nullptr) {
GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed");
return false;
}
GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed");
/// 1. Decode all nodes first, node should include input
/// and output nodes and nodes which represent sub graphs
std::map<std::string, NodePtr> node_map;
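To make the prefix/suffix plumbing above concrete, a small sketch of the attribute keys produced when dumping, say, an input-desc attribute named "scale" on input index 2 (the attribute name is illustrative):

// Mirrors prefix + attr_name + "_desc_dtype" + suffix in AddAttrProtoForAttrsFromAttrMap.
std::string prefix = "input_desc_attr_";          // kPrefixForInputDesc
std::string suffix = ":" + std::to_string(2);     // per-index suffix
std::string dtype_key = prefix + "scale" + "_desc_dtype" + suffix;  // "input_desc_attr_scale_desc_dtype:2"
std::string shape_key = prefix + "scale" + "_desc_shape" + suffix;  // "input_desc_attr_scale_desc_shape:2"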


+ 14
- 0
src/common/graph/utils/ge_ir_utils.h

@@ -131,6 +131,10 @@ class OnnxUtils {

static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc);

static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map<std::string, ge::proto::AttrDef> &attr_map,
onnx::NodeProto *node_proto, const std::string &prefix = "",
const std::string &suffix = "");

static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto);

static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type);
@@ -172,10 +176,20 @@ class OnnxUtils {

static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value);

static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_output_desc, int32_t index,
OpDescPtr &op_desc);

static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_input_desc, int32_t index,
OpDescPtr &op_desc);

static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_input_output_desc, int32_t index,
OpDescPtr &op_desc);

static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def);

static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc);

static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr);


+ 879
- 65
src/common/graph/utils/graph_utils.cc
File diff suppressed because it is too large


+ 44
- 0
src/common/graph/utils/node_utils.cc

@@ -15,10 +15,12 @@
*/

#include "utils/node_utils.h"
#include "graph/utils/graph_utils.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "graph/debug/ge_attr_define.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"

@@ -109,6 +111,7 @@ graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_pt
graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) {
GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED,
"node or in_data_anchor is nullptr");

bool find_flag = false;
uint32_t index = 0;
vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end();
@@ -358,4 +361,45 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const
input_desc->SetShape(shape);
return GRAPH_SUCCESS;
}
std::string NodeUtils::GetNodeType(const Node &node) {
if (node.GetType() != FRAMEWORKOP) {
return node.GetType();
}
std::string type;
(void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
return type;
}
ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
auto op_desc = node.GetOpDesc();
if (op_desc == nullptr) {
return nullptr;
}
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
if (root_graph == nullptr) {
return nullptr;
}
return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index));
}

graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) {
if (subgraph == nullptr) {
GE_LOGE("Failed to add subgraph to node %s, null subgraph", node.GetName().c_str());
return GRAPH_PARAM_INVALID;
}
auto op_desc = node.GetOpDesc();
if (op_desc == nullptr) {
return GRAPH_PARAM_INVALID;
}
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
if (root_graph == nullptr) {
GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str());
return GRAPH_PARAM_INVALID;
}
op_desc->AddSubgraphInstanceName(subgraph->GetName());
subgraph->SetParentNode(node.shared_from_this());
subgraph->SetParentGraph(node.GetOwnerComputeGraph());
root_graph->AddSubgraph(subgraph);

return GRAPH_SUCCESS;
}
} // namespace ge
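A short, hedged sketch of the new NodeUtils helpers above, assuming node is a function-like node (for example a While or If) that already belongs to a graph reachable from a root graph; graph names are illustrative:

// Illustrative only.
std::string real_type = NodeUtils::GetNodeType(*node);  // unwraps a FrameworkOp's original type

ComputeGraphPtr body = ComGraphMakeShared<ComputeGraph>("while_body_0");
if (NodeUtils::AddSubgraph(*node, body) != GRAPH_SUCCESS) {
  GELOGE(GRAPH_FAILED, "Attach subgraph to node %s failed", node->GetName().c_str());
}
// Fetch it back later by its position among the node's subgraph instance names.
ComputeGraphPtr fetched = NodeUtils::GetSubgraph(*node, 0);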

+ 79
- 4
src/common/graph/utils/op_desc_utils.cc

@@ -15,9 +15,7 @@
*/

#include "utils/op_desc_utils.h"

#include <algorithm>

#include "debug/ge_attr_define.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
@@ -209,6 +207,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData(
const vector<ge::NodePtr> &input_nodes) {
vector<ConstGeTensorPtr> ret;

for (const auto &input_node : input_nodes) {
auto temp_weight = MutableWeights(input_node->GetOpDesc());
if (temp_weight == nullptr) {
@@ -379,7 +378,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt
if (NodeUtils::IsAnchorStatusSet(*node)) {
for (const auto &in_anchor : node->GetAllInDataAnchors()) {
if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) {
(void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
}
}
} else {
@@ -389,7 +388,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt
continue;
}
if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) {
(void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx()));
}
}
}
@@ -572,4 +571,80 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei
}
return GRAPH_SUCCESS;
}

///
/// @brief Add input
/// @param [in] name
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) {
inputs_.emplace_back(name);
return *this;
}

///
/// @brief Add dynamic input
/// @param [in] name
/// @param [in] num
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name,
uint32_t num) {
for (uint32_t i = 0; i < num; i++) {
inputs_.emplace_back(name + std::to_string(i));
}
return *this;
}

///
/// @brief Add output
/// @param [in] name
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) {
outputs_.emplace_back(name);
return *this;
}

///
/// @brief Add dynamic output
/// @param [in] name
/// @param [in] num
/// @return OpDescBuilder
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name,
uint32_t num) {
for (uint32_t i = 0; i < num; i++) {
outputs_.emplace_back(name + std::to_string(i));
}
return *this;
}

///
/// @brief Build op_desc
/// @return OpDescPtr
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() {
OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_));
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "OpDesc is nullptr");
return nullptr;
}

for (auto &input : inputs_) {
if (op_desc->AddInputDesc(input, GeTensorDesc()) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Add input_desc failed.");
return nullptr;
}
}

for (auto &output : outputs_) {
if (op_desc->AddOutputDesc(output, GeTensorDesc()) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Add output_desc failed.");
return nullptr;
}
}

return op_desc;
}
} // namespace ge
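A minimal sketch of the new OpDescBuilder fluent interface, assuming the builder is constructed with an op name and type (its constructor is declared in the header, not shown in this hunk); the op and port names are illustrative:

OpDescPtr concat = OpDescBuilder("my_concat", "ConcatD")
                       .AddDynamicInput("x", 3)   // expands to x0, x1, x2
                       .AddInput("concat_dim")
                       .AddOutput("y")
                       .Build();
if (concat == nullptr) {
  GELOGE(GRAPH_FAILED, "Build op_desc failed.");
}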

+ 16
- 9
src/common/graph/utils/tensor_utils.cc

@@ -15,7 +15,6 @@
*/

#include "graph/utils/tensor_utils.h"

#include <cmath>

#include "debug/ge_log.h"
@@ -276,6 +275,14 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format
break;
case FORMAT_FRACTAL_NZ:
case FORMAT_FRACTAL_ZZ:
case FORMAT_NDHWC:
case FORMAT_NCDHW:
case FORMAT_DHWCN:
case FORMAT_DHWNC:
case FORMAT_FRACTAL_Z_3D:
case FORMAT_FRACTAL_Z_3D_TRANSPOSE:
case FORMAT_NDC1HWC0:
case FORMAT_FRACTAL_Z_C04:
graph_status = CalcElementCntByDims(dims, element_cnt);
break;
default:
@@ -351,21 +358,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTens
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp) {
TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) {
graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp);
if (graph_status != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
// 64-byte alignment, if size is 0, align to 32 bytes
if (size_temp > (UINT32_MAX - kNum2 * kDataMemAlignSize)) {
GELOGW("The updated mem size %u is bigger than UINT32_MAX", size_temp);
if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) {
GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp);
} else {
size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize;
}
return GRAPH_SUCCESS;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp) {
TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) {
GeShape output_shape = desc_temp.GetShape();
Format format = desc_temp.GetFormat();
DataType data_type = desc_temp.GetDataType();
@@ -376,13 +383,13 @@ TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_
return GRAPH_FAILED;
}

if ((output_mem_size > UINT32_MAX) || (output_mem_size < 0)) {
GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %u]",
output_mem_size, UINT32_MAX);
if (output_mem_size < 0) {
GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]",
output_mem_size, INT64_MAX);
return GRAPH_FAILED;
}

size_temp = static_cast<uint32_t>(output_mem_size);
size_temp = output_mem_size;
return GRAPH_SUCCESS;
}
} // namespace ge
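A self-contained illustration of the alignment arithmetic in GetTensorMemorySizeInBytes above, assuming kDataMemAlignSize is 32 and kNum2 is 2 (neither constant is shown in this hunk): the expression rounds the size up to the alignment unit and reserves one extra unit, so a zero-sized tensor still occupies 32 bytes:

#include <cstdint>
#include <cassert>

constexpr int64_t kDataMemAlignSize = 32;  // assumed value
constexpr int64_t kNum2 = 2;               // assumed value

int64_t AlignTensorMemSize(int64_t size) {
  // Same expression as the updated GetTensorMemorySizeInBytes.
  return ((size + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize;
}

int main() {
  assert(AlignTensorMemSize(0) == 32);    // zero-sized tensor still reserves 32 bytes
  assert(AlignTensorMemSize(1) == 64);
  assert(AlignTensorMemSize(96) == 128);  // already-aligned sizes still gain one unit
  return 0;
}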

+ 161
- 156
src/common/graph/utils/type_utils.cc

@@ -19,43 +19,45 @@

namespace ge {
static const std::map<Format, std::string> kFormatToStringMap = {
{FORMAT_NCHW, "NCHW"},
{FORMAT_NHWC, "NHWC"},
{FORMAT_ND, "ND"},
{FORMAT_NC1HWC0, "NC1HWC0"},
{FORMAT_FRACTAL_Z, "FRACTAL_Z"},
{FORMAT_NC1C0HWPAD, "NC1C0HWPAD"},
{FORMAT_NHWC1C0, "NHWC1C0"},
{FORMAT_FSR_NCHW, "FSR_NCHW"},
{FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"},
{FORMAT_C1HWNC0, "C1HWNC0"},
{FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"},
{FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"},
{FORMAT_NC1HWC0_C04, "NC1HWC0_C04"},
{FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"},
{FORMAT_CHWN, "CHWN"},
{FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"},
{FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"},
{FORMAT_BN_WEIGHT, "BN_WEIGHT"},
{FORMAT_FILTER_HWCK, "FILTER_HWCK"},
{FORMAT_HWCN, "HWCN"},
{FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"},
{FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"},
{FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"},
{FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"},
{FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"},
{FORMAT_MD, "MD"},
{FORMAT_NDHWC, "NDHWC"},
{FORMAT_NCDHW, "NCDHW"},
{FORMAT_DHWCK, "DHWCK"},
{FORMAT_NDC1HWC0, "NDC1HWC0"},
{FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"},
{FORMAT_C1HWNCoC0, "C1HWNCoC0"},
{FORMAT_FRACTAL_NZ, "FRACTAL_NZ"},
{FORMAT_CN, "CN"},
{FORMAT_NC, "NC"},
{FORMAT_RESERVED, "FORMAT_RESERVED"},
{FORMAT_ALL, "ALL"}};
{FORMAT_NCHW, "NCHW"},
{FORMAT_NHWC, "NHWC"},
{FORMAT_ND, "ND"},
{FORMAT_NC1HWC0, "NC1HWC0"},
{FORMAT_FRACTAL_Z, "FRACTAL_Z"},
{FORMAT_NC1C0HWPAD, "NC1C0HWPAD"},
{FORMAT_NHWC1C0, "NHWC1C0"},
{FORMAT_FSR_NCHW, "FSR_NCHW"},
{FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"},
{FORMAT_C1HWNC0, "C1HWNC0"},
{FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"},
{FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"},
{FORMAT_NC1HWC0_C04, "NC1HWC0_C04"},
{FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"},
{FORMAT_CHWN, "CHWN"},
{FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"},
{FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"},
{FORMAT_BN_WEIGHT, "BN_WEIGHT"},
{FORMAT_FILTER_HWCK, "FILTER_HWCK"},
{FORMAT_HWCN, "HWCN"},
{FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"},
{FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"},
{FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"},
{FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"},
{FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"},
{FORMAT_MD, "MD"},
{FORMAT_NDHWC, "NDHWC"},
{FORMAT_NCDHW, "NCDHW"},
{FORMAT_DHWCN, "DHWCN"},
{FORMAT_DHWNC, "DHWNC"},
{FORMAT_NDC1HWC0, "NDC1HWC0"},
{FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"},
{FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"},
{FORMAT_C1HWNCoC0, "C1HWNCoC0"},
{FORMAT_FRACTAL_NZ, "FRACTAL_NZ"},
{FORMAT_CN, "CN"},
{FORMAT_NC, "NC"},
{FORMAT_RESERVED, "FORMAT_RESERVED"},
{FORMAT_ALL, "ALL"}};

static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0",
"FRACTAL_Z",
@@ -73,137 +75,140 @@ static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0",
"FRACTAL_ZZ",
"FRACTAL_NZ",
"NDC1HWC0",
"FORMAT_FRACTAL_Z_3D"};
"FORMAT_FRACTAL_Z_3D",
"FORMAT_FRACTAL_Z_3D_TRANSPOSE"};

static const std::map<std::string, Format> kDataFormatMap = {
{"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"ND", FORMAT_ND}};
{"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}};

static const std::map<std::string, Format> kStringToFormatMap = {
{"NCHW", FORMAT_NCHW},
{"NHWC", FORMAT_NHWC},
{"ND", FORMAT_ND},
{"NC1HWC0", FORMAT_NC1HWC0},
{"FRACTAL_Z", FORMAT_FRACTAL_Z},
{"NC1C0HWPAD", FORMAT_NC1C0HWPAD},
{"NHWC1C0", FORMAT_NHWC1C0},
{"FSR_NCHW", FORMAT_FSR_NCHW},
{"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV},
{"C1HWNC0", FORMAT_C1HWNC0},
{"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE},
{"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS},
{"NC1HWC0_C04", FORMAT_NC1HWC0_C04},
{"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04},
{"CHWN", FORMAT_CHWN},
{"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS},
{"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0},
{"BN_WEIGHT", FORMAT_BN_WEIGHT},
{"FILTER_HWCK", FORMAT_FILTER_HWCK},
{"HWCN", FORMAT_HWCN},
{"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS},
{"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS},
{"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE},
{"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT},
{"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS},
{"MD", FORMAT_MD},
{"C1HWNCoC0", FORMAT_C1HWNCoC0},
{"FRACTAL_NZ", FORMAT_FRACTAL_NZ},
{"NDHWC", FORMAT_NDHWC},
{"NCDHW", FORMAT_NCDHW},
{"DHWCK", FORMAT_DHWCK},
{"NDC1HWC0", FORMAT_NDC1HWC0},
{"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D},
{"CN", FORMAT_CN},
{"NC", FORMAT_NC},
{"FORMAT_RESERVED", FORMAT_RESERVED},
{"ALL", FORMAT_ALL}};
{"NCHW", FORMAT_NCHW},
{"NHWC", FORMAT_NHWC},
{"ND", FORMAT_ND},
{"NC1HWC0", FORMAT_NC1HWC0},
{"FRACTAL_Z", FORMAT_FRACTAL_Z},
{"NC1C0HWPAD", FORMAT_NC1C0HWPAD},
{"NHWC1C0", FORMAT_NHWC1C0},
{"FSR_NCHW", FORMAT_FSR_NCHW},
{"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV},
{"C1HWNC0", FORMAT_C1HWNC0},
{"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE},
{"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS},
{"NC1HWC0_C04", FORMAT_NC1HWC0_C04},
{"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04},
{"CHWN", FORMAT_CHWN},
{"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS},
{"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0},
{"BN_WEIGHT", FORMAT_BN_WEIGHT},
{"FILTER_HWCK", FORMAT_FILTER_HWCK},
{"HWCN", FORMAT_HWCN},
{"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS},
{"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS},
{"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE},
{"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT},
{"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS},
{"MD", FORMAT_MD},
{"C1HWNCoC0", FORMAT_C1HWNCoC0},
{"FRACTAL_NZ", FORMAT_FRACTAL_NZ},
{"NDHWC", FORMAT_NDHWC},
{"NCDHW", FORMAT_NCDHW},
{"DHWCN", FORMAT_DHWCN},
{"DHWNC", FORMAT_DHWNC},
{"NDC1HWC0", FORMAT_NDC1HWC0},
{"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D},
{"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE},
{"CN", FORMAT_CN},
{"NC", FORMAT_NC},
{"FORMAT_RESERVED", FORMAT_RESERVED},
{"ALL", FORMAT_ALL}};

static const std::map<DataType, std::string> kDataTypeToStringMap = {
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set.
{DT_FLOAT, "DT_FLOAT"}, // float type
{DT_FLOAT16, "DT_FLOAT16"}, // fp16 type
{DT_INT8, "DT_INT8"}, // int8 type
{DT_INT16, "DT_INT16"}, // int16 type
{DT_UINT16, "DT_UINT16"}, // uint16 type
{DT_UINT8, "DT_UINT8"}, // uint8 type
{DT_INT32, "DT_INT32"}, // uint32 type
{DT_INT64, "DT_INT64"}, // int64 type
{DT_UINT32, "DT_UINT32"}, // unsigned int32
{DT_UINT64, "DT_UINT64"}, // unsigned int64
{DT_BOOL, "DT_BOOL"}, // bool type
{DT_DOUBLE, "DT_DOUBLE"}, // double type
{DT_DUAL, "DT_DUAL"}, // dual output type
{DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type
{DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type
{DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type
{DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type
{DT_QINT8, "DT_QINT8"}, // qint8 type
{DT_QINT16, "DT_QINT16"}, // qint16 type
{DT_QINT32, "DT_QINT32"}, // qint32 type
{DT_QUINT8, "DT_QUINT8"}, // quint8 type
{DT_QUINT16, "DT_QUINT16"}, // quint16 type
{DT_RESOURCE, "DT_RESOURCE"}, // resource type
{DT_STRING_REF, "DT_STRING_REF"}, // string ref type
{DT_STRING, "DT_STRING"}, // string type
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set.
{DT_FLOAT, "DT_FLOAT"}, // float type
{DT_FLOAT16, "DT_FLOAT16"}, // fp16 type
{DT_INT8, "DT_INT8"}, // int8 type
{DT_INT16, "DT_INT16"}, // int16 type
{DT_UINT16, "DT_UINT16"}, // uint16 type
{DT_UINT8, "DT_UINT8"}, // uint8 type
{DT_INT32, "DT_INT32"}, // uint32 type
{DT_INT64, "DT_INT64"}, // int64 type
{DT_UINT32, "DT_UINT32"}, // unsigned int32
{DT_UINT64, "DT_UINT64"}, // unsigned int64
{DT_BOOL, "DT_BOOL"}, // bool type
{DT_DOUBLE, "DT_DOUBLE"}, // double type
{DT_DUAL, "DT_DUAL"}, // dual output type
{DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type
{DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type
{DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type
{DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type
{DT_QINT8, "DT_QINT8"}, // qint8 type
{DT_QINT16, "DT_QINT16"}, // qint16 type
{DT_QINT32, "DT_QINT32"}, // qint32 type
{DT_QUINT8, "DT_QUINT8"}, // quint8 type
{DT_QUINT16, "DT_QUINT16"}, // quint16 type
{DT_RESOURCE, "DT_RESOURCE"}, // resource type
{DT_STRING_REF, "DT_STRING_REF"}, // string ref type
{DT_STRING, "DT_STRING"}, // string type
};

static const std::map<std::string, DataType> kStringTodataTypeMap = {
{"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set.
{"DT_FLOAT", DT_FLOAT}, // float type
{"DT_FLOAT16", DT_FLOAT16},              // fp16 type
{"DT_INT8", DT_INT8}, // int8 type
{"DT_INT16", DT_INT16}, // int16 type
{"DT_UINT16", DT_UINT16}, // uint16 type
{"DT_UINT8", DT_UINT8}, // uint8 type
{"DT_INT32", DT_INT32}, // uint32 type
{"DT_INT64", DT_INT64}, // int64 type
{"DT_UINT32", DT_UINT32}, // unsigned int32
{"DT_UINT64", DT_UINT64}, // unsigned int64
{"DT_BOOL", DT_BOOL}, // bool type
{"DT_DOUBLE", DT_DOUBLE}, // double type
{"DT_DUAL", DT_DUAL}, // dual output type
{"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type
{"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type
{"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type
{"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type
{"DT_QINT8", DT_QINT8}, // qint8 type
{"DT_QINT16", DT_QINT16}, // qint16 type
{"DT_QINT32", DT_QINT32}, // qint32 type
{"DT_QUINT8", DT_QUINT8}, // quint8 type
{"DT_QUINT16", DT_QUINT16}, // quint16 type
{"DT_RESOURCE", DT_RESOURCE}, // resource type
{"DT_STRING_REF", DT_STRING_REF}, // string ref type
{"DT_STRING", DT_STRING}, // string type
{"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set.
{"DT_FLOAT", DT_FLOAT}, // float type
{"DT_FLOAT16", DT_FLOAT16},              // fp16 type
{"DT_INT8", DT_INT8}, // int8 type
{"DT_INT16", DT_INT16}, // int16 type
{"DT_UINT16", DT_UINT16}, // uint16 type
{"DT_UINT8", DT_UINT8}, // uint8 type
{"DT_INT32", DT_INT32}, // uint32 type
{"DT_INT64", DT_INT64}, // int64 type
{"DT_UINT32", DT_UINT32}, // unsigned int32
{"DT_UINT64", DT_UINT64}, // unsigned int64
{"DT_BOOL", DT_BOOL}, // bool type
{"DT_DOUBLE", DT_DOUBLE}, // double type
{"DT_DUAL", DT_DUAL}, // dual output type
{"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type
{"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type
{"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type
{"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type
{"DT_QINT8", DT_QINT8}, // qint8 type
{"DT_QINT16", DT_QINT16}, // qint16 type
{"DT_QINT32", DT_QINT32}, // qint32 type
{"DT_QUINT8", DT_QUINT8}, // quint8 type
{"DT_QUINT16", DT_QUINT16}, // quint16 type
{"DT_RESOURCE", DT_RESOURCE}, // resource type
{"DT_STRING_REF", DT_STRING_REF}, // string ref type
{"DT_STRING", DT_STRING}, // string type
};
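
The two tables above are the forward and reverse directions of the same serialization. A hedged sketch of how they are typically wrapped (helper names are assumptions, not this file's API):

// Illustrative helpers only.
static std::string DataTypeToString(DataType dt) {
  auto it = kDataTypeToStringMap.find(dt);
  return (it != kDataTypeToStringMap.end()) ? it->second : "DT_UNDEFINED";
}

static DataType DataTypeFromString(const std::string &str) {
  auto it = kStringTodataTypeMap.find(str);
  return (it != kStringTodataTypeMap.end()) ? it->second : DT_UNDEFINED;
}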

static const std::map<ge::DataType, uint32_t> kDataTypeToLength = {
{DT_BOOL, sizeof(bool)},
{DT_INT64, sizeof(int64_t)},
{DT_UINT64, sizeof(int64_t)},
{DT_FLOAT, sizeof(float)},
{DT_INT32, sizeof(int32_t)},
{DT_UINT32, sizeof(int32_t)},
{DT_INT8, sizeof(char)},
{DT_UINT8, sizeof(char)},
{DT_INT16, sizeof(int16_t)},
{DT_UINT16, sizeof(int16_t)},
{DT_FLOAT16, sizeof(int16_t)},
{DT_DOUBLE, sizeof(double)},
{DT_DUAL, sizeof(float) + sizeof(int8_t)},
{DT_DUAL_SUB_INT8, sizeof(int8_t)},
{DT_DUAL_SUB_UINT8, sizeof(uint8_t)},
{DT_COMPLEX64, sizeof(int64_t)},
{DT_COMPLEX128, sizeof(int64_t) * 2},
{DT_QINT8, sizeof(int8_t)},
{DT_QINT16, sizeof(int16_t)},
{DT_QINT32, sizeof(int32_t)},
{DT_QUINT8, sizeof(uint8_t)},
{DT_QUINT16, sizeof(uint16_t)},
{DT_STRING_REF, sizeof(uint64_t) * 2},
{DT_STRING, sizeof(uint64_t)},
{DT_RESOURCE, sizeof(uint64_t)},
{DT_BOOL, sizeof(bool)},
{DT_INT64, sizeof(int64_t)},
{DT_UINT64, sizeof(int64_t)},
{DT_FLOAT, sizeof(float)},
{DT_INT32, sizeof(int32_t)},
{DT_UINT32, sizeof(int32_t)},
{DT_INT8, sizeof(char)},
{DT_UINT8, sizeof(char)},
{DT_INT16, sizeof(int16_t)},
{DT_UINT16, sizeof(int16_t)},
{DT_FLOAT16, sizeof(int16_t)},
{DT_DOUBLE, sizeof(double)},
{DT_DUAL, sizeof(float) + sizeof(int8_t)},
{DT_DUAL_SUB_INT8, sizeof(int8_t)},
{DT_DUAL_SUB_UINT8, sizeof(uint8_t)},
{DT_COMPLEX64, sizeof(int64_t)},
{DT_COMPLEX128, sizeof(int64_t) * 2},
{DT_QINT8, sizeof(int8_t)},
{DT_QINT16, sizeof(int16_t)},
{DT_QINT32, sizeof(int32_t)},
{DT_QUINT8, sizeof(uint8_t)},
{DT_QUINT16, sizeof(uint16_t)},
{DT_STRING_REF, sizeof(uint64_t) * 2},
{DT_STRING, sizeof(uint64_t)},
{DT_RESOURCE, sizeof(uint64_t)},
};
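
A typical use of kDataTypeToLength is turning an element count into a byte size. The helper below is a sketch under that assumption; its name and signature are illustrative:

// Illustrative only: returns false for data types without a fixed per-element
// length or for a negative element count.
static bool GetTensorSizeInBytes(int64_t element_count, DataType dt, int64_t &size_out) {
  auto it = kDataTypeToLength.find(dt);
  if (it == kDataTypeToLength.end() || element_count < 0) {
    return false;
  }
  size_out = element_count * static_cast<int64_t>(it->second);
  return true;
}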

bool TypeUtils::IsDataTypeValid(DataType dt) {
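
The diff view is cut off before the function body. A minimal sketch, assuming the check is driven by kDataTypeToStringMap above (not necessarily the real implementation):

// Sketch only -- the actual body is not shown in this diff.
bool TypeUtils::IsDataTypeValid(DataType dt) {
  return kDataTypeToStringMap.find(dt) != kDataTypeToStringMap.end();
}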


+40 -33   src/ge/CMakeLists.txt

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================

# libge.so & libge_train.so
# libge_compiler.so & libge_train.so
# will later be integrated into libgraph_runner.so, works for both training and inference
# compiling proto files generates some warnings, use no-unused-variable to suppress them
set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}")
@@ -49,7 +49,7 @@ include_directories(${CMAKE_BINARY_DIR}/proto/ge)

######### libge_train.so #############
# need to remove dependencies on pb files later
file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"common/formats/format_transfers/*.cc"
"common/formats/formats.cc"
"common/formats/utils/formats_trans_utils.cc"
@@ -57,20 +57,24 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"common/ge/plugin_manager.cc"
"common/profiling/profiling_manager.cc"
"engine_manager/dnnengine_manager.cc"
"ge_local_engine/engine/host_cpu_engine.cc"
"generator/ge_generator.cc"
"generator/generator_api.cc"
"graph/build/graph_build.cc"
"graph/build/graph_builder.cc"
"graph/build/label_allocator.cc"
"graph/build/logical_stream_allocator.cc"
"graph/build/model_builder.cc"
"graph/build/optimize_stream_graph.cc"
"graph/build/run_context.cc"
"graph/build/stream_allocator.cc"
"graph/build/stream_graph_optimizer.cc"
"graph/build/task_generator.cc"
"graph/common/bcast.cc"
"graph/common/omg_util.cc"
"graph/common/transop_util.cc"
"graph/execute/graph_execute.cc"
"graph/label/*.cc"
"graph/load/graph_loader.cc"
"graph/load/new_model_manager/cpu_queue_schedule.cc"
"graph/load/new_model_manager/data_dumper.cc"
"graph/load/new_model_manager/data_inputer.cc"
"graph/load/new_model_manager/davinci_model.cc"
@@ -92,10 +96,12 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc"
"graph/load/new_model_manager/task_info/stream_active_task_info.cc"
"graph/load/new_model_manager/task_info/stream_switch_task_info.cc"
"graph/load/new_model_manager/task_info/stream_switchn_task_info.cc"
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc"
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc"
"graph/load/new_model_manager/task_info/task_info.cc"
"graph/load/new_model_manager/tbe_handle_store.cc"
"graph/load/output/output.cc"
"graph/manager/custom/custom_op.cc"
"graph/manager/graph_context.cc"
"graph/manager/graph_manager.cc"
"graph/manager/graph_manager_utils.cc"
@@ -105,12 +111,9 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/manager/trans_var_data_utils.cc"
"graph/manager/util/debug.cc"
"graph/manager/util/hcom_util.cc"
"graph/manager/util/node_searcher/need_rebuild_node_searcher.cc"
"graph/manager/util/rt_context_util.cc"
"graph/manager/util/variable_accelerate_ctrl.cc"
"graph/optimize/graph_functiondef.cc"
"graph/optimize/graph_optimize.cc"
"graph/optimize/graph_optimizer.cc"
"graph/optimize/optimizer/allreduce_fusion_pass.cc"
"graph/optimize/summary_optimize.cc"
"graph/partition/engine_place.cc"
@@ -120,7 +123,9 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/passes/assert_pass.cc"
"graph/passes/atomic_addr_clean_pass.cc"
"graph/passes/base_pass.cc"
"graph/passes/cast_remove_pass.cc"
"graph/passes/cast_translate_pass.cc"
"graph/passes/common_subexpression_elimination_pass.cc"
"graph/passes/compile_nodes_pass.cc"
"graph/passes/constant_folding_pass.cc"
"graph/passes/constant_fuse_same_pass.cc"
@@ -159,12 +164,14 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/passes/folding_kernel/shape_kernel.cc"
"graph/passes/folding_kernel/shape_n_kernel.cc"
"graph/passes/folding_kernel/size_kernel.cc"
"graph/passes/folding_kernel/slice_d_kernel.cc"
"graph/passes/folding_kernel/slice_kernel.cc"
"graph/passes/folding_kernel/squeeze_kernel.cc"
"graph/passes/folding_kernel/ssd_prior_box_kernel.cc"
"graph/passes/folding_kernel/strided_slice_kernel.cc"
"graph/passes/folding_kernel/sub_kernel.cc"
"graph/passes/folding_kernel/transdata_kernel.cc"
"graph/passes/folding_kernel/unpack_kernel.cc"
"graph/passes/folding_pass.cc"
"graph/passes/get_original_format_pass.cc"
"graph/passes/guarantee_const_pass.cc"
@@ -179,7 +186,6 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/passes/multi_batch_pass.cc"
"graph/passes/net_output_pass.cc"
"graph/passes/next_iteration_pass.cc"
"graph/passes/no_reshape_op_remove_pass.cc"
"graph/passes/no_use_reshape_remove_pass.cc"
"graph/passes/pass_manager.cc"
"graph/passes/pass_utils.cc"
@@ -188,6 +194,7 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/passes/prevent_gradient_pass.cc"
"graph/passes/print_op_pass.cc"
"graph/passes/prune_pass.cc"
"graph/passes/replace_with_empty_const_pass.cc"
"graph/passes/reshape_remove_pass.cc"
"graph/passes/resource_pair_add_control_pass.cc"
"graph/passes/resource_pair_remove_control_pass.cc"
@@ -206,14 +213,12 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/passes/transpose_transdata_pass.cc"
"graph/passes/unused_const_pass.cc"
"graph/passes/unused_op_remove_pass.cc"
"graph/passes/update_net_output_pass.cc"
"graph/passes/var_is_initialized_op_pass.cc"
"graph/passes/variable_format_pass.cc"
"graph/passes/variable_op_pass.cc"
"graph/passes/variable_prepare_op_pass.cc"
"graph/passes/variable_ref_delete_op_pass.cc"
"graph/preprocess/graph_preprocess.cc"
"graph/preprocess/insert_op/base_insert_op.cc"
"graph/preprocess/insert_op/ge_aipp_op.cc"
"graph/preprocess/insert_op/util_insert_aipp_op.cc"
"graph/preprocess/multi_batch_copy_graph.cc"
@@ -223,13 +228,8 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"opskernel_manager/ops_kernel_manager.cc"
"session/inner_session.cc"
"session/session_manager.cc"
"single_op/single_op.cc"
"single_op/single_op_manager.cc"
"single_op/single_op_model.cc"
"single_op/stream_resource.cc"
"single_op/task/build_task_utils.cc"
"single_op/task/op_task.cc"
"single_op/task/tbe_task_builder.cc"
"single_op/*.cc"
"single_op/task/*.cc"
)


@@ -261,9 +261,9 @@ target_link_libraries(ge_train
rt
dl)

######### libge.so #############
######### libge_compiler.so #############
# need to remove dependencies on pb files later
file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"common/formats/format_transfers/*.cc"
"common/formats/formats.cc"
"common/formats/utils/formats_trans_utils.cc"
@@ -271,20 +271,24 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"common/ge/plugin_manager.cc"
"common/profiling/profiling_manager.cc"
"engine_manager/dnnengine_manager.cc"
"ge_local_engine/engine/host_cpu_engine.cc"
"generator/ge_generator.cc"
"generator/generator_api.cc"
"graph/build/graph_build.cc"
"graph/build/graph_builder.cc"
"graph/build/label_allocator.cc"
"graph/build/logical_stream_allocator.cc"
"graph/build/model_builder.cc"
"graph/build/optimize_stream_graph.cc"
"graph/build/run_context.cc"
"graph/build/stream_allocator.cc"
"graph/build/stream_graph_optimizer.cc"
"graph/build/task_generator.cc"
"graph/common/bcast.cc"
"graph/common/omg_util.cc"
"graph/common/transop_util.cc"
"graph/execute/graph_execute.cc"
"graph/label/*.cc"
"graph/load/graph_loader.cc"
"graph/load/new_model_manager/cpu_queue_schedule.cc"
"graph/load/new_model_manager/data_dumper.cc"
"graph/load/new_model_manager/data_inputer.cc"
"graph/load/new_model_manager/davinci_model.cc"
@@ -305,10 +309,12 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc"
"graph/load/new_model_manager/task_info/stream_active_task_info.cc"
"graph/load/new_model_manager/task_info/stream_switch_task_info.cc"
"graph/load/new_model_manager/task_info/stream_switchn_task_info.cc"
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc"
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc"
"graph/load/new_model_manager/task_info/task_info.cc"
"graph/load/new_model_manager/tbe_handle_store.cc"
"graph/load/output/output.cc"
"graph/manager/custom/custom_op.cc"
"graph/manager/graph_context.cc"
"graph/manager/graph_manager.cc"
"graph/manager/graph_manager_utils.cc"
@@ -317,13 +323,9 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/manager/model_manager/event_manager.cc"
"graph/manager/trans_var_data_utils.cc"
"graph/manager/util/debug.cc"
"graph/manager/util/node_searcher/need_rebuild_node_searcher.cc"
"graph/manager/util/rt_context_util.cc"
"graph/manager/util/variable_accelerate_ctrl.cc"
"graph/optimize/graph_functiondef.cc"
"graph/optimize/graph_optimize.cc"
"graph/optimize/graph_optimizer.cc"
"graph/optimize/optimizer/allreduce_fusion_inference_pass.cc"
"graph/optimize/summary_optimize.cc"
"graph/partition/engine_place.cc"
"graph/partition/graph_partition.cc"
@@ -332,7 +334,9 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/passes/assert_pass.cc"
"graph/passes/atomic_addr_clean_pass.cc"
"graph/passes/base_pass.cc"
"graph/passes/cast_remove_pass.cc"
"graph/passes/cast_translate_pass.cc"
"graph/passes/common_subexpression_elimination_pass.cc"
"graph/passes/compile_nodes_pass.cc"
"graph/passes/constant_folding_pass.cc"
"graph/passes/constant_fuse_same_pass.cc"
@@ -371,12 +375,14 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/passes/folding_kernel/shape_kernel.cc"
"graph/passes/folding_kernel/shape_n_kernel.cc"
"graph/passes/folding_kernel/size_kernel.cc"
"graph/passes/folding_kernel/slice_d_kernel.cc"
"graph/passes/folding_kernel/slice_kernel.cc"
"graph/passes/folding_kernel/squeeze_kernel.cc"
"graph/passes/folding_kernel/ssd_prior_box_kernel.cc"
"graph/passes/folding_kernel/strided_slice_kernel.cc"
"graph/passes/folding_kernel/sub_kernel.cc"
"graph/passes/folding_kernel/transdata_kernel.cc"
"graph/passes/folding_kernel/unpack_kernel.cc"
"graph/passes/folding_pass.cc"
"graph/passes/get_original_format_pass.cc"
"graph/passes/guarantee_const_pass.cc"
@@ -391,7 +397,6 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/passes/multi_batch_pass.cc"
"graph/passes/net_output_pass.cc"
"graph/passes/next_iteration_pass.cc"
"graph/passes/no_reshape_op_remove_pass.cc"
"graph/passes/no_use_reshape_remove_pass.cc"
"graph/passes/pass_manager.cc"
"graph/passes/pass_utils.cc"
@@ -400,6 +405,7 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/passes/prevent_gradient_pass.cc"
"graph/passes/print_op_pass.cc"
"graph/passes/prune_pass.cc"
"graph/passes/replace_with_empty_const_pass.cc"
"graph/passes/reshape_remove_pass.cc"
"graph/passes/resource_pair_add_control_pass.cc"
"graph/passes/resource_pair_remove_control_pass.cc"
@@ -418,14 +424,12 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/passes/transpose_transdata_pass.cc"
"graph/passes/unused_const_pass.cc"
"graph/passes/unused_op_remove_pass.cc"
"graph/passes/update_net_output_pass.cc"
"graph/passes/var_is_initialized_op_pass.cc"
"graph/passes/variable_format_pass.cc"
"graph/passes/variable_op_pass.cc"
"graph/passes/variable_prepare_op_pass.cc"
"graph/passes/variable_ref_delete_op_pass.cc"
"graph/preprocess/graph_preprocess.cc"
"graph/preprocess/insert_op/base_insert_op.cc"
"graph/preprocess/insert_op/ge_aipp_op.cc"
"graph/preprocess/insert_op/util_insert_aipp_op.cc"
"graph/preprocess/multi_batch_copy_graph.cc"
@@ -442,16 +446,19 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"single_op/task/build_task_utils.cc"
"single_op/task/op_task.cc"
"single_op/task/tbe_task_builder.cc"
##########################################
# "ir_build/ge_ir_build.cc"
# "offline/atc_ir_common.cc"
)

add_library(ge SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS})
target_compile_definitions(ge PRIVATE
add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS})
target_compile_definitions(ge_compiler PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
DAVINCI_SUPPORT_PROFILING
REUSE_MEMORY=1
FMK_HOST_INFER
PLATFORM_CLOUD)
target_link_libraries(ge
target_link_libraries(ge_compiler
graph
ge_common
"-Wl,--whole-archive"


+1 -1   src/ge/client/CMakeLists.txt

@@ -80,7 +80,7 @@ target_compile_definitions(ge_client_train PRIVATE
PLATFORM_CLOUD)
target_link_libraries(ge_client
graph
ge
ge_compiler
ge_common
${PROTOBUF_LIBRARY}
${register}


+25 -17   src/ge/client/ge_api.cc

@@ -61,14 +61,14 @@ Status CheckDumpAndReuseMemory(const std::map<string, string> &options) {
const int kDecimal = 10;
auto dump_op_env = std::getenv("DUMP_OP");
int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0;
auto disable_reuse_memory_iter = options.find("ge.exec.disableReuseMemory");
if (disable_reuse_memory_iter != options.end()) {
if (disable_reuse_memory_iter->second == "0") {
auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory");
if (disableReuseMemoryIter != options.end()) {
if (disableReuseMemoryIter->second == "0") {
GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open");
if (dump_op_flag) {
GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0");
}
} else if (disable_reuse_memory_iter->second == "1") {
} else if (disableReuseMemoryIter->second == "1") {
GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close");
} else {
GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is invalid");
@@ -128,22 +128,29 @@ Status GEInitialize(const std::map<string, string> &options) {
OpsProtoManager *manager = OpsProtoManager::Instance();
std::map<string, string> option_tmp;
option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
GE_TIMESTAMP_START(GEInitialize);
bool is_proto_init = manager->Initialize(option_tmp);
GE_TIMESTAMP_END(GEInitialize, "GEInitialize::ManagerInitialize");
if (!is_proto_init) {
GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, ops proto path is invalid.");
return FAILED;
}

// check options is valid
GE_TIMESTAMP_START(CheckOptionsValid);
if (CheckOptionsValid(options) != SUCCESS) {
return FAILED;
}
GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid");

GE_TIMESTAMP_START(InitPreparation);
SaveDdkVersion(options);
GE_TIMESTAMP_END(InitPreparation, "GEInitialize::InitPreparation");
// call Initialize
GELOGT(TRACE_RUNNING, "Initializing environment");
GE_TIMESTAMP_START(GELibInitialize);
Status ret = ge::GELib::Initialize(options);
GE_TIMESTAMP_END(GELibInitialize, "GEInitialize::GELibInitialize");
if (ret != SUCCESS) {
GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, error code = %u", ret);
return FAILED;
@@ -170,17 +177,20 @@ Status GEFinalize() {

std::lock_guard<std::mutex> lock(kGeReleaseMutex);
// call Finalize
Status ret = SUCCESS;
Status middle_ret;
GELOGT(TRACE_RUNNING, "Finalizing environment");
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GEFinalize Failed: GE not initialized");
return GE_CLI_GE_NOT_INITIALIZED;
}
Status ret = instance_ptr->Finalize();
GELOGI("GEFinalize finalize gelib ret=%u", ret);
if (ret != SUCCESS) {
GELOGE(ret, "GEFinalize Failed");
return FAILED;
std::shared_ptr<GELib> instancePtr = ge::GELib::GetInstance();
if (instancePtr == nullptr || !instancePtr->InitFlag()) {
GELOGW("GEFinalize Failed: GE not initialized.");
ret = GE_CLI_GE_NOT_INITIALIZED;
}
if (ret != GE_CLI_GE_NOT_INITIALIZED) {
middle_ret = instancePtr->Finalize();
GELOGI("GEFinalize finalize gelib ret=%u", middle_ret);
if (middle_ret != SUCCESS) {
ret = middle_ret;
}
}

if (kGeInitialized && ret == SUCCESS) {
@@ -379,8 +389,6 @@ Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, s
}

Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) {
GELOGW(
"The callback function will not be checked. Please ensure that the implementation of the function is trusted.");
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback);
}
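
For context, a hedged sketch of the client-side call sequence these APIs participate in; the Graph construction, Session::AddGraph, and the header path are assumed from the public GE API rather than shown in this diff:

#include <map>
#include <string>
#include <vector>

#include "ge/ge_api.h"  // assumed public header location

// Minimal, illustrative client flow; error handling trimmed.
int main() {
  std::map<std::string, std::string> options = {{"ge.exec.disableReuseMemory", "0"}};
  if (ge::GEInitialize(options) != ge::SUCCESS) {
    return -1;
  }
  {
    ge::Session session(options);                    // assumed constructor taking options
    // uint32_t graph_id = 1;
    // session.AddGraph(graph_id, graph);            // graph built elsewhere (assumed API)
    // std::vector<ge::Tensor> inputs, outputs;
    // session.RunGraph(graph_id, inputs, outputs);  // signature as in the hunk above
  }
  (void)ge::GEFinalize();  // call after all sessions are destroyed
  return 0;
}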



Some files were not shown because too many files changed in this diff
