Merge pull request !22 from yanghaoran/mastertags/v0.3.0-alpha
@@ -18,7 +18,6 @@ | |||||
#define INC_COMMON_BLOCKING_QUEUE_H_ | #define INC_COMMON_BLOCKING_QUEUE_H_ | ||||
#include <stdint.h> | #include <stdint.h> | ||||
#include <condition_variable> | #include <condition_variable> | ||||
#include <list> | #include <list> | ||||
#include <mutex> | #include <mutex> | ||||
@@ -87,7 +86,7 @@ class BlockingQueue { | |||||
is_stoped_ = false; | is_stoped_ = false; | ||||
} | } | ||||
// if the queue is stopped, this function will be called to release the unprocessed items | |||||
// if the queue is stopped, call this function to release the unprocessed items | |||||
std::list<T> GetRemainItems() { | std::list<T> GetRemainItems() { | ||||
std::unique_lock<std::mutex> lock(mutex_); | std::unique_lock<std::mutex> lock(mutex_); | ||||
@@ -19,10 +19,10 @@ | |||||
#include <stdint.h> | #include <stdint.h> | ||||
/// | |||||
/// @ingroup dnn | |||||
/// @brief struct define of dynamic aipp batch parameter. | |||||
/// | |||||
/** | |||||
* @ingroup dnn | |||||
* @brief struct definition of dynamic aipp batch parameter. | |||||
*/ | |||||
typedef struct tagAippDynamicBatchPara { | typedef struct tagAippDynamicBatchPara { | ||||
int8_t cropSwitch; // crop switch | int8_t cropSwitch; // crop switch | ||||
int8_t scfSwitch; // resize switch | int8_t scfSwitch; // resize switch | ||||
@@ -66,10 +66,10 @@ typedef struct tagAippDynamicBatchPara { | |||||
int8_t reserve1[16]; // 32B assign, for ub copy | int8_t reserve1[16]; // 32B assign, for ub copy | ||||
} kAippDynamicBatchPara; | } kAippDynamicBatchPara; | ||||
/// | |||||
/// @ingroup dnn | |||||
/// @brief struct definition of dynamic aipp parameter. lite:64+96*batchNum byte ; tiny:64+64*batchNum byte | |||||
/// | |||||
/** | |||||
* @ingroup dnn | |||||
* @brief struct definition of dynamic aipp parameter. lite: 64+96*batchNum bytes; tiny: 64+64*batchNum bytes | |||||
*/ | |||||
typedef struct tagAippDynamicPara { | typedef struct tagAippDynamicPara { | ||||
uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8 | uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8 | ||||
int8_t cscSwitch; // csc switch | int8_t cscSwitch; // csc switch | ||||
@@ -61,19 +61,19 @@ typedef enum tagHiAiNpuModuleId { | |||||
HIAI_DP = 23, | HIAI_DP = 23, | ||||
} HiAiNpuModuleId; | } HiAiNpuModuleId; | ||||
// bit 31-bit30 to be hiai local | |||||
/* bit 31-bit30 to be hiai local */ | |||||
#define HIAI_NPULOCAL_MASK 0xC0000000 | #define HIAI_NPULOCAL_MASK 0xC0000000 | ||||
#define SHIFT_LOCAL_MASK 30 | #define SHIFT_LOCAL_MASK 30 | ||||
#define HIAI_NPULOCAL_VAL_MASK 0x3 | #define HIAI_NPULOCAL_VAL_MASK 0x3 | ||||
// bit 29 -bit28 to be hiai aicpu code type | |||||
/* bit 29 -bit28 to be hiai aicpu code type */ | |||||
#define HIAI_CODE_TYPE_MASK 0x30000000 | #define HIAI_CODE_TYPE_MASK 0x30000000 | ||||
#define SHIFT_CODE_MASK 28 | #define SHIFT_CODE_MASK 28 | ||||
#define HIAI_CODE_TYPE_VAL_MASK 0x3 | #define HIAI_CODE_TYPE_VAL_MASK 0x3 | ||||
// bit 27 -bit25 to be hiai error level | |||||
/* bit 27 -bit25 to be hiai error level */ | |||||
#define HIAI_ERROR_LEVEL_MASK 0x0E000000 | #define HIAI_ERROR_LEVEL_MASK 0x0E000000 | ||||
#define SHIFT_ERROR_LVL_MASK 25 | #define SHIFT_ERROR_LVL_MASK 25 | ||||
#define HIAI_ERROR_LEVEL_VAL_MASK 0x7 | #define HIAI_ERROR_LEVEL_VAL_MASK 0x7 | ||||
// bit 24 -bit17 to be hiai mod | |||||
/* bit 24 -bit17 to be hiai mod */ | |||||
#define HIAI_MODE_ID_MASK 0x01FE0000 | #define HIAI_MODE_ID_MASK 0x01FE0000 | ||||
#define SHIFT_MODE_MASK 17 | #define SHIFT_MODE_MASK 17 | ||||
#define HIAI_MODE_ID_VAL_MASK 0xFF | #define HIAI_MODE_ID_VAL_MASK 0xFF | ||||
@@ -19,13 +19,12 @@ | |||||
#include <runtime/rt.h> | #include <runtime/rt.h> | ||||
#include <stdint.h> | #include <stdint.h> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
using std::string; | using std::string; | ||||
namespace ge { | namespace ge { | ||||
// DAVINCI_TRAIN/DAVINCI_CLOUD is not needed when GETaskKernelHcclInfo needed | |||||
// when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD | |||||
struct GETaskKernelHcclInfo { | struct GETaskKernelHcclInfo { | ||||
string hccl_type; | string hccl_type; | ||||
void *inputDataAddr; | void *inputDataAddr; | ||||
@@ -21,7 +21,6 @@ | |||||
#include <map> | #include <map> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "./ge_task_info.h" | #include "./ge_task_info.h" | ||||
#include "./ops_kernel_info_types.h" | #include "./ops_kernel_info_types.h" | ||||
#include "cce/aicpu_engine_struct.h" | #include "cce/aicpu_engine_struct.h" | ||||
@@ -29,7 +28,6 @@ | |||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "proto/task.pb.h" | #include "proto/task.pb.h" | ||||
using std::map; | using std::map; | ||||
using std::string; | using std::string; | ||||
using std::to_string; | using std::to_string; | ||||
@@ -47,7 +45,7 @@ class OpsKernelInfoStore { | |||||
// initialize opsKernelInfoStore | // initialize opsKernelInfoStore | ||||
virtual Status Initialize(const map<string, string> &options) = 0; | virtual Status Initialize(const map<string, string> &options) = 0; | ||||
// finalize opsKernelInfoStore | |||||
// close opsKernelInfoStore | |||||
virtual Status Finalize() = 0; | virtual Status Finalize() = 0; | ||||
virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; } | virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; } | ||||
@@ -57,18 +55,20 @@ class OpsKernelInfoStore { | |||||
// get all opsKernelInfo | // get all opsKernelInfo | ||||
virtual void GetAllOpsKernelInfo(map<string, OpInfo> &infos) const = 0; | virtual void GetAllOpsKernelInfo(map<string, OpInfo> &infos) const = 0; | ||||
// check whether opsKernelInfoStore is supported based on the operator attribute | |||||
// check whether the opsKernelInfoStore is supported based on the operator attribute | |||||
virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; | virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; | ||||
virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, | virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, | ||||
bool realQuery = false) const { | bool realQuery = false) const { | ||||
return CheckSupported(opDescPtr, un_supported_reason); | return CheckSupported(opDescPtr, un_supported_reason); | ||||
} | } | ||||
// opsFlag opsFlag[0] indicates constant folding is supported or not | |||||
virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){}; | |||||
// requirement of memory allocation | |||||
// memory allocation requirement | |||||
virtual Status CalcOpRunningParam(Node &node) = 0; | virtual Status CalcOpRunningParam(Node &node) = 0; | ||||
// generate task for op | |||||
// generate task for op. | |||||
virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0; | virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0; | ||||
// only call fe engine interface to compile single op | // only call fe engine interface to compile single op | ||||
@@ -77,10 +77,10 @@ class OpsKernelInfoStore { | |||||
// load task for op | // load task for op | ||||
virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } | virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } | ||||
// only to call aicpu interface for generating task struct | |||||
// only call aicpu interface to generate task struct | |||||
virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | ||||
// only to call aicpu interface for generating task struct | |||||
// only call aicpu interface to generate task struct | |||||
virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -37,6 +37,7 @@ struct RunContext { | |||||
ge::Buffer weightsBuffer; | ge::Buffer weightsBuffer; | ||||
std::vector<rtStream_t> graphStreamList; // all streams of graph, order by ge stream id(0,1,...) | std::vector<rtStream_t> graphStreamList; // all streams of graph, order by ge stream id(0,1,...) | ||||
std::vector<rtEvent_t> graphEventList; // all events of graph, order by ge event id(0,1,...) | std::vector<rtEvent_t> graphEventList; // all events of graph, order by ge event id(0,1,...) | ||||
std::vector<rtLabel_t> graphLabelList; // all labels of graph, order by ge label id(0,1,...) | |||||
}; | }; | ||||
struct Task { | struct Task { | ||||
@@ -19,7 +19,6 @@ | |||||
#include <map> | #include <map> | ||||
#include <string> | #include <string> | ||||
#include "./graph_optimizer_types.h" | #include "./graph_optimizer_types.h" | ||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "common/opskernel/ops_kernel_info_types.h" | #include "common/opskernel/ops_kernel_info_types.h" | ||||
@@ -39,19 +38,19 @@ class GraphOptimizer { | |||||
// close graphOptimizer | // close graphOptimizer | ||||
virtual Status Finalize() = 0; | virtual Status Finalize() = 0; | ||||
// optimize original graph for FE quant optimization | |||||
// optimize original graph for FE quant optimization | |||||
virtual Status OptimizeGraphPrepare(ComputeGraph &graph) { return SUCCESS; } | virtual Status OptimizeGraphPrepare(ComputeGraph &graph) { return SUCCESS; } | ||||
// optimize original graph used in the graph preparation stage | |||||
// optimize original graph; used in the graph preparation stage | |||||
virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; | virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; | ||||
// optimize fused graph | // optimize fused graph | ||||
virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; | virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; | ||||
// optimize the whole graph which will be used after graph merged | |||||
// optimize whole graph; used in the stage after the graph is merged | |||||
virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; | virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; | ||||
// get attributes of graph optimizer | |||||
// get attribute of graph optimizer | |||||
virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; | virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; | ||||
// optimize streamed Graph | // optimize streamed Graph | ||||
@@ -19,8 +19,6 @@ | |||||
#include <stdint.h> | #include <stdint.h> | ||||
#include <string> | #include <string> | ||||
using std::string; | |||||
namespace ge { | namespace ge { | ||||
enum OPTIMIZER_SCOPE { | enum OPTIMIZER_SCOPE { | ||||
UNIT = 0, | UNIT = 0, | ||||
@@ -28,7 +26,7 @@ enum OPTIMIZER_SCOPE { | |||||
}; | }; | ||||
struct GraphOptimizerAttribute { | struct GraphOptimizerAttribute { | ||||
string engineName; | |||||
std::string engineName; | |||||
OPTIMIZER_SCOPE scope; | OPTIMIZER_SCOPE scope; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -20,6 +20,7 @@ | |||||
#include <cstdint> | #include <cstdint> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include <set> | |||||
namespace ge { | namespace ge { | ||||
// Option key: graph run mode | // Option key: graph run mode | ||||
@@ -38,9 +39,11 @@ const char *const GE_AICPU_FLAG = "ge.aicpuFlag"; | |||||
const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | ||||
const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | ||||
const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | ||||
const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | |||||
// Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 | // Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 | ||||
const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | ||||
const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | ||||
const char *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory"; | |||||
// Option key: memory init | // Option key: memory init | ||||
const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | ||||
@@ -141,19 +144,43 @@ const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; | |||||
// configure outputDatatype to set the net output type | // configure outputDatatype to set the net output type | ||||
const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; | const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; | ||||
// configure opSelectImplmode to set the op select implmode | |||||
const std::string kOpSelectImplmode = "ge.opSelectImplmode"; | |||||
// configure whether to enable hcom parallel by session constructor options param, | // configure whether to enable hcom parallel by session constructor options param, | ||||
// its value should be "0" or "1", default value is "0" | // its value should be "0" or "1", default value is "0" | ||||
const std::string HCOM_PARALLEL = "ge.hcomParallel"; | const std::string HCOM_PARALLEL = "ge.hcomParallel"; | ||||
// configure whether to use dynamic batch size | |||||
const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; | |||||
// configure whether to use dynamic image size | |||||
const char *const kDynamicImageSize = "ge.dynamicImageSize"; | |||||
// Configure auto tune mode, this option only take effect while AUTO_TUNE_FLAG is Y, | // Configure auto tune mode, this option only take effect while AUTO_TUNE_FLAG is Y, | ||||
// example: GA|RL, support configure multiple, split by | | // example: GA|RL, support configure multiple, split by | | ||||
const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; | const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; | ||||
// Configure soc version , example: "Ascend310" | |||||
const std::string SOC_VERSION = "ge.socVersion"; | |||||
// Configure core type "VectorEngine", default value is "AIcoreEngine" | // Configure core type "VectorEngine", default value is "AIcoreEngine" | ||||
const std::string CORE_TYPE = "ge.engineType"; | const std::string CORE_TYPE = "ge.engineType"; | ||||
// Configure soc version , example: "Ascend310" | |||||
const std::string SOC_VERSION = "ge.socVersion"; | |||||
// Configure AICORE NUM | |||||
const std::string AICORE_NUM = "ge.aicoreNum"; | |||||
// Configure L1FUSION | |||||
const std::string L1_FUSION = "ge.l1Fusion"; | |||||
// Configure Small Channel flag | |||||
const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | |||||
// Configure Compress Weight flag | |||||
const std::string ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; | |||||
// Configure fusion switch file path | |||||
const std::string FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; | |||||
// Save original model | // Save original model | ||||
const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | ||||
@@ -194,6 +221,28 @@ struct TensorInfo { | |||||
DataDesc data; // tensor data | DataDesc data; // tensor data | ||||
ShapeDesc shapeInfo; // tensor shape | ShapeDesc shapeInfo; // tensor shape | ||||
}; | }; | ||||
// Option keys for the IR build interfaces (aclgrphBuildInitialize / aclgrphBuildModel), exposed as C strings. | |||||
// Several names here (AUTO_TUNE_MODE, CORE_TYPE, SOC_VERSION, ...) repeat names that already exist in | |||||
// namespace ge, which is why the right-hand sides below are written with an explicit ge:: qualifier. | |||||
namespace ir_option { | |||||
// Keys that are spelled out here as string literals. | |||||
static const char *const INPUT_FORMAT = "input_format"; | |||||
static const char *const INPUT_SHAPE = "input_shape"; | |||||
static const char *const OP_NAME_MAP = "op_name_map"; | |||||
// Aliases of the const char * keys kDynamicBatchSize / kDynamicImageSize declared in namespace ge. | |||||
static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; | |||||
static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; | |||||
// The .c_str() aliases below borrow their storage from namespace-scope std::string objects in ge; the | |||||
// pointers are only meaningful while those strings are alive and already constructed. | |||||
// NOTE(review): these are initialized from other globals at start-up; confirm every referenced string is | |||||
// defined earlier in this header so the initialization order within a translation unit is well defined. | |||||
static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); | |||||
static const char *const PRECISION_MODE = ge::PRECISION_MODE.c_str(); | |||||
static const char *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY; | |||||
static const char *const HEAD_STREAM = ge::HEAD_STREAM.c_str(); | |||||
static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); | |||||
static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); | |||||
static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); | |||||
// Option keys for interface aclgrphBuildModel; the set stores std::string copies of the keys above. | |||||
// NOTE(review): "suppported" (three p's) is the real identifier; renaming it would break existing users. | |||||
const std::set<std::string> ir_builder_suppported_options = { | |||||
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, | |||||
DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||||
AUTO_TUNE_MODE}; | |||||
// Option keys for interface aclgrphBuildInitialize. | |||||
const std::set<std::string> global_options = {HEAD_STREAM, CORE_TYPE, SOC_VERSION}; | |||||
} // namespace ir_option | |||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_EXTERNAL_GE_GE_API_TYPES_H_ | #endif // INC_EXTERNAL_GE_GE_API_TYPES_H_ |
@@ -0,0 +1,75 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_EXTERNAL_GE_IR_BUILD_H_ | |||||
#define INC_EXTERNAL_GE_IR_BUILD_H_ | |||||
#include <string> | |||||
#include <map> | |||||
#include <memory> | |||||
#include "graph/graph.h" | |||||
#include "graph/ge_error_codes.h" | |||||
namespace ge { | |||||
/**
 * @brief In-memory model buffer filled in by aclgrphBuildModel and read by aclgrphSaveModel.
 *
 * data:   shared-ownership pointer to the buffer, null until a model is stored.
 * length: size that goes with data (presumably a byte count - TODO confirm with the producer).
 */
struct ModelBufferData {
  std::shared_ptr<uint8_t> data = nullptr;
  // Zero by default so a default-constructed object never pairs a null data pointer
  // with an indeterminate length.
  uint64_t length = 0;
};
/** | |||||
* @ingroup AscendCL | |||||
* @brief initialize the build with global options | |||||
* NOTE(review): the previous brief was a copy of aclgrphBuildModel's; confirm the exact semantics against the implementation | |||||
* | |||||
* @param global_options[IN] global init params for build (taken by value) | |||||
* @retval GRAPH_SUCCESS The function is successfully executed. | |||||
* @retval OtherValues Failure | |||||
*/ | |||||
graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options); | |||||
/** | |||||
* @ingroup AscendCL | |||||
* @brief finalize the build; takes no arguments and returns nothing | |||||
* NOTE(review): presumably the counterpart of aclgrphBuildInitialize - confirm what it releases against the implementation | |||||
* | |||||
*/ | |||||
void aclgrphBuildFinalize(); | |||||
/** | |||||
* @ingroup AscendCL | |||||
* @brief build model. Notice the model is stored in buffer | |||||
* | |||||
* @param graph[IN] the graph ready to build | |||||
* @param build_options[IN] options used for build | |||||
* @param model[OUT] built model | |||||
* @retval GRAPH_SUCCESS The function is successfully executed. | |||||
* @retval OtherValues Failure | |||||
*/ | |||||
graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | |||||
ModelBufferData &model); | |||||
/** | |||||
* @ingroup AscendCL | |||||
* @brief save model buffer to file | |||||
* | |||||
* @param output_file[IN] the file path to be saved | |||||
* @param model[IN] model buffer data | |||||
* @retval GRAPH_SUCCESS The function is successfully executed. | |||||
* @retval OtherValues Failure | |||||
*/ | |||||
graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model); | |||||
}; // namespace ge | |||||
#endif |
@@ -22,7 +22,7 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "external/graph/ge_error_codes.h" | |||||
#include "./ge_error_codes.h" | |||||
using std::make_shared; | using std::make_shared; | ||||
using std::map; | using std::map; | ||||
@@ -22,7 +22,7 @@ | |||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "external/graph/operator.h" | |||||
#include "./operator.h" | |||||
namespace ge { | namespace ge { | ||||
class GraphImpl; | class GraphImpl; | ||||
@@ -21,8 +21,8 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "external/graph/tensor.h" | |||||
#include "external/graph/types.h" | |||||
#include "./tensor.h" | |||||
#include "./types.h" | |||||
namespace ge { | namespace ge { | ||||
class InferenceContext; | class InferenceContext; | ||||
@@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | |||||
static std::unique_ptr<InferenceContext> Create(); | static std::unique_ptr<InferenceContext> Create(); | ||||
private: | private: | ||||
InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||||
explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||||
std::shared_ptr<InferenceContextImpl> inference_context_impl_; | std::shared_ptr<InferenceContextImpl> inference_context_impl_; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -23,9 +23,9 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "external/graph/ge_error_codes.h" | |||||
#include "external/graph/inference_context.h" | |||||
#include "external/graph/tensor.h" | |||||
#include "./ge_error_codes.h" | |||||
#include "./inference_context.h" | |||||
#include "./tensor.h" | |||||
#ifndef USER_GE_LOGI | #ifndef USER_GE_LOGI | ||||
#define USER_GE_LOGI(...) | #define USER_GE_LOGI(...) | ||||
@@ -22,8 +22,8 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "external/graph//operator.h" | |||||
#include "external/graph/ge_error_codes.h" | |||||
#include "./operator.h" | |||||
#include "./ge_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
using OpCreator = std::function<Operator(const std::string &)>; | using OpCreator = std::function<Operator(const std::string &)>; | ||||
@@ -22,10 +22,10 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "external/graph/operator.h" | |||||
#include "external/graph/operator_factory.h" | |||||
#include "external/graph/tensor.h" | |||||
#include "external/graph/types.h" | |||||
#include "./operator.h" | |||||
#include "./operator_factory.h" | |||||
#include "./tensor.h" | |||||
#include "./types.h" | |||||
namespace ge { | namespace ge { | ||||
using std::function; | using std::function; | ||||
@@ -60,7 +60,7 @@ class OpReg { | |||||
\ | \ | ||||
private: \ | private: \ | ||||
void __##x() { \ | void __##x() { \ | ||||
OpReg() | |||||
OpReg() | |||||
#define ATTR(x, Type, ...) \ | #define ATTR(x, Type, ...) \ | ||||
N(); \ | N(); \ | ||||
@@ -86,7 +86,7 @@ class OpReg { | |||||
void __attr_##x() { \ | void __attr_##x() { \ | ||||
Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ | Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ | ||||
string attr_name(#x); \ | string attr_name(#x); \ | ||||
(void)OpReg() | |||||
(void)OpReg() | |||||
#define REQUIRED_ATTR(x, Type) \ | #define REQUIRED_ATTR(x, Type) \ | ||||
N(); \ | N(); \ | ||||
@@ -112,7 +112,7 @@ class OpReg { | |||||
void __required_attr_##x() { \ | void __required_attr_##x() { \ | ||||
Operator::RequiredAttrRegister(#x); \ | Operator::RequiredAttrRegister(#x); \ | ||||
string attr_name(#x); \ | string attr_name(#x); \ | ||||
(void)OpReg() | |||||
(void)OpReg() | |||||
#define INPUT(x, t) \ | #define INPUT(x, t) \ | ||||
N(); \ | N(); \ | ||||
@@ -137,7 +137,7 @@ class OpReg { | |||||
private: \ | private: \ | ||||
void __input_##x() { \ | void __input_##x() { \ | ||||
Operator::InputRegister(#x); \ | Operator::InputRegister(#x); \ | ||||
(void)OpReg() | |||||
(void)OpReg() | |||||
#define OPTIONAL_INPUT(x, t) \ | #define OPTIONAL_INPUT(x, t) \ | ||||
N(); \ | N(); \ | ||||
@@ -162,7 +162,7 @@ class OpReg { | |||||
private: \ | private: \ | ||||
void __optional_input_##x() { \ | void __optional_input_##x() { \ | ||||
Operator::OptionalInputRegister(#x); \ | Operator::OptionalInputRegister(#x); \ | ||||
(void)OpReg() | |||||
(void)OpReg() | |||||
#define OUTPUT(x, t) \ | #define OUTPUT(x, t) \ | ||||
N(); \ | N(); \ | ||||
@@ -179,7 +179,7 @@ class OpReg { | |||||
private: \ | private: \ | ||||
void __out_##x() { \ | void __out_##x() { \ | ||||
Operator::OutputRegister(#x); \ | Operator::OutputRegister(#x); \ | ||||
(void)OpReg() | |||||
(void)OpReg() | |||||
#define DYNAMIC_INPUT(x, t) \ | #define DYNAMIC_INPUT(x, t) \ | ||||
N(); \ | N(); \ | ||||
@@ -206,7 +206,7 @@ class OpReg { | |||||
\ | \ | ||||
private: \ | private: \ | ||||
void __dy_input_##x() { \ | void __dy_input_##x() { \ | ||||
(void)OpReg() | |||||
(void)OpReg() | |||||
#define DYNAMIC_OUTPUT(x, t) \ | #define DYNAMIC_OUTPUT(x, t) \ | ||||
N(); \ | N(); \ | ||||
@@ -227,18 +227,18 @@ class OpReg { | |||||
\ | \ | ||||
private: \ | private: \ | ||||
void __dy_output_##x() { \ | void __dy_output_##x() { \ | ||||
(void)OpReg() | |||||
(void)OpReg() | |||||
#define PASTE(g_register, y) g_register##y | #define PASTE(g_register, y) g_register##y | ||||
#define __OP_END_IMPL__(x, y) \ | |||||
N(); \ | |||||
} \ | |||||
static_assert( \ | |||||
std::is_same<x, _THIS_TYPE>::value, \ | |||||
"The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ | |||||
} \ | |||||
; \ | |||||
static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ | |||||
#define __OP_END_IMPL__(x, y) \ | |||||
N(); \ | |||||
} \ | |||||
static_assert( \ | |||||
std::is_same<x, _THIS_TYPE>::value, \ | |||||
"The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ | |||||
} \ | |||||
; \ | |||||
static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ | |||||
} | } | ||||
#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) | #define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) | ||||
@@ -286,7 +286,7 @@ class OpReg { | |||||
// Common shape inferencer | // Common shape inferencer | ||||
#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ | #define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ | ||||
[](Operator op)->graphStatus { \ | |||||
[](Operator op) -> graphStatus { \ | |||||
auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \ | auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \ | ||||
auto x_type = op.GetInputDesc(in_name).GetDataType(); \ | auto x_type = op.GetInputDesc(in_name).GetDataType(); \ | ||||
TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ | TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ | ||||
@@ -300,7 +300,7 @@ graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape, | |||||
const function<void(const vector<int64_t> &y_shape)> &set_out_shape); | const function<void(const vector<int64_t> &y_shape)> &set_out_shape); | ||||
#define BROADCAST_INFER(in1_name, in2_name, out_name) \ | #define BROADCAST_INFER(in1_name, in2_name, out_name) \ | ||||
[](Operator op)->graphStatus { \ | |||||
[](Operator op) -> graphStatus { \ | |||||
return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \ | return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \ | ||||
[&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \ | [&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \ | ||||
[&](const vector<int64_t> &y_shape) { \ | [&](const vector<int64_t> &y_shape) { \ | ||||
@@ -22,8 +22,8 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "external/graph/ge_error_codes.h" | |||||
#include "external/graph/types.h" | |||||
#include "./ge_error_codes.h" | |||||
#include "./types.h" | |||||
namespace ge { | namespace ge { | ||||
class ShapeImpl; | class ShapeImpl; | ||||
@@ -133,11 +133,13 @@ enum Format { | |||||
FORMAT_FRACTAL_ZZ, | FORMAT_FRACTAL_ZZ, | ||||
FORMAT_FRACTAL_NZ, | FORMAT_FRACTAL_NZ, | ||||
FORMAT_NCDHW, | FORMAT_NCDHW, | ||||
FORMAT_DHWCK, // 3D filter input tensor format | |||||
FORMAT_DHWCN, // 3D filter input tensor format | |||||
FORMAT_NDC1HWC0, | FORMAT_NDC1HWC0, | ||||
FORMAT_FRACTAL_Z_3D, | FORMAT_FRACTAL_Z_3D, | ||||
FORMAT_CN, | FORMAT_CN, | ||||
FORMAT_NC, | FORMAT_NC, | ||||
FORMAT_DHWNC, | |||||
FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format | |||||
FORMAT_RESERVED, | FORMAT_RESERVED, | ||||
FORMAT_ALL | FORMAT_ALL | ||||
}; | }; | ||||
@@ -47,6 +47,12 @@ class Tensor; | |||||
class TBEPluginManager; | class TBEPluginManager; | ||||
} // namespace ge | } // namespace ge | ||||
// Forward declaration of google::protobuf::Message; the declarations that follow only use it through pointers. | |||||
// NOTE(review): presumably added so this header does not have to include the protobuf headers - confirm. | |||||
namespace google { | |||||
namespace protobuf { | |||||
class Message; | |||||
} | |||||
} // namespace google | |||||
namespace domi { | namespace domi { | ||||
Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | ||||
Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | ||||
@@ -56,6 +62,8 @@ using google::protobuf::Message; | |||||
class OpRegistrationDataImpl; | class OpRegistrationDataImpl; | ||||
using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | ||||
// Callback type set through OpRegistrationData::FusionParseParamsFn: receives a vector of protobuf Message | |||||
// pointers (by value) plus the ge::Operator to work on, and returns a domi::Status. | |||||
using FusionParseParamFunc = | |||||
std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | |||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | ||||
public: | public: | ||||
@@ -71,15 +79,20 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||||
OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); | OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); | ||||
OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | |||||
OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | ||||
OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | ||||
OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); | |||||
domi::ImplyType GetImplyType() const; | domi::ImplyType GetImplyType() const; | ||||
std::string GetOmOptype() const; | std::string GetOmOptype() const; | ||||
std::set<std::string> GetOriginOpTypeSet() const; | std::set<std::string> GetOriginOpTypeSet() const; | ||||
domi::FrameworkType GetFrameworkType() const; | domi::FrameworkType GetFrameworkType() const; | ||||
ParseParamFunc GetParseParamFn() const; | ParseParamFunc GetParseParamFn() const; | ||||
FusionParseParamFunc GetFusionParseParamFn() const; | |||||
private: | private: | ||||
std::shared_ptr<OpRegistrationDataImpl> impl_; | std::shared_ptr<OpRegistrationDataImpl> impl_; | ||||
@@ -103,5 +116,27 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||||
namespace ge { | namespace ge { | ||||
using OpRegistrationData = domi::OpRegistrationData; | using OpRegistrationData = domi::OpRegistrationData; | ||||
using OpReceiver = domi::OpReceiver; | using OpReceiver = domi::OpReceiver; | ||||
/// Abstract interface for a host CPU operator implementation.
/// Subclasses provide Compute(); instances are produced by the factory function
/// that is passed to HostCpuOpRegistrar.
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp {
 public:
  HostCpuOp() = default;
  virtual ~HostCpuOp() = default;

  /// Runs the operator.
  /// @param op       operator instance the computation belongs to
  /// @param inputs   read-only input tensors, keyed by string
  /// @param outputs  output tensors, keyed by string; passed by non-const reference
  /// @return graphStatus result code of the computation
  virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs,
                              std::map<std::string, Tensor> &outputs) = 0;
};
/// Registration helper for host CPU operators: its constructor receives an operator
/// type name and a factory function that returns a HostCpuOp pointer.
/// NOTE(review): the constructor is defined elsewhere; presumably it stores the factory
/// in a registry keyed by op_type -- confirm against the implementation.
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar {
 public:
  /// @param op_type    operator type name (C string)
  /// @param create_fn  factory returning a HostCpuOp pointer
  HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)());
};
// Registers class `op` as the host CPU implementation for operator type `name` by
// defining a static ::ge::HostCpuOpRegistrar object.
// The expansion goes through an extra helper macro so that __COUNTER__ is replaced by
// its numeric value before it is pasted onto the variable name, giving each use a
// distinct identifier within a translation unit.
// The factory lambda allocates with new (std::nothrow), so it returns nullptr instead
// of throwing when the allocation fails.
#define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op)

#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op)

#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op)                              \
  static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \
    ::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); })
} // namespace ge | } // namespace ge | ||||
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ |
@@ -22,7 +22,7 @@ | |||||
#define DECLARE_ERRORNO(sysid, modid, name, value) \ | #define DECLARE_ERRORNO(sysid, modid, name, value) \ | ||||
const domi::Status name = \ | const domi::Status name = \ | ||||
((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); | |||||
((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); | |||||
#define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) | #define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) | ||||
@@ -33,6 +33,7 @@ using Status = uint32_t; | |||||
DECLARE_ERRORNO(0, 0, SUCCESS, 0); | DECLARE_ERRORNO(0, 0, SUCCESS, 0); | ||||
DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); | DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); | ||||
DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 | DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 | ||||
DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201); | |||||
} // namespace domi | } // namespace domi | ||||
#endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ | #endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ |
@@ -48,6 +48,10 @@ typedef enum tagDomiTensorFormat { | |||||
DOMI_TENSOR_BN_WEIGHT, | DOMI_TENSOR_BN_WEIGHT, | ||||
DOMI_TENSOR_CHWN, // Android NN Depth CONV | DOMI_TENSOR_CHWN, // Android NN Depth CONV | ||||
DOMI_TENSOR_FILTER_HWCK, // filter input tensor format | DOMI_TENSOR_FILTER_HWCK, // filter input tensor format | ||||
DOMI_TENSOR_NDHWC, | |||||
DOMI_TENSOR_NCDHW, | |||||
DOMI_TENSOR_DHWCN, // 3D filter input tensor format | |||||
DOMI_TENSOR_DHWNC, | |||||
DOMI_TENSOR_RESERVED | DOMI_TENSOR_RESERVED | ||||
} domiTensorFormat_t; | } domiTensorFormat_t; | ||||
} // namespace domi | } // namespace domi | ||||
@@ -18,11 +18,13 @@ | |||||
#define INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ | #define INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ | ||||
#include <cstdint> | #include <cstdint> | ||||
#include <unistd.h> | |||||
#include <sys/syscall.h> | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "toolchain/slog.h" | #include "toolchain/slog.h" | ||||
#define GE_MODULE_NAME GE | |||||
#define GE_MODULE_NAME static_cast<int>(GE) | |||||
// trace status of log | // trace status of log | ||||
enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; | enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; | ||||
@@ -35,15 +37,20 @@ enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; | |||||
#define GELOGO(...) GE_LOG_OPLOG(GE_MODULE_NAME, __VA_ARGS__) | #define GELOGO(...) GE_LOG_OPLOG(GE_MODULE_NAME, __VA_ARGS__) | ||||
#define GELOGT(VALUE, ...) GE_LOG_TRACE(GE_MODULE_NAME, VALUE, __VA_ARGS__) | #define GELOGT(VALUE, ...) GE_LOG_TRACE(GE_MODULE_NAME, VALUE, __VA_ARGS__) | ||||
inline bool IsLogEnable(int module_name, int log_level) noexcept { | |||||
int32_t enable_event = 0; | |||||
int32_t dlog_level = dlog_getlevel(module_name, &enable_event); | |||||
if (dlog_level <= log_level) { | |||||
inline bool IsLogEnable(int module_name, int log_level) { | |||||
int32_t enable = CheckLogLevel(module_name, log_level); | |||||
// 1:enable, 0:disable | |||||
if (enable == 1) { | |||||
return true; | return true; | ||||
} | } | ||||
return false; | return false; | ||||
} | } | ||||
/// Returns the thread id of the calling thread as reported by the gettid system call.
/// The system call is issued once per thread; the result is cached in a thread-local
/// variable and returned directly on later calls from the same thread.
inline pid_t GetTid() {
  // syscall() returns long; the conversion to pid_t is spelled out explicitly.
  thread_local static const pid_t cached_tid = static_cast<pid_t>(syscall(__NR_gettid));
  return cached_tid;
}
#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | #define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | ||||
#define GE_TIMESTAMP_END(stage, stage_name) \ | #define GE_TIMESTAMP_END(stage, stage_name) \ | ||||
@@ -68,29 +75,35 @@ inline bool IsLogEnable(int module_name, int log_level) noexcept { | |||||
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ | GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ | ||||
call_num_of##stage) | call_num_of##stage) | ||||
// Logging macros. Every message is prefixed with the calling thread id and the name of
// the enclosing function.
//
// GetTid() returns pid_t, which does not match the "%lu" conversion used in the prefix;
// passing it unconverted through a variadic call is undefined behaviour. The value is
// therefore converted to unsigned long explicitly at every call.
//
// GE_LOG_WARN / GE_LOG_INFO / GE_LOG_DEBUG only emit when IsLogEnable() reports that the
// level is active. They are wrapped in do { } while (0) so that each expands to a single
// statement and cannot capture an `else` that follows the macro call.
#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...)                                                       \
  dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, static_cast<unsigned long>(GetTid()), __FUNCTION__, \
             ERROR_CODE, ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__)

#define GE_LOG_WARN(MOD_NAME, fmt, ...)                                                                      \
  do {                                                                                                       \
    if (IsLogEnable(MOD_NAME, DLOG_WARN)) {                                                                  \
      dlog_warn(MOD_NAME, "%lu %s:" fmt, static_cast<unsigned long>(GetTid()), __FUNCTION__, ##__VA_ARGS__); \
    }                                                                                                        \
  } while (0)

#define GE_LOG_INFO(MOD_NAME, fmt, ...)                                                                      \
  do {                                                                                                       \
    if (IsLogEnable(MOD_NAME, DLOG_INFO)) {                                                                  \
      dlog_info(MOD_NAME, "%lu %s:" fmt, static_cast<unsigned long>(GetTid()), __FUNCTION__, ##__VA_ARGS__); \
    }                                                                                                        \
  } while (0)

#define GE_LOG_DEBUG(MOD_NAME, fmt, ...)                                                                      \
  do {                                                                                                        \
    if (IsLogEnable(MOD_NAME, DLOG_DEBUG)) {                                                                  \
      dlog_debug(MOD_NAME, "%lu %s:" fmt, static_cast<unsigned long>(GetTid()), __FUNCTION__, ##__VA_ARGS__); \
    }                                                                                                         \
  } while (0)

#define GE_LOG_EVENT(MOD_NAME, fmt, ...) \
  dlog_event(MOD_NAME, "%lu %s:" fmt, static_cast<unsigned long>(GetTid()), __FUNCTION__, ##__VA_ARGS__)

#define GE_LOG_OPLOG(MOD_NAME, fmt, ...) \
  Dlog(MOD_NAME, DLOG_OPLOG, "%lu %s:" fmt, static_cast<unsigned long>(GetTid()), __FUNCTION__, ##__VA_ARGS__)

// Emits a trace record carrying a "status" key whose value is the textual name of
// `value` (a TraceStatus). The name table is indexed by the enum's integer value, so
// `value` must be one of the declared TraceStatus enumerators.
#define GE_LOG_TRACE(MOD_NAME, value, fmt, ...)                                                   \
  do {                                                                                            \
    TraceStatus stat = value;                                                                     \
    const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"};                    \
    int idx = static_cast<int>(stat);                                                             \
    char *k = const_cast<char *>("status");                                                       \
    char *v = const_cast<char *>(TraceStatStr[idx]);                                              \
    KeyValue kv = {k, v};                                                                         \
    DlogWithKV(static_cast<int>(MOD_NAME), DLOG_TRACE, &kv, 1, "%lu %s:" fmt,                     \
               static_cast<unsigned long>(GetTid()), __FUNCTION__, ##__VA_ARGS__);                \
  } while (0)
// Writes an info log entry for an allocation whose size is greater than 1 KB (1024).
//   FUNC    - token naming the allocating function; it is stringified into the message
//   PURPOSE - C string describing the purpose of the allocation (printed with %s)
//   SIZE    - allocation size; converted to size_t for printing
// NOTE(review): the expansion already ends with ';', so the macro is not a single
// statement when used in an unbraced if/else. The semicolon is kept here because call
// sites may rely on it -- confirm before removing.
#define GE_PRINT_DYNAMIC_MEMORY(FUNC, PURPOSE, SIZE)                                                        \
  do {                                                                                                      \
    if ((SIZE) > 1024) {                                                                                    \
      GELOGI("MallocMemory, func=%s, size=%zu, purpose=%s", (#FUNC), static_cast<size_t>(SIZE), (PURPOSE)); \
    }                                                                                                       \
  } while (0);
#endif // INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ | #endif // INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ |
@@ -29,7 +29,18 @@ | |||||
using cce::CC_STATUS_SUCCESS; | using cce::CC_STATUS_SUCCESS; | ||||
using cce::ccStatus_t; | using cce::ccStatus_t; | ||||
#define GE_LOGE(...) DAV_LOGE("GE", __VA_ARGS__) | |||||
#if !defined(__ANDROID__) && !defined(ANDROID) | |||||
#define DOMI_LOGE(...) DAV_LOGE("DOMI", __VA_ARGS__) | |||||
#else | |||||
#include <android/log.h> | |||||
#if defined(BUILD_VERSION_PERF) | |||||
#define DOMI_LOGE(fmt, ...) | |||||
#else | |||||
// The Android system has strict log control. Do not modify the log. | |||||
#define DOMI_LOGE(fmt, ...) \ | |||||
__android_log_print(ANDROID_LOG_ERROR, "NPU_FMK", "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#endif | |||||
#endif | |||||
// ge macro
#define GE_LOGI_IF(condition, ...) \ | #define GE_LOGI_IF(condition, ...) \ | ||||
@@ -44,7 +55,7 @@ using cce::ccStatus_t; | |||||
#define GE_LOGE_IF(condition, ...) \ | #define GE_LOGE_IF(condition, ...) \ | ||||
if ((condition)) { \ | if ((condition)) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
} | } | ||||
// If expr is not SUCCESS, print the log and return the same value | // If expr is not SUCCESS, print the log and return the same value | ||||
@@ -52,7 +63,7 @@ using cce::ccStatus_t; | |||||
do { \ | do { \ | ||||
const ge::Status _status = (expr); \ | const ge::Status _status = (expr); \ | ||||
if (_status != ge::SUCCESS) { \ | if (_status != ge::SUCCESS) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
@@ -62,7 +73,7 @@ using cce::ccStatus_t; | |||||
do { \ | do { \ | ||||
const ge::Status _status = (expr); \ | const ge::Status _status = (expr); \ | ||||
if (_status != ge::SUCCESS) { \ | if (_status != ge::SUCCESS) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
@@ -75,6 +86,15 @@ using cce::ccStatus_t; | |||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
// If expr is not ge::GRAPH_SUCCESS, print the log and return ge::FAILED.
// The return code is namespace-qualified, like the other status codes used by the
// macros in this header, so the macro does not depend on an unqualified FAILED being
// visible at the call site.
// The expansion ends with ';' in the same way as the neighbouring GE_CHK_* macros.
#define GE_CHK_GRAPH_STATUS_RET(expr, ...) \
  do {                                     \
    if ((expr) != ge::GRAPH_SUCCESS) {     \
      DOMI_LOGE(__VA_ARGS__);              \
      return ge::FAILED;                   \
    }                                      \
  } while (0);
// If expr is not SUCCESS, print the log and execute a custom statement | // If expr is not SUCCESS, print the log and execute a custom statement | ||||
#define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \ | #define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \ | ||||
do { \ | do { \ | ||||
@@ -91,25 +111,11 @@ using cce::ccStatus_t; | |||||
(void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | ||||
(void)msg.append( \ | (void)msg.append( \ | ||||
ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | ||||
GE_LOGE("%s", msg.c_str()); \ | |||||
DOMI_LOGE("%s", msg.c_str()); \ | |||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
// If expr is not true, print the Info log and return the specified status | |||||
#define GE_CHK_BOOL_RET_STATUS_LOGI(expr, _status, ...) \ | |||||
do { \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
std::string msg; \ | |||||
(void)msg.append(StringUtils::FormatString(__VA_ARGS__)); \ | |||||
(void)msg.append( \ | |||||
StringUtils::FormatString(" Check result false, status: 0x%X %s", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||||
GELOGI("%s", msg.c_str()); \ | |||||
return _status; \ | |||||
} \ | |||||
} while (0); | |||||
// If expr is not true, print the log and return the specified status | // If expr is not true, print the log and return the specified status | ||||
#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ | #define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ | ||||
do { \ | do { \ | ||||
@@ -124,7 +130,7 @@ using cce::ccStatus_t; | |||||
{ \ | { \ | ||||
bool b = (expr); \ | bool b = (expr); \ | ||||
if (!b) { \ | if (!b) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
exec_expr; \ | exec_expr; \ | ||||
} \ | } \ | ||||
}; | }; | ||||
@@ -163,7 +169,7 @@ using cce::ccStatus_t; | |||||
{ \ | { \ | ||||
bool b = (expr); \ | bool b = (expr); \ | ||||
if (b) { \ | if (b) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
exec_expr; \ | exec_expr; \ | ||||
} \ | } \ | ||||
}; | }; | ||||
@@ -182,7 +188,7 @@ using cce::ccStatus_t; | |||||
{ \ | { \ | ||||
bool b = (expr); \ | bool b = (expr); \ | ||||
if (b) { \ | if (b) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
exec_expr; \ | exec_expr; \ | ||||
return; \ | return; \ | ||||
} \ | } \ | ||||
@@ -193,7 +199,7 @@ using cce::ccStatus_t; | |||||
{ \ | { \ | ||||
bool b = (expr); \ | bool b = (expr); \ | ||||
if (b) { \ | if (b) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
exec_expr; \ | exec_expr; \ | ||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
@@ -210,62 +216,42 @@ using cce::ccStatus_t; | |||||
// -----------------runtime related macro definitions------------------------------- | // -----------------runtime related macro definitions------------------------------- | ||||
// If expr is not RT_ERROR_NONE, print the log | // If expr is not RT_ERROR_NONE, print the log | ||||
#define GE_CHK_RT(expr) \ | |||||
do { \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
} \ | |||||
#define GE_CHK_RT(expr) \ | |||||
do { \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
} \ | |||||
} while (0); | } while (0); | ||||
// If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression | // If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression | ||||
#define GE_CHK_RT_EXEC(expr, exec_expr) \ | |||||
{ \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
exec_expr; \ | |||||
} \ | |||||
#define GE_CHK_RT_EXEC(expr, exec_expr) \ | |||||
{ \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
exec_expr; \ | |||||
} \ | |||||
} | } | ||||
// If expr is not RT_ERROR_NONE, print the log and return | // If expr is not RT_ERROR_NONE, print the log and return | ||||
#define GE_CHK_RT_RET(expr) \ | |||||
do { \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
return ge::RT_FAILED; \ | |||||
} \ | |||||
#define GE_CHK_RT_RET(expr) \ | |||||
do { \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
return ge::RT_FAILED; \ | |||||
} \ | |||||
} while (0); | } while (0); | ||||
// ------------------------cce related macro definitions---------------------------- | // ------------------------cce related macro definitions---------------------------- | ||||
// If expr is not CC_STATUS_SUCCESS, print the log | // If expr is not CC_STATUS_SUCCESS, print the log | ||||
#define GE_CHK_CCE(expr) \ | |||||
do { \ | |||||
ccStatus_t _cc_ret = (expr); \ | |||||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
} \ | |||||
} while (0); | |||||
// If expr is not CC_STATUS_SUCCESS, print the log and execute the exec_expr expression | |||||
#define GE_CHK_CCE_EXEC(expr, exec_expr) \ | |||||
do { \ | |||||
ccStatus_t _cc_ret = (expr); \ | |||||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
exec_expr; \ | |||||
} \ | |||||
} while (0); | |||||
// If expr is not CC_STATUS_SUCCESS, print the log and return | |||||
#define GE_CHK_CCE_RET(expr) \ | |||||
do { \ | |||||
ccStatus_t _cc_ret = (expr); \ | |||||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
return ge::CCE_FAILED; \ | |||||
} \ | |||||
#define GE_CHK_CCE(expr) \ | |||||
do { \ | |||||
ccStatus_t _cc_ret = (expr); \ | |||||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
} \ | |||||
} while (0); | } while (0); | ||||
// If expr is true, execute exec_expr without printing logs | // If expr is true, execute exec_expr without printing logs | ||||
@@ -281,37 +267,8 @@ using cce::ccStatus_t; | |||||
try { \ | try { \ | ||||
exec_expr0; \ | exec_expr0; \ | ||||
} catch (const std::bad_alloc &) { \ | } catch (const std::bad_alloc &) { \ | ||||
GE_LOGE("Make shared failed"); \ | |||||
DOMI_LOGE("Make shared failed"); \ | |||||
exec_expr1; \ | exec_expr1; \ | ||||
} | } | ||||
#define GE_CHECK_INT32_MUL_OVERFLOW(a, b, ...) \ | |||||
do { \ | |||||
if ((a) > 0) { \ | |||||
if ((b) > 0) { \ | |||||
if ((a) > (INT32_MAX / (b))) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} else { \ | |||||
if ((b) < (INT32_MIN / (a))) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} \ | |||||
} else { \ | |||||
if ((b) > 0) { \ | |||||
if ((a) < (INT32_MAX / (b))) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} else { \ | |||||
if (((a) != 0) && ((b) < (INT32_MAX / (a)))) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} \ | |||||
} \ | |||||
} while (0); | |||||
#endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | #endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ |
@@ -152,7 +152,6 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g | |||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255 | |||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | ||||
@@ -204,15 +203,16 @@ GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_GET_GRAPH_REBUILD_FAILED, 60, | |||||
GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61, | GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61, | ||||
"Failed set graph finish rebuild in node searcher."); // 1343242301 | "Failed set graph finish rebuild in node searcher."); // 1343242301 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302 | GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302 | ||||
// Optimization error codes
GE_ERRORNO_GRAPH(TO_BE_DELETED, 200, "The node of the graph to be deleted.");
GE_ERRORNO_GRAPH(NOT_CHANGED, 201, "The node of the graph not changed.");
// Engine_manager module error code definition | // Engine_manager module error code definition | ||||
GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336 | GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336 | ||||
GE_ERRORNO_ENGINE(GE_ENG_FINALIZE_FAILED, 1, "Engine finalize failed."); // 1343246337 | GE_ERRORNO_ENGINE(GE_ENG_FINALIZE_FAILED, 1, "Engine finalize failed."); // 1343246337 | ||||
GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338 | GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338 | ||||
// Optimize errocode | |||||
GE_ERRORNO_GRAPH(TO_BE_DELETED, 63, "The node of the graph to be deleted."); // 1343242303 | |||||
GE_ERRORNO_GRAPH(NOT_CHANGED, 64, "The node of the graph no changed."); // 1343242304 | |||||
// Ops module error code definition | // Ops module error code definition | ||||
GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432 | GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432 | ||||
GE_ERRORNO_OPS(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, 1, "Failed to initialize GraphOptimizer."); // 1343250433 | GE_ERRORNO_OPS(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, 1, "Failed to initialize GraphOptimizer."); // 1343250433 | ||||
@@ -24,8 +24,7 @@ | |||||
#include "common/fmk_error_codes.h" | #include "common/fmk_error_codes.h" | ||||
#include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
using std::string; | |||||
#include "external/graph/types.h" | |||||
namespace ge { | namespace ge { | ||||
enum RuntimeType { HOST = 0, DEVICE = 1 }; | enum RuntimeType { HOST = 0, DEVICE = 1 }; | ||||
@@ -56,7 +55,7 @@ struct DataBuffer { | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief External inputdata | |||||
/// @brief External input data | |||||
/// | /// | ||||
struct InputData { | struct InputData { | ||||
uint32_t index; // Index of input data | uint32_t index; // Index of input data | ||||
@@ -65,13 +64,14 @@ struct InputData { | |||||
uint32_t model_id; // Model ID required for data processing | uint32_t model_id; // Model ID required for data processing | ||||
uint64_t request_id = 0; // Request ID | uint64_t request_id = 0; // Request ID | ||||
std::vector<DataBuffer> blobs; // Actual input data, currently only supports one input | std::vector<DataBuffer> blobs; // Actual input data, currently only supports one input | ||||
bool is_dynamic_batch = false; // Whether is dynamic batch size scene, default:false | |||||
std::string batch_label; // Gear used for current inference in dynamic batch scene | |||||
}; | }; | ||||
// The definition of output result structure | |||||
/// Output result structure definition | |||||
struct OutputData { | struct OutputData { | ||||
uint32_t index; // Index of input data | uint32_t index; // Index of input data | ||||
uint32_t model_id; // The model ID corresponding to the processing result | uint32_t model_id; // The model ID corresponding to the processing result | ||||
/// Output data cache, arranged in sequence of output operators. | /// Output data cache, arranged in sequence of output operators. | ||||
/// If the operator has multiple outputs, | /// If the operator has multiple outputs, | ||||
/// the data buffer order of the operator is the same as that defined in the | /// the data buffer order of the operator is the same as that defined in the | ||||
@@ -142,11 +142,31 @@ struct Options { | |||||
bool deployMode; | bool deployMode; | ||||
bool isAICPUMode; | bool isAICPUMode; | ||||
bool enable_atomic; | bool enable_atomic; | ||||
string podName; | |||||
std::string podName; | |||||
int64_t rankId; | int64_t rankId; | ||||
string rankTableFile; | |||||
std::string rankTableFile; | |||||
int32_t ge_hccl_flag = 0; | int32_t ge_hccl_flag = 0; | ||||
int32_t physical_device_id; | int32_t physical_device_id; | ||||
}; | }; | ||||
/// Profiling information describing a single task.
/// NOTE(review): the integer members have no default initializers, so a
/// default-constructed object leaves them unset; callers are expected to assign them.
struct TaskDescInfo {
  std::string op_name;  // name of the operator
  uint32_t block_dim;   // block dimension
  uint32_t task_id;     // task identifier
  uint32_t stream_id;   // stream identifier
};
// Profiling info of graph | |||||
struct ComputeGraphDescInfo { | |||||
std::string op_name; | |||||
std::string op_type; | |||||
std::vector<Format> input_format; | |||||
std::vector<std::vector<int64_t>> input_shape; | |||||
std::vector<DataType> input_data_type; | |||||
std::vector<Format> output_format; | |||||
std::vector<std::vector<int64_t>> output_shape; | |||||
std::vector<DataType> output_data_type; | |||||
}; | |||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ | #endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ |
@@ -19,7 +19,6 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <memory> | |||||
#include "common/fmk_types.h" | #include "common/fmk_types.h" | ||||
#include "common/helper/om_file_helper.h" | #include "common/helper/om_file_helper.h" | ||||
@@ -33,36 +32,41 @@ class ModelHelper { | |||||
ModelHelper() = default; | ModelHelper() = default; | ||||
~ModelHelper(); | ~ModelHelper(); | ||||
Status SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, const std::string &output_file); | |||||
Status SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file); | |||||
Status LoadModel(const ge::ModelData &model_data); | |||||
Status SaveToOmModel(const GeModelPtr& ge_model, const SaveParam& save_param, const std::string& output_file, | |||||
ge::ModelBufferData& model); | |||||
Status SaveOriginalGraphToOmModel(const ge::Graph& graph, const std::string& output_file); | |||||
Status LoadModel(const ge::ModelData& model_data); | |||||
Status GetModelBufferData(ge::ModelBufferData& model); | |||||
ModelFileHeader *GetFileHeader() { return file_header_; } | |||||
ModelFileHeader* GetFileHeader() { return file_header_; } | |||||
GeModelPtr GetGeModel(); | GeModelPtr GetGeModel(); | ||||
void SetSaveMode(bool val) { is_offline_ = val; } | |||||
bool GetSaveMode(void) const { return is_offline_; } | |||||
static Status TransModelToGeModel(const ModelPtr &model, GeModelPtr &ge_model); | |||||
static Status TransGeModelToModel(const GeModelPtr &geModelPtr, ModelPtr &modelPtr); | |||||
static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model); | |||||
static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr); | |||||
private: | private: | ||||
bool is_assign_model_ = false; | bool is_assign_model_ = false; | ||||
ModelFileHeader *file_header_ = nullptr; | |||||
bool is_offline_ = true; | |||||
ModelFileHeader* file_header_ = nullptr; | |||||
// Encrypted model need delete temp model and unencrypted model need not delete model | // Encrypted model need delete temp model and unencrypted model need not delete model | ||||
uint8_t *model_addr_tmp_ = nullptr; | |||||
uint8_t* model_addr_tmp_ = nullptr; | |||||
uint32_t model_len_tmp_ = 0; | uint32_t model_len_tmp_ = 0; | ||||
GeModelPtr model_; | GeModelPtr model_; | ||||
ModelHelper(const ModelHelper &); | |||||
ModelHelper &operator=(const ModelHelper &); | |||||
Status GenerateGeModel(OmFileLoadHelper &om_load_helper); | |||||
Status LoadModelData(OmFileLoadHelper &om_load_helper); | |||||
void SetModelToGeModel(ge::Model &model); | |||||
Status LoadWeights(OmFileLoadHelper &om_load_helper); | |||||
Status LoadTask(OmFileLoadHelper &om_load_helper); | |||||
Status LoadTBEKernelStore(OmFileLoadHelper &om_load_helper); | |||||
ModelHelper(const ModelHelper&); | |||||
ModelHelper& operator=(const ModelHelper&); | |||||
Status GenerateGeModel(OmFileLoadHelper& om_load_helper); | |||||
Status LoadModelData(OmFileLoadHelper& om_load_helper); | |||||
void SetModelToGeModel(ge::Model& model); | |||||
Status LoadWeights(OmFileLoadHelper& om_load_helper); | |||||
Status LoadTask(OmFileLoadHelper& om_load_helper); | |||||
Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | |||||
Status ReleaseLocalModelData() noexcept; | Status ReleaseLocalModelData() noexcept; | ||||
Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type, | |||||
const uint8_t *data, size_t size); | |||||
Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | |||||
const uint8_t* data, size_t size); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ | #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ |
@@ -20,10 +20,12 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "external/ge/ge_ir_build.h" | |||||
#include "framework/common/fmk_types.h" | #include "framework/common/fmk_types.h" | ||||
#include "framework/common/ge_types.h" | |||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "framework/common/ge_types.h" | |||||
using ProcParam = struct PROC_PARAM; | |||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
@@ -80,9 +82,10 @@ class OmFileSaveHelper { | |||||
const std::vector<ModelPartition> &GetModelPartitions() const; | const std::vector<ModelPartition> &GetModelPartitions() const; | ||||
Status SaveModel(const SaveParam &save_param, const char *target_file); | |||||
Status SaveModel(const SaveParam &save_param, const char *target_file, ge::ModelBufferData &model, | |||||
bool is_offline = true); | |||||
Status SaveModelToFile(const char *output_file); | |||||
Status SaveModelToFile(const char *output_file, ge::ModelBufferData &model, bool is_offline = true); | |||||
ModelFileHeader model_header_; | ModelFileHeader model_header_; | ||||
OmFileContext context_; | OmFileContext context_; | ||||
@@ -120,4 +120,4 @@ class L2CacheOptimize { | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ | |||||
#endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ |
@@ -649,6 +649,8 @@ extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_M | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; | extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; | ||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_LABEL_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; | extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; | ||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; | extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; | ||||
@@ -801,6 +803,8 @@ extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_N | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | ||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | ||||
// For constant folding | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||||
} // namespace domi | } // namespace domi | ||||
#endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | #endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ |
@@ -17,11 +17,12 @@ | |||||
#ifndef INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | #ifndef INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | ||||
#define INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | #define INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | ||||
#include <google/protobuf/map.h> | |||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <string> | #include <string> | ||||
#include <google/protobuf/map.h> | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
using domi::AttrDef; | using domi::AttrDef; | ||||
@@ -18,7 +18,6 @@ | |||||
#define INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ | #define INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ | ||||
#include <cce/dnn.h> | #include <cce/dnn.h> | ||||
#include <memory> | #include <memory> | ||||
#include <vector> | #include <vector> | ||||
@@ -56,6 +55,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_TR | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT; | ||||
// FunctionOp | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t IF_COND_INPUT; | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_START_INPUT; | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT_INPUT; | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; | |||||
class OpUtils { | class OpUtils { | ||||
public: | public: | ||||
/// | /// | ||||
@@ -164,15 +172,23 @@ class OpUtils { | |||||
/// | /// | ||||
static Status ConvertAippParams(const GeAttrValue::NamedAttrs &aipp_attr, domi::AippOpParams *aipp_params); | static Status ConvertAippParams(const GeAttrValue::NamedAttrs &aipp_attr, domi::AippOpParams *aipp_params); | ||||
static Status TransferDim(const std::vector<int64_t> &dim, std::vector<int64_t> &dim_vector); | static Status TransferDim(const std::vector<int64_t> &dim, std::vector<int64_t> &dim_vector); | ||||
static void SliceData(std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, int64_t begin, | |||||
int64_t out_dim, int64_t stride); | |||||
template <typename T> | |||||
static void SliceData(const std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, | |||||
int64_t begin, int64_t out_dim, int64_t stride); | |||||
template <typename T> | |||||
static Status SetDataByDataType(size_t out_size, const std::vector<char *> &chunk_input, | |||||
const std::vector<char *> &chunk_output, GeTensor *output); | |||||
template <typename T> | |||||
static Status SetOutputSliceDataByDataType(void *data, int64_t data_size, const std::vector<int64_t> &input_dims, | |||||
const std::vector<int64_t> &begin, const std::vector<int64_t> &output_dims, | |||||
ge::GeTensor *output, const std::vector<int64_t> &stride); | |||||
static Status SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int64_t> &input_dims, | static Status SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int64_t> &input_dims, | ||||
std::vector<int64_t> &begin, std::vector<int64_t> &output_dims, ge::GeTensor *output, | std::vector<int64_t> &begin, std::vector<int64_t> &output_dims, ge::GeTensor *output, | ||||
std::vector<int64_t> &stride); | std::vector<int64_t> &stride); | ||||
/// | /// | ||||
/// @ingroup domi_omg | /// @ingroup domi_omg | ||||
/// @brief Convert the convolution weight data from [h, w, c, k] to [k, c, h, w] | |||||
/// @brief Convert the convolutional weight data from [h, w, c, k] to [k, c, h, w] | |||||
/// @param [in] input Weight data in HWCK format | /// @param [in] input Weight data in HWCK format | ||||
/// @param [in] H value of H dimension | /// @param [in] H value of H dimension | ||||
/// @param [in] W value of W dimension | /// @param [in] W value of W dimension | ||||
@@ -183,7 +199,7 @@ class OpUtils { | |||||
static void TransDataHWCK2KCHW(const void *input, int64_t H, int64_t W, int64_t C, int64_t K, void **output); | static void TransDataHWCK2KCHW(const void *input, int64_t H, int64_t W, int64_t C, int64_t K, void **output); | ||||
/// | /// | ||||
/// @ingroup domi_omg | /// @ingroup domi_omg | ||||
/// @brief Converts the convolution weight data from [k, c, h, w] to [h, w, c, k]. | |||||
/// @brief Converts the convolutional weight data from [k, c, h, w] to [h, w, c, k]. | |||||
/// @param [in] input Weight data in HWCK format | /// @param [in] input Weight data in HWCK format | ||||
/// @param [in] K value of K dimension | /// @param [in] K value of K dimension | ||||
/// @param [in] C value of C dimension | /// @param [in] C value of C dimension | ||||
@@ -222,7 +238,6 @@ using CceTensorDescriptorPtr = std::shared_ptr<CceTensorDescriptor>; | |||||
class CceTensorDescriptor { | class CceTensorDescriptor { | ||||
public: | public: | ||||
explicit CceTensorDescriptor(ccTensorDescriptor_t cc_tensor); | explicit CceTensorDescriptor(ccTensorDescriptor_t cc_tensor); | ||||
CceTensorDescriptor(const CceTensorDescriptor &) = delete; | CceTensorDescriptor(const CceTensorDescriptor &) = delete; | ||||
CceTensorDescriptor &operator=(const CceTensorDescriptor &) = delete; | CceTensorDescriptor &operator=(const CceTensorDescriptor &) = delete; | ||||
@@ -22,7 +22,7 @@ | |||||
#include <math.h> | #include <math.h> | ||||
#include <stdint.h> | #include <stdint.h> | ||||
namespace domi { | |||||
namespace ge { | |||||
// general | // general | ||||
const float DEFAULT_ALPHA_VALUE = 1.0; | const float DEFAULT_ALPHA_VALUE = 1.0; | ||||
const float DEFAULT_BETA_VALUE = 0.0; | const float DEFAULT_BETA_VALUE = 0.0; | ||||
@@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2; | |||||
// Shufflechannel | // Shufflechannel | ||||
const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ |
@@ -25,7 +25,7 @@ | |||||
/// MAKE_GUARD([&] { Release Resource 1 }) | /// MAKE_GUARD([&] { Release Resource 1 }) | ||||
/// Acquire Resource 2 | /// Acquire Resource 2 | ||||
// MAKE_GUARD([&] { Release Resource 2 }) | // MAKE_GUARD([&] { Release Resource 2 }) | ||||
#define GE_MAKE_GUARD(var, callback) ge::ScopeGuard make_guard_##var(callback) | |||||
#define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) | |||||
#define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | #define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | ||||
namespace ge { | namespace ge { | ||||
@@ -156,6 +156,7 @@ REGISTER_OPTYPE_DECLARE(GATHER, "Gather"); | |||||
REGISTER_OPTYPE_DECLARE(REALDIV, "RealDiv"); | REGISTER_OPTYPE_DECLARE(REALDIV, "RealDiv"); | ||||
REGISTER_OPTYPE_DECLARE(PACK, "Pack"); | REGISTER_OPTYPE_DECLARE(PACK, "Pack"); | ||||
REGISTER_OPTYPE_DECLARE(SLICE, "Slice"); | REGISTER_OPTYPE_DECLARE(SLICE, "Slice"); | ||||
REGISTER_OPTYPE_DECLARE(SLICED, "SliceD"); | |||||
REGISTER_OPTYPE_DECLARE(FLOORDIV, "FloorDiv"); | REGISTER_OPTYPE_DECLARE(FLOORDIV, "FloorDiv"); | ||||
REGISTER_OPTYPE_DECLARE(SQUEEZE, "Squeeze"); | REGISTER_OPTYPE_DECLARE(SQUEEZE, "Squeeze"); | ||||
REGISTER_OPTYPE_DECLARE(STRIDEDSLICE, "StridedSlice"); | REGISTER_OPTYPE_DECLARE(STRIDEDSLICE, "StridedSlice"); | ||||
@@ -188,6 +189,19 @@ REGISTER_OPTYPE_DECLARE(REFNEXTITERATION, "RefNextIteration"); | |||||
REGISTER_OPTYPE_DECLARE(EXIT, "Exit"); | REGISTER_OPTYPE_DECLARE(EXIT, "Exit"); | ||||
REGISTER_OPTYPE_DECLARE(REFEXIT, "RefExit"); | REGISTER_OPTYPE_DECLARE(REFEXIT, "RefExit"); | ||||
REGISTER_OPTYPE_DECLARE(CONTROLTRIGGER, "ControlTrigger"); | REGISTER_OPTYPE_DECLARE(CONTROLTRIGGER, "ControlTrigger"); | ||||
REGISTER_OPTYPE_DECLARE(SYMBOLICGRADIENT, "SymbolicGradient"); | |||||
REGISTER_OPTYPE_DECLARE(REMOTECALL, "RemoteCall"); | |||||
REGISTER_OPTYPE_DECLARE(_IF, "_If"); | |||||
REGISTER_OPTYPE_DECLARE(STATELESSIF, "StatelessIf"); | |||||
REGISTER_OPTYPE_DECLARE(IF, "If"); | |||||
REGISTER_OPTYPE_DECLARE(CASE, "Case"); | |||||
REGISTER_OPTYPE_DECLARE(_WHILE, "_While"); | |||||
REGISTER_OPTYPE_DECLARE(WHILE, "While"); | |||||
REGISTER_OPTYPE_DECLARE(STATELESSWHILE, "StatelessWhile"); | |||||
REGISTER_OPTYPE_DECLARE(FOR, "For"); | |||||
REGISTER_OPTYPE_DECLARE(PARTITIONEDCALL, "PartitionedCall"); | |||||
REGISTER_OPTYPE_DECLARE(STATEFULPARTITIONEDCALL, "StatefulPartitionedCall"); | |||||
REGISTER_OPTYPE_DECLARE(FAKEPARAM, "FakeParam"); | |||||
REGISTER_OPTYPE_DECLARE(TRANSPOSE, "Transpose"); | REGISTER_OPTYPE_DECLARE(TRANSPOSE, "Transpose"); | ||||
REGISTER_OPTYPE_DECLARE(TRANSPOSED, "TransposeD"); | REGISTER_OPTYPE_DECLARE(TRANSPOSED, "TransposeD"); | ||||
REGISTER_OPTYPE_DECLARE(CAST, "Cast"); | REGISTER_OPTYPE_DECLARE(CAST, "Cast"); | ||||
@@ -424,6 +438,12 @@ REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | |||||
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | ||||
REGISTER_OPTYPE_DECLARE(SEND, "Send"); | REGISTER_OPTYPE_DECLARE(SEND, "Send"); | ||||
REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | ||||
REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | |||||
REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | |||||
REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | |||||
REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | |||||
REGISTER_OPTYPE_DECLARE(ATOMICADDRCLEAN, "AtomicAddrClean"); | REGISTER_OPTYPE_DECLARE(ATOMICADDRCLEAN, "AtomicAddrClean"); | ||||
REGISTER_OPTYPE_DECLARE(ABS_GRAD, "AbsGrad"); | REGISTER_OPTYPE_DECLARE(ABS_GRAD, "AbsGrad"); | ||||
@@ -1032,14 +1052,11 @@ struct BasicInfo { | |||||
uint32_t workspace_size; // workspace | uint32_t workspace_size; // workspace | ||||
uint32_t total_size; // total memory size | uint32_t total_size; // total memory size | ||||
}; | }; | ||||
#pragma pack() // Cancels single-byte alignment | #pragma pack() // Cancels single-byte alignment | ||||
} // namespace ge | } // namespace ge | ||||
namespace domi { | namespace domi { | ||||
/// @brief Data structure definition related to task sinking | /// @brief Data structure definition related to task sinking | ||||
/// Build model | |||||
enum BuildMode { | enum BuildMode { | ||||
GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | ||||
GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | ||||
@@ -30,6 +30,14 @@ | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
// Check that the size parameter is strictly positive. If it is zero or negative,
// record the error and return ge::PARAM_INVALID from the enclosing function.
// The argument is parenthesized so that a compound expression (for example a
// conditional expression) is compared as a whole, and the status code is
// namespace-qualified, like the other GE_CHECK_* macros in this header, so the
// macro also expands correctly outside namespace ge.
#define GE_CHECK_POSITIVE_SIZE_RANGE(size)              \
  do {                                                  \
    if ((size) <= 0) {                                  \
      DOMI_LOGE(param[#size] is not a positive number); \
      return ge::PARAM_INVALID;                         \
    }                                                   \
  } while (0)
#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | #define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | ||||
{ \ | { \ | ||||
bool b = (expr); \ | bool b = (expr); \ | ||||
@@ -50,21 +58,6 @@ | |||||
if (var) GE_CHK_RT(rtStreamDestroy(var)); \ | if (var) GE_CHK_RT(rtStreamDestroy(var)); \ | ||||
}); | }); | ||||
#define GE_MAKE_GUARD_RTEVENT(var) \ | |||||
GE_MAKE_GUARD(var, [&] { \ | |||||
if (var) GE_CHK_RT(rtEventDestroy(var)); \ | |||||
}); | |||||
#define GE_MAKE_GUARD_TENSOR(var) \ | |||||
GE_MAKE_GUARD(var, [&] { \ | |||||
if (var) GE_CHK_CCE(ccDestroyTensorDescriptor(&var)); \ | |||||
}); | |||||
#define GE_MAKE_GUARD_FILTER_DESC(var) \ | |||||
GE_MAKE_GUARD(var, [&] { \ | |||||
if (var) GE_CHK_CCE(ccDestroyFilterDescriptor(&var)); \ | |||||
}); | |||||
// For propagating errors when calling a function. | // For propagating errors when calling a function. | ||||
#define GE_RETURN_IF_ERROR(expr) \ | #define GE_RETURN_IF_ERROR(expr) \ | ||||
do { \ | do { \ | ||||
@@ -76,7 +69,7 @@ | |||||
do { \ | do { \ | ||||
const ::ge::Status _status = (expr); \ | const ::ge::Status _status = (expr); \ | ||||
if (_status) { \ | if (_status) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
@@ -85,7 +78,7 @@ | |||||
#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ | #define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ | ||||
do { \ | do { \ | ||||
if (condition) { \ | if (condition) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | return ge::FAILED; \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
@@ -95,7 +88,7 @@ | |||||
do { \ | do { \ | ||||
bool _condition = (condition); \ | bool _condition = (condition); \ | ||||
if (!_condition) { \ | if (!_condition) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | return ge::FAILED; \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
@@ -104,7 +97,7 @@ | |||||
#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ | #define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ | ||||
do { \ | do { \ | ||||
if (condition) { \ | if (condition) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
return ge::PARAM_INVALID; \ | return ge::PARAM_INVALID; \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
@@ -114,111 +107,90 @@ | |||||
do { \ | do { \ | ||||
bool _condition = (condition); \ | bool _condition = (condition); \ | ||||
if (!_condition) { \ | if (!_condition) { \ | ||||
GE_LOGE(__VA_ARGS__); \ | |||||
DOMI_LOGE(__VA_ARGS__); \ | |||||
return ge::PARAM_INVALID; \ | return ge::PARAM_INVALID; \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
// Check if the parameter is null. If yes, return PARAM_INVALID and record the error | // Check if the parameter is null. If yes, return PARAM_INVALID and record the error | ||||
#define GE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GE_LOGE(param[#val] must not be null.); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE(param[#val] must not be null.); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the parameter is null. If yes, return PARAM_INVALID and record the error | |||||
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GE_LOGE(param[#val] must not be null.); \ | |||||
return; \ | |||||
} \ | |||||
// Check if the parameter is null. If yes, just return and record the error | |||||
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE(param[#val] must not be null.); \ | |||||
return; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | // Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | ||||
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GE_LOGE(param[#val] must not be null.); \ | |||||
exec_expr; \ | |||||
} \ | |||||
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE(param[#val] must not be null.); \ | |||||
exec_expr; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check whether the parameter is null. If yes, return directly and record the error log | // Check whether the parameter is null. If yes, return directly and record the error log | ||||
#define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GE_LOGE(param[#val] must not be null.); \ | |||||
return; \ | |||||
} \ | |||||
#define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE(param[#val] must not be null.); \ | |||||
return; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the parameter is null. If yes, return false and record the error log | // Check if the parameter is null. If yes, return false and record the error log | ||||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GE_LOGE(param[#val] must not be null.); \ | |||||
return false; \ | |||||
} \ | |||||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE(param[#val] must not be null.); \ | |||||
return false; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the parameter is out of bounds | // Check if the parameter is out of bounds | ||||
#define GE_CHECK_SIZE(size) \ | |||||
do { \ | |||||
if (size == 0) { \ | |||||
GE_LOGE(param[#size] is out of range); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_SIZE(size) \ | |||||
do { \ | |||||
if (size == 0) { \ | |||||
DOMI_LOGE(param[#size] is out of range); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Macros that define the size variable | |||||
#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ | |||||
uint32_t _var_name; \ | |||||
do { \ | |||||
uint32_t _expr_size = (_expr); \ | |||||
uint32_t _sizeof_size = (_sizeof); \ | |||||
if (_expr_size > (0xffffffff) / _sizeof_size) { \ | |||||
GE_LOGE(byte size : #_var_name is out of range); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
_var_name = _sizeof_size * _expr_size; \ | |||||
} while (0); | |||||
// Check if the container is empty | // Check if the container is empty | ||||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
do { \ | |||||
if (vector.empty()) { \ | |||||
GE_LOGE(param[#vector] is empty !); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} while (0) | |||||
#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size <= 0) { \ | |||||
GE_LOGE(param[#size] is not a positive number); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
do { \ | |||||
if (vector.empty()) { \ | |||||
DOMI_LOGE(param[#vector] is empty !); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the value on the left is greater than or equal to the value on the right | // Check if the value on the left is greater than or equal to the value on the right | ||||
#define GE_CHECK_GE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs < rhs) { \ | |||||
GE_LOGE(param[#lhs] is less than[#rhs]); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_GE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs < rhs) { \ | |||||
DOMI_LOGE(param[#lhs] is less than[#rhs]); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the value on the left is less than or equal to the value on the right | // Check if the value on the left is less than or equal to the value on the right | ||||
#define GE_CHECK_LE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs > rhs) { \ | |||||
GE_LOGE(param[#lhs] is greater than[#rhs]); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_LE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs > rhs) { \ | |||||
DOMI_LOGE(param[#lhs] is greater than[#rhs]); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
#define GE_DELETE_NEW_SINGLE(var) \ | #define GE_DELETE_NEW_SINGLE(var) \ | ||||
@@ -52,10 +52,10 @@ | |||||
#define DLOG_DECLARE(level) \ | #define DLOG_DECLARE(level) \ | ||||
void Log_##level(const char *mod_name, const char *func, const char *file, int line, const char *format, ...) | void Log_##level(const char *mod_name, const char *func, const char *file, int line, const char *format, ...) | ||||
namespace ge { | |||||
namespace domi { | |||||
DLOG_DECLARE(INFO); | DLOG_DECLARE(INFO); | ||||
DLOG_DECLARE(WARNING); | DLOG_DECLARE(WARNING); | ||||
DLOG_DECLARE(ERROR); | DLOG_DECLARE(ERROR); | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_DLOG_LOG_H_ | #endif // INC_FRAMEWORK_DLOG_LOG_H_ |
@@ -38,7 +38,7 @@ struct DNNEngineAttribute { | |||||
std::vector<std::string> mem_type; | std::vector<std::string> mem_type; | ||||
uint32_t compute_cost; | uint32_t compute_cost; | ||||
enum RuntimeType runtime_type; // HOST, DEVICE | enum RuntimeType runtime_type; // HOST, DEVICE | ||||
// set this attribute if the inputformat of engine must be specific, otherwise set FORMAT_RESERVED | |||||
// If engine input format must be specific, set this attribute, else set FORMAT_RESERVED | |||||
Format engine_input_format; | Format engine_input_format; | ||||
Format engine_output_format; | Format engine_output_format; | ||||
}; | }; | ||||
@@ -26,6 +26,7 @@ | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/tensor.h" | #include "graph/tensor.h" | ||||
#include "runtime/base.h" | #include "runtime/base.h" | ||||
#include "common/dynamic_aipp.h" | |||||
namespace ge { | namespace ge { | ||||
class ModelListenerAdapter; | class ModelListenerAdapter; | ||||
@@ -33,12 +34,15 @@ class ModelListenerAdapter; | |||||
class SingleOp; | class SingleOp; | ||||
struct RunModelData { | struct RunModelData { | ||||
uint32_t index; // Data index | |||||
uint32_t model_id; // Model id | |||||
std::vector<DataBuffer> blobs; // All input/output data buffer | |||||
uint32_t timestamp; // Data creation time | |||||
uint32_t timeout; // Processing timeout | |||||
uint64_t request_id = 0; // Request ID | |||||
uint32_t index; // Data index | |||||
uint32_t modelId; | |||||
std::vector<DataBuffer> blobs; // All input/output data buffer | |||||
uint32_t timestamp; // Data creation time | |||||
uint32_t timeout; // Processing timeout | |||||
uint64_t request_id = 0; // Request ID | |||||
uint64_t dynamic_batch_size = 0; // Dynamic batch size scene, set dynamic size, not supported by default:0 | |||||
uint64_t dynamic_image_height = 0; // Dynamic image size scene, set image height, not supported by default:0 | |||||
uint64_t dynamic_image_width = 0; // Dynamic image size scene, set image width, not supported by default:0 | |||||
}; | }; | ||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | ||||
@@ -46,12 +50,13 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
GeExecutor(); | GeExecutor(); | ||||
~GeExecutor() = default; | ~GeExecutor() = default; | ||||
ge::Status Initialize(); | ge::Status Initialize(); | ||||
ge::Status Finalize(); | |||||
// Load model | // Load model | ||||
ge::Status LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key, int32_t priority, | ge::Status LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key, int32_t priority, | ||||
std::shared_ptr<ge::ModelListener> listener); | std::shared_ptr<ge::ModelListener> listener); | ||||
ge::Status UnloadModel(uint32_t model_id); | |||||
ge::Status UnloadModel(uint32_t modelId); | |||||
ge::Status RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data); | ge::Status RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data); | ||||
@@ -59,6 +64,52 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ||||
std::vector<ge::TensorDesc> &output_desc); | std::vector<ge::TensorDesc> &output_desc); | ||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Set dynamic batch size | |||||
/// @param [in] model_id: model id allocated by the manager | |||||
/// @param [in] dynamic_input_addr: dynamic input address created by the user | |||||
/// @param [in] length: length of the dynamic input address | |||||
/// @param [in] batch_size: batch size entered by the user in the dynamic multi-batch scenario | |||||
/// @return execute result | |||||
/// | |||||
ge::Status SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t batch_size); | |||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Set dynamic image info | |||||
/// @param [in] model_id: model id allocated by the manager | |||||
/// @param [in] dynamic_input_addr: dynamic input address created by the user | |||||
/// @param [in] length: length of the dynamic input address | |||||
/// @param [in] image_height: image height entered by the user in the dynamic multi-resolution scenario | |||||
/// @param [in] image_width: image width entered by the user in the dynamic multi-resolution scenario | |||||
/// @return execute result | |||||
/// | |||||
ge::Status SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height, | |||||
uint64_t image_width); | |||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Get dynamic batch info | |||||
/// @param [in] model_id: id of the model to query | |||||
/// @param [out] batch_info: filled with the batch info (a vector of int64_t vectors) | |||||
/// @return execute result | |||||
/// | |||||
ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info); | |||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Set dynamic aipp data | |||||
/// @param [in] model_id: model id allocated by the manager | |||||
/// @param [in] dynamic_input_addr: dynamic input address created by the user | |||||
/// @param [in] length: length of the dynamic input address | |||||
/// @param [in] aippBatchPara: vector of kAippDynamicBatchPara supplied by the user for dynamic aipp | |||||
/// @param [in] aippParms: kAippDynamicPara supplied by the user for dynamic aipp | |||||
/// @return execute result | |||||
/// | |||||
ge::Status SetDynamicAippData(uint32_t model_id, void *dynamic_input_addr, uint64_t length, | |||||
const std::vector<kAippDynamicBatchPara> &aippBatchPara, | |||||
const kAippDynamicPara &aippParms); | |||||
ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ||||
std::vector<ge::TensorDesc> &output_desc); | std::vector<ge::TensorDesc> &output_desc); | ||||
@@ -147,7 +198,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
/// | /// | ||||
ge::Status GetMemAndWeightSize(const void *model_data, size_t model_size, size_t &mem_size, size_t &weight_size); | ge::Status GetMemAndWeightSize(const void *model_data, size_t model_size, size_t &mem_size, size_t &weight_size); | ||||
static ge::Status LoadSingleOp(const std::string &model_name, const ge::ModelData &model_data, void *stream, | |||||
static ge::Status LoadSingleOp(const std::string &modelName, const ge::ModelData &modelData, void *stream, | |||||
SingleOp **single_op); | SingleOp **single_op); | ||||
static ge::Status ExecuteAsync(SingleOp *executor, const std::vector<DataBuffer> &inputs, | static ge::Status ExecuteAsync(SingleOp *executor, const std::vector<DataBuffer> &inputs, | ||||
@@ -156,8 +207,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
static ge::Status ReleaseSingleOpResource(void *stream); | static ge::Status ReleaseSingleOpResource(void *stream); | ||||
private: | private: | ||||
static bool is_init_; | |||||
std::vector<std::shared_ptr<ModelListenerAdapter>> listener_adapters_; | |||||
static bool isInit_; | |||||
}; | }; | ||||
ge::Status ModelInfoParser(const ge::ModelData &model, ge::ModelInfo &model_info); | ge::Status ModelInfoParser(const ge::ModelData &model, ge::ModelInfo &model_info); | ||||
@@ -21,7 +21,7 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "ge/ge_ir_build.h" | |||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
@@ -45,6 +45,8 @@ class GeGenerator { | |||||
Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, | Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, | ||||
const std::vector<GeTensor> &inputs = std::vector<GeTensor>()); | const std::vector<GeTensor> &inputs = std::vector<GeTensor>()); | ||||
Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData &model); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief: Build single OP in Model. | /// @brief: Build single OP in Model. | ||||
@@ -58,6 +60,8 @@ class GeGenerator { | |||||
const std::vector<GeTensor> &outputs, const std::string &model_file_name); | const std::vector<GeTensor> &outputs, const std::string &model_file_name); | ||||
private: | private: | ||||
Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | |||||
ge::ModelBufferData &model, bool is_offline = true); | |||||
class Impl; | class Impl; | ||||
std::shared_ptr<Impl> impl_; | std::shared_ptr<Impl> impl_; | ||||
@@ -24,7 +24,6 @@ extern "C" { | |||||
#endif | #endif | ||||
typedef uint32_t Status_t; | typedef uint32_t Status_t; | ||||
using Status_t = uint32_t; | |||||
typedef void *OpAttr_t; | typedef void *OpAttr_t; | ||||
typedef void *OpTensor_t; | typedef void *OpTensor_t; | ||||
@@ -23,7 +23,7 @@ | |||||
#include "graph/node.h" | #include "graph/node.h" | ||||
namespace ge { | namespace ge { | ||||
const int64_t kMemAlignSize = 512; | |||||
const int64_t MEM_ALIGN_SIZE = 512; | |||||
class MemoryAssigner { | class MemoryAssigner { | ||||
public: | public: | ||||
explicit MemoryAssigner(ge::ComputeGraphPtr compute_graph) : compute_graph_(std::move(compute_graph)) {} | explicit MemoryAssigner(ge::ComputeGraphPtr compute_graph) : compute_graph_(std::move(compute_graph)) {} | ||||
@@ -39,4 +39,4 @@ class MemoryAssigner { | |||||
ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ | |||||
#endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ |
@@ -31,7 +31,6 @@ | |||||
using domi::DOMI_TENSOR_ND; | using domi::DOMI_TENSOR_ND; | ||||
using domi::DOMI_TENSOR_RESERVED; | using domi::DOMI_TENSOR_RESERVED; | ||||
using domi::domiTensorFormat_t; | using domi::domiTensorFormat_t; | ||||
using domi::FMK_TYPE_RESERVED; | |||||
using domi::FrameworkType; | using domi::FrameworkType; | ||||
using std::map; | using std::map; | ||||
using std::string; | using std::string; | ||||
@@ -44,10 +43,10 @@ namespace ge { | |||||
* @brief run model | * @brief run model | ||||
*/ | */ | ||||
enum RunMode { | enum RunMode { | ||||
kGeOmModel = 0, // generate offline model file | |||||
kModelToJson = 1, // convert to JSON file | |||||
kOnlyPreCheck = 3, // only for pre-check | |||||
kPbtxtToJson = 5 // pbtxt to json | |||||
GEN_OM_MODEL = 0, // generate offline model file | |||||
MODEL_TO_JSON = 1, // convert to JSON file | |||||
ONLY_PRE_CHECK = 3, // only for pre-check | |||||
PBTXT_TO_JSON = 5 // pbtxt to json | |||||
}; | }; | ||||
/// | /// | ||||
@@ -56,10 +55,10 @@ enum RunMode { | |||||
/// | /// | ||||
enum HighPrecisionMode { | enum HighPrecisionMode { | ||||
// the FP16 high-precision function is disabled in common mode | // the FP16 high-precision function is disabled in common mode | ||||
kHighPrecisonDefault = 0, | |||||
HIGH_PRECISION_DEFAULT = 0, | |||||
// high-precision mode, in which FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved) is enable | |||||
kHighPrecisionFP16 = 1 | |||||
// high-precision mode, enabling FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved) | |||||
HIGH_PRECISION_FP16 = 1 | |||||
}; | }; | ||||
/// | /// | ||||
@@ -99,21 +98,23 @@ struct OmgContext { | |||||
// preferential format used by the entire network | // preferential format used by the entire network | ||||
domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | ||||
domi::FrameworkType type = domi::FMK_TYPE_RESERVED; | domi::FrameworkType type = domi::FMK_TYPE_RESERVED; | ||||
RunMode run_mode = kOnlyPreCheck; | |||||
RunMode run_mode = ONLY_PRE_CHECK; | |||||
bool train_flag = false; | bool train_flag = false; | ||||
// whether to use FP16 high precision | // whether to use FP16 high precision | ||||
int32_t fp16_high_precision = kHighPrecisonDefault; | |||||
int32_t fp16_high_precision = HIGH_PRECISION_DEFAULT; | |||||
std::string output_type; | std::string output_type; | ||||
// Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | // Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | ||||
// network require special processing based on the specific network. | |||||
// e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope | |||||
// fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The | |||||
// convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel. | |||||
// network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module | |||||
// is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the | |||||
// FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape | |||||
// operator needs to be deleted for the convolution kernel. | |||||
std::string net_name; | std::string net_name; | ||||
// whether to enable dynamic batch | |||||
bool enable_l2dynamic = false; | |||||
// Whether to use dynamic batch size or dynamic image size | |||||
bool is_dynamic_input = false; | |||||
std::string dynamic_batch_size; | |||||
std::string dynamic_image_size; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -32,15 +32,7 @@ class PlatformVersionManager { | |||||
PlatformVersionManager() = delete; | PlatformVersionManager() = delete; | ||||
~PlatformVersionManager() = delete; | ~PlatformVersionManager() = delete; | ||||
static Status GetPlatformVersion(std::string &ver) { | static Status GetPlatformVersion(std::string &ver) { | ||||
#if defined PLATFORM_PHOENIX | |||||
ver = "3.51.z"; | |||||
#elif defined PLATFORM_ORLANDO | |||||
ver = "3.31.z"; | |||||
#elif defined PLATFORM_MINI | |||||
ver = "1.11.z"; | ver = "1.11.z"; | ||||
#elif defined PLATFORM_CLOUD | |||||
ver = "1.61.z"; | |||||
#endif | |||||
std::vector<std::string> version_splits = StringUtils::Split(ver, '.'); | std::vector<std::string> version_splits = StringUtils::Split(ver, '.'); | ||||
GE_IF_BOOL_EXEC(version_splits.size() < 3, GELOGW("Read platform version error!"); return FAILED;); | GE_IF_BOOL_EXEC(version_splits.size() < 3, GELOGW("Read platform version error!"); return FAILED;); | ||||
@@ -20,13 +20,17 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/range_vistor.h" | #include "graph/range_vistor.h" | ||||
#include "graph/types.h" | #include "graph/types.h" | ||||
namespace ge { | namespace ge { | ||||
enum AnchorStatus { ANCHOR_SUSPEND = 0, ANCHOR_CONST = 1, ANCHOR_DATA = 2, ANCHOR_RESERVED = 3 }; | |||||
enum AnchorStatus { | |||||
ANCHOR_SUSPEND = 0, // dat null | |||||
ANCHOR_CONST = 1, | |||||
ANCHOR_DATA = 2, // Effective | |||||
ANCHOR_RESERVED = 3 | |||||
}; | |||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
@@ -81,17 +85,19 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable | |||||
virtual ~Anchor() = default; | virtual ~Anchor() = default; | ||||
protected: | protected: | ||||
// Whether the two anchors are equal | |||||
// Whether the two anchor is equal | |||||
virtual bool Equal(AnchorPtr anchor) const = 0; | virtual bool Equal(AnchorPtr anchor) const = 0; | ||||
virtual bool IsTypeOf(TYPE type) const; | virtual bool IsTypeOf(TYPE type) const; | ||||
public: | public: | ||||
// Get all peer anchors connected to current anchor | // Get all peer anchors connected to current anchor | ||||
Vistor<AnchorPtr> GetPeerAnchors() const; | Vistor<AnchorPtr> GetPeerAnchors() const; | ||||
// Get the first peer anchor | |||||
// Get peer anchor size | |||||
size_t GetPeerAnchorsSize() const; | |||||
// Get first peer anchor | |||||
AnchorPtr GetFirstPeerAnchor() const; | AnchorPtr GetFirstPeerAnchor() const; | ||||
// Get the node which is the owner of the anchor | |||||
// Get the anchor belong to which node | |||||
NodePtr GetOwnerNode() const; | NodePtr GetOwnerNode() const; | ||||
// Remove all links with the anchor | // Remove all links with the anchor | ||||
@@ -100,22 +106,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable | |||||
// Remove link with the given anchor | // Remove link with the given anchor | ||||
graphStatus Unlink(const AnchorPtr &peer); | graphStatus Unlink(const AnchorPtr &peer); | ||||
// Replace the peeranchor with the new peeranchor | |||||
// Replace peer with new peers | |||||
graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); | graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); | ||||
// Judge if the anchor is linked with the given anchor | // Judge if the anchor is linked with the given anchor | ||||
bool IsLinkedWith(const AnchorPtr &peer); | bool IsLinkedWith(const AnchorPtr &peer); | ||||
// Get the anchor index of the node | |||||
// Get anchor index of the node | |||||
int GetIdx() const; | int GetIdx() const; | ||||
// Set the anchor index of the node | |||||
// set anchor index of the node | |||||
void SetIdx(int index); | void SetIdx(int index); | ||||
protected: | protected: | ||||
// All peer anchors connected to current anchor | // All peer anchors connected to current anchor | ||||
vector<std::weak_ptr<Anchor>> peer_anchors_; | vector<std::weak_ptr<Anchor>> peer_anchors_; | ||||
// The owner nodes of the anchor | |||||
// The owner node of anchor | |||||
std::weak_ptr<Node> owner_node_; | std::weak_ptr<Node> owner_node_; | ||||
// The index of current anchor | // The index of current anchor | ||||
int idx_; | int idx_; | ||||
@@ -167,7 +173,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataA | |||||
virtual ~InDataAnchor() = default; | virtual ~InDataAnchor() = default; | ||||
// Get source out data anchor | |||||
// Get source out data anchor | |||||
OutDataAnchorPtr GetPeerOutAnchor() const; | OutDataAnchorPtr GetPeerOutAnchor() const; | ||||
// Build connection from OutDataAnchor to InDataAnchor | // Build connection from OutDataAnchor to InDataAnchor | ||||
@@ -19,10 +19,10 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
namespace ge { | namespace ge { | ||||
class GeAttrValue; | class GeAttrValue; | ||||
class _GeSerializable { | class _GeSerializable { | ||||
public: | public: | ||||
@@ -107,7 +107,6 @@ class _GeSerializable { | |||||
static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; } | static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; } | ||||
}; | }; | ||||
#define _GE_FI(a) #a, a | #define _GE_FI(a) #a, a | ||||
#define _GE_MAP_FIELDS1(a1) _GE_FI(a1) | #define _GE_MAP_FIELDS1(a1) _GE_FI(a1) | ||||
#define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) | #define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) | ||||
@@ -130,23 +129,23 @@ class _GeSerializable { | |||||
#define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ | #define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ | ||||
_GE_FI(a1) \ | _GE_FI(a1) \ | ||||
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | ||||
_GE_FI(a11) | |||||
_GE_FI(a11) | |||||
#define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ | #define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ | ||||
_GE_FI(a1) \ | _GE_FI(a1) \ | ||||
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | ||||
_GE_FI(a11), _GE_FI(a12) | |||||
_GE_FI(a11), _GE_FI(a12) | |||||
#define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ | #define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ | ||||
_GE_FI(a1) \ | _GE_FI(a1) \ | ||||
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | ||||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13) | |||||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13) | |||||
#define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ | #define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ | ||||
_GE_FI(a1) \ | _GE_FI(a1) \ | ||||
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | ||||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) | |||||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) | |||||
#define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ | #define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ | ||||
_GE_FI(a1) \ | _GE_FI(a1) \ | ||||
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | ||||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) | |||||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) | |||||
#define _GE_PRIVATE_ARGS_GLUE(x, y) x y | #define _GE_PRIVATE_ARGS_GLUE(x, y) x y | ||||
@@ -17,12 +17,11 @@ | |||||
#ifndef INC_GRAPH_BUFFER_H_ | #ifndef INC_GRAPH_BUFFER_H_ | ||||
#define INC_GRAPH_BUFFER_H_ | #define INC_GRAPH_BUFFER_H_ | ||||
#include <graph/types.h> | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
#include "graph/types.h" | |||||
namespace ge { | namespace ge { | ||||
#ifdef HOST_VISIBILITY | #ifdef HOST_VISIBILITY | ||||
@@ -72,7 +71,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { | |||||
GeIrProtoHelper<proto::AttrDef> data_; | GeIrProtoHelper<proto::AttrDef> data_; | ||||
std::string *buffer_ = nullptr; | std::string *buffer_ = nullptr; | ||||
// Create buffer from protobuf obj | |||||
// Create from protobuf obj | |||||
Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); | Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); | ||||
Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); | Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); | ||||
@@ -17,7 +17,6 @@ | |||||
#ifndef INC_GRAPH_COMPUTE_GRAPH_H_ | #ifndef INC_GRAPH_COMPUTE_GRAPH_H_ | ||||
#define INC_GRAPH_COMPUTE_GRAPH_H_ | #define INC_GRAPH_COMPUTE_GRAPH_H_ | ||||
#include <deque> | |||||
#include <map> | #include <map> | ||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
@@ -63,7 +62,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>; | using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>; | ||||
explicit ComputeGraph(const std::string &name); | explicit ComputeGraph(const std::string &name); | ||||
virtual ~ComputeGraph(); | |||||
~ComputeGraph() override; | |||||
std::string GetName() const; | std::string GetName() const; | ||||
void SetName(const std::string &name); | void SetName(const std::string &name); | ||||
@@ -81,7 +80,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
Vistor<NodePtr> GetOutputNodes() const; | Vistor<NodePtr> GetOutputNodes() const; | ||||
NodePtr FindNode(const std::string &name) const; | NodePtr FindNode(const std::string &name) const; | ||||
// Add node | |||||
// AddNode with NodePtr | |||||
NodePtr AddNode(NodePtr node); | NodePtr AddNode(NodePtr node); | ||||
NodePtr AddNode(OpDescPtr op); | NodePtr AddNode(OpDescPtr op); | ||||
NodePtr AddNodeFront(NodePtr node); | NodePtr AddNodeFront(NodePtr node); | ||||
@@ -94,9 +93,40 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
graphStatus RemoveOutputNode(const NodePtr &node); | graphStatus RemoveOutputNode(const NodePtr &node); | ||||
graphStatus RemoveConstInput(const NodePtr &node); | graphStatus RemoveConstInput(const NodePtr &node); | ||||
/// Add a subgraph to this graph. The subgraph must have a parent graph and a parent node, | |||||
/// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph | |||||
/// must be called before adding it to the root graph, and subgraph->GetParentNode()->GetOwnerGraph() | |||||
/// must be equal to subgraph->GetOwnerGraph(). | |||||
/// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph. | |||||
/// The subgraph's name SHOULD (not must) be the same as the parameter `name`. | |||||
graphStatus AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph); | |||||
graphStatus AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph); | |||||
void RemoveSubgraph(const std::string &name); | |||||
void RemoveSubgraph(const std::shared_ptr<ComputeGraph> &subgraph); | |||||
std::shared_ptr<ComputeGraph> GetSubgraph(const std::string &name) const; | |||||
std::vector<std::shared_ptr<ComputeGraph>> GetAllSubgraphs() const; | |||||
// obsolete | |||||
std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph); | std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph); | ||||
// obsolete | |||||
graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph); | graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph); | ||||
/// | |||||
/// @brief Update input-mapping | |||||
/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input | |||||
/// @return graphStatus | |||||
/// | |||||
graphStatus UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping); | |||||
/// | |||||
/// @brief Update output-mapping | |||||
/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output | |||||
/// @return graphStatus | |||||
/// | |||||
graphStatus UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping); | |||||
graphStatus TopologicalSorting(); | graphStatus TopologicalSorting(); | ||||
bool IsValid() const; | bool IsValid() const; | ||||
void Dump() const; | void Dump() const; | ||||
@@ -127,6 +157,11 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
} | } | ||||
} | } | ||||
shared_ptr<ComputeGraph> GetParentGraph(); | |||||
void SetParentGraph(const shared_ptr<ComputeGraph> &parent); | |||||
shared_ptr<Node> GetParentNode(); | |||||
void SetParentNode(const shared_ptr<Node> &parent); | |||||
const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; } | const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; } | ||||
void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } | void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } | ||||
@@ -138,8 +173,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
uint32_t GetInputSize() const { return input_size_; } | uint32_t GetInputSize() const { return input_size_; } | ||||
/// | /// | ||||
/// Set iteration needed. | |||||
/// If set is true, it means this graph need run iteration some | |||||
/// Set whether train iteration is needed. | |||||
/// If set to true, it means this graph needs to run the iteration multiple | |||||
/// times(according variant "npu_runconfig/iterations_per_loop"). | /// times(according variant "npu_runconfig/iterations_per_loop"). | ||||
/// @param need_iteration is need iteration | /// @param need_iteration is need iteration | ||||
/// | /// | ||||
@@ -150,7 +185,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
const std::string GetOutput(); | const std::string GetOutput(); | ||||
/// | /// | ||||
/// Get need_iteration. | |||||
/// Get whether train iteration is needed. | |||||
/// @return is need iteration | /// @return is need iteration | ||||
/// | /// | ||||
bool GetNeedIteration() const { return need_iteration_; } | bool GetNeedIteration() const { return need_iteration_; } | ||||
@@ -201,6 +236,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
std::deque<NodePtr> &stack); | std::deque<NodePtr> &stack); | ||||
graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::map<string, NodePtr> &breadth_node_map); | std::map<string, NodePtr> &breadth_node_map); | ||||
graphStatus TopologicalSortingSubgraph(); | |||||
graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | ||||
size_t GetInEdgeSize(const NodePtr &node); | size_t GetInEdgeSize(const NodePtr &node); | ||||
size_t GetOutEdgeSize(const NodePtr &node); | size_t GetOutEdgeSize(const NodePtr &node); | ||||
@@ -210,31 +246,38 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector, | bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector, | ||||
const std::vector<NodePtr> &l_node_ptr_vector) const; | const std::vector<NodePtr> &l_node_ptr_vector) const; | ||||
ProtoAttrMapHelper attrs_; | |||||
friend class ModelSerializeImp; | friend class ModelSerializeImp; | ||||
friend class GraphDebugImp; | friend class GraphDebugImp; | ||||
friend class OnnxUtils; | friend class OnnxUtils; | ||||
std::string name_; | |||||
uint32_t graph_id_ = 0; | |||||
ProtoAttrMapHelper attrs_; | |||||
std::vector<NodePtr> nodes_; | std::vector<NodePtr> nodes_; | ||||
std::map<OperatorImplPtr, NodePtr> all_nodes_infos_; | |||||
std::vector<NodePtr> target_nodes_info_; | |||||
std::vector<NodePtr> input_nodes_; | std::vector<NodePtr> input_nodes_; | ||||
std::vector<std::string> inputs_order_; | |||||
uint32_t input_size_ = 1; | |||||
std::map<std::string, std::vector<int32_t>> out_nodes_map_; | |||||
uint32_t output_size_ = 1; | |||||
std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_; | |||||
std::vector<std::shared_ptr<ComputeGraph>> sub_graph_; | std::vector<std::shared_ptr<ComputeGraph>> sub_graph_; | ||||
std::string name_; | |||||
std::map<std::string, std::shared_ptr<ComputeGraph>> names_to_subgraph_; | |||||
std::weak_ptr<ComputeGraph> parent_graph_; | |||||
std::weak_ptr<Node> parent_node_; | |||||
// the members that follow should not be in the ComputeGraph class | |||||
bool is_valid_flag_; | bool is_valid_flag_; | ||||
bool is_summary_graph_ = false; | bool is_summary_graph_ = false; | ||||
// Indicates whether it is need iteration | // Indicates whether it is need iteration | ||||
bool need_iteration_ = false; | bool need_iteration_ = false; | ||||
std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_; | std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_; | ||||
std::map<std::string, std::vector<int32_t>> out_nodes_map_; | |||||
// TaskIdx -> op_name Map | // TaskIdx -> op_name Map | ||||
std::map<uint32_t, std::string> op_name_map_; | std::map<uint32_t, std::string> op_name_map_; | ||||
std::vector<std::string> inputs_order_; | |||||
uint32_t output_size_ = 1; | |||||
uint32_t input_size_ = 1; | |||||
std::map<OperatorImplPtr, NodePtr> all_nodes_infos_; | |||||
std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_; | |||||
std::vector<NodePtr> target_nodes_info_; | |||||
uint64_t session_id_ = 0; | uint64_t session_id_ = 0; | ||||
uint32_t graph_id_ = 0; | |||||
ge::Format data_format_ = ge::FORMAT_ND; | ge::Format data_format_ = ge::FORMAT_ND; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -18,7 +18,6 @@ | |||||
#define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | #define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | ||||
#include <string> | #include <string> | ||||
#include "graph/types.h" | #include "graph/types.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -59,6 +58,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | ||||
@@ -75,8 +76,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | ||||
// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | |||||
// ATTR_NAME_WEIGHTS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | ||||
@@ -124,6 +124,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | ||||
@@ -141,10 +148,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | ||||
@@ -166,15 +178,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | ||||
// _Arg | // _Arg | ||||
@@ -263,7 +275,29 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||||
// Huberloss | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; | |||||
// SSDRealDivTileMul | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||||
// SSDSumMulRealDivMean | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||||
/// ConcatFive2Four | |||||
/// ConcatFour2Five | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; | |||||
// Scale | // Scale | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | ||||
@@ -300,7 +334,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | ||||
// Roipooling | // Roipooling | ||||
@@ -313,6 +346,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI | |||||
// DetectionOutput | // DetectionOutput | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | ||||
@@ -371,6 +405,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ | |||||
// Permute | // Permute | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; | |||||
// SSD Normalize | // SSD Normalize | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | ||||
@@ -411,9 +446,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | ||||
// Log | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; | |||||
// Pack | // Pack | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | ||||
// Dynamic stitch | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||||
// Unpack | // Unpack | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | ||||
// Gathernd | // Gathernd | ||||
@@ -422,8 +463,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND | |||||
// Argmax | // Argmax | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||||
// Upsample | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||||
// Relu | // Relu | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | ||||
@@ -439,6 +488,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_AT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_WORKSPACE_TYPE; | |||||
// Squeeze | // Squeeze | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; | ||||
@@ -461,6 +511,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_TF; | |||||
// Generate_rpn_proposal | // Generate_rpn_proposal | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | ||||
@@ -493,6 +544,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | ||||
static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||||
// ENTER | // ENTER | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | ||||
@@ -518,6 +570,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | ||||
// RetinaNet | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; | |||||
// MatMul | // MatMul | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | ||||
@@ -566,10 +621,30 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||||
// Upsample | // Upsample | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | ||||
// PadV2 | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||||
// MirrorPad | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||||
// Filler | // Filler | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | ||||
@@ -590,36 +665,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | ||||
@@ -634,6 +679,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | ||||
@@ -642,12 +689,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||||
// Public attribute | // Public attribute | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | ||||
@@ -685,11 +726,178 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; | ||||
// Used for operators that do not generate task | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOTASK; | |||||
// Used for operators that output reuse input | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_REUSE_INPUT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||||
// L2_normalize | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||||
// HCOM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||||
// Log time stamp | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; | |||||
// SpaceToDepth/DepthToSpace | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; | |||||
// SparseSoftmaxCrossEntropyWithLogits | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||||
// MaxPoolGradWithArgmax | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||||
// AvgPoolGrad | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||||
// Variable | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||||
// Assign | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; | |||||
// ShapeN | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; | |||||
// Space2batch batch2space | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; | |||||
// Depth_to_space space_to_depth | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||||
// FakeQuantWithMinMaxVars | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||||
// Mobilenet_ssd_conv_fusion | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||||
// Lsh project | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; | |||||
// Control flow | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||||
// GatherV2 attr def | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||||
// Reshape attr def | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||||
// Axis attr def | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; | |||||
// The node link with SparseSoftmaxCrossEntropyWithLogits | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||||
// For constant folding | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||||
// Used to mark the active label list, in order to find the stream of the activated node | // Used to mark the active label list, in order to find the stream of the activated node | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE; | |||||
// Multi batch | // Multi batch | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; | ||||
@@ -697,7 +905,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
// Control flow | // Control flow | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | ||||
@@ -709,6 +916,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; | ||||
// Function Op | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; | |||||
// Used to mark that the active node is for a loop, type: bool | // Used to mark that the active node is for a loop, type: bool | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; | ||||
@@ -742,6 +952,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INS | |||||
// For inserted op | // For inserted op | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | ||||
// For compress weight | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPRESS_WEIGHT; | |||||
// For data dump | // For data dump | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; | ||||
@@ -752,6 +965,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; | ||||
// used for l1 fusion and other fusion in future | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | |||||
// used for label switch | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | |||||
// Variable | // Variable | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | ||||
@@ -20,10 +20,8 @@ | |||||
#include <atomic> | #include <atomic> | ||||
#include <memory> | #include <memory> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/attr_value_serializable.h" | #include "graph/attr_value_serializable.h" | ||||
#include "graph/buffer.h" | #include "graph/buffer.h" | ||||
namespace ge { | namespace ge { | ||||
#define DEF_TYPE_DEC(type, name) \ | #define DEF_TYPE_DEC(type, name) \ | ||||
inline void set_##name(const type &value) { name = value; } \ | inline void set_##name(const type &value) { name = value; } \ | ||||
@@ -49,10 +47,9 @@ namespace ge { | |||||
inline void add_##name(type value) { name.push_back(value); } \ | inline void add_##name(type value) { name.push_back(value); } \ | ||||
inline std::vector<type> *mutable_##name() { return &name; } | inline std::vector<type> *mutable_##name() { return &name; } | ||||
#define DEF_TYPE_BYTES_DEC(name) \ | |||||
inline void clear_##name() { name.ClearBuffer(); } \ | |||||
inline void set_##name(const void *value, size_t size) { \ | |||||
name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ | |||||
#define DEF_TYPE_BYTES_DEC(name) \ | |||||
inline void clear_##name() { name.ClearBuffer(); } \ | |||||
inline void set_##name(const void *value, size_t size) { name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ | |||||
inline Buffer *mutable_##name() { return &name; } | inline Buffer *mutable_##name() { return &name; } | ||||
struct CompressInfo { | struct CompressInfo { | ||||
@@ -23,7 +23,6 @@ | |||||
#include <unordered_set> | #include <unordered_set> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/detail/any_map.h" | #include "graph/detail/any_map.h" | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/types.h" | #include "graph/types.h" | ||||
@@ -96,7 +95,7 @@ class GeIrProtoHelper { | |||||
} | } | ||||
} | } | ||||
// protoMsg_ is part of protoOwner_ and they have the same runtime | |||||
// protoMsg_ is part of protoOwner_, they have the same runtime | |||||
ProtoMsgOwner protoOwner_ = nullptr; | ProtoMsgOwner protoOwner_ = nullptr; | ||||
ProtoType *protoMsg_ = nullptr; | ProtoType *protoMsg_ = nullptr; | ||||
friend class GeIrProtoHelper<typename std::conditional< | friend class GeIrProtoHelper<typename std::conditional< | ||||
@@ -21,9 +21,7 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/anchor.h" | #include "graph/anchor.h" | ||||
#include "graph/model.h" | |||||
#include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
@@ -48,15 +46,15 @@ struct NodeNameNodeReq { | |||||
class ModelSerializeImp { | class ModelSerializeImp { | ||||
public: | public: | ||||
bool SerializeModel(const Model &model, proto::ModelDef *modeProto); | |||||
bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false); | |||||
bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto); | |||||
bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false); | |||||
bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); | bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); | ||||
bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto); | |||||
bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false); | |||||
bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto); | |||||
bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false); | |||||
bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); | bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); | ||||
@@ -23,7 +23,6 @@ | |||||
#include <string> | #include <string> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/buffer.h" | #include "graph/buffer.h" | ||||
#include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
@@ -139,15 +138,14 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||||
template <typename vector_type> | template <typename vector_type> | ||||
// To cols | // To cols | ||||
using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, | |||||
int>::type; | |||||
using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, int>::type; | |||||
template <typename one_type> | template <typename one_type> | ||||
using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type; | using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type; | ||||
template <typename val_type> | template <typename val_type> | ||||
using enable_if_type_valid_t = | using enable_if_type_valid_t = | ||||
typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type; | |||||
typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type; | |||||
template <typename seriliable_type> | template <typename seriliable_type> | ||||
using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; | using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; | ||||
@@ -18,7 +18,6 @@ | |||||
#define INC_GRAPH_GE_CONTEXT_H_ | #define INC_GRAPH_GE_CONTEXT_H_ | ||||
#include <string> | #include <string> | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -42,4 +41,4 @@ class GEContext { | |||||
GEContext &GetContext(); | GEContext &GetContext(); | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_GRAPH_GE_CONTEXT_H_ | |||||
#endif // INC_GRAPH_GE_CONTEXT_H_ |
@@ -20,7 +20,6 @@ | |||||
#include <map> | #include <map> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
using std::map; | using std::map; | ||||
@@ -42,5 +41,4 @@ class GEThreadLocalContext { | |||||
GEThreadLocalContext &GetThreadLocalContext(); | GEThreadLocalContext &GetThreadLocalContext(); | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ | #endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ |
@@ -21,12 +21,10 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
#include "graph/buffer.h" | #include "graph/buffer.h" | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/types.h" | #include "graph/types.h" | ||||
namespace ge { | namespace ge { | ||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | ||||
public: | public: | ||||
@@ -43,6 +41,18 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||||
int64_t GetShapeSize() const; | int64_t GetShapeSize() const; | ||||
std::string ToString() const; | std::string ToString() const; | ||||
/// | |||||
/// @brief Check whether the shape is unknown | |||||
/// @return bool | |||||
/// | |||||
bool IsUnknownShape() const; | |||||
/// | |||||
/// @brief Check whether the shape is a scalar | |||||
/// @return bool | |||||
/// | |||||
bool IsScalar() const; | |||||
GeShape(const GeShape &other); | GeShape(const GeShape &other); | ||||
GeShape(GeShape &&other); | GeShape(GeShape &&other); | ||||
GeShape &operator=(const GeShape &other); | GeShape &operator=(const GeShape &other); | ||||
@@ -51,7 +61,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||||
private: | private: | ||||
GeIrProtoHelper<proto::ShapeDef> shape_def_; | GeIrProtoHelper<proto::ShapeDef> shape_def_; | ||||
friend class GeTensorDesc; | friend class GeTensorDesc; | ||||
// Create geshape from proto obj | |||||
// Create from proto obj | |||||
GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); | GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); | ||||
void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } | void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } | ||||
@@ -112,7 +122,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH | |||||
void Init(); | void Init(); | ||||
// Create getensordesc from proto obj | |||||
// Create from proto obj | |||||
GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); | GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); | ||||
friend class GeTensor; | friend class GeTensor; | ||||
friend class GeAttrValueImp; | friend class GeAttrValueImp; | ||||
@@ -159,10 +169,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { | |||||
friend class GeAttrValueImp; | friend class GeAttrValueImp; | ||||
friend class ModelSerializeImp; | friend class ModelSerializeImp; | ||||
friend class OnnxUtils; | friend class OnnxUtils; | ||||
// Create getensor from proto obj | |||||
// Create from proto obj | |||||
GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); | GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); | ||||
GeIrProtoHelper<proto::TensorDef> tensor_def_; | GeIrProtoHelper<proto::TensorDef> tensor_def_; | ||||
// Reference from tensorDef_, cab not use it directly | |||||
// Reference from tensor_def_; do not use it directly | |||||
mutable GeTensorDesc __desc_; | mutable GeTensorDesc __desc_; | ||||
GeTensorDesc &DescReference() const; | GeTensorDesc &DescReference() const; | ||||
}; | }; | ||||
@@ -21,7 +21,6 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
@@ -62,7 +61,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { | |||||
using AttrHolder::HasAttr; | using AttrHolder::HasAttr; | ||||
using AttrHolder::SetAttr; | using AttrHolder::SetAttr; | ||||
graphStatus Save(Buffer &buffer) const; | |||||
graphStatus Save(Buffer &buffer, bool is_dump = false) const; | |||||
graphStatus SaveToFile(const string &file_name) const; | graphStatus SaveToFile(const string &file_name) const; | ||||
// Model will be rewritten | // Model will be rewritten | ||||
@@ -19,7 +19,6 @@ | |||||
#include <map> | #include <map> | ||||
#include <string> | #include <string> | ||||
#include "graph/buffer.h" | #include "graph/buffer.h" | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "graph/model.h" | #include "graph/model.h" | ||||
@@ -27,7 +26,7 @@ | |||||
namespace ge { | namespace ge { | ||||
class ModelSerialize { | class ModelSerialize { | ||||
public: | public: | ||||
Buffer SerializeModel(const Model &model); | |||||
Buffer SerializeModel(const Model &model, bool is_dump = false); | |||||
Model UnserializeModel(const uint8_t *data, size_t len); | Model UnserializeModel(const uint8_t *data, size_t len); | ||||
Model UnserializeModel(ge::proto::ModelDef &model_def); | Model UnserializeModel(ge::proto::ModelDef &model_def); | ||||
@@ -113,25 +113,25 @@ class Node : public std::enable_shared_from_this<Node> { | |||||
bool IsAllInNodesSeen(std::unordered_set<Node *> &nodes_seen) const; | bool IsAllInNodesSeen(std::unordered_set<Node *> &nodes_seen) const; | ||||
// All inData nodes | |||||
// All in Data nodes | |||||
Vistor<NodePtr> GetInDataNodes() const; | Vistor<NodePtr> GetInDataNodes() const; | ||||
// All inControl nodes | |||||
// All in Control nodes | |||||
Vistor<NodePtr> GetInControlNodes() const; | Vistor<NodePtr> GetInControlNodes() const; | ||||
// GetInAllNodes = InDataNodes + InControlNodes | // GetInAllNodes = InDataNodes + InControlNodes | ||||
Vistor<NodePtr> GetInAllNodes() const; | Vistor<NodePtr> GetInAllNodes() const; | ||||
// All outData nodes | |||||
// All out Data nodes | |||||
Vistor<NodePtr> GetOutDataNodes() const; | Vistor<NodePtr> GetOutDataNodes() const; | ||||
uint32_t GetOutDataNodesSize() const; | uint32_t GetOutDataNodesSize() const; | ||||
// All outControl nodes | |||||
// All out Control nodes | |||||
Vistor<NodePtr> GetOutControlNodes() const; | Vistor<NodePtr> GetOutControlNodes() const; | ||||
// GetOutAllNodes = OutDataNodes + OutControlNodes | // GetOutAllNodes = OutDataNodes + OutControlNodes | ||||
Vistor<NodePtr> GetOutAllNodes() const; | Vistor<NodePtr> GetOutAllNodes() const; | ||||
// Get all indata nodes and its outanchor | |||||
// Get all in data nodes and its out-anchor | |||||
Vistor<std::pair<NodePtr, OutDataAnchorPtr>> GetInDataNodesAndAnchors() const; | Vistor<std::pair<NodePtr, OutDataAnchorPtr>> GetInDataNodesAndAnchors() const; | ||||
// Get all outdata nodes and its inanchor | |||||
// Get all out data nodes and its in-anchor | |||||
Vistor<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesAndAnchors() const; | Vistor<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesAndAnchors() const; | ||||
graphStatus InferShapeAndType() const; | graphStatus InferShapeAndType() const; | ||||
@@ -176,7 +176,7 @@ class Node : public std::enable_shared_from_this<Node> { | |||||
void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } | void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } | ||||
NodePtr GetOrigNode(void) { return orig_node_; } | |||||
NodePtr GetOrigNode() { return orig_node_; } | |||||
private: | private: | ||||
bool NodeMembersAreEqual(const Node &r_node) const; | bool NodeMembersAreEqual(const Node &r_node) const; | ||||
@@ -23,7 +23,6 @@ | |||||
#include <string> | #include <string> | ||||
#include <unordered_set> | #include <unordered_set> | ||||
#include <vector> | #include <vector> | ||||
#include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
#include "graph/range_vistor.h" | #include "graph/range_vistor.h" | ||||
@@ -108,6 +107,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
size_t GetInputsSize() const; | size_t GetInputsSize() const; | ||||
size_t GetAllInputsSize() const; | |||||
graphStatus AddOutputDesc(const GeTensorDesc &output_desc); | graphStatus AddOutputDesc(const GeTensorDesc &output_desc); | ||||
graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); | graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); | ||||
@@ -122,6 +123,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
GeTensorDescPtr MutableOutputDesc(uint32_t index) const; | GeTensorDescPtr MutableOutputDesc(uint32_t index) const; | ||||
uint32_t GetAllOutputsDescSize() const; | |||||
Vistor<GeTensorDesc> GetAllOutputsDesc() const; | Vistor<GeTensorDesc> GetAllOutputsDesc() const; | ||||
Vistor<GeTensorDescPtr> GetAllOutputsDescPtr() const; | Vistor<GeTensorDescPtr> GetAllOutputsDescPtr() const; | ||||
@@ -132,6 +135,10 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; | ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; | ||||
ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const; | |||||
ConstGeTensorDescPtr GetInputDescPtr(const string &name) const; | |||||
graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | ||||
graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | ||||
@@ -140,7 +147,11 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
bool IsOptionalInput(uint32_t index) const; | bool IsOptionalInput(uint32_t index) const; | ||||
std::map<string, uint32_t> GetAllInputName(); | |||||
std::map<string, uint32_t> GetAllInputName() const; | |||||
void SetAllInputName(const std::map<string, uint32_t> &input_name_idx); | |||||
std::vector<string> GetAllOptionalInputName() const; | |||||
std::map<string, uint32_t> GetAllOutputName(); | std::map<string, uint32_t> GetAllOutputName(); | ||||
@@ -225,6 +236,14 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
std::string GetOpEngineName() const; | std::string GetOpEngineName() const; | ||||
graphStatus AddSubgraphName(const std::string &name); | |||||
const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const; | |||||
std::string GetSubgraphInstanceName(uint32_t index) const; | |||||
const std::vector<std::string> &GetSubgraphInstanceNames() const; | |||||
void AddSubgraphInstanceName(std::string name); | |||||
void RemoveSubgraphInstanceName(const std::string &name); | |||||
protected: | protected: | ||||
ProtoAttrMapHelper MutableAttrMap() override; | ProtoAttrMapHelper MutableAttrMap() override; | ||||
ConstProtoAttrMapHelper GetAttrMap() const override; | ConstProtoAttrMapHelper GetAttrMap() const override; | ||||
@@ -236,9 +255,9 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; | bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; | ||||
GeIrProtoHelper<ge::proto::OpDef> op_def_; | GeIrProtoHelper<ge::proto::OpDef> op_def_; | ||||
std::vector<std::string> subgraph_instance_names_; | |||||
std::map<std::string, uint32_t> subgraph_names_to_index_; | |||||
vector<GeTensorDescPtr> inputs_desc_{}; | vector<GeTensorDescPtr> inputs_desc_{}; | ||||
map<string, uint32_t> input_name_idx_{}; | |||||
std::unordered_set<string> optional_input_names_{}; | |||||
vector<GeTensorDescPtr> outputs_desc_{}; | vector<GeTensorDescPtr> outputs_desc_{}; | ||||
map<string, uint32_t> output_name_idx_{}; | map<string, uint32_t> output_name_idx_{}; | ||||
std::function<graphStatus(Operator &)> infer_func_ = nullptr; | std::function<graphStatus(Operator &)> infer_func_ = nullptr; | ||||
@@ -21,7 +21,6 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/operator_factory.h" | #include "graph/operator_factory.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -47,7 +46,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl { | |||||
static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); | static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); | ||||
private: | |||||
static shared_ptr<std::map<string, OpCreator>> operator_creators_; | static shared_ptr<std::map<string, OpCreator>> operator_creators_; | ||||
static shared_ptr<std::map<string, InferShapeFunc>> operator_infershape_funcs_; | static shared_ptr<std::map<string, InferShapeFunc>> operator_infershape_funcs_; | ||||
static shared_ptr<std::map<string, InferFormatFunc>> operator_inferformat_funcs_; | static shared_ptr<std::map<string, InferFormatFunc>> operator_inferformat_funcs_; | ||||
@@ -18,8 +18,8 @@ | |||||
#define INC_GRAPH_SHAPE_REFINER_H_ | #define INC_GRAPH_SHAPE_REFINER_H_ | ||||
#include <string> | #include <string> | ||||
#include "external/graph/inference_context.h" | #include "external/graph/inference_context.h" | ||||
#include "external/graph/ge_error_codes.h" | #include "external/graph/ge_error_codes.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
@@ -27,8 +27,10 @@ namespace ge { | |||||
// ShapeRefiner performs shape inference for compute graphs | // ShapeRefiner performs shape inference for compute graphs | ||||
class ShapeRefiner { | class ShapeRefiner { | ||||
public: | public: | ||||
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); | |||||
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph); | |||||
static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); | |||||
static graphStatus InferShapeAndType(const NodePtr &node); | static graphStatus InferShapeAndType(const NodePtr &node); | ||||
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); | |||||
private: | private: | ||||
static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); | static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); | ||||
@@ -14,8 +14,8 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
#define INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
#ifndef INC_GRAPH_USR_TYPES_H_ | |||||
#define INC_GRAPH_USR_TYPES_H_ | |||||
#include <atomic> | #include <atomic> | ||||
#include <memory> | #include <memory> | ||||
@@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams { | |||||
#undef USR_TYPE_BYTES_DEC | #undef USR_TYPE_BYTES_DEC | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
#endif // INC_GRAPH_USR_TYPES_H_ |
@@ -99,8 +99,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||||
static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | ||||
static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); | static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); | ||||
// Value will be moved | // Value will be moved | ||||
static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, | |||||
vector<Buffer> &listBuffer); | |||||
static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | |||||
static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | ||||
static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value); | static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value); | ||||
@@ -116,6 +115,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||||
static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); | static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); | ||||
static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); | |||||
class AttrHolderAdapter { | class AttrHolderAdapter { | ||||
public: | public: | ||||
AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} | AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} | ||||
@@ -137,6 +137,18 @@ class GraphUtils { | |||||
static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, | static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, | ||||
const std::vector<OpDescPtr> &vec_op_desc); | const std::vector<OpDescPtr> &vec_op_desc); | ||||
/// | |||||
/// @brief Insert insert_node between src and the anchors in dsts: | |||||
///        src->insert_node:input_index, insert_node:output_index->dst | |||||
/// @param [in] src: output data anchor that is linked to insert_node's input | |||||
/// @param [in] dsts: input data anchors that are linked from insert_node's output | |||||
/// @param [in] insert_node: node to be inserted | |||||
/// @param [in] input_index: index of the insert_node input that src is linked to, 0 by default | |||||
/// @param [in] output_index: index of the insert_node output that is linked to each dst, 0 by default | |||||
/// @return graphStatus | |||||
/// | |||||
static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||||
const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); | |||||
static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); | static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); | ||||
static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); | static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); | ||||
@@ -145,16 +157,12 @@ class GraphUtils { | |||||
static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node); | static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node); | ||||
static bool CheckIsTrainGraph(const ge::ComputeGraphPtr &compute_graph); | |||||
static bool MatchDumpStr(const std::string &suffix); | static bool MatchDumpStr(const std::string &suffix); | ||||
static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false); | static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false); | ||||
static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | ||||
static bool CheckGlobalStepNode(const ge::NodePtr &node); | |||||
static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos); | static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos); | ||||
static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | ||||
@@ -252,6 +260,315 @@ class GraphUtils { | |||||
/// @return success: GRAPH_SUCESS | /// @return success: GRAPH_SUCESS | ||||
/// | /// | ||||
static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | ||||
static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | |||||
}; | |||||
class ComputeGraphBuilder { | |||||
public: | |||||
ComputeGraphBuilder() : owner_graph_(nullptr) {} | |||||
ComputeGraphBuilder(const ComputeGraphBuilder &) = delete; | |||||
ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete; | |||||
ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete; | |||||
ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete; | |||||
~ComputeGraphBuilder() = default; | |||||
/// | |||||
/// @brief Add node to graph | |||||
/// @param [in] op_desc | |||||
/// @return ComputeGraphBuilder | |||||
/// | |||||
virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc); | |||||
/// | |||||
/// @brief Add data-link among nodes in graph | |||||
/// @param [in] src_name | |||||
/// @param [in] out_anchor_ind | |||||
/// @param [in] dst_name | |||||
/// @param [in] in_anchor_ind | |||||
/// @return ComputeGraphBuilder | |||||
/// | |||||
virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, | |||||
const std::string &dst_name, uint32_t in_anchor_ind); | |||||
/// | |||||
/// @brief Add ctrl-link among nodes in graph | |||||
/// @param [in] src_name | |||||
/// @param [in] dst_name | |||||
/// @return ComputeGraphBuilder | |||||
/// | |||||
virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name); | |||||
/// | |||||
/// @brief Build graph | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return ComputeGraphPtr | |||||
/// | |||||
virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0; | |||||
/// @brief Get node with name | |||||
/// @param [in] name | |||||
/// @return NodePtr | |||||
/// | |||||
NodePtr GetNode(const std::string &name); | |||||
protected: | |||||
/// | |||||
/// @brief Build nodes | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return void | |||||
/// | |||||
void BuildNodes(graphStatus &error_code, std::string &error_msg); | |||||
/// | |||||
/// @brief Build data-links | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return void | |||||
/// | |||||
void BuildDataLinks(graphStatus &error_code, std::string &error_msg); | |||||
/// | |||||
/// @brief Build ctrl-links | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return void | |||||
/// | |||||
void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg); | |||||
ComputeGraphPtr owner_graph_; | |||||
// node_name -> node | |||||
std::map<std::string, NodePtr> node_names_; | |||||
std::vector<OpDescPtr> nodes_; | |||||
// <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind> | |||||
std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_; | |||||
// src_node_name -> dst_node_name | |||||
std::vector<std::pair<std::string, std::string>> ctrl_links_; | |||||
}; | |||||
class CompleteGraphBuilder : public ComputeGraphBuilder { | |||||
public: | |||||
explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {} | |||||
CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; | |||||
CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; | |||||
CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; | |||||
CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete; | |||||
~CompleteGraphBuilder() = default; | |||||
/// | |||||
/// @brief Add node to graph | |||||
/// @param [in] op_desc | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override; | |||||
/// | |||||
/// @brief Add data-link among nodes in graph | |||||
/// @param [in] src_name | |||||
/// @param [in] out_anchor_ind | |||||
/// @param [in] dst_name | |||||
/// @param [in] in_anchor_ind | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, | |||||
uint32_t in_anchor_ind) override; | |||||
/// | |||||
/// @brief Add ctrl-link among nodes in graph | |||||
/// @param [in] src_name | |||||
/// @param [in] dst_name | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; | |||||
/// | |||||
/// @brief Set index_th input anchor for graph | |||||
/// @param [in] index | |||||
/// @param [in] node_names | |||||
/// @param [in] anchor_inds | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names, | |||||
const std::vector<uint32_t> &anchor_inds); | |||||
/// | |||||
/// @brief Set index_th input of graph as useless | |||||
/// @param [in] index | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &SetUselessInput(uint32_t index); | |||||
/// | |||||
/// @brief Add output anchor for graph | |||||
/// @param [in] owner_node_name | |||||
/// @param [in] anchor_ind | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); | |||||
/// | |||||
/// @brief Set parent-node of graph | |||||
/// @param [in] parent_node | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node); | |||||
/// | |||||
/// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node | |||||
/// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping); | |||||
/// | |||||
/// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind | |||||
/// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping); | |||||
/// | |||||
/// @brief Build graph | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return ComputeGraphPtr | |||||
/// | |||||
ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; | |||||
private: | |||||
/// | |||||
/// @brief Build inputs | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return void | |||||
/// | |||||
void BuildInputs(graphStatus &error_code, std::string &error_msg); | |||||
/// | |||||
/// @brief Add data node | |||||
/// @param [in] index | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return void | |||||
/// | |||||
NodePtr AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||||
/// | |||||
/// @brief Build outputs | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return void | |||||
/// | |||||
void BuildOutputs(graphStatus &error_code, std::string &error_msg); | |||||
/// | |||||
/// @brief Add NetOutput node | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return NodePtr | |||||
/// | |||||
NodePtr AddNetOutputNode(graphStatus &error_code, std::string &error_msg); | |||||
/// | |||||
/// @brief Add input/output tensor for NetOutput node | |||||
/// @param [in] out_nodes_info | |||||
/// @param [out] net_output_desc | |||||
/// @return graphStatus | |||||
/// | |||||
graphStatus BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||||
OpDescPtr &net_output_desc); | |||||
/// | |||||
/// @brief Add edge for NetOutput node | |||||
/// @param [in] out_nodes_info | |||||
/// @param [out] net_output_node | |||||
/// @return graphStatus | |||||
/// | |||||
graphStatus AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||||
const NodePtr &net_output_node); | |||||
std::string name_; | |||||
NodePtr parent_node_; | |||||
std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_; | |||||
std::vector<std::pair<std::string, uint32_t>> graph_outputs_; | |||||
// index_of_graph_input -> in_anchor_index_of_parent_node | |||||
std::map<uint32_t, uint32_t> input_mapping_; | |||||
// index_of_graph_output -> out_anchor_index_of_parent_node | |||||
std::map<uint32_t, uint32_t> output_mapping_; | |||||
}; | |||||
class PartialGraphBuilder : public ComputeGraphBuilder { | |||||
public: | |||||
PartialGraphBuilder() = default; | |||||
PartialGraphBuilder(const PartialGraphBuilder &) = delete; | |||||
PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete; | |||||
PartialGraphBuilder(const PartialGraphBuilder &&) = delete; | |||||
PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete; | |||||
~PartialGraphBuilder() = default; | |||||
/// | |||||
/// @brief Add node to graph | |||||
/// @param [in] op_desc | |||||
/// @return PartialGraphBuilder | |||||
/// | |||||
PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override; | |||||
/// | |||||
/// @brief Add data-link among nodes in graph | |||||
/// @param [in] src_name | |||||
/// @param [in] out_anchor_ind | |||||
/// @param [in] dst_name | |||||
/// @param [in] in_anchor_ind | |||||
/// @return PartialGraphBuilder | |||||
/// | |||||
PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, | |||||
uint32_t in_anchor_ind) override; | |||||
/// | |||||
/// @brief Add ctrl-link among nodes in graph | |||||
/// @param [in] src_name | |||||
/// @param [in] dst_name | |||||
/// @return PartialGraphBuilder | |||||
/// | |||||
PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; | |||||
/// | |||||
/// @brief Set owner graph | |||||
/// @param [in] graph | |||||
/// @return PartialGraphBuilder | |||||
/// | |||||
PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph); | |||||
/// | |||||
/// @brief Add exist node | |||||
/// @param [in] node | |||||
/// @return PartialGraphBuilder | |||||
/// | |||||
PartialGraphBuilder &AddExistNode(const NodePtr &node); | |||||
/// | |||||
/// @brief Build multi nodes with links | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return ComputeGraphPtr | |||||
/// | |||||
ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; | |||||
private: | |||||
/// | |||||
/// @brief Build exist nodes | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return void | |||||
/// | |||||
void BuildExistNodes(graphStatus &error_code, std::string &error_msg); | |||||
std::vector<NodePtr> exist_nodes_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -56,6 +56,11 @@ class NodeUtils { | |||||
static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | ||||
static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | ||||
static std::string GetNodeType(const Node &node); | |||||
static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | |||||
static graphStatus AddSubgraph(Node &node, const ComputeGraphPtr &subgraph); | |||||
private: | private: | ||||
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | ||||
static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | ||||
@@ -20,7 +20,6 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/def_types.h" | #include "graph/def_types.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
@@ -29,7 +28,6 @@ | |||||
namespace ge { | namespace ge { | ||||
class OpDesc; | class OpDesc; | ||||
using OpDescPtr = std::shared_ptr<OpDesc>; | using OpDescPtr = std::shared_ptr<OpDesc>; | ||||
class OpDescUtils { | class OpDescUtils { | ||||
@@ -39,55 +37,108 @@ class OpDescUtils { | |||||
OpDescUtils() = default; | OpDescUtils() = default; | ||||
~OpDescUtils() = default; | ~OpDescUtils() = default; | ||||
static bool HasQuantizeFactorParams(const OpDescPtr &op_desc); | |||||
static bool HasQuantizeFactorParams(const OpDesc &op_desc); | |||||
static graphStatus GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant); | |||||
static graphStatus GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant); | |||||
static graphStatus SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant); | |||||
static graphStatus SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant); | |||||
static vector<ge::NodePtr> GetConstInputNode(const ge::Node &node); | |||||
static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr> &input_nodes); | |||||
static vector<ConstGeTensorPtr> GetWeights(const ge::Node &node); | |||||
static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr &node); | |||||
static vector<GeTensorPtr> MutableWeights(const ge::Node &node); | |||||
static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); | |||||
static bool HasQuantizeFactorParams(const OpDesc& op_desc); | |||||
static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant); | |||||
static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant); | |||||
static graphStatus SetQuantizeFactorParams(const OpDescPtr& op_desc, const QuantizeFactorParams& quant); | |||||
static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant); | |||||
static vector<ge::NodePtr> GetConstInputNode(const ge::Node& node); | |||||
static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr>& input_nodes); | |||||
static vector<ConstGeTensorPtr> GetWeights(const ge::Node& node); | |||||
static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr& node); | |||||
static vector<GeTensorPtr> MutableWeights(const ge::Node& node); | |||||
static vector<GeTensorPtr> MutableWeights(const ge::NodePtr node); | static vector<GeTensorPtr> MutableWeights(const ge::NodePtr node); | ||||
static graphStatus SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights); | |||||
static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights); | |||||
static graphStatus SetWeights(ge::Node& node, const vector<ge::GeTensorPtr>& weights); | |||||
static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr>& weights); | |||||
static graphStatus ClearWeights(ge::NodePtr node); | static graphStatus ClearWeights(ge::NodePtr node); | ||||
static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); | static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); | ||||
static bool ClearInputDesc(const ge::NodePtr &node); | |||||
static bool ClearOutputDesc(const ge::OpDescPtr &op_desc, uint32_t index); | |||||
static bool ClearOutputDesc(const ge::NodePtr &node); | |||||
static vector<ge::NodePtr> GetConstInputs(const ge::Node &node); | |||||
static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr &node); | |||||
static size_t GetNonConstInputsSize(const ge::Node &node); | |||||
static bool ClearInputDesc(const ge::NodePtr& node); | |||||
static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index); | |||||
static bool ClearOutputDesc(const ge::NodePtr& node); | |||||
static vector<ge::NodePtr> GetConstInputs(const ge::Node& node); | |||||
static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr& node); | |||||
static size_t GetNonConstInputsSize(const ge::Node& node); | |||||
static size_t GetNonConstInputsSize(ge::ConstNodePtr node); | static size_t GetNonConstInputsSize(ge::ConstNodePtr node); | ||||
// Index: Indicate the index of all non const inputs | |||||
static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const = 0); | |||||
static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const = 0); | |||||
static bool GetNonConstInputIndex(const ge::Node &node, size_t index_non_const, size_t &index); | |||||
static bool GetNonConstInputIndex(const ge::ConstNodePtr &node, size_t index_non_const, size_t &index); | |||||
// Index: Indicate the index of all inputs | |||||
static bool IsNonConstInput(const ge::Node &node, size_t index = 0); | |||||
static bool IsNonConstInput(const ge::ConstNodePtr &node, size_t index = 0); | |||||
static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr &node); | |||||
static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr); | |||||
// Index: Indicates the index of all non const inputs | |||||
static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0); | |||||
static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0); | |||||
static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index); | |||||
static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index); | |||||
// Index: Indicates the index of all inputs | |||||
static bool IsNonConstInput(const ge::Node& node, size_t index = 0); | |||||
static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0); | |||||
static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr& node); | |||||
static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); | |||||
static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); | static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); | ||||
static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); | static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); | ||||
static OpDescPtr GetOpDescFromOperator(const Operator &oprt); | |||||
static OpDescPtr GetOpDescFromOperator(const Operator& oprt); | |||||
static OpDescPtr CreateConstOp(const GeTensorPtr &tensor_ptr); | |||||
static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); | |||||
private: | private: | ||||
static GeTensorPtr MutableWeights(ge::OpDesc &op_desc); | |||||
static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); | |||||
static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | ||||
static graphStatus SetWeights(ge::OpDesc &op_desc, const GeTensorPtr weight); | |||||
static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); | |||||
static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); | static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); | ||||
}; | }; | ||||
class OpDescBuilder { | |||||
public: | |||||
OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} | |||||
OpDescBuilder(const OpDescBuilder&) = delete; | |||||
OpDescBuilder& operator=(const OpDescBuilder&) = delete; | |||||
OpDescBuilder(const OpDescBuilder&&) = delete; | |||||
OpDescBuilder& operator=(const OpDescBuilder&&) = delete; | |||||
~OpDescBuilder() = default; | |||||
/// | |||||
/// @brief Add input | |||||
/// @param [in] name | |||||
/// @return OpDescBuilder | |||||
/// | |||||
OpDescBuilder& AddInput(const std::string& name); | |||||
/// | |||||
/// @brief Add dynamic input | |||||
/// @param [in] name | |||||
/// @param [in] num | |||||
/// @return OpDescBuilder | |||||
/// | |||||
OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); | |||||
/// | |||||
/// @brief Add output | |||||
/// @param [in] name | |||||
/// @return OpDescBuilder | |||||
/// | |||||
OpDescBuilder& AddOutput(const std::string& name); | |||||
/// | |||||
/// @brief Add dynamic output | |||||
/// @param [in] name | |||||
/// @param [in] num | |||||
/// @return OpDescBuilder | |||||
/// | |||||
OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); | |||||
/// | |||||
/// @brief Build op_desc | |||||
/// @return OpDescPtr | |||||
/// | |||||
OpDescPtr Build(); | |||||
private: | |||||
std::string name_; | |||||
std::string type_; | |||||
std::vector<std::string> inputs_; | |||||
std::vector<std::string> outputs_; | |||||
}; | |||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ | #endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ |
@@ -18,15 +18,14 @@ | |||||
#define INC_GRAPH_UTILS_TENSOR_UTILS_H_ | #define INC_GRAPH_UTILS_TENSOR_UTILS_H_ | ||||
#include <vector> | #include <vector> | ||||
#include "graph/def_types.h" | #include "graph/def_types.h" | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
namespace ge { | namespace ge { | ||||
class TensorUtils { | class TensorUtils { | ||||
public: | public: | ||||
static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, uint32_t &size); | |||||
static void SetSize(GeTensorDesc &tensorDesc, uint32_t size); | |||||
static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size); | |||||
static void SetSize(GeTensorDesc &tensorDesc, int64_t size); | |||||
static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); | static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); | ||||
static uint32_t GetWeightSize(const GeTensor &tensor); | static uint32_t GetWeightSize(const GeTensor &tensor); | ||||
static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); | static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); | ||||
@@ -62,16 +61,16 @@ class TensorUtils { | |||||
static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); | static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); | ||||
/// | /// | ||||
/// calculate mem size of the tensor. | |||||
/// calculate tensor mem size. | |||||
/// @param shape tensor shape | /// @param shape tensor shape | ||||
/// @param format tensor format | /// @param format tensor format | ||||
/// @param data_type tensor data type | /// @param data_type tensor data type | ||||
/// @param mem_size -1 means unknown shape,others means mem size | |||||
/// @return GRAPH_SUCCESS:success, others:failed | |||||
/// @param mem_size -1 means unknown shape,other means mem size | |||||
/// @return GRAPH_SUCCESS:success, other:failed | |||||
/// | /// | ||||
static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); | static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); | ||||
static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp); | |||||
static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp); | |||||
static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); | |||||
static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ | #endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ |
@@ -58,6 +58,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||||
include_directories(${GE_SOURCE_DIR}/inc/graph) | include_directories(${GE_SOURCE_DIR}/inc/graph) | ||||
include_directories(${GE_SOURCE_DIR}/inc/common) | include_directories(${GE_SOURCE_DIR}/inc/common) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | ||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
@@ -26,6 +26,8 @@ Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), id | |||||
bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf<Anchor>(), type) == 0; } | bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf<Anchor>(), type) == 0; } | ||||
// Returns the number of entries held in peer_anchors_.
size_t Anchor::GetPeerAnchorsSize() const {
  const size_t peer_count = peer_anchors_.size();
  return peer_count;
}
Anchor::Vistor<AnchorPtr> Anchor::GetPeerAnchors() const { | Anchor::Vistor<AnchorPtr> Anchor::GetPeerAnchors() const { | ||||
vector<AnchorPtr> ret; | vector<AnchorPtr> ret; | ||||
for (const auto &anchor : peer_anchors_) { | for (const auto &anchor : peer_anchors_) { | ||||
@@ -32,8 +32,7 @@ Buffer::Buffer(const Buffer &other) { | |||||
buffer_ = other.buffer_; | buffer_ = other.buffer_; | ||||
} | } | ||||
// default | |||||
Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { | |||||
Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default | |||||
auto proto_msg = data_.GetProtoMsg(); | auto proto_msg = data_.GetProtoMsg(); | ||||
if (proto_msg != nullptr) { | if (proto_msg != nullptr) { | ||||
try { | try { | ||||
@@ -15,9 +15,7 @@ | |||||
*/ | */ | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include <deque> | #include <deque> | ||||
#include "./format_refiner.h" | #include "./format_refiner.h" | ||||
#include "./ge_context.h" | #include "./ge_context.h" | ||||
#include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
@@ -41,7 +39,7 @@ const size_t OUTPUT_PARAM_SIZE = 2; | |||||
} // namespace | } // namespace | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) | ||||
: nodes_(), input_nodes_(), sub_graph_(), name_(name), is_valid_flag_(false), need_iteration_(false) { | |||||
: name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { | |||||
attrs_.InitDefault(); | attrs_.InitDefault(); | ||||
} | } | ||||
ComputeGraph::~ComputeGraph() {} | ComputeGraph::~ComputeGraph() {} | ||||
@@ -154,7 +152,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNod | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( | ||||
const ComputeGraph &r_graph) const { | const ComputeGraph &r_graph) const { | ||||
return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.sub_graph_.size()") && | |||||
return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") && | |||||
IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && | IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && | ||||
VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && | VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && | ||||
IsEqual(this->name_, r_graph.name_, "graph.name_") && | IsEqual(this->name_, r_graph.name_, "graph.name_") && | ||||
@@ -398,6 +396,165 @@ graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &su | |||||
} | } | ||||
} | } | ||||
///
/// @brief Register a subgraph on this graph under the given name.
///        The subgraph must have a parent graph and a parent node, the parent node must be
///        owned by that parent graph, and this graph must not have a parent graph itself
///        (subgraphs are only added to the root graph).
/// @param [in] name: key under which the subgraph is stored; a mismatch with the subgraph's
///        own name only produces a warning
/// @param [in] subgraph: subgraph to add
/// @return GRAPH_SUCCESS on success, GRAPH_PARAM_INVALID when a precondition is violated
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph) {
  if (subgraph == nullptr) {
    GE_LOGE("Try to add a null subgraph, name %s", name.c_str());
    return GRAPH_PARAM_INVALID;
  }
  auto parent_graph = subgraph->GetParentGraph();
  if (parent_graph == nullptr) {
    GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str());
    return GRAPH_PARAM_INVALID;
  }
  auto parent_node = subgraph->GetParentNode();
  if (parent_node == nullptr) {
    GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str());
    return GRAPH_PARAM_INVALID;
  }
  if (parent_node->GetOwnerComputeGraph() != parent_graph) {
    // Log the parent node's name, as the message says; the graph's name was passed here before.
    GE_LOGE(
      "Try to add a subgraph which parent node's parent graph is not equal to "
      "the subgraph's parent graph, subgraph name %s, parent node name %s",
      subgraph->GetName().c_str(), parent_node->GetName().c_str());
    return GRAPH_PARAM_INVALID;
  }
  if (!this->parent_graph_.expired()) {
    GE_LOGE("The subgraphs can only be added to the root graph");
    return GRAPH_PARAM_INVALID;
  }
  if (name != subgraph->GetName()) {
    GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str());
  }

  // If the name is already registered, the map entry below gets overwritten. Drop the graph
  // previously stored under this name from sub_graph_ as well, otherwise it would stay in the
  // vector with no name referring to it and RemoveSubgraph(name) could never reach it.
  auto existing = names_to_subgraph_.find(name);
  if (existing != names_to_subgraph_.end()) {
    GELOGW("The subgraph %s already exists and will be replaced", name.c_str());
    for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) {
      if (*vec_iter == existing->second) {
        sub_graph_.erase(vec_iter);
        break;
      }
    }
  }
  sub_graph_.push_back(subgraph);
  names_to_subgraph_[name] = subgraph;
  return GRAPH_SUCCESS;
}
///
/// @brief Register a subgraph under its own name
/// @param [in] subgraph : subgraph to add
/// @return GRAPH_PARAM_INVALID when subgraph is null, otherwise the result of the named overload
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
ComputeGraph::AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph) {
  if (subgraph != nullptr) {
    const std::string &own_name = subgraph->GetName();
    return AddSubgraph(own_name, subgraph);
  }
  return GRAPH_PARAM_INVALID;
}
///
/// @brief Remove the subgraph registered under the given name; does nothing if the name is unknown
/// @param [in] name : key the subgraph was stored under
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) {
  auto entry = names_to_subgraph_.find(name);
  if (entry == names_to_subgraph_.end()) {
    return;
  }
  // Only the first vector element pointing at the same graph is removed.
  auto position = std::find(sub_graph_.begin(), sub_graph_.end(), entry->second);
  if (position != sub_graph_.end()) {
    sub_graph_.erase(position);
  }
  names_to_subgraph_.erase(entry);
}
///
/// @brief Remove a subgraph, looked up by its own name; a null pointer is ignored
/// @param [in] subgraph : subgraph to remove
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(
  const std::shared_ptr<ComputeGraph> &subgraph) {
  if (subgraph == nullptr) {
    return;
  }
  RemoveSubgraph(subgraph->GetName());
}
///
/// @brief Look up a subgraph by the name it was registered under
/// @param [in] name : key to search for in names_to_subgraph_
/// @return the registered subgraph, or nullptr when the name is not present
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph(
  const std::string &name) const {
  auto found = names_to_subgraph_.find(name);
  if (found == names_to_subgraph_.end()) {
    return nullptr;
  }
  return found->second;
}
///
/// @brief Get every subgraph held in sub_graph_
/// @return a copy of the subgraph vector
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>>
ComputeGraph::GetAllSubgraphs() const {
  return std::vector<std::shared_ptr<ComputeGraph>>(sub_graph_);
}
///
/// @brief Get the parent graph
/// @return the parent graph, or an empty pointer when the weak reference cannot be locked
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<ComputeGraph> ComputeGraph::GetParentGraph() {
  shared_ptr<ComputeGraph> parent = parent_graph_.lock();
  return parent;
}
///
/// @brief Set the parent graph; it is stored in parent_graph_ and read back through lock()
/// @param [in] parent : graph to record as this graph's parent
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph(
  const shared_ptr<ComputeGraph> &parent) {
  this->parent_graph_ = parent;
}
///
/// @brief Get the parent node
/// @return the parent node, or an empty pointer when the weak reference cannot be locked
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<Node> ComputeGraph::GetParentNode() {
  shared_ptr<Node> parent = parent_node_.lock();
  return parent;
}
///
/// @brief Set the parent node; it is stored in parent_node_ and read back through lock()
/// @param [in] parent : node to record as this graph's parent node
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr<Node> &parent) {
  this->parent_node_ = parent;
}
///
/// @brief Rewrite ATTR_NAME_PARENT_NODE_INDEX on the input nodes according to the given mapping
/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input
/// @return GRAPH_SUCCESS, or GRAPH_FAILED when the attribute cannot be written
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) {
  for (auto &input : input_nodes_) {
    uint32_t cur_index = 0;
    // Inputs that carry no parent index, or whose index is not remapped, are left untouched.
    if (ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) {
      auto mapped = input_mapping.find(cur_index);
      if ((mapped != input_mapping.end()) &&
          !ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, mapped->second)) {
        GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed.");
        return GRAPH_FAILED;
      }
    }
  }

  return GRAPH_SUCCESS;
}
///
/// @brief Rewrite ATTR_NAME_PARENT_NODE_INDEX on the input tensors of the net-output node
/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output
/// @return GRAPH_SUCCESS, or GRAPH_FAILED when the net-output node or its op desc is missing,
///         or when an attribute or tensor desc cannot be updated
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) {
  NodePtr net_output = FindNode(kNodeNameNetOutput);
  if (net_output == nullptr) {
    GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", kNodeNameNetOutput);
    return GRAPH_FAILED;
  }
  OpDescPtr op_desc = net_output->GetOpDesc();
  if (op_desc == nullptr) {
    GE_LOGE("UpdateOutputMapping failed: op_desc is NULL.");
    return GRAPH_FAILED;
  }

  size_t num = op_desc->GetInputsSize();
  for (size_t i = 0; i < num; i++) {
    // GetInputDesc hands back a copy, so the change is written back with UpdateInputDesc below.
    GeTensorDesc tensor = op_desc->GetInputDesc(i);
    uint32_t cur_index = 0;
    if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) {
      continue;
    }
    auto iter = output_mapping.find(cur_index);
    if (iter == output_mapping.end()) {
      continue;
    }
    if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) {
      GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed.");
      return GRAPH_FAILED;
    }
    if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) {
      // i is a size_t, so it needs %zu; %u mismatches the argument width on LP64 targets.
      GE_LOGE("UpdateOutputMapping failed: update %zu input_tensor failed.", i);
      return GRAPH_FAILED;
    }
  }

  return GRAPH_SUCCESS;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { | ||||
std::vector<NodePtr> node_vec = nodes_; | std::vector<NodePtr> node_vec = nodes_; | ||||
for (const auto &node : GetAllNodes()) { | for (const auto &node : GetAllNodes()) { | ||||
@@ -551,6 +708,23 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No | |||||
} | } | ||||
///
/// @brief Topologically sort this graph, then each graph returned by GetAllSubgraphs()
/// @return SUCCESS when every sort succeeds, otherwise the status of the first failing sort
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() {
  // Sort this graph's own nodes first.
  auto ret = TopologicalSortingSubgraph();
  if (ret != SUCCESS) {
    GELOGE(ret, "Graph topological sort Failed");
    return ret;
  }

  // Then sort every registered subgraph.
  for (const auto &sub_graph : GetAllSubgraphs()) {
    ret = sub_graph->TopologicalSortingSubgraph();
    if (ret != SUCCESS) {
      GELOGE(ret, "Sub graph topological sort Failed");
      return ret;
    }
  }
  return SUCCESS;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSortingSubgraph() { | |||||
std::vector<NodePtr> node_vec; | std::vector<NodePtr> node_vec; | ||||
std::map<NodePtr, uint32_t> map_in_edge_num; | std::map<NodePtr, uint32_t> map_in_edge_num; | ||||
bool use_BFS = false; | bool use_BFS = false; | ||||
@@ -598,6 +772,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog | |||||
node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] | node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] | ||||
nodes_.push_back(node); | nodes_.push_back(node); | ||||
} | } | ||||
is_valid_flag_ = true; | is_valid_flag_ = true; | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -614,7 +789,7 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||||
verify_isolated = true; | verify_isolated = true; | ||||
} | } | ||||
} | } | ||||
for (const auto &node : GetAllNodes()) { | |||||
for (const auto &node : GetDirectNode()) { | |||||
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | ||||
map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node)); | map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node)); | ||||
if (map_in_edge_num[node] == 0) { | if (map_in_edge_num[node] == 0) { | ||||
@@ -640,16 +815,16 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||||
/// 2. Compare two indices, if not match, swap the positions of two inputs | /// 2. Compare two indices, if not match, swap the positions of two inputs | ||||
/// *: Remind: stack is reverse-order | /// *: Remind: stack is reverse-order | ||||
for (size_t i = 0; i < stack.size(); ++i) { | for (size_t i = 0; i < stack.size(); ++i) { | ||||
// [stack: should not be null] | |||||
// If not found in 'inputs_order_', skip it | |||||
auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); | |||||
GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); | |||||
auto inx_i = it_i - inputs_order_.begin(); | |||||
for (size_t j = i + 1; j < stack.size(); ++j) { | for (size_t j = i + 1; j < stack.size(); ++j) { | ||||
// If not found in 'inputs_order_', skip it | // If not found in 'inputs_order_', skip it | ||||
auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); | |||||
GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); | |||||
auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); | auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); | ||||
GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); | GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); | ||||
// Compare index, swap them if it should be | // Compare index, swap them if it should be | ||||
auto inx_i = it_i - inputs_order_.begin(); | |||||
auto inx_j = it_j - inputs_order_.begin(); | auto inx_j = it_j - inputs_order_.begin(); | ||||
GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); | GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); | ||||
} | } | ||||
@@ -663,7 +838,7 @@ size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | |||||
return in_edge_size; | return in_edge_size; | ||||
} | } | ||||
for (const auto &anchor : node->GetAllInDataAnchors()) { | for (const auto &anchor : node->GetAllInDataAnchors()) { | ||||
in_edge_size = in_edge_size + anchor->GetPeerAnchors().size(); | |||||
in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize(); | |||||
// Break flow control data loop. | // Break flow control data loop. | ||||
OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); | OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); | ||||
if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) { | if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) { | ||||
@@ -680,10 +855,11 @@ size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | |||||
} | } | ||||
} | } | ||||
if (node->GetInControlAnchor() != nullptr) { | if (node->GetInControlAnchor() != nullptr) { | ||||
in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchors().size(); | |||||
in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize(); | |||||
} | } | ||||
return in_edge_size; | return in_edge_size; | ||||
} | } | ||||
size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | ||||
size_t out_edge_size = 0; | size_t out_edge_size = 0; | ||||
if (node == nullptr) { | if (node == nullptr) { | ||||
@@ -699,7 +875,7 @@ size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | |||||
} | } | ||||
} | } | ||||
if (node->GetOutControlAnchor() != nullptr) { | if (node->GetOutControlAnchor() != nullptr) { | ||||
if (out_edge_size > (UINT32_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { | |||||
if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { | |||||
return 0; | return 0; | ||||
} | } | ||||
out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size(); | out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size(); | ||||
@@ -724,17 +900,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str())); | peer_in_anchor->GetOwnerNode()->GetName().c_str())); | ||||
} | } | ||||
} | } | ||||
GE_IF_BOOL_EXEC(node->GetOutControlAnchor() == nullptr, GELOGE(GRAPH_FAILED, "Out control anchor is null"); | |||||
return ); | |||||
for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { | |||||
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||||
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||||
} | |||||
for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInDataAnchors()) { | |||||
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||||
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||||
auto out_control_anchor = node->GetOutControlAnchor(); | |||||
if (out_control_anchor != nullptr) { | |||||
for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) { | |||||
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||||
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||||
} | |||||
for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) { | |||||
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||||
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||||
peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -18,21 +18,9 @@ | |||||
#define COMMON_GRAPH_DEBUG_GE_LOG_H_ | #define COMMON_GRAPH_DEBUG_GE_LOG_H_ | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "toolchain/slog.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#define GE_MOD_ID GE | |||||
#ifdef _MSC_VER | |||||
#define FUNC_NAME __FUNCTION__ | |||||
#else | |||||
#define FUNC_NAME __PRETTY_FUNCTION__ | |||||
#endif | |||||
#define D_GE_LOGE(fmt, ...) \ | |||||
dlog_error(static_cast<int>(GE_MOD_ID), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define GE_LOGE(...) D_GE_LOGE(__VA_ARGS__) | |||||
#define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||||
#define GE_LOGI_IF(condition, ...) \ | #define GE_LOGI_IF(condition, ...) \ | ||||
if ((condition)) { \ | if ((condition)) { \ | ||||
@@ -44,15 +32,15 @@ | |||||
GELOGW(__VA_ARGS__); \ | GELOGW(__VA_ARGS__); \ | ||||
} | } | ||||
#define GE_LOGE_IF(condition, ...) \ | |||||
if ((condition)) { \ | |||||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
#define GE_LOGE_IF(condition, ...) \ | |||||
if ((condition)) { \ | |||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
} | } | ||||
#define GE_CHK_STATUS_RET_NOLOG(expr) \ | #define GE_CHK_STATUS_RET_NOLOG(expr) \ | ||||
do { \ | do { \ | ||||
const ge::graphStatus _status = (expr); \ | const ge::graphStatus _status = (expr); \ | ||||
if (_status != ge::GRAPH_SUCCESS) { \ | |||||
if (ge::SUCCESS != _status) { \ | |||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
@@ -61,7 +49,7 @@ | |||||
do { \ | do { \ | ||||
bool b = (expr); \ | bool b = (expr); \ | ||||
if (!b) { \ | if (!b) { \ | ||||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
@@ -85,7 +73,7 @@ | |||||
do { \ | do { \ | ||||
const ge::graphStatus _status = (expr); \ | const ge::graphStatus _status = (expr); \ | ||||
if (_status) { \ | if (_status) { \ | ||||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
@@ -95,7 +83,7 @@ | |||||
{ \ | { \ | ||||
bool b = (expr); \ | bool b = (expr); \ | ||||
if (b) { \ | if (b) { \ | ||||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
exec_expr; \ | exec_expr; \ | ||||
} \ | } \ | ||||
} | } | ||||
@@ -119,63 +107,41 @@ | |||||
} while (0) | } while (0) | ||||
// If expr is not true, the log is printed and a custom statement is executed | // If expr is not true, the log is printed and a custom statement is executed | ||||
#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ | |||||
{ \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
exec_expr; \ | |||||
} \ | |||||
} | |||||
// If expr is not true, the log is printed and a custom statement is executed | |||||
#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||||
{ \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
GELOGI(__VA_ARGS__); \ | |||||
exec_expr; \ | |||||
} \ | |||||
#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ | |||||
{ \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
exec_expr; \ | |||||
} \ | |||||
} | } | ||||
// If expr is not true, the log is printed and a custom statement is executed | // If expr is not true, the log is printed and a custom statement is executed | ||||
#define GE_CHK_BOOL_EXEC_DEBUG(expr, exec_expr, ...) \ | |||||
{ \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
GELOGD(__VA_ARGS__); \ | |||||
exec_expr; \ | |||||
} \ | |||||
#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||||
{ \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
GELOGI(__VA_ARGS__); \ | |||||
exec_expr; \ | |||||
} \ | |||||
} | } | ||||
// If expr is not GRAPH_SUCCESS, print the log and return the same value | // If expr is not GRAPH_SUCCESS, print the log and return the same value | ||||
#define GE_CHK_STATUS_RET(expr, ...) \ | |||||
do { \ | |||||
const ge::graphStatus _status = (expr); \ | |||||
if (_status != ge::GRAPH_SUCCESS) { \ | |||||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
return _status; \ | |||||
} \ | |||||
#define GE_CHK_STATUS_RET(expr, ...) \ | |||||
do { \ | |||||
const ge::graphStatus _status = (expr); \ | |||||
if (ge::SUCCESS != _status) { \ | |||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
return _status; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||||
try { \ | |||||
exec_expr0; \ | |||||
} catch (...) { \ | |||||
GELOGE(ge::GRAPH_FAILED, "Make shared failed"); \ | |||||
exec_expr1; \ | |||||
#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||||
try { \ | |||||
exec_expr0; \ | |||||
} catch (...) { \ | |||||
GELOGE(ge::FAILED, "Make shared failed"); \ | |||||
exec_expr1; \ | |||||
} | } | ||||
/// CCE related macro definition | |||||
/// If expr is not CC_STATUS_GRAPH_SUCCESS, print the log and return | |||||
#define GE_CHK_CCE_RET(expr) \ | |||||
do { \ | |||||
ccgraphStatus_t _cc_ret = (expr); \ | |||||
if (_cc_ret != CC_STATUS_GRAPH_SUCCESS) { \ | |||||
GELOGE(ge::GRAPH_FAILED, "Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
return ge::GRAPH_FAILED; \ | |||||
} \ | |||||
} while (0) | |||||
#endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ | #endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ | ||||
@@ -25,7 +25,6 @@ | |||||
#include <string> | #include <string> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/debug/ge_log.h" | #include "graph/debug/ge_log.h" | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
@@ -15,12 +15,10 @@ | |||||
*/ | */ | ||||
#include "graph/debug/graph_debug.h" | #include "graph/debug/graph_debug.h" | ||||
#include <algorithm> | #include <algorithm> | ||||
#include <unordered_set> | #include <unordered_set> | ||||
#include <vector> | #include <vector> | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#define TAB " " | #define TAB " " | ||||
@@ -16,13 +16,11 @@ | |||||
#ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | #ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | ||||
#define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | #define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | ||||
#include <cstdint> | #include <cstdint> | ||||
#include <fstream> | #include <fstream> | ||||
#include <iostream> | #include <iostream> | ||||
#include <sstream> | #include <sstream> | ||||
#include <string> | #include <string> | ||||
#include "external/graph/graph.h" | #include "external/graph/graph.h" | ||||
#include "./ge_error_codes.h" | #include "./ge_error_codes.h" | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
@@ -15,9 +15,7 @@ | |||||
*/ | */ | ||||
#include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
#include <map> | #include <map> | ||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -14,14 +14,12 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "graph/format_refiner.h" | |||||
#include "format_refiner.h" | |||||
#include <deque> | #include <deque> | ||||
#include <iostream> | #include <iostream> | ||||
#include <set> | #include <set> | ||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <unordered_set> | #include <unordered_set> | ||||
#include "./compute_graph.h" | #include "./compute_graph.h" | ||||
#include "./ge_error_codes.h" | #include "./ge_error_codes.h" | ||||
#include "./graph/ge_tensor.h" | #include "./graph/ge_tensor.h" | ||||
@@ -57,6 +55,7 @@ graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | |||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | ||||
std::vector<ge::NodePtr> &data_nodes, | std::vector<ge::NodePtr> &data_nodes, | ||||
std::unordered_map<ge::NodePtr, bool> &node_status) { | std::unordered_map<ge::NodePtr, bool> &node_status) { | ||||
@@ -82,10 +81,10 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||||
// consider special node save process | // consider special node save process | ||||
// get all input desc format | // get all input desc format | ||||
bool node_is_all_nd = false; | bool node_is_all_nd = false; | ||||
for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetInputsSize()); i++) { | |||||
auto input_desc = op_desc->GetInputDesc(i); | |||||
auto input_size = static_cast<uint32_t>(op_desc->GetInputsSize()); | |||||
for (uint32_t i = 0; i < input_size; i++) { | |||||
// Operator pre-set format but not origin format | // Operator pre-set format but not origin format | ||||
auto input_format = input_desc.GetFormat(); | |||||
auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); | |||||
// Pre-save data node and default infer fail | // Pre-save data node and default infer fail | ||||
if (node_ptr->GetType() == DATA) { | if (node_ptr->GetType() == DATA) { | ||||
data_nodes.push_back(node_ptr); | data_nodes.push_back(node_ptr); | ||||
@@ -95,9 +94,9 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||||
} | } | ||||
} | } | ||||
// Get all output desc format | // Get all output desc format | ||||
for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetOutputsSize()); i++) { | |||||
GeTensorDesc output_desc = op_desc->GetOutputDesc(i); | |||||
auto output_format = output_desc.GetFormat(); | |||||
auto output_size = static_cast<uint32_t>(op_desc->GetOutputsSize()); | |||||
for (uint32_t i = 0; i < output_size; i++) { | |||||
auto output_format = op_desc->MutableOutputDesc(i)->GetFormat(); | |||||
if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { | if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { | ||||
node_is_all_nd = true; | node_is_all_nd = true; | ||||
} | } | ||||
@@ -145,7 +144,8 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||||
for (const auto &in_anchor : node->GetAllInDataAnchors()) { | for (const auto &in_anchor : node->GetAllInDataAnchors()) { | ||||
GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); | GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); | ||||
auto in_data_anchor_idx = in_anchor->GetIdx(); | auto in_data_anchor_idx = in_anchor->GetIdx(); | ||||
auto to_be_set_format = (node->GetOpDesc()->GetInputDesc(in_data_anchor_idx)).GetOriginFormat(); | |||||
auto to_be_set_format = | |||||
node->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx))->GetOriginFormat(); | |||||
if (to_be_set_format == FORMAT_ND) { | if (to_be_set_format == FORMAT_ND) { | ||||
GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); | GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); | ||||
continue; | continue; | ||||
@@ -162,7 +162,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||||
} | } | ||||
// Check format whether have been set | // Check format whether have been set | ||||
int idx = peer_out_data_anchor->GetIdx(); | int idx = peer_out_data_anchor->GetIdx(); | ||||
auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(idx); | |||||
auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx)); | |||||
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | ||||
auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | ||||
if (dim_num == 0) { | if (dim_num == 0) { | ||||
@@ -182,7 +182,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | ge_tensor_desc.SetOriginFormat(to_be_set_format); | ||||
ge_tensor_desc.SetFormat(to_be_set_format); | ge_tensor_desc.SetFormat(to_be_set_format); | ||||
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(idx, ge_tensor_desc); | |||||
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||||
// Call operator infer format api (forward) to get out format | // Call operator infer format api (forward) to get out format | ||||
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | ||||
@@ -205,7 +205,8 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||||
GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); | GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); | ||||
GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); | ||||
auto out_data_anchor_idx = out_data_anchor->GetIdx(); | auto out_data_anchor_idx = out_data_anchor->GetIdx(); | ||||
auto to_be_set_format = (node->GetOpDesc()->GetOutputDesc(out_data_anchor_idx)).GetOriginFormat(); | |||||
auto to_be_set_format = | |||||
node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(out_data_anchor_idx))->GetOriginFormat(); | |||||
if (to_be_set_format == FORMAT_ND) { | if (to_be_set_format == FORMAT_ND) { | ||||
GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); | GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); | ||||
continue; | continue; | ||||
@@ -222,7 +223,7 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||||
} | } | ||||
// Check format whether have been set | // Check format whether have been set | ||||
int idx = peer_in_data_anchor->GetIdx(); | int idx = peer_in_data_anchor->GetIdx(); | ||||
auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(idx); | |||||
auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx)); | |||||
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | ||||
auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | ||||
if (dim_num == 0) { | if (dim_num == 0) { | ||||
@@ -285,9 +286,9 @@ void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = | |||||
graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | ||||
std::unordered_map<ge::NodePtr, bool> &node_status) { | std::unordered_map<ge::NodePtr, bool> &node_status) { | ||||
bool is_internal_format = TypeUtils::IsInternalFormat(data_format); | bool is_internal_format = TypeUtils::IsInternalFormat(data_format); | ||||
bool need_process = ((!is_first_infer) && (is_internal_format == false) && (data_format != FORMAT_ND)); | |||||
bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND); | |||||
if (!need_process) { | if (!need_process) { | ||||
GELOGI("no necessary to do DataNodeFormatProcess.IsFirstInfer: %d, data_format:%s", is_first_infer, | |||||
GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer, | |||||
TypeUtils::FormatToSerialString(data_format).c_str()); | TypeUtils::FormatToSerialString(data_format).c_str()); | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -378,9 +379,9 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||||
/// Notice: ignore 5D formats | /// Notice: ignore 5D formats | ||||
auto data_format = graph->GetDataFormat(); | auto data_format = graph->GetDataFormat(); | ||||
status = DataNodeFormatProcess(data_nodes, data_format, node_status); | status = DataNodeFormatProcess(data_nodes, data_format, node_status); | ||||
// Set infer flag to false | // Set infer flag to false | ||||
SetInferOrigineFormatFlag(false); | SetInferOrigineFormatFlag(false); | ||||
return status; | return status; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -42,6 +42,8 @@ const std::string ATTR_NAME_BIAS = "bias"; | |||||
const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | ||||
const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; | |||||
const std::string ATTR_NAME_PAD = "pad"; | const std::string ATTR_NAME_PAD = "pad"; | ||||
const std::string ATTR_NAME_PADS = "pad"; | const std::string ATTR_NAME_PADS = "pad"; | ||||
@@ -83,6 +85,7 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; | |||||
const std::string ATTR_NAME_AXIS = "axis"; | const std::string ATTR_NAME_AXIS = "axis"; | ||||
const std::string ATTR_NAME_BROADCAST = "broadcast"; | const std::string ATTR_NAME_BROADCAST = "broadcast"; | ||||
const std::string ATTR_NAME_OUTPUT = "output"; | |||||
const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | ||||
const std::string ATTR_NAME_TIDX = "t_idx"; | const std::string ATTR_NAME_TIDX = "t_idx"; | ||||
@@ -103,6 +106,13 @@ const std::string ATTR_NAME_TSHAPE = "Tshape"; | |||||
const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | ||||
const std::string ATTR_NAME_AIPP = "aipp"; | const std::string ATTR_NAME_AIPP = "aipp"; | ||||
const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | |||||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; | |||||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; | |||||
const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||||
const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | ||||
const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | ||||
@@ -111,6 +121,7 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; | |||||
const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | ||||
const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | ||||
const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | ||||
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||||
const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | ||||
const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | ||||
@@ -122,15 +133,11 @@ const std::string ATTR_NAME_WEIGHTS = "value"; | |||||
const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | ||||
const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | ||||
const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | ||||
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||||
const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||||
const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||||
const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||||
const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | ||||
const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | ||||
const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; | |||||
const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||||
const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||||
// To be deleted | // To be deleted | ||||
const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | ||||
@@ -144,15 +151,13 @@ const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; | |||||
const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | ||||
const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | ||||
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||||
// Refinedet | // Refinedet | ||||
const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | ||||
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||||
const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | ||||
const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | ||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||||
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||||
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||||
// _Arg | // _Arg | ||||
const std::string ATTR_NAME_INDEX = "index"; | const std::string ATTR_NAME_INDEX = "index"; | ||||
@@ -242,6 +247,30 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; | |||||
const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | ||||
const std::string BATCHNORM_ATTR_SCALE = "scale"; | const std::string BATCHNORM_ATTR_SCALE = "scale"; | ||||
const std::string BATCHNORM_ATTR_BIAS = "bias"; | const std::string BATCHNORM_ATTR_BIAS = "bias"; | ||||
const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; | |||||
const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; | |||||
const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; | |||||
// huberloss | |||||
const std::string HUBER_LOSS_ATTR_DELTA = "delta"; | |||||
// SSDRealDivTileMul | |||||
const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; | |||||
// SSDSumMulRealDivMean | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; | |||||
// ConcatFive2Four | |||||
// ConcatFour2Five | |||||
const std::string SSD_BOX_TYPE_NUM = "box_type_num"; | |||||
const std::string SSD_CLASS_NUM = "class_num"; | |||||
const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; | |||||
const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; | |||||
const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; | |||||
const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; | |||||
// Scale | // Scale | ||||
const std::string SCALE_ATTR_SCALE = "scale"; | const std::string SCALE_ATTR_SCALE = "scale"; | ||||
@@ -346,6 +375,7 @@ const std::string SOFTMAX_ATTR_AXIS = "axis"; | |||||
// Permute | // Permute | ||||
const std::string PERMUTE_ATTR_ORDER = "order"; | const std::string PERMUTE_ATTR_ORDER = "order"; | ||||
const std::string PERMUTE_ATTR_PERM = "perm"; | |||||
// SSD Normalize | // SSD Normalize | ||||
const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | ||||
@@ -373,6 +403,10 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; | |||||
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | ||||
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | ||||
// RefinedetDetectionOutput | |||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||||
// PRelu | // PRelu | ||||
const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | ||||
@@ -386,11 +420,16 @@ const std::string POWER_ATTR_NAME_POWER = "power"; | |||||
const std::string POWER_ATTR_NAME_SCALE = "scale"; | const std::string POWER_ATTR_NAME_SCALE = "scale"; | ||||
const std::string POWER_ATTR_NAME_SHIFT = "shift"; | const std::string POWER_ATTR_NAME_SHIFT = "shift"; | ||||
// log | |||||
const std::string LOG_ATTR_NAME_SCALE = "scale"; | |||||
const std::string LOG_ATTR_NAME_SHIFT = "shift"; | |||||
const std::string LOG_ATTR_NAME_BASE = "base"; | |||||
// Pack | // Pack | ||||
const std::string PACK_ATTR_NAME_NUM = "N"; | const std::string PACK_ATTR_NAME_NUM = "N"; | ||||
// Unpack | // Unpack | ||||
const std::string UNPACK_ATTR_NAME_NUM = "num"; | const std::string UNPACK_ATTR_NAME_NUM = "num"; | ||||
const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||||
// Gathernd | // Gathernd | ||||
const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | ||||
const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | ||||
@@ -400,6 +439,13 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; | |||||
const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | ||||
const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | ||||
const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | ||||
const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; | |||||
const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; | |||||
const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; | |||||
// upsample | |||||
const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; | |||||
const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; | |||||
// Relu | // Relu | ||||
const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | ||||
@@ -416,6 +462,7 @@ const std::string SPLIT_ATTR_NAME_NUM_SPLIT = "num_split"; | |||||
const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; | const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; | ||||
const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; | const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; | ||||
const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; | const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; | ||||
const std::string TVM_ATTR_NAME_WORKSPACE_TYPE = "tvm_workspace_type"; | |||||
// Squeeze | // Squeeze | ||||
const std::string SQUEEZE_ATTR_AXIS = "axis"; | const std::string SQUEEZE_ATTR_AXIS = "axis"; | ||||
@@ -438,6 +485,7 @@ const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; | |||||
const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; | const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; | ||||
const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; | const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; | ||||
const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; | const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; | ||||
const std::string ROIALIGN_ATTR_NAME_TF = "roialign_tf"; | |||||
// Generate_rpn_proposal | // Generate_rpn_proposal | ||||
const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; | const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; | ||||
@@ -536,19 +584,42 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape | |||||
const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | ||||
// Rnn | // Rnn | ||||
const std::string RNN_MODE_ = "rnn_"; | |||||
const std::string CNN_RNN = "cnn_rnn"; | |||||
const std::string RNN_TENSORFLOW = "rnn_tensorflow"; | |||||
const std::string RNN_MODE_STATIC = "rnn_static"; | |||||
const std::string MUTI_RNN = "multi_rnn"; | const std::string MUTI_RNN = "multi_rnn"; | ||||
const std::string CNN_RNN = "cnn_rnn"; | |||||
const std::string RNN_MODE_ = "rnn_"; | |||||
const std::string CELL_MODE = "mode"; | const std::string CELL_MODE = "mode"; | ||||
const std::string LSTM_CELL = "lstm_cell"; | const std::string LSTM_CELL = "lstm_cell"; | ||||
const std::string GRU_CELL = "gru_cell"; | const std::string GRU_CELL = "gru_cell"; | ||||
const std::string RNN_HT = "ht"; | const std::string RNN_HT = "ht"; | ||||
const std::string RNN_XT_HT = "xt_ht"; | const std::string RNN_XT_HT = "xt_ht"; | ||||
const std::string RNN_BATCH_SIZE = "batch_size"; | const std::string RNN_BATCH_SIZE = "batch_size"; | ||||
const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; | |||||
const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; | |||||
const std::string LSTM_ACTIVATE = "lstm_activate"; | |||||
const std::string LSTM_OUT_MAP = "lstm_out_map"; | |||||
const std::string LSTM_OUT_MODE = "lstm_out_mode"; | |||||
const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; | |||||
const std::string LSTM_TIME_MAJOR = "lstm_time_major"; | |||||
const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; | |||||
// Upsample | // Upsample | ||||
const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | ||||
// PadV2 | |||||
const std::string PADV2_ATTR_NAME_MODE = "mode"; | |||||
const std::string PADV2_ATTR_NAME_PADS = "paddings"; | |||||
const std::string PADV2_ATTR_NAME_T = "T"; | |||||
const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||||
const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; | |||||
// MirrorPad | |||||
const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; | |||||
const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; | |||||
const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||||
const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; | |||||
// Filler | // Filler | ||||
const std::string FILLER_TYPE = "filler_type"; | const std::string FILLER_TYPE = "filler_type"; | ||||
const std::string FILLER_VALUE = "filler_value"; | const std::string FILLER_VALUE = "filler_value"; | ||||
@@ -559,9 +630,6 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group"; | |||||
// TopKV2 | // TopKV2 | ||||
const std::string TOPKV2_ATTR_K = "k"; | const std::string TOPKV2_ATTR_K = "k"; | ||||
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||||
const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||||
// Calibaration | // Calibaration | ||||
const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | ||||
const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | ||||
@@ -616,6 +684,8 @@ const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; | |||||
const std::string ATTR_MODEL_EVENT_NUM = "event_num"; | const std::string ATTR_MODEL_EVENT_NUM = "event_num"; | ||||
const std::string ATTR_MODEL_LABEL_NUM = "label_num"; | |||||
const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | ||||
const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | ||||
@@ -630,6 +700,8 @@ const std::string ATTR_MODEL_VAR_SIZE = "variable_size"; | |||||
const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; | const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; | ||||
const std::string ATTR_MODEL_CORE_TYPE = "core_type"; | |||||
// Public attribute | // Public attribute | ||||
const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; | const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; | ||||
@@ -661,17 +733,145 @@ const std::string TARGET_TYPE_TINY = "TINY"; | |||||
const std::string TARGET_TYPE_LITE = "LITE"; | const std::string TARGET_TYPE_LITE = "LITE"; | ||||
// l2_normalize | |||||
const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; | |||||
const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||||
const std::string POOL_PARAMA_ATTR_WINDOW = "window"; | |||||
const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; | |||||
const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; | |||||
const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; | |||||
const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; | |||||
const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; | |||||
// HCOM | |||||
const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||||
const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||||
const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||||
const std::string HCOM_ATTR_GROUP = "group"; | |||||
const std::string HCOM_ATTR_SR_TAG = "sr_tag"; | |||||
const std::string HCOM_ATTR_SRC_RANK = "src_rank"; | |||||
const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; | |||||
const std::string HCOM_ATTR_FUSION = "fusion"; | |||||
const std::string HCOM_ATTR_SHAPE = "shape"; | |||||
const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||||
// SpaceToDepth/DepthToSpace | |||||
const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; | |||||
// SparseSoftmaxCrossEntropyWithLogits | |||||
const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; | |||||
// MaxPoolGradWithArgmax | |||||
const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; | |||||
// AvgPoolGrad | |||||
const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; | |||||
// Pad | |||||
const std::string ATTR_PAD_FORMAT = "attr_pad_format"; | |||||
// Varible | |||||
const std::string VAR_ATTR_FORMAT = "_var_format"; | |||||
const std::string VAR_ATTR_NAME = "var_name"; | |||||
const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; | |||||
const std::string VAR_ATTR_4D_FORMAT = "4D"; | |||||
const std::string VAR_ATTR_5D_FORMAT = "5D"; | |||||
const std::string VAR_ATTR_DATA_TYPE = "data_format"; | |||||
const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; | |||||
const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; | |||||
const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; | |||||
const std::string VAR_ATTR_SHAPE = "shape"; | |||||
const std::string HALF_VAR_NAME_END = "_fp16"; | |||||
const std::string VAR_ATTR_INITED = "var_is_inited"; | |||||
const std::string VAR_ATTR_CONTAINER = "container"; | |||||
const std::string VAR_ATTR_SHARED_NAME = "shared_name"; | |||||
const std::string VAR_ATTR_DTYPE = "dtype"; | |||||
const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||||
const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; | |||||
const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||||
const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||||
const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||||
const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||||
// Assign | |||||
const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; | |||||
// space2bacth batch2space | |||||
const std::string BATCH_SPACE_ATTR_BLOCK = "block"; | |||||
const std::string BATCH_SPACE_ATTR_PADDING = "padding"; | |||||
// depth_to_space space_to_depth | |||||
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||||
// FakeQuantWithMinMaxVars | |||||
const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; | |||||
const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; | |||||
// mobilenet_ssd_conv_fusion | |||||
const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; | |||||
const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; | |||||
const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; | |||||
// lsh project | |||||
const std::string LSH_PROJ_TYPE = "lsh_project_type"; | |||||
// log time stamp | |||||
const std::string LOG_TIME_STAMP_LOGID = "logid"; | |||||
const std::string LOG_TIME_STAMP_NOTIFY = "notify"; | |||||
// ShapeN | |||||
const std::string SHAPEN_ATTR_N = "N"; | |||||
const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; | |||||
const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; | |||||
// GatherV2 attr def | |||||
const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; | |||||
const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; | |||||
const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; | |||||
// Reshape attr def | |||||
const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; | |||||
const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; | |||||
// axis attr def | |||||
const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; | |||||
const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; | |||||
const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; | |||||
const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; | |||||
// For constant folding | |||||
const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; | |||||
const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | ||||
const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | ||||
const std::string ATTR_NAME_REFERENCE = "reference"; | const std::string ATTR_NAME_REFERENCE = "reference"; | ||||
const std::string ATTR_NAME_NOTASK = "_no_task"; | |||||
const std::string ATTR_NAME_OUTPUT_REUSE_INPUT = "_output_reuse_input"; | |||||
const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX = "_reuse_input_on_dim_index"; | |||||
const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT = "_no_padding_continuous_input"; | |||||
const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT = "_no_padding_continuous_output"; | |||||
const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; | const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; | ||||
// Used for mark the active label list stream of activated node | // Used for mark the active label list stream of activated node | ||||
const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; | const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; | ||||
// Used for l2cache, true: the memory of all inputs is used for the last time. | |||||
const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_lifecycle"; | |||||
// Multi batch | // Multi batch | ||||
const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; | const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; | ||||
const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; | const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; | ||||
@@ -682,6 +882,8 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; | |||||
const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | ||||
const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | ||||
const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | ||||
const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | |||||
const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | |||||
const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | ||||
const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | ||||
@@ -691,6 +893,9 @@ const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG = "_cyclic_dependence_flag"; | |||||
const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; | const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; | ||||
// Function Op | |||||
const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; | |||||
// Used for mark the active node is for loop, type:bool | // Used for mark the active node is for loop, type:bool | ||||
const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; | const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; | ||||
@@ -702,6 +907,20 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; | |||||
const std::string MODEL_ATTR_SESSION_ID = "session_id"; | const std::string MODEL_ATTR_SESSION_ID = "session_id"; | ||||
// l1 fusion and other fusion in future | |||||
const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | |||||
const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | |||||
const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; | |||||
const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; | |||||
const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; | |||||
const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; | |||||
const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; | |||||
const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; | |||||
const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; | |||||
const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | |||||
const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | |||||
const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | |||||
// Atomic addr clean attrs | // Atomic addr clean attrs | ||||
const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | ||||
const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; | const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; | ||||
@@ -722,6 +941,9 @@ const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; | |||||
// For inserted op | // For inserted op | ||||
const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | ||||
// For compress weight | |||||
const std::string ATTR_NAME_COMPRESS_WEIGHT = "_is_compress_weight"; | |||||
// For data dump | // For data dump | ||||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; | ||||
const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; | const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; | ||||
@@ -732,24 +954,17 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou | |||||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | ||||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | ||||
// Variable | |||||
const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||||
const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||||
const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||||
const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||||
const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||||
// HCOM | |||||
const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||||
const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||||
const std::string HCOM_ATTR_SHAPE = "shape"; | |||||
const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||||
// functional ops attr | |||||
const std::string ATTR_NAME_TCOND = "Tcond"; | |||||
const std::string ATTR_NAME_TIN = "Tin"; | |||||
const std::string ATTR_NAME_TOUT = "Tout"; | |||||
const std::string ATTR_NAME_THEN_BRANCH = "then_branch"; | |||||
const std::string ATTR_NAME_ELSE_BRANCH = "else_branch"; | |||||
const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||||
// used for label switch | |||||
const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | |||||
const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | |||||
const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | ||||
const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | ||||
// Dynamic stitch | |||||
const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||||
} // namespace ge | } // namespace ge |
@@ -22,7 +22,7 @@ | |||||
#include "graph/model_serialize.h" | #include "graph/model_serialize.h" | ||||
#include "proto/ge_ir.pb.h" | #include "proto/ge_ir.pb.h" | ||||
#include "detail/model_serialize_imp.h" | #include "detail/model_serialize_imp.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "debug/ge_attr_define.h" | |||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
@@ -53,7 +53,7 @@ string GeAttrValue::NamedAttrs::GetName() const { | |||||
GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const { | GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const { | ||||
GeAttrValue value; | GeAttrValue value; | ||||
(void)GetAttr(key, value); | |||||
GetAttr(key, value); | |||||
return value; | return value; | ||||
} | } | ||||
@@ -1081,6 +1081,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstA | |||||
if (!GetListInt(std::move(obj), name, int64_list)) { | if (!GetListInt(std::move(obj), name, int64_list)) { | ||||
return false; | return false; | ||||
} | } | ||||
for (size_t i = 0; i < int64_list.size(); ++i) { | for (size_t i = 0; i < int64_list.size(); ++i) { | ||||
if (int64_list[i] > INT32_MAX) { | if (int64_list[i] > INT32_MAX) { | ||||
GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); | GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); | ||||
@@ -1098,6 +1099,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstA | |||||
if (!GetListInt(std::move(obj), name, int64_list)) { | if (!GetListInt(std::move(obj), name, int64_list)) { | ||||
return false; | return false; | ||||
} | } | ||||
for (size_t i = 0; i < int64_list.size(); ++i) { | for (size_t i = 0; i < int64_list.size(); ++i) { | ||||
if (int64_list[i] > UINT32_MAX) { | if (int64_list[i] > UINT32_MAX) { | ||||
GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); | GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); | ||||
@@ -1215,6 +1217,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( | |||||
GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); | GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); | ||||
op_desc->extAttrs_ = org_op_desc->extAttrs_; | op_desc->extAttrs_ = org_op_desc->extAttrs_; | ||||
if (op_desc->HasAttr("_input_name_idx_key")) { | |||||
if (op_desc->DelAttr("_input_name_idx_key") != SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_key failed."); | |||||
} | |||||
} | |||||
if (op_desc->HasAttr("_input_name_idx_value")) { | |||||
if (op_desc->DelAttr("_input_name_idx_value") != SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_value failed."); | |||||
} | |||||
} | |||||
if (op_desc->HasAttr("_opt_input")) { | |||||
if (op_desc->DelAttr("_opt_input") != SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "DelAttr _opt_input failed."); | |||||
} | |||||
} | |||||
return op_desc; | return op_desc; | ||||
} | } | ||||
@@ -1237,11 +1256,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||||
op_desc->extAttrs_ = org_op_desc->extAttrs_; | op_desc->extAttrs_ = org_op_desc->extAttrs_; | ||||
op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end()); | |||||
op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(), | |||||
org_op_desc->optional_input_names_.end()); | |||||
op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | |||||
op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | ||||
op_desc->infer_func_ = org_op_desc->infer_func_; | op_desc->infer_func_ = org_op_desc->infer_func_; | ||||
@@ -1250,4 +1264,25 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||||
return op_desc; | return op_desc; | ||||
} | } | ||||
std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) { | |||||
auto holder = obj.get(); | |||||
if (holder == nullptr) { | |||||
return ""; | |||||
} | |||||
auto attrs_map = holder->GetAttrMap(); | |||||
if (attrs_map.GetProtoMsg() == nullptr) { | |||||
return ""; | |||||
} | |||||
std::map<std::string, std::string> ordered_attrs; | |||||
for (auto &attr : *(attrs_map.GetProtoMsg())) { | |||||
ordered_attrs[attr.first] = attr.second.SerializeAsString(); | |||||
} | |||||
std::stringstream ss; | |||||
for (auto &attr : ordered_attrs) { | |||||
ss << attr.first << ":" << attr.second << ";"; | |||||
} | |||||
return ss.str(); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -163,6 +163,34 @@ int64_t GeShape::GetShapeSize() const { | |||||
return res; | return res; | ||||
} | } | ||||
/// | |||||
/// @brief Check is unknown shape | |||||
/// @return bool | |||||
/// /// | |||||
bool GeShape::IsUnknownShape() const { | |||||
auto proto_msg = shape_def_.GetProtoMsg(); | |||||
if (proto_msg != nullptr) { | |||||
for (auto i : proto_msg->dim()) { | |||||
if (i < 0) { | |||||
return true; | |||||
} | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
/// | |||||
/// @brief Check is a scalar | |||||
/// @return bool | |||||
/// | |||||
bool GeShape::IsScalar() const { | |||||
auto proto_msg = shape_def_.GetProtoMsg(); | |||||
if (proto_msg != nullptr) { | |||||
return proto_msg->dim().empty(); | |||||
} | |||||
return false; | |||||
} | |||||
const string TENSOR_UTILS_SIZE = "size"; | const string TENSOR_UTILS_SIZE = "size"; | ||||
const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; | const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; | ||||
const string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; | const string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; | ||||
@@ -639,14 +667,14 @@ GeTensor &GeTensor::operator=(const GeTensor &other) { | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, | ||||
uint32_t &size) { | |||||
int64_t &size) { | |||||
auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | ||||
GE_CHECK_NOTNULL(tensor_descriptor_msg); | GE_CHECK_NOTNULL(tensor_descriptor_msg); | ||||
size = static_cast<uint32_t>(tensor_descriptor_msg->size()); | |||||
size = static_cast<int64_t>(tensor_descriptor_msg->size()); | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, uint32_t size) { | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) { | |||||
auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | ||||
if (tensor_descriptor_msg != nullptr) { | if (tensor_descriptor_msg != nullptr) { | ||||
tensor_descriptor_msg->set_size(size); | tensor_descriptor_msg->set_size(size); | ||||
@@ -49,6 +49,7 @@ void Model::Init() { | |||||
(void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); | ||||
(void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); | ||||
(void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); | ||||
(void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); | |||||
(void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); | ||||
(void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); | (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); | ||||
version_ = 0; | version_ = 0; | ||||
@@ -77,9 +78,9 @@ void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; } | |||||
Graph Model::GetGraph() const { return graph_; } | Graph Model::GetGraph() const { return graph_; } | ||||
graphStatus Model::Save(Buffer &buffer) const { | |||||
graphStatus Model::Save(Buffer &buffer, bool is_dump) const { | |||||
ModelSerialize serialize; | ModelSerialize serialize; | ||||
buffer = serialize.SerializeModel(*this); | |||||
buffer = serialize.SerializeModel(*this, is_dump); | |||||
return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED; | return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED; | ||||
} | } | ||||
@@ -113,7 +114,7 @@ graphStatus Model::SaveToFile(const string &file_name) const { | |||||
} | } | ||||
int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); | int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); | ||||
if (fd < 0) { | if (fd < 0) { | ||||
GELOGE(GRAPH_FAILED, "open file failed, file path [%s] ", real_path); | |||||
GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno)); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
bool ret = ge_proto.SerializeToFileDescriptor(fd); | bool ret = ge_proto.SerializeToFileDescriptor(fd); | ||||
@@ -129,6 +130,10 @@ graphStatus Model::SaveToFile(const string &file_name) const { | |||||
GELOGE(GRAPH_FAILED, "close file descriptor fail."); | GELOGE(GRAPH_FAILED, "close file descriptor fail."); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if (!ret) { | |||||
GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -152,7 +157,7 @@ graphStatus Model::LoadFromFile(const string &file_name) { | |||||
} | } | ||||
int fd = open(real_path, O_RDONLY); | int fd = open(real_path, O_RDONLY); | ||||
if (fd < 0) { | if (fd < 0) { | ||||
GELOGE(GRAPH_FAILED, "open file failed"); | |||||
GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno)); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
@@ -170,6 +175,10 @@ graphStatus Model::LoadFromFile(const string &file_name) { | |||||
GELOGE(GRAPH_FAILED, "close file descriptor fail."); | GELOGE(GRAPH_FAILED, "close file descriptor fail."); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if (!ret) { | |||||
GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
return Load(model_def); | return Load(model_def); | ||||
} | } | ||||
@@ -15,10 +15,8 @@ | |||||
*/ | */ | ||||
#include "graph/model_serialize.h" | #include "graph/model_serialize.h" | ||||
#include <google/protobuf/text_format.h> | #include <google/protobuf/text_format.h> | ||||
#include <iostream> | #include <iostream> | ||||
#include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
@@ -26,6 +24,7 @@ | |||||
#include "graph/detail/model_serialize_imp.h" | #include "graph/detail/model_serialize_imp.h" | ||||
#include "proto/ge_ir.pb.h" | #include "proto/ge_ir.pb.h" | ||||
#include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
#include "debug/ge_op_types.h" | |||||
using std::string; | using std::string; | ||||
@@ -84,20 +83,29 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_ | |||||
return true; | return true; | ||||
} | } | ||||
bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { | |||||
bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { | |||||
if (op_desc == nullptr || op_def_proto == nullptr) { | if (op_desc == nullptr || op_def_proto == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "Input Para Invalid"); | GELOGE(GRAPH_FAILED, "Input Para Invalid"); | ||||
return false; | return false; | ||||
} | } | ||||
if (op_desc->op_def_.GetProtoMsg() != nullptr) { | if (op_desc->op_def_.GetProtoMsg() != nullptr) { | ||||
*op_def_proto = *op_desc->op_def_.GetProtoMsg(); | *op_def_proto = *op_desc->op_def_.GetProtoMsg(); | ||||
// Delete unnecessary attr | |||||
if (is_dump) { | |||||
auto attr = op_def_proto->mutable_attr(); | |||||
attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF); | |||||
attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF); | |||||
attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF); | |||||
GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP), | |||||
attr->erase(ATTR_NAME_WEIGHTS)); | |||||
} | |||||
op_def_proto->clear_input_desc(); | op_def_proto->clear_input_desc(); | ||||
op_def_proto->clear_output_desc(); | op_def_proto->clear_output_desc(); | ||||
// Input descs | // Input descs | ||||
if (op_desc->GetInputsSize() > 0) { | |||||
auto size = static_cast<uint32_t>(op_desc->GetInputsSize()); | |||||
if (op_desc->GetAllInputsSize() > 0) { | |||||
auto size = static_cast<uint32_t>(op_desc->GetAllInputsSize()); | |||||
for (uint32_t i = 0; i < size; i++) { | for (uint32_t i = 0; i < size; i++) { | ||||
auto tensor_desc = op_desc->GetInputDescPtr(i); | |||||
auto tensor_desc = op_desc->GetInputDescPtrDfault(i); | |||||
if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { | if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { | ||||
*op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); | *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); | ||||
} | } | ||||
@@ -117,12 +125,12 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op | |||||
return true; | return true; | ||||
} | } | ||||
bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto) { | |||||
bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { | |||||
if (node == nullptr || op_def_proto == nullptr) { | if (node == nullptr || op_def_proto == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); | GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); | ||||
return false; | return false; | ||||
} | } | ||||
if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto)) { | |||||
if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) { | |||||
GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); | GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); | ||||
return false; | return false; | ||||
} | } | ||||
@@ -134,7 +142,8 @@ bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_ | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, | ||||
proto::GraphDef *graph_proto) { | |||||
proto::GraphDef *graph_proto, | |||||
bool is_dump) { | |||||
if (graph == nullptr || graph_proto == nullptr) { | if (graph == nullptr || graph_proto == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "Input para Invalid"); | GELOGE(GRAPH_FAILED, "Input para Invalid"); | ||||
return false; | return false; | ||||
@@ -156,7 +165,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize | |||||
*graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); | *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); | ||||
} | } | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
if (!SerializeNode(node, graph_proto->add_op())) { | |||||
if (!SerializeNode(node, graph_proto->add_op(), is_dump)) { | |||||
if (node->GetOpDesc() != nullptr) { | if (node->GetOpDesc() != nullptr) { | ||||
GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); | GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); | ||||
} | } | ||||
@@ -166,7 +175,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize | |||||
return true; | return true; | ||||
} | } | ||||
bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto) { | |||||
bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) { | |||||
if (model_proto == nullptr) { | if (model_proto == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "model_proto para Invalid"); | GELOGE(GRAPH_FAILED, "model_proto para Invalid"); | ||||
return false; | return false; | ||||
@@ -183,7 +192,7 @@ bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *mode | |||||
GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); | GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); | ||||
return false; | return false; | ||||
} | } | ||||
if (!SerializeGraph(compute_graph, model_proto->add_graph())) { | |||||
if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) { | |||||
GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | ||||
return false; | return false; | ||||
} | } | ||||
@@ -390,10 +399,10 @@ bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf:: | |||||
return true; | return true; | ||||
} | } | ||||
Buffer ModelSerialize::SerializeModel(const Model &model) { | |||||
Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) { | |||||
proto::ModelDef model_def; | proto::ModelDef model_def; | ||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
if (!imp.SerializeModel(model, &model_def)) { | |||||
if (!imp.SerializeModel(model, &model_def, is_dump)) { | |||||
return Buffer(); | return Buffer(); | ||||
} | } | ||||
#if !defined(__ANDROID__) && !defined(ANDROID) | #if !defined(__ANDROID__) && !defined(ANDROID) | ||||
@@ -401,7 +401,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get | |||||
vec.push_back(in_anchor); | vec.push_back(in_anchor); | ||||
} | } | ||||
} | } | ||||
// Push back in_control_anchor_ | |||||
// Push back in_control_anchor_ | |||||
if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || | if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || | ||||
(in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { | (in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { | ||||
auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_); | auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_); | ||||
@@ -512,7 +512,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn | |||||
auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); | auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); | ||||
for (const auto &out_anchor : peer_out_anchors) { | for (const auto &out_anchor : peer_out_anchors) { | ||||
GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, " in_control_anchor_ peer out data anchors is nullptr"); | |||||
GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr"); | |||||
auto node = out_anchor->GetOwnerNode(); | auto node = out_anchor->GetOwnerNode(); | ||||
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | ||||
vec.push_back(node); | vec.push_back(node); | ||||
@@ -521,7 +521,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn | |||||
auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); | auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); | ||||
for (const auto &out_control_anchor : peer_out_control_anchors) { | for (const auto &out_control_anchor : peer_out_control_anchors) { | ||||
GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, | GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, | ||||
" in_control_anchor_ peer out control anchors is nullptr"); | |||||
"in_control_anchor_ peer out control anchors is nullptr"); | |||||
auto node = out_control_anchor->GetOwnerNode(); | auto node = out_control_anchor->GetOwnerNode(); | ||||
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | ||||
vec.push_back(node); | vec.push_back(node); | ||||
@@ -785,6 +785,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(co | |||||
GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, | GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, | ||||
"Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), | "Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), | ||||
op_desc->GetInputsSize()); | op_desc->GetInputsSize()); | ||||
GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, | GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, | ||||
"Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), | "Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), | ||||
op_desc->GetOutputsSize()); | op_desc->GetOutputsSize()); | ||||
@@ -61,6 +61,12 @@ const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; | |||||
const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | ||||
const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; | |||||
const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; | |||||
const std::string ATTR_NAME_INPUT_NAME_IDX_VALUE = "_input_name_idx_value"; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { | ||||
op_def_.InitDefault(); | op_def_.InitDefault(); | ||||
if (op_def_.GetProtoMsg() != nullptr) { | if (op_def_.GetProtoMsg() != nullptr) { | ||||
@@ -202,7 +208,8 @@ graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_d | |||||
} | } | ||||
graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | ||||
if (input_name_idx_.find(name) != input_name_idx_.end()) { | |||||
auto input_name_idx = GetAllInputName(); | |||||
if (input_name_idx.find(name) != input_name_idx.end()) { | |||||
GELOGI("input %s is exist, update it", name.c_str()); | GELOGI("input %s is exist, update it", name.c_str()); | ||||
graphStatus ret = UpdateInputDesc(name, input_desc); | graphStatus ret = UpdateInputDesc(name, input_desc); | ||||
return ret; | return ret; | ||||
@@ -214,15 +221,17 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
inputs_desc_.push_back(in_desc); | inputs_desc_.push_back(in_desc); | ||||
(void)input_name_idx_.insert(make_pair(name, index)); | |||||
(void)input_name_idx.insert(make_pair(name, index)); | |||||
SetAllInputName(input_name_idx); | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
} | } | ||||
graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | ||||
auto input_name_idx = GetAllInputName(); | |||||
for (unsigned int i = 0; i < num; i++) { | for (unsigned int i = 0; i < num; i++) { | ||||
string input_name = name + std::to_string(i); | string input_name = name + std::to_string(i); | ||||
GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, | |||||
GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, | |||||
"Add input tensor_desc is existed. name[%s]", input_name.c_str()); | "Add input tensor_desc is existed. name[%s]", input_name.c_str()); | ||||
std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | ||||
@@ -234,12 +243,13 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n | |||||
(void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | ||||
// Update index in input_name_idx | // Update index in input_name_idx | ||||
for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { | |||||
for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { | |||||
it->second += 1; | it->second += 1; | ||||
} | } | ||||
(void)input_name_idx_.insert(make_pair(input_name, 0)); | |||||
(void)input_name_idx.insert(make_pair(input_name, 0)); | |||||
} | } | ||||
SetAllInputName(input_name_idx); | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -270,10 +280,19 @@ graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int | |||||
graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | ||||
if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; | if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; | ||||
(void)optional_input_names_.insert(name); | |||||
vector<string> optional_input_names; | |||||
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||||
optional_input_names.push_back(name); | |||||
(void)AttrUtils::SetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
std::vector<string> OpDesc::GetAllOptionalInputName() const { | |||||
vector<string> optional_input_names; | |||||
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||||
return optional_input_names; | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | ||||
GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); | GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); | ||||
@@ -288,11 +307,12 @@ OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { | ||||
return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && | |||||
IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && | |||||
IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") && | |||||
IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && | |||||
IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); | |||||
return ( | |||||
IsEqual(this->GetAllInputName(), r_op_desc.GetAllInputName(), "OpDesc.GetAllInputName()") && | |||||
IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && | |||||
IsEqual(this->GetAllOptionalInputName(), r_op_desc.GetAllOptionalInputName(), "OpDesc.GetAllOptionalInputName()") && | |||||
IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && | |||||
IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { | ||||
@@ -366,8 +386,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpD | |||||
} | } | ||||
graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { | graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { | ||||
auto it = input_name_idx_.find(name); | |||||
if (it == input_name_idx_.end()) { | |||||
auto input_name_idx = GetAllInputName(); | |||||
auto it = input_name_idx.find(name); | |||||
if (it == input_name_idx.end()) { | |||||
GELOGW("Cann't find the input desc. name[%s]", name.c_str()); | GELOGW("Cann't find the input desc. name[%s]", name.c_str()); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
@@ -387,8 +408,9 @@ graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc & | |||||
} | } | ||||
bool OpDesc::InputIsSet(const string &name) const { | bool OpDesc::InputIsSet(const string &name) const { | ||||
auto it = input_name_idx_.find(name); | |||||
if (it != input_name_idx_.end()) { | |||||
auto input_name_idx = GetAllInputName(); | |||||
auto it = input_name_idx.find(name); | |||||
if (it != input_name_idx.end()) { | |||||
GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); | GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); | ||||
auto tensor_desc = inputs_desc_[it->second]; | auto tensor_desc = inputs_desc_[it->second]; | ||||
GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); | GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); | ||||
@@ -406,18 +428,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc | |||||
} | } | ||||
GeTensorDesc OpDesc::GetInputDesc(const string &name) const { | GeTensorDesc OpDesc::GetInputDesc(const string &name) const { | ||||
auto it = input_name_idx_.find(name); | |||||
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc()); | |||||
auto input_name_idx = GetAllInputName(); | |||||
auto it = input_name_idx.find(name); | |||||
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), GeTensorDesc()); | |||||
GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); | GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); | ||||
return *(inputs_desc_[it->second].get()); | return *(inputs_desc_[it->second].get()); | ||||
} | } | ||||
GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const { | GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const { | ||||
auto input_name_idx = GetAllInputName(); | |||||
vector<string> names; | vector<string> names; | ||||
if (input_name_idx_.empty()) { | |||||
if (input_name_idx.empty()) { | |||||
return OpDesc::Vistor<string>(shared_from_this(), names); | return OpDesc::Vistor<string>(shared_from_this(), names); | ||||
} | } | ||||
for (std::pair<string, uint32_t> input : input_name_idx_) { | |||||
for (std::pair<string, uint32_t> input : input_name_idx) { | |||||
names.push_back(input.first); | names.push_back(input.first); | ||||
} | } | ||||
return OpDesc::Vistor<string>(shared_from_this(), names); | return OpDesc::Vistor<string>(shared_from_this(), names); | ||||
@@ -483,6 +507,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() co | |||||
return size; | return size; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); } | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { | ||||
int index = static_cast<int>(outputs_desc_.size()); | int index = static_cast<int>(outputs_desc_.size()); | ||||
return AddOutputDesc("__output" + std::to_string(index), output_desc); | return AddOutputDesc("__output" + std::to_string(index), output_desc); | ||||
@@ -548,6 +574,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu | |||||
return outputs_desc_[index]; | return outputs_desc_[index]; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { | |||||
return static_cast<uint32_t>(outputs_desc_.size()); | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllOutputsDesc() const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllOutputsDesc() const { | ||||
vector<GeTensorDesc> temp{}; | vector<GeTensorDesc> temp{}; | ||||
for (const auto &it : outputs_desc_) { | for (const auto &it : outputs_desc_) { | ||||
@@ -580,6 +610,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetI | |||||
} | } | ||||
} | } | ||||
// Returns the input tensor descriptor stored at position `index` of inputs_desc_.
// Returns nullptr when `index` is not a valid position.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr
OpDesc::GetInputDescPtrDfault(uint32_t index) const {
  // Compare in size_t so the container size is never truncated, and index
  // without round-tripping through a signed 32-bit value.
  GE_CHK_BOOL_RET_STATUS_NOLOG(static_cast<size_t>(index) < inputs_desc_.size(), nullptr);
  return inputs_desc_[index];
}
// Looks up the input tensor descriptor registered under `name`.
// Returns an empty pointer when the name is unknown or when the index stored
// for it does not address an existing entry of inputs_desc_.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const {
  const auto input_name_idx = GetAllInputName();
  const auto it = input_name_idx.find(name);
  GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), shared_ptr<const GeTensorDesc>());
  // The name/index table is rebuilt from attributes on every call, so guard
  // against an index that no longer matches the descriptor list.
  GE_CHK_BOOL_RET_STATUS_NOLOG(static_cast<size_t>(it->second) < inputs_desc_.size(),
                               shared_ptr<const GeTensorDesc>());
  return inputs_desc_[it->second];
}
graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { | graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { | ||||
if (is_push_back) { | if (is_push_back) { | ||||
for (unsigned int i = 0; i < num; i++) { | for (unsigned int i = 0; i < num; i++) { | ||||
@@ -603,12 +646,45 @@ graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int | |||||
} | } | ||||
bool OpDesc::IsOptionalInput(const string &name) const { | bool OpDesc::IsOptionalInput(const string &name) const { | ||||
return optional_input_names_.find(name) != optional_input_names_.end(); | |||||
vector<string> optional_input_names; | |||||
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||||
for (auto &item : optional_input_names) { | |||||
if (item == name) { | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | } | ||||
bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } | bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } | ||||
std::map<string, uint32_t> OpDesc::GetAllInputName() { return input_name_idx_; } | |||||
std::map<string, uint32_t> OpDesc::GetAllInputName() const { | |||||
std::map<string, uint32_t> input_name_idx; | |||||
std::vector<string> key; | |||||
std::vector<uint32_t> value; | |||||
(void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); | |||||
(void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); | |||||
if (key.size() != value.size()) { | |||||
GE_LOGE("twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size()); | |||||
} else { | |||||
for (uint32_t i = 0; i < key.size(); ++i) { | |||||
input_name_idx.insert(std::pair<string, uint32_t>(key.at(i), value.at(i))); | |||||
} | |||||
} | |||||
return input_name_idx; | |||||
} | |||||
void OpDesc::SetAllInputName(const std::map<string, uint32_t> &input_name_idx) { | |||||
std::vector<string> key; | |||||
std::vector<uint32_t> value; | |||||
for (auto &item : input_name_idx) { | |||||
key.emplace_back(item.first); | |||||
value.emplace_back(item.second); | |||||
} | |||||
(void)AttrUtils::SetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); | |||||
(void)AttrUtils::SetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); | |||||
} | |||||
// Returns a copy of the output-name -> index table held in output_name_idx_.
std::map<string, uint32_t> OpDesc::GetAllOutputName() {
  return output_name_idx_;
}
@@ -619,6 +695,7 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||||
auto factory_map_size = input_name_idx.size(); | auto factory_map_size = input_name_idx.size(); | ||||
// It indicates that some inputs have no optionalname. | // It indicates that some inputs have no optionalname. | ||||
// The redundant optionalname of factory needs to be deleted and then assigned | // The redundant optionalname of factory needs to be deleted and then assigned | ||||
auto all_input_name_idx = GetAllInputName(); | |||||
if (input_map_size < factory_map_size) { | if (input_map_size < factory_map_size) { | ||||
GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, | GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, | ||||
factory_map_size); | factory_map_size); | ||||
@@ -631,22 +708,23 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||||
} | } | ||||
if (input_name_idx.size() == input_map_size) { | if (input_name_idx.size() == input_map_size) { | ||||
GELOGI("UpdateInputName"); | GELOGI("UpdateInputName"); | ||||
input_name_idx_ = input_name_idx; | |||||
all_input_name_idx = input_name_idx; | |||||
} else { | } else { | ||||
ret = false; | ret = false; | ||||
GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); | GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); | ||||
} | } | ||||
} else if (input_map_size == factory_map_size) { | } else if (input_map_size == factory_map_size) { | ||||
input_name_idx_ = input_name_idx; | |||||
all_input_name_idx = input_name_idx; | |||||
} else { | } else { | ||||
ret = false; | ret = false; | ||||
GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); | GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); | ||||
} | } | ||||
SetAllInputName(all_input_name_idx); | |||||
return ret; | return ret; | ||||
} | } | ||||
bool OpDesc::UpdateOutputName(std::map<string, uint32_t> output_name_idx) { | bool OpDesc::UpdateOutputName(std::map<string, uint32_t> output_name_idx) { | ||||
size_t output_map_size = GetAllOutputsDesc().size(); | |||||
size_t output_map_size = GetAllOutputsDescSize(); | |||||
size_t factory_map_size = output_name_idx.size(); | size_t factory_map_size = output_name_idx.size(); | ||||
if (output_map_size < factory_map_size) { | if (output_map_size < factory_map_size) { | ||||
GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size, | GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size, | ||||
@@ -754,17 +832,17 @@ graphStatus OpDesc::OpVerify() { | |||||
} | } | ||||
graphStatus OpDesc::CommonVerify() const { | graphStatus OpDesc::CommonVerify() const { | ||||
for (string iname : GetAllInputNames()) { | |||||
for (const string &iname : GetAllInputNames()) { | |||||
// Checking shape of all inputs | // Checking shape of all inputs | ||||
vector<int64_t> ishape = GetInputDesc(iname).GetShape().GetDims(); | |||||
vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims(); | |||||
for (int64_t dim : ishape) { | for (int64_t dim : ishape) { | ||||
GE_CHK_BOOL_RET_STATUS(dim >= -1, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", | GE_CHK_BOOL_RET_STATUS(dim >= -1, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", | ||||
iname.c_str()); | iname.c_str()); | ||||
} | } | ||||
} | } | ||||
// Check all attributes defined | // Check all attributes defined | ||||
const auto all_attributes = GetAllAttrs(); | |||||
for (const auto name : GetAllAttrNames()) { | |||||
const auto &all_attributes = GetAllAttrs(); | |||||
for (const auto &name : GetAllAttrNames()) { | |||||
GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, | GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, | ||||
"operator attribute %s is empty.", name.c_str()); | "operator attribute %s is empty.", name.c_str()); | ||||
} | } | ||||
@@ -773,19 +851,21 @@ graphStatus OpDesc::CommonVerify() const { | |||||
} | } | ||||
// Returns the name of the first entry of the input-name table (in map order)
// whose index equals `index`, or an empty string when no entry matches.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const {
  // The table is a local snapshot rebuilt by GetAllInputName().
  const auto input_name_idx = GetAllInputName();
  for (const auto &item : input_name_idx) {
    if (item.second == index) {
      return item.first;
    }
  }
  return "";
}
int OpDesc::GetInputIndexByName(const string &name) const { | int OpDesc::GetInputIndexByName(const string &name) const { | ||||
auto it_find = input_name_idx_.find(name); | |||||
GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); | |||||
auto input_name_idx = GetAllInputName(); | |||||
auto it_find = input_name_idx.find(name); | |||||
GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx.end(), -1); | |||||
return static_cast<int>(it_find->second); | return static_cast<int>(it_find->second); | ||||
} | } | ||||
@@ -1065,10 +1145,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<bool> OpDesc::GetIsInputCo | |||||
// Adds the pair (name, index) to the input-name table and writes the table
// back through SetAllInputName(). If `name` is already present, an info
// message is logged and the existing mapping is kept, because map::insert does
// not overwrite. Always returns GRAPH_SUCCESS.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name,
                                                                                       const int &index) {
  auto input_name_idx = GetAllInputName();
  if (input_name_idx.find(name) != input_name_idx.end()) {
    GELOGI("Restore input name index is existed. name[%s]", name.c_str());
  }
  (void)input_name_idx.insert(make_pair(name, index));
  SetAllInputName(input_name_idx);
  return GRAPH_SUCCESS;
}
@@ -1104,4 +1186,45 @@ graphStatus OpDesc::CallInferFormatFunc(Operator &op) { | |||||
} | } | ||||
return (graphStatus)infer_format_func_(op); | return (graphStatus)infer_format_func_(op); | ||||
} | } | ||||
// Returns the subgraph instance name stored at `index`, or an empty string
// when `index` is past the end of subgraph_instance_names_.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const {
  const auto position = static_cast<size_t>(index);
  if (position < subgraph_instance_names_.size()) {
    return subgraph_instance_names_.at(position);
  }
  return "";
}
// Gives read-only access to the list of subgraph instance names.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector<std::string> &OpDesc::GetSubgraphInstanceNames()
    const {
  const std::vector<std::string> &instance_names = subgraph_instance_names_;
  return instance_names;
}
// Appends `name` to subgraph_instance_names_. The by-value parameter is moved
// into the list.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::AddSubgraphInstanceName(std::string name) {
  subgraph_instance_names_.push_back(std::move(name));
}
// Erases the first entry of subgraph_instance_names_ equal to `name`.
// Does nothing when no entry matches; later duplicates are left in place.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) {
  auto pos = subgraph_instance_names_.begin();
  const auto last = subgraph_instance_names_.end();
  while (pos != last && *pos != name) {
    ++pos;
  }
  if (pos != last) {
    subgraph_instance_names_.erase(pos);
  }
}
// Registers `name` in subgraph_names_to_index_, giving it the current map
// size as its index. Returns GRAPH_FAILED, with a warning, when the name is
// already registered; otherwise GRAPH_SUCCESS.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) {
  const auto existing = subgraph_names_to_index_.find(name);
  if (existing == subgraph_names_to_index_.end()) {
    // Take the size before operator[] inserts the new key.
    const auto next_index = subgraph_names_to_index_.size();
    subgraph_names_to_index_[name] = next_index;
    return GRAPH_SUCCESS;
  }
  GELOGW("The subgraph name %s exists, index %u", name.c_str(), existing->second);
  return GRAPH_FAILED;
}
// Gives read-only access to the subgraph-name -> index map.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, uint32_t> &OpDesc::GetSubgraphNameIndexes()
    const {
  const std::map<std::string, uint32_t> &name_indexes = subgraph_names_to_index_;
  return name_indexes;
}
} // namespace ge | } // namespace ge |
@@ -20,8 +20,7 @@ | |||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
using std::function; | |||||
using std::vector; | |||||
using namespace std; | |||||
namespace ge { | namespace ge { | ||||
@@ -15,13 +15,12 @@ | |||||
*/ | */ | ||||
#include "external/graph/operator.h" | #include "external/graph/operator.h" | ||||
#include <stdint.h> | #include <stdint.h> | ||||
#include <algorithm> | #include <algorithm> | ||||
#include <mutex> | #include <mutex> | ||||
#include <queue> | #include <queue> | ||||
#include <set> | #include <set> | ||||
#include "array_ops.h" | |||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
@@ -33,7 +32,6 @@ | |||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "graph/operator_factory.h" | |||||
#include "graph/usr_types.h" | #include "graph/usr_types.h" | ||||
#include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
#include "utils/op_desc_utils.h" | #include "utils/op_desc_utils.h" | ||||
@@ -48,10 +46,6 @@ using std::string; | |||||
using std::to_string; | using std::to_string; | ||||
using std::vector; | using std::vector; | ||||
namespace { | |||||
const char *const kValue = "value"; | |||||
} // namespace | |||||
namespace ge { | namespace ge { | ||||
class OpIO { | class OpIO { | ||||
public: | public: | ||||
@@ -148,6 +142,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) { | for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) { | ||||
is_input_const.push_back(false); | is_input_const.push_back(false); | ||||
} | } | ||||
is_input_const[dst_index] = is_const; | is_input_const[dst_index] = is_const; | ||||
op_desc_->SetIsInputConst(is_input_const); | op_desc_->SetIsInputConst(is_input_const); | ||||
@@ -179,8 +174,8 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), | GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), | ||||
op_desc_->GetName().c_str()); | op_desc_->GetName().c_str()); | ||||
auto out_op_impl = out_handler->GetOwner(); | auto out_op_impl = out_handler->GetOwner(); | ||||
GE_CHK_BOOL_EXEC(out_op_impl && out_op_impl->GetOpDescImpl(), return, "out_handler invalid. name[%s]", | |||||
dst_name.c_str()); | |||||
GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return, | |||||
"out_handler invalid. name[%s]", dst_name.c_str()); | |||||
bool is_const = false; | bool is_const = false; | ||||
if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { | if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { | ||||
is_const = true; | is_const = true; | ||||
@@ -193,7 +188,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
op_desc_->SetIsInputConst(is_input_const); | op_desc_->SetIsInputConst(is_input_const); | ||||
OpIO in_handler(dst_name, dst_index, shared_from_this()); | OpIO in_handler(dst_name, dst_index, shared_from_this()); | ||||
GE_CHK_BOOL_EXEC(!!out_op_impl, return, "Get out_handler's impl failed."); | |||||
GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed."); | |||||
out_op_impl->UpdateLinkMapImpl(src_name, in_handler); | out_op_impl->UpdateLinkMapImpl(src_name, in_handler); | ||||
auto src_output_desc = out_op_impl->GetOutputDesc(src_name); | auto src_output_desc = out_op_impl->GetOutputDesc(src_name); | ||||
@@ -210,7 +205,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
void AddControlInputImp(const ge::Operator &src_oprt) { | void AddControlInputImp(const ge::Operator &src_oprt) { | ||||
if (src_oprt.operator_impl_ == nullptr) { | if (src_oprt.operator_impl_ == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "Src operator impl is nullptr"); | |||||
GELOGE(FAILED, "Src operator impl is nullptr"); | |||||
return; | return; | ||||
} | } | ||||
for (auto &input : control_input_link_) { | for (auto &input : control_input_link_) { | ||||
@@ -520,9 +515,9 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co | |||||
if (peer_node_ptr->GetOpDesc() != nullptr) { | if (peer_node_ptr->GetOpDesc() != nullptr) { | ||||
const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); | const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); | ||||
if (op_descType == CONSTANTOP) { | if (op_descType == CONSTANTOP) { | ||||
return const_op.GetAttr(kValue, data); | |||||
return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||||
} else if (op_descType == CONSTANT) { | } else if (op_descType == CONSTANT) { | ||||
return const_op.GetAttr(kValue, data); | |||||
return const_op.GetAttr(op::Const::name_attr_value(), data); | |||||
} | } | ||||
} | } | ||||
} else { | } else { | ||||
@@ -542,9 +537,9 @@ graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) | |||||
Operator const_op(out_handle.GetOwner()); | Operator const_op(out_handle.GetOwner()); | ||||
const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); | const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); | ||||
if (op_desc_impl_type == CONSTANTOP) { | if (op_desc_impl_type == CONSTANTOP) { | ||||
return const_op.GetAttr(kValue, data); | |||||
return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||||
} else if (op_desc_impl_type == CONSTANT) { | } else if (op_desc_impl_type == CONSTANT) { | ||||
return const_op.GetAttr(kValue, data); | |||||
return const_op.GetAttr(op::Const::name_attr_value(), data); | |||||
} | } | ||||
} | } | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
@@ -709,6 +704,7 @@ void Operator::InputRegister(const string &name) { | |||||
void Operator::OptionalInputRegister(const string &name) { | void Operator::OptionalInputRegister(const string &name) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
// [No need to verify return value] | |||||
(void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, | (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, | ||||
GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); | GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); | ||||
} | } | ||||
@@ -716,24 +712,28 @@ void Operator::OptionalInputRegister(const string &name) { | |||||
// Forwards `func` to AddInferFunc() on the underlying OpDesc.
// Logs and returns early when the operator impl or its OpDesc is missing.
void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  const auto op_desc = operator_impl_->GetOpDescImpl();
  GE_CHK_BOOL_EXEC(op_desc != nullptr, return, "GetOpDescImpl is nullptr.");
  // [No need to verify return value]
  (void)op_desc->AddInferFunc(func);
}
// Forwards `func` to AddInferFormatFunc() on the underlying OpDesc.
// Logs and returns early when the operator impl or its OpDesc is missing.
void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  const auto op_desc = operator_impl_->GetOpDescImpl();
  GE_CHK_BOOL_EXEC(op_desc != nullptr, return, "GetOpDescImpl is nullptr.");
  // [No need to verify return value]
  (void)op_desc->AddInferFormatFunc(func);
}
// Forwards `func` to AddVerifierFunc() on the underlying OpDesc.
// Logs and returns early when the operator impl or its OpDesc is missing.
void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) {
  GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr.");
  const auto op_desc = operator_impl_->GetOpDescImpl();
  GE_CHK_BOOL_EXEC(op_desc != nullptr, return, "GetOpDescImpl is nullptr.");
  // [No need to verify return value]
  (void)op_desc->AddVerifierFunc(func);
}
void Operator::OutputRegister(const string &name) { | void Operator::OutputRegister(const string &name) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
// [No need to verify return value] | |||||
(void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); | (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); | ||||
} | } | ||||
@@ -757,7 +757,8 @@ int Operator::GetDynamicInputNum(const string &name) const { | |||||
void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { | void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
(void)AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num); | |||||
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return, | |||||
"Set %s int failed", name.c_str()); | |||||
(void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); | (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); | ||||
} | } | ||||
@@ -765,7 +766,8 @@ int Operator::GetDynamicOutputNum(const string &name) const { | |||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); | ||||
int num = 0; | int num = 0; | ||||
(void)AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num); | |||||
GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num, | |||||
"Get %s int failed", name.c_str()); | |||||
return num; | return num; | ||||
} | } | ||||
@@ -1141,7 +1143,9 @@ class GraphBuilderImpl { | |||||
GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); | GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); | ||||
} | } | ||||
} | } | ||||
GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, | |||||
"User Input do not include operator such as \ | |||||
Data, Variable operator or operator that has output but no input."); | |||||
auto ret = WalkAllOperators(vec_inputs); | auto ret = WalkAllOperators(vec_inputs); | ||||
GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); | GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); | ||||
@@ -1163,7 +1167,8 @@ class GraphBuilderImpl { | |||||
que.pop(); | que.pop(); | ||||
for (const auto &op_impl : vec_tem) { | for (const auto &op_impl : vec_tem) { | ||||
GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") | GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") | ||||
GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue) | |||||
GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue, | |||||
"This node %s has created.", op_impl->GetName().c_str()) | |||||
auto node_ptr = graph_->AddNode(op_impl->op_desc_); | auto node_ptr = graph_->AddNode(op_impl->op_desc_); | ||||
GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); | GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); | ||||
all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); | all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); | ||||
@@ -1202,10 +1207,13 @@ class GraphBuilderImpl { | |||||
for (const auto &node_info : all_nodes_info_) { | for (const auto &node_info : all_nodes_info_) { | ||||
auto src_op_impl_ptr = node_info.first; | auto src_op_impl_ptr = node_info.first; | ||||
auto src_node_ptr = node_info.second; | auto src_node_ptr = node_info.second; | ||||
GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); | GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); | ||||
auto out_links = src_op_impl_ptr->output_links_; | auto out_links = src_op_impl_ptr->output_links_; | ||||
GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED, | |||||
"Src operator impl's op_desc is null."); | |||||
auto &op_desc = src_op_impl_ptr->op_desc_; | auto &op_desc = src_op_impl_ptr->op_desc_; | ||||
GE_IF_BOOL_EXEC(op_desc == nullptr, continue); | |||||
for (const auto &out : out_links) { | for (const auto &out : out_links) { | ||||
auto src_idx = op_desc->GetOutputIndexByName(out.first); | auto src_idx = op_desc->GetOutputIndexByName(out.first); | ||||
GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); | GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); | ||||
@@ -1216,7 +1224,9 @@ class GraphBuilderImpl { | |||||
for (const auto &dst_opio : out.second) { | for (const auto &dst_opio : out.second) { | ||||
auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); | auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); | ||||
GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); | GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); | ||||
GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); | GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); | ||||
auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); | auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); | ||||
GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); | GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); | ||||
@@ -1260,8 +1270,7 @@ inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { | |||||
ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) { | ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) { | ||||
auto graph_builder_impl = GraphBuilderImpl(name); | auto graph_builder_impl = GraphBuilderImpl(name); | ||||
ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); | ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); | ||||
GE_IF_BOOL_EXEC(compute_graph == nullptr, return compute_graph); | |||||
GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Computer graph is nullptr"); | |||||
compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); | compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); | ||||
if (HasSameNameNode(compute_graph)) { | if (HasSameNameNode(compute_graph)) { | ||||
GELOGW("Compute do not allow has same name nodes."); | GELOGW("Compute do not allow has same name nodes."); | ||||
@@ -15,13 +15,11 @@ | |||||
*/ | */ | ||||
#include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
#include <algorithm> | |||||
#include <cstdlib> | #include <cstdlib> | ||||
#include <algorithm> | |||||
#include <functional> | #include <functional> | ||||
#include <iostream> | #include <iostream> | ||||
#include <sstream> | #include <sstream> | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/debug/ge_log.h" | #include "graph/debug/ge_log.h" | ||||
@@ -155,7 +153,7 @@ void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) { | |||||
// Load .so file | // Load .so file | ||||
for (auto elem : file_list) { | for (auto elem : file_list) { | ||||
void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE); | |||||
void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||||
if (handle == nullptr) { | if (handle == nullptr) { | ||||
GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); | GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); | ||||
continue; | continue; | ||||
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "./ge_context.h" | #include "./ge_context.h" | ||||
#include "./ge_global_options.h" | #include "./ge_global_options.h" | ||||
#include "./ge_local_context.h" | #include "./ge_local_context.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -87,4 +86,5 @@ uint32_t GEContext::DeviceId() { return device_id_; } | |||||
// Returns the current value of trace_id_.
uint64_t GEContext::TraceId() {
  return trace_id_;
}
// Overwrites device_id_ with `device_id`.
void GEContext::SetCtxDeviceId(uint32_t device_id) {
  device_id_ = device_id;
}
} // namespace ge | } // namespace ge |
@@ -22,6 +22,7 @@ | |||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/utils/graph_utils.h" | |||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
#include "external/graph/operator.h" | #include "external/graph/operator.h" | ||||
@@ -34,6 +35,122 @@ | |||||
#include "utils/type_utils.h" | #include "utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
namespace { | |||||
constexpr const char *kRefIndex = "parent_node_index"; | |||||
// Copies the parent node's input tensor descriptions onto the Data nodes of its subgraphs.
//
// For every subgraph instance name recorded on the OpDesc of `node`, each direct node of type DATA
// in that subgraph must carry the kRefIndex ("parent_node_index") attribute. The parent's input
// desc at that index is written to both input 0 and output 0 of the Data node.
//
// `node` must be non-null (the caller in this file checks it before calling).
//
// Returns GRAPH_SUCCESS when `node` has no subgraphs or every Data node was updated. Returns
// GRAPH_FAILED when a graph, OpDesc, attribute or input desc cannot be found, or the status of the
// failing UpdateInputDesc/UpdateOutputDesc call.
graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
  auto op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    GE_LOGE("Invalid parent node %s, no OpDesc", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return GRAPH_SUCCESS;
  }

  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  if (root_graph == nullptr) {
    GE_LOGE("Can not find the root graph for node %s", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  for (const auto &name : sub_graph_names) {
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      GE_LOGE("Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    for (const auto &node_sub : sub_graph->GetDirectNode()) {
      // Only Data nodes mirror an input of the parent node.
      if (node_sub->GetType() != DATA) {
        continue;
      }
      auto data_opdesc = node_sub->GetOpDesc();
      if (data_opdesc == nullptr) {
        GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      int ref_i = 0;
      if (!AttrUtils::GetInt(data_opdesc, kRefIndex, ref_i)) {
        GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto input_desc = op_desc->MutableInputDesc(ref_i);
      if (input_desc == nullptr) {
        // Report the number of inputs: the ref index addresses an input of the parent node.
        GE_LOGE(
          "The ref index(%d) on the data %s on the sub graph %s "
          "parent node %s are incompatible, inputs num %u",
          ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
      ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
    }
  }
  return GRAPH_SUCCESS;
}
// Copies the tensor descriptions feeding each subgraph's NetOutput node back onto the parent node's
// outputs.
//
// For every subgraph instance name recorded on the OpDesc of `node`, the last direct node of type
// NETOUTPUT is located. Each of its input tensor descs that carries the kRefIndex
// ("parent_node_index") attribute is written to the parent's output desc at that index. Inputs
// without the attribute are skipped.
//
// `node` must be non-null (the caller in this file checks it before calling).
//
// Returns GRAPH_SUCCESS when `node` has no subgraphs or all referenced outputs were updated.
// Returns GRAPH_FAILED when a graph, NetOutput node, OpDesc or tensor desc cannot be found, or the
// status of the failing UpdateOutputDesc call.
graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
  auto op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    GE_LOGE("Invalid parent node %s, no OpDesc", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return GRAPH_SUCCESS;
  }

  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  if (root_graph == nullptr) {
    GE_LOGE("Can not find the root graph for node %s", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  for (const auto &name : sub_graph_names) {
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      GE_LOGE("Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }

    // Scan from the back of the direct-node list and take the first NetOutput found.
    NodePtr netoutput = nullptr;
    auto sub_nodes = sub_graph->GetDirectNode();
    for (size_t i = sub_nodes.size(); i > 0; --i) {
      auto sub_node = sub_nodes.at(i - 1);
      if (sub_node->GetType() == NETOUTPUT) {
        netoutput = sub_node;
        break;
      }
    }
    if (netoutput == nullptr) {
      GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    auto netoutput_opdesc = netoutput->GetOpDesc();
    if (netoutput_opdesc == nullptr) {
      GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }

    for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
      auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
      if (edge_desc == nullptr) {
        GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(),
                node->GetName().c_str(), edge_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      int ref_i = 0;
      if (!AttrUtils::GetInt(edge_desc, kRefIndex, ref_i)) {
        // No ref index on the tensor desc: this output is not consumed by the parent node.
        continue;
      }
      const auto parent_out_index = static_cast<uint32_t>(ref_i);
      auto output_desc = op_desc->MutableOutputDesc(parent_out_index);
      if (output_desc == nullptr) {
        GE_LOGE(
          "The ref index(%d) on the input %d of netoutput %s on the sub graph %s "
          "parent node %s are incompatible, outputs num %u",
          ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(),
          node->GetAllOutDataAnchorsSize());
        return GRAPH_FAILED;
      }
      // Write to the parent output selected by the ref index (the slot validated above), not to the
      // NetOutput input index, which addresses a different tensor list.
      auto ret = op_desc->UpdateOutputDesc(parent_out_index, *edge_desc);
      if (ret != GRAPH_SUCCESS) {
        GE_LOGE("Failed to update output desc %d of parent node %s from netoutput %s on the sub graph %s", ref_i,
                node->GetName().c_str(), netoutput->GetName().c_str(), name.c_str());
        return ret;
      }
    }
  }
  return GRAPH_SUCCESS;
}
} // namespace | |||||
void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { | void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { | ||||
if (node == nullptr) { | if (node == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "node is null"); | GELOGE(GRAPH_FAILED, "node is null"); | ||||
@@ -42,7 +159,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||||
ge::OpDescPtr op_desc = node->GetOpDesc(); | ge::OpDescPtr op_desc = node->GetOpDesc(); | ||||
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); | GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); | ||||
std::string str; | std::string str; | ||||
if (!op_desc->GetAllInputsDescPtr().empty()) { | |||||
if (op_desc->GetInputsSize() != 0) { | |||||
std::string input_desc_str = "input shape: "; | std::string input_desc_str = "input shape: "; | ||||
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | ||||
input_desc_str += "["; | input_desc_str += "["; | ||||
@@ -56,7 +173,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||||
str += input_desc_str; | str += input_desc_str; | ||||
} | } | ||||
if (!op_desc->GetAllOutputsDescPtr().empty()) { | |||||
if (op_desc->GetAllOutputsDescSize() != 0) { | |||||
std::string output_desc_str = "output shape: "; | std::string output_desc_str = "output shape: "; | ||||
for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | ||||
if (output_desc == nullptr) { | if (output_desc == nullptr) { | ||||
@@ -76,13 +193,24 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||||
} | } | ||||
graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { | graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { | ||||
return InferShapeAndType(node, op, true); | |||||
} | |||||
graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) { | |||||
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); | GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); | ||||
const auto &op_type = op_desc->GetType(); | const auto &op_type = op_desc->GetType(); | ||||
graphStatus ret; | |||||
if (before_subgraph) { | |||||
ret = UpdateSubGraphDataNodes(node); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
return ret; | |||||
} | |||||
} | |||||
// Get infer func and execute | // Get infer func and execute | ||||
graphStatus ret = op_desc->CallInferFunc(op); | |||||
ret = op_desc->CallInferFunc(op); | |||||
if (ret == GRAPH_PARAM_INVALID) { | if (ret == GRAPH_PARAM_INVALID) { | ||||
// Op ir no infer func, try to get infer func from operator factory | // Op ir no infer func, try to get infer func from operator factory | ||||
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); | auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); | ||||
@@ -113,7 +241,14 @@ graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator & | |||||
ret = op_desc->CallInferFunc(op); | ret = op_desc->CallInferFunc(op); | ||||
GELOGI("op CallInferFunc second. ret: %u", ret); | GELOGI("op CallInferFunc second. ret: %u", ret); | ||||
} | } | ||||
return ret; | |||||
if (ret != GRAPH_SUCCESS) { | |||||
return ret; | |||||
} | |||||
if (!before_subgraph) { | |||||
return UpdateParentNodeOutTensor(node); | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | } | ||||
InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | ||||
@@ -179,8 +314,11 @@ InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, Inf | |||||
namespace { | namespace { | ||||
std::unordered_map<NodePtr, InferenceContextPtr> context_map; | std::unordered_map<NodePtr, InferenceContextPtr> context_map; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { | ||||
return InferShapeAndType(node, true); | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, | |||||
bool before_subgraph) { | |||||
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | ||||
if (node->Verify() != GRAPH_SUCCESS) { | if (node->Verify() != GRAPH_SUCCESS) { | ||||
GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); | GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); | ||||
@@ -199,7 +337,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh | |||||
Operator op = OpDescUtils::CreateOperatorFromNode(node); | Operator op = OpDescUtils::CreateOperatorFromNode(node); | ||||
op.SetInferenceContext(inference_context); | op.SetInferenceContext(inference_context); | ||||
graphStatus status = InferShapeAndType(node, op); | |||||
graphStatus status = InferShapeAndType(node, op, before_subgraph); | |||||
if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | ||||
(void)ge::NodeUtils::UpdatePeerNodeInputDesc(node); | (void)ge::NodeUtils::UpdatePeerNodeInputDesc(node); | ||||
} else { | } else { | ||||
@@ -353,6 +353,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) | |||||
} | } | ||||
} | } | ||||
} | } | ||||
impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); | impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); | ||||
} | } | ||||
@@ -516,13 +517,14 @@ graphStatus Tensor::IsValid() { | |||||
GELOGW("mul overflow: %lu, %u", shape_size, type_length); | GELOGW("mul overflow: %lu, %u", shape_size, type_length); | ||||
} else { | } else { | ||||
if (shape_size * type_length != data_size) { | if (shape_size * type_length != data_size) { | ||||
// [Just log] Constructor | |||||
GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | ||||
data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
return GRAPH_FAILED; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -539,7 +541,7 @@ GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_des | |||||
tensor_desc.GetDataType()); | tensor_desc.GetDataType()); | ||||
ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | ||||
ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | ||||
auto size = static_cast<uint32_t>(tensor_desc.GetSize()); | |||||
auto size = tensor_desc.GetSize(); | |||||
TensorUtils::SetSize(ge_tensor_desc, size); | TensorUtils::SetSize(ge_tensor_desc, size); | ||||
auto real_dim_cnt = static_cast<uint32_t>(tensor_desc.GetRealDimCnt()); | auto real_dim_cnt = static_cast<uint32_t>(tensor_desc.GetRealDimCnt()); | ||||
@@ -552,7 +554,7 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ | |||||
ge_tensor_desc.GetDataType()); | ge_tensor_desc.GetDataType()); | ||||
tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | ||||
tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | ||||
uint32_t size = 0; | |||||
int64_t size = 0; | |||||
(void)TensorUtils::GetSize(ge_tensor_desc, size); | (void)TensorUtils::GetSize(ge_tensor_desc, size); | ||||
tensor_desc.SetSize(size); | tensor_desc.SetSize(size); | ||||
@@ -15,18 +15,21 @@ | |||||
*/ | */ | ||||
#include "graph/utils/ge_ir_utils.h" | #include "graph/utils/ge_ir_utils.h" | ||||
#include <utility> | #include <utility> | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
namespace { | namespace { | ||||
const char *const kControlAnchorIndex = ":-1"; | const char *const kControlAnchorIndex = ":-1"; | ||||
const char *const kNodeTypeForSubgraph = "subgraph"; | const char *const kNodeTypeForSubgraph = "subgraph"; | ||||
const char *const kPrefixForInputDesc = "input_desc_attr_"; | |||||
const char *const kPrefixForOutputDesc = "output_desc_attr_"; | |||||
const char *const kDumpGEGraph = "DUMP_GE_GRAPH"; | const char *const kDumpGEGraph = "DUMP_GE_GRAPH"; | ||||
const int8_t kMaxRecursionDepth = 10; | const int8_t kMaxRecursionDepth = 10; | ||||
const char *const kDumpGeGraph = std::getenv(kDumpGEGraph); | const char *const kDumpGeGraph = std::getenv(kDumpGEGraph); | ||||
const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP; | const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP; | ||||
const int64_t kInputPrefixLength = 5; | |||||
const int64_t kOutputPrefixLength = 6; | |||||
using AttrDefPair = ::google::protobuf::MapPair<std::string, ge::proto::AttrDef>; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
@@ -198,7 +201,7 @@ void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_A | |||||
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, | void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, | ||||
::google::protobuf::RepeatedField<bool> data) { | ::google::protobuf::RepeatedField<bool> data) { | ||||
if (node_proto == nullptr) { | if (node_proto == nullptr) { | ||||
GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); | |||||
GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); | |||||
return; | return; | ||||
} | } | ||||
if (!data.empty()) { | if (!data.empty()) { | ||||
@@ -320,7 +323,16 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const | |||||
auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset(); | auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | ||||
"input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset); | "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset); | ||||
const auto &tensor_desc_map = tensor_descriptor->attr(); | |||||
std::string suffix = ":" + std::to_string(i); | |||||
AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix); | |||||
} else { | |||||
GELOGW("Tensor descriptor is nullptr"); | |||||
continue; | |||||
} | } | ||||
} else { | |||||
GELOGW("Input desc is nullptr"); | |||||
continue; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -360,16 +372,25 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const | |||||
auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); | auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | ||||
"output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); | "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); | ||||
const auto &tensor_desc_map = tensor_descriptor->attr(); | |||||
std::string suffix = ":" + std::to_string(i); | |||||
AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix); | |||||
} else { | |||||
GELOGW("Tensor descriptor is nullptr"); | |||||
continue; | |||||
} | } | ||||
} else { | |||||
GELOGW("Output desc is nullptr"); | |||||
continue; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
} | } | ||||
void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto) { | |||||
GE_CHK_BOOL_EXEC(op_def != nullptr, return, "Opdef is nullptr"); | |||||
const auto &op_def_attr_map = op_def->attr(); | |||||
for (const auto &item : op_def_attr_map) { | |||||
void OnnxUtils::AddAttrProtoForAttrsFromAttrMap( | |||||
const ::google::protobuf::Map<std::string, ::ge::proto::AttrDef> &attr_map, onnx::NodeProto *node_proto, | |||||
const std::string &prefix, const std::string &suffix) { | |||||
for (const auto &item : attr_map) { | |||||
auto attr_name = item.first; | auto attr_name = item.first; | ||||
auto attr_def = item.second; | auto attr_def = item.second; | ||||
auto attr_type = attr_def.value_case(); | auto attr_type = attr_def.value_case(); | ||||
@@ -377,36 +398,40 @@ void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, on | |||||
const auto &tensor_def = attr_def.t(); | const auto &tensor_def = attr_def.t(); | ||||
const auto &tensor_desc = tensor_def.desc(); | const auto &tensor_desc = tensor_def.desc(); | ||||
auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); | auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_desc_dtype:", &data_type); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix, | |||||
&data_type); | |||||
auto dims = tensor_desc.shape().dim(); | auto dims = tensor_desc.shape().dim(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name + "_desc_shape:", dims); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix, | |||||
dims); | |||||
auto layout = tensor_desc.layout(); | auto layout = tensor_desc.layout(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_desc_layout:", &layout); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix, | |||||
&layout); | |||||
auto device_type = tensor_desc.device_type(); | auto device_type = tensor_desc.device_type(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, | ||||
attr_name + "_desc_device_type:", &device_type); | |||||
prefix + attr_name + "_desc_device_type" + suffix, &device_type); | |||||
if (kDumpLevel == DUMP_ALL) { | if (kDumpLevel == DUMP_ALL) { | ||||
auto data = tensor_def.data(); | auto data = tensor_def.data(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_data", &data); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix, | |||||
&data); | |||||
} | } | ||||
} | } | ||||
if (attr_type == ge::proto::AttrDef::kS) { | if (attr_type == ge::proto::AttrDef::kS) { | ||||
if (kDumpLevel == DUMP_ALL) { | if (kDumpLevel == DUMP_ALL) { | ||||
auto str_value = attr_def.s(); | auto str_value = attr_def.s(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name, &str_value); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value); | |||||
} | } | ||||
} | } | ||||
if (attr_type == ge::proto::AttrDef::kI) { | if (attr_type == ge::proto::AttrDef::kI) { | ||||
auto int_value = attr_def.i(); | auto int_value = attr_def.i(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, attr_name, &int_value); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); | |||||
} | } | ||||
if (attr_type == ge::proto::AttrDef::kF) { | if (attr_type == ge::proto::AttrDef::kF) { | ||||
auto float_value = attr_def.f(); | auto float_value = attr_def.f(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, attr_name, &float_value); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value); | |||||
} | } | ||||
if (attr_type == ge::proto::AttrDef::kB) { | if (attr_type == ge::proto::AttrDef::kB) { | ||||
auto int_value = static_cast<int64_t>(attr_def.b()); | auto int_value = static_cast<int64_t>(attr_def.b()); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, attr_name, &int_value); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); | |||||
} | } | ||||
if (attr_type == ge::proto::AttrDef::kList) { | if (attr_type == ge::proto::AttrDef::kList) { | ||||
const auto &list_value = attr_def.list(); | const auto &list_value = attr_def.list(); | ||||
@@ -415,21 +440,21 @@ void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, on | |||||
ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { | ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { | ||||
if (kDumpLevel == DUMP_ALL) { | if (kDumpLevel == DUMP_ALL) { | ||||
const auto &strings = list_value.s(); | const auto &strings = list_value.s(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, attr_name, strings); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings); | |||||
} | } | ||||
} | } | ||||
if (list_value_type == | if (list_value_type == | ||||
ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { | ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { | ||||
const auto &floats = list_value.f(); | const auto &floats = list_value.f(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, attr_name, floats); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats); | |||||
} | } | ||||
if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { | if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { | ||||
const auto &ints = list_value.i(); | const auto &ints = list_value.i(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name, ints); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints); | |||||
} | } | ||||
if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { | if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { | ||||
const auto &bools = list_value.b(); | const auto &bools = list_value.b(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name, bools); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -481,8 +506,15 @@ void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes); | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes); | ||||
const auto &is_input_const = op_def->is_input_const(); | const auto &is_input_const = op_def->is_input_const(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const); | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const); | ||||
AddAttrProtoForAttrsFromOpDef(op_def, node_proto); | |||||
const auto &op_def_attr_map = op_def->attr(); | |||||
AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto); | |||||
} else { | |||||
GELOGE(FAILED, "Opdef is nullptr"); | |||||
return; | |||||
} | } | ||||
} else { | |||||
GELOGE(FAILED, "Opdesc is nullptr"); | |||||
return; | |||||
} | } | ||||
} | } | ||||
@@ -526,15 +558,13 @@ bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) | |||||
node_proto->clear_input(); | node_proto->clear_input(); | ||||
// 1. Add input by in data edge | // 1. Add input by in data edge | ||||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
if (in_data_anchor != nullptr) { | |||||
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { | |||||
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + | |||||
std::to_string(peer_out_anchor->GetIdx())); | |||||
} else { | |||||
// Add "" input | |||||
node_proto->add_input(""); | |||||
} | |||||
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { | |||||
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + | |||||
std::to_string(peer_out_anchor->GetIdx())); | |||||
} else { | |||||
// Add "" input | |||||
node_proto->add_input(""); | |||||
} | } | ||||
} | } | ||||
@@ -547,6 +577,9 @@ bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) | |||||
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); | node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); | ||||
} | } | ||||
} | } | ||||
} else { | |||||
GELOGE(FAILED, "Incontrol anchor is nullptr"); | |||||
return false; | |||||
} | } | ||||
// 3. Add output for Netron visual support | // 3. Add output for Netron visual support | ||||
@@ -584,7 +617,7 @@ void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_T | |||||
} | } | ||||
const auto &op_desc = node->GetOpDesc(); | const auto &op_desc = node->GetOpDesc(); | ||||
if (op_desc != nullptr) { | if (op_desc != nullptr) { | ||||
auto size_out = op_desc->GetOutputsSize(); | |||||
uint32_t size_out = static_cast<uint32_t>(op_desc->GetOutputsSize()); | |||||
if (size_out > 0) { | if (size_out > 0) { | ||||
for (uint32_t i = 0; i < size_out; i++) { | for (uint32_t i = 0; i < size_out; i++) { | ||||
const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); | const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); | ||||
@@ -598,7 +631,13 @@ void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_T | |||||
auto dim = shape->add_dim(); | auto dim = shape->add_dim(); | ||||
dim->set_dim_value(d); | dim->set_dim_value(d); | ||||
} | } | ||||
} else { | |||||
GELOGW("Shape is nullptr"); | |||||
continue; | |||||
} | } | ||||
} else { | |||||
GELOGW("Ge tensor is nullptr"); | |||||
continue; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -666,7 +705,7 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr | |||||
} | } | ||||
// For subgraphs: a subgraph is represented by a node | // For subgraphs: a subgraph is represented by a node | ||||
for (const auto &sub_compute_graph : compute_graph->sub_graph_) { | |||||
for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) { | |||||
if (sub_compute_graph != nullptr) { | if (sub_compute_graph != nullptr) { | ||||
auto node_proto = graph_proto->add_node(); | auto node_proto = graph_proto->add_node(); | ||||
if (node_proto == nullptr) { | if (node_proto == nullptr) { | ||||
@@ -679,6 +718,10 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr | |||||
attr->set_name("graph"); | attr->set_name("graph"); | ||||
attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); | attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); | ||||
auto sub_graph_proto = attr->mutable_g(); | auto sub_graph_proto = attr->mutable_g(); | ||||
if (sub_graph_proto == nullptr) { | |||||
GELOGW("Sub graph proto is nullptr"); | |||||
continue; | |||||
} | |||||
if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { | if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { | ||||
GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); | GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); | ||||
continue; | continue; | ||||
@@ -831,56 +874,116 @@ void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t | |||||
value = attr_proto.i(); | value = attr_proto.i(); | ||||
} | } | ||||
void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||||
const std::string &attr_name_for_input_output_desc, int32_t index, | |||||
OpDescPtr &op_desc) { | |||||
if (op_desc == nullptr || op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "op_desc or op_desc->MutableInputDesc(index) is nullptr"); | |||||
void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, | |||||
const std::string &attr_name_for_input_desc, int32_t index, | |||||
OpDescPtr &op_desc) { | |||||
if (op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast<uint32_t>(index)) is nullptr", | |||||
op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); | |||||
return; | return; | ||||
} | } | ||||
if (attr_name_for_input_output_desc == "input_desc_dtype") { | |||||
if (attr_name_for_input_desc == "input_desc_dtype") { | |||||
auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | ||||
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | ||||
} else if (attr_name_for_input_output_desc == "input_desc_shape") { | |||||
} else if (attr_name_for_input_desc == "input_desc_shape") { | |||||
std::vector<std::int64_t> ints; | std::vector<std::int64_t> ints; | ||||
DecodeAttribute(attr_proto, ints); | DecodeAttribute(attr_proto, ints); | ||||
GeShape ge_shape(ints); | GeShape ge_shape(ints); | ||||
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | ||||
} else if (attr_name_for_input_output_desc == "input_desc_layout") { | |||||
} else if (attr_name_for_input_desc == "input_desc_layout") { | |||||
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | ||||
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | ||||
} else if (attr_name_for_input_output_desc == "input_desc_origin_shape") { | |||||
} else if (attr_name_for_input_desc == "input_desc_origin_shape") { | |||||
std::vector<std::int64_t> ints; | std::vector<std::int64_t> ints; | ||||
DecodeAttribute(attr_proto, ints); | DecodeAttribute(attr_proto, ints); | ||||
GeShape ge_shape(ints); | GeShape ge_shape(ints); | ||||
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | ||||
} else if (attr_name_for_input_output_desc == "input_desc_origin_layout") { | |||||
} else if (attr_name_for_input_desc == "input_desc_origin_layout") { | |||||
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | ||||
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | ||||
} else if (attr_name_for_input_output_desc == "output_desc_dtype") { | |||||
} else if (attr_name_for_input_desc == "input_desc_size") { | |||||
int64_t input_size = 0; | |||||
auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||||
DecodeAttribute(attr_proto, input_size); | |||||
tensor_descriptor->set_size(input_size); | |||||
} else if (attr_name_for_input_desc == "input_desc_data_offset") { | |||||
auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||||
int64_t offset = 0; | |||||
DecodeAttribute(attr_proto, offset); | |||||
tensor_descriptor->set_data_offset(offset); | |||||
} else { | |||||
return; | |||||
} | |||||
} | |||||
void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, | |||||
const std::string &attr_name_for_output_desc, int32_t index, | |||||
OpDescPtr &op_desc) { | |||||
if (op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) is nullptr", | |||||
op_desc->GetName().c_str(), attr_name_for_output_desc.c_str()); | |||||
return; | |||||
} | |||||
if (attr_name_for_output_desc == "output_desc_dtype") { | |||||
auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | ||||
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | ||||
} else if (attr_name_for_input_output_desc == "output_desc_shape") { | |||||
} else if (attr_name_for_output_desc == "output_desc_shape") { | |||||
std::vector<std::int64_t> ints; | std::vector<std::int64_t> ints; | ||||
DecodeAttribute(attr_proto, ints); | DecodeAttribute(attr_proto, ints); | ||||
GeShape ge_shape(ints); | GeShape ge_shape(ints); | ||||
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | ||||
} else if (attr_name_for_input_output_desc == "output_desc_layout") { | |||||
} else if (attr_name_for_output_desc == "output_desc_layout") { | |||||
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | ||||
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | ||||
} else if (attr_name_for_input_output_desc == "output_desc_origin_shape") { | |||||
} else if (attr_name_for_output_desc == "output_desc_origin_shape") { | |||||
std::vector<std::int64_t> ints; | std::vector<std::int64_t> ints; | ||||
DecodeAttribute(attr_proto, ints); | DecodeAttribute(attr_proto, ints); | ||||
GeShape ge_shape(ints); | GeShape ge_shape(ints); | ||||
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | ||||
} else if (attr_name_for_input_output_desc == "output_desc_origin_layout") { | |||||
} else if (attr_name_for_output_desc == "output_desc_origin_layout") { | |||||
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | ||||
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | ||||
} else if (attr_name_for_output_desc == "output_desc_size") { | |||||
int64_t output_size = 0; | |||||
auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||||
DecodeAttribute(attr_proto, output_size); | |||||
tensor_descriptor->set_size(output_size); | |||||
} else if (attr_name_for_output_desc == "output_desc_data_offset") { | |||||
auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||||
int64_t offset = 0; | |||||
DecodeAttribute(attr_proto, offset); | |||||
tensor_descriptor->set_data_offset(offset); | |||||
} else { | |||||
return; | |||||
} | |||||
} | |||||
void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||||
const std::string &attr_name_for_input_output_desc, int32_t index, | |||||
OpDescPtr &op_desc) { | |||||
if (op_desc == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "op_desc is nullptr"); | |||||
return; | |||||
} | |||||
if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") { | |||||
DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); | |||||
} else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") { | |||||
DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); | |||||
} else { | } else { | ||||
return; | return; | ||||
} | } | ||||
} | } | ||||
void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) { | |||||
auto attr_map = op_def.mutable_attr(); | |||||
const auto &attr_name = attr_proto.name(); | |||||
ge::proto::AttrDef op_attr; | |||||
int64_t value = 0; | |||||
DecodeAttribute(attr_proto, value); | |||||
op_attr.set_i(value); | |||||
attr_map->insert(AttrDefPair(attr_name, op_attr)); | |||||
} | |||||
void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { | void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { | ||||
if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); | GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); | ||||
@@ -910,6 +1013,16 @@ void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_pr | |||||
std::vector<std::int64_t> ints; | std::vector<std::int64_t> ints; | ||||
DecodeAttribute(attr_proto, ints); | DecodeAttribute(attr_proto, ints); | ||||
op_desc->SetDstIndex(ints); | op_desc->SetDstIndex(ints); | ||||
} else if (attr_name == "fusion_scope") { | |||||
DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg()); | |||||
} else if (attr_name == "input_i") { | |||||
std::vector<std::int64_t> ints; | |||||
DecodeAttribute(attr_proto, ints); | |||||
op_desc->SetInputOffset(ints); | |||||
} else if (attr_name == "output_i") { | |||||
std::vector<std::int64_t> ints; | |||||
DecodeAttribute(attr_proto, ints); | |||||
op_desc->SetOutputOffset(ints); | |||||
} else { | } else { | ||||
return; | return; | ||||
} | } | ||||
@@ -939,20 +1052,14 @@ bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_ | |||||
auto size_in = attr.i(); | auto size_in = attr.i(); | ||||
for (int64_t i = 0; i < size_in; i++) { | for (int64_t i = 0; i < size_in; i++) { | ||||
GeTensorDesc ge_tensor_desc; | GeTensorDesc ge_tensor_desc; | ||||
if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) { | |||||
GELOGW("Add inputdesc failed"); | |||||
continue; | |||||
} | |||||
GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed."); | |||||
} | } | ||||
} | } | ||||
if (attr.name() == "output_desc_nums") { | if (attr.name() == "output_desc_nums") { | ||||
auto size_out = attr.i(); | auto size_out = attr.i(); | ||||
for (int64_t i = 0; i < size_out; i++) { | for (int64_t i = 0; i < size_out; i++) { | ||||
GeTensorDesc ge_tensor_desc; | GeTensorDesc ge_tensor_desc; | ||||
if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) { | |||||
GELOGW("add inputdesc failed"); | |||||
continue; | |||||
} | |||||
GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed."); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -970,10 +1077,7 @@ bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_p | |||||
} | } | ||||
graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name()); | graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name()); | ||||
if (graph == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); | |||||
return false; | |||||
} | |||||
GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed"); | |||||
/// 1. Decode all nodes first, node should include input | /// 1. Decode all nodes first, node should include input | ||||
/// and output nodes and nodes which represent sub graphs | /// and output nodes and nodes which represent sub graphs | ||||
std::map<std::string, NodePtr> node_map; | std::map<std::string, NodePtr> node_map; | ||||
@@ -131,6 +131,10 @@ class OnnxUtils { | |||||
static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); | static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); | ||||
static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map<std::string, ge::proto::AttrDef> &attr_map, | |||||
onnx::NodeProto *node_proto, const std::string &prefix = "", | |||||
const std::string &suffix = ""); | |||||
static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto); | static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto); | ||||
static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); | static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); | ||||
@@ -172,10 +176,20 @@ class OnnxUtils { | |||||
static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); | static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); | ||||
static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, | |||||
const std::string &attr_name_for_output_desc, int32_t index, | |||||
OpDescPtr &op_desc); | |||||
static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, | |||||
const std::string &attr_name_for_input_desc, int32_t index, | |||||
OpDescPtr &op_desc); | |||||
static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | ||||
const std::string &attr_name_for_input_output_desc, int32_t index, | const std::string &attr_name_for_input_output_desc, int32_t index, | ||||
OpDescPtr &op_desc); | OpDescPtr &op_desc); | ||||
static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def); | |||||
static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); | static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); | ||||
static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); | static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); | ||||
@@ -15,10 +15,12 @@ | |||||
*/ | */ | ||||
#include "utils/node_utils.h" | #include "utils/node_utils.h" | ||||
#include "graph/utils/graph_utils.h" | |||||
#include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/anchor.h" | #include "graph/anchor.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "utils/tensor_utils.h" | #include "utils/tensor_utils.h" | ||||
#include "utils/type_utils.h" | #include "utils/type_utils.h" | ||||
@@ -109,6 +111,7 @@ graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_pt | |||||
graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { | graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { | ||||
GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, | GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, | ||||
"node or in_data_anchor is nullptr"); | "node or in_data_anchor is nullptr"); | ||||
bool find_flag = false; | bool find_flag = false; | ||||
uint32_t index = 0; | uint32_t index = 0; | ||||
vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end(); | vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end(); | ||||
@@ -358,4 +361,45 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const | |||||
input_desc->SetShape(shape); | input_desc->SetShape(shape); | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
std::string NodeUtils::GetNodeType(const Node &node) { | |||||
if (node.GetType() != FRAMEWORKOP) { | |||||
return node.GetType(); | |||||
} | |||||
std::string type; | |||||
(void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||||
return type; | |||||
} | |||||
ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | |||||
auto op_desc = node.GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
return nullptr; | |||||
} | |||||
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||||
if (root_graph == nullptr) { | |||||
return nullptr; | |||||
} | |||||
return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); | |||||
} | |||||
graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) { | |||||
if (subgraph == nullptr) { | |||||
GE_LOGE("Failed to add subgraph to node %s, null subgraph", node.GetName().c_str()); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
auto op_desc = node.GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||||
if (root_graph == nullptr) { | |||||
GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
op_desc->AddSubgraphInstanceName(subgraph->GetName()); | |||||
subgraph->SetParentNode(node.shared_from_this()); | |||||
subgraph->SetParentGraph(node.GetOwnerComputeGraph()); | |||||
root_graph->AddSubgraph(subgraph); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -15,9 +15,7 @@ | |||||
*/ | */ | ||||
#include "utils/op_desc_utils.h" | #include "utils/op_desc_utils.h" | ||||
#include <algorithm> | #include <algorithm> | ||||
#include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
#include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
@@ -209,6 +207,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData( | ||||
const vector<ge::NodePtr> &input_nodes) { | const vector<ge::NodePtr> &input_nodes) { | ||||
vector<ConstGeTensorPtr> ret; | vector<ConstGeTensorPtr> ret; | ||||
for (const auto &input_node : input_nodes) { | for (const auto &input_node : input_nodes) { | ||||
auto temp_weight = MutableWeights(input_node->GetOpDesc()); | auto temp_weight = MutableWeights(input_node->GetOpDesc()); | ||||
if (temp_weight == nullptr) { | if (temp_weight == nullptr) { | ||||
@@ -379,7 +378,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt | |||||
if (NodeUtils::IsAnchorStatusSet(*node)) { | if (NodeUtils::IsAnchorStatusSet(*node)) { | ||||
for (const auto &in_anchor : node->GetAllInDataAnchors()) { | for (const auto &in_anchor : node->GetAllInDataAnchors()) { | ||||
if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { | if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { | ||||
(void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
} | } | ||||
} | } | ||||
} else { | } else { | ||||
@@ -389,7 +388,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt | |||||
continue; | continue; | ||||
} | } | ||||
if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { | if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { | ||||
(void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -572,4 +571,80 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei | |||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
/// | |||||
/// @brief Add input | |||||
/// @param [in] name | |||||
/// @return OpDescBuilder | |||||
/// | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { | |||||
inputs_.emplace_back(name); | |||||
return *this; | |||||
} | |||||
/// | |||||
/// @brief Add dynamic input | |||||
/// @param [in] name | |||||
/// @param [in] num | |||||
/// @return OpDescBuilder | |||||
/// | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, | |||||
uint32_t num) { | |||||
for (uint32_t i = 0; i < num; i++) { | |||||
inputs_.emplace_back(name + std::to_string(i)); | |||||
} | |||||
return *this; | |||||
} | |||||
/// | |||||
/// @brief Add output | |||||
/// @param [in] name | |||||
/// @return OpDescBuilder | |||||
/// | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { | |||||
outputs_.emplace_back(name); | |||||
return *this; | |||||
} | |||||
/// | |||||
/// @brief Add dynamic output | |||||
/// @param [in] name | |||||
/// @param [in] num | |||||
/// @return OpDescBuilder | |||||
/// | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, | |||||
uint32_t num) { | |||||
for (uint32_t i = 0; i < num; i++) { | |||||
outputs_.emplace_back(name + std::to_string(i)); | |||||
} | |||||
return *this; | |||||
} | |||||
/// | |||||
/// @brief Build op_desc | |||||
/// @return OpDescPtr | |||||
/// | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { | |||||
OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_)); | |||||
if (op_desc == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "OpDesc is nullptr"); | |||||
return nullptr; | |||||
} | |||||
for (auto &input : inputs_) { | |||||
if (op_desc->AddInputDesc(input, GeTensorDesc()) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Add input_desc failed."); | |||||
return nullptr; | |||||
} | |||||
} | |||||
for (auto &output : outputs_) { | |||||
if (op_desc->AddOutputDesc(output, GeTensorDesc()) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Add output_desc failed."); | |||||
return nullptr; | |||||
} | |||||
} | |||||
return op_desc; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include <cmath> | #include <cmath> | ||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
@@ -276,6 +275,14 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||||
break; | break; | ||||
case FORMAT_FRACTAL_NZ: | case FORMAT_FRACTAL_NZ: | ||||
case FORMAT_FRACTAL_ZZ: | case FORMAT_FRACTAL_ZZ: | ||||
case FORMAT_NDHWC: | |||||
case FORMAT_NCDHW: | |||||
case FORMAT_DHWCN: | |||||
case FORMAT_DHWNC: | |||||
case FORMAT_FRACTAL_Z_3D: | |||||
case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | |||||
case FORMAT_NDC1HWC0: | |||||
case FORMAT_FRACTAL_Z_C04: | |||||
graph_status = CalcElementCntByDims(dims, element_cnt); | graph_status = CalcElementCntByDims(dims, element_cnt); | ||||
break; | break; | ||||
default: | default: | ||||
@@ -351,21 +358,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTens | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp) { | |||||
TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { | |||||
graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); | graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); | ||||
if (graph_status != GRAPH_SUCCESS) { | if (graph_status != GRAPH_SUCCESS) { | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
// 64-byte alignment, if size is 0, align to 32 bytes | // 64-byte alignment, if size is 0, align to 32 bytes | ||||
if (size_temp > (UINT32_MAX - kNum2 * kDataMemAlignSize)) { | |||||
GELOGW("The updated mem size %u is bigger than UINT32_MAX", size_temp); | |||||
if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) { | |||||
GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp); | |||||
} else { | } else { | ||||
size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; | size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; | ||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp) { | |||||
TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { | |||||
GeShape output_shape = desc_temp.GetShape(); | GeShape output_shape = desc_temp.GetShape(); | ||||
Format format = desc_temp.GetFormat(); | Format format = desc_temp.GetFormat(); | ||||
DataType data_type = desc_temp.GetDataType(); | DataType data_type = desc_temp.GetDataType(); | ||||
@@ -376,13 +383,13 @@ TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_ | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if ((output_mem_size > UINT32_MAX) || (output_mem_size < 0)) { | |||||
GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %u]", | |||||
output_mem_size, UINT32_MAX); | |||||
if (output_mem_size < 0) { | |||||
GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]", | |||||
output_mem_size, INT64_MAX); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
size_temp = static_cast<uint32_t>(output_mem_size); | |||||
size_temp = output_mem_size; | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -19,43 +19,45 @@ | |||||
namespace ge { | namespace ge { | ||||
static const std::map<Format, std::string> kFormatToStringMap = { | static const std::map<Format, std::string> kFormatToStringMap = { | ||||
{FORMAT_NCHW, "NCHW"}, | |||||
{FORMAT_NHWC, "NHWC"}, | |||||
{FORMAT_ND, "ND"}, | |||||
{FORMAT_NC1HWC0, "NC1HWC0"}, | |||||
{FORMAT_FRACTAL_Z, "FRACTAL_Z"}, | |||||
{FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, | |||||
{FORMAT_NHWC1C0, "NHWC1C0"}, | |||||
{FORMAT_FSR_NCHW, "FSR_NCHW"}, | |||||
{FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, | |||||
{FORMAT_C1HWNC0, "C1HWNC0"}, | |||||
{FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, | |||||
{FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, | |||||
{FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, | |||||
{FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, | |||||
{FORMAT_CHWN, "CHWN"}, | |||||
{FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, | |||||
{FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, | |||||
{FORMAT_BN_WEIGHT, "BN_WEIGHT"}, | |||||
{FORMAT_FILTER_HWCK, "FILTER_HWCK"}, | |||||
{FORMAT_HWCN, "HWCN"}, | |||||
{FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, | |||||
{FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, | |||||
{FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, | |||||
{FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, | |||||
{FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, | |||||
{FORMAT_MD, "MD"}, | |||||
{FORMAT_NDHWC, "NDHWC"}, | |||||
{FORMAT_NCDHW, "NCDHW"}, | |||||
{FORMAT_DHWCK, "DHWCK"}, | |||||
{FORMAT_NDC1HWC0, "NDC1HWC0"}, | |||||
{FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, | |||||
{FORMAT_C1HWNCoC0, "C1HWNCoC0"}, | |||||
{FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||||
{FORMAT_CN, "CN"}, | |||||
{FORMAT_NC, "NC"}, | |||||
{FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||||
{FORMAT_ALL, "ALL"}}; | |||||
{FORMAT_NCHW, "NCHW"}, | |||||
{FORMAT_NHWC, "NHWC"}, | |||||
{FORMAT_ND, "ND"}, | |||||
{FORMAT_NC1HWC0, "NC1HWC0"}, | |||||
{FORMAT_FRACTAL_Z, "FRACTAL_Z"}, | |||||
{FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, | |||||
{FORMAT_NHWC1C0, "NHWC1C0"}, | |||||
{FORMAT_FSR_NCHW, "FSR_NCHW"}, | |||||
{FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, | |||||
{FORMAT_C1HWNC0, "C1HWNC0"}, | |||||
{FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, | |||||
{FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, | |||||
{FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, | |||||
{FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, | |||||
{FORMAT_CHWN, "CHWN"}, | |||||
{FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, | |||||
{FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, | |||||
{FORMAT_BN_WEIGHT, "BN_WEIGHT"}, | |||||
{FORMAT_FILTER_HWCK, "FILTER_HWCK"}, | |||||
{FORMAT_HWCN, "HWCN"}, | |||||
{FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, | |||||
{FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, | |||||
{FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, | |||||
{FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, | |||||
{FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, | |||||
{FORMAT_MD, "MD"}, | |||||
{FORMAT_NDHWC, "NDHWC"}, | |||||
{FORMAT_NCDHW, "NCDHW"}, | |||||
{FORMAT_DHWCN, "DHWCN"}, | |||||
{FORMAT_DHWNC, "DHWNC"}, | |||||
{FORMAT_NDC1HWC0, "NDC1HWC0"}, | |||||
{FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, | |||||
{FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, | |||||
{FORMAT_C1HWNCoC0, "C1HWNCoC0"}, | |||||
{FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||||
{FORMAT_CN, "CN"}, | |||||
{FORMAT_NC, "NC"}, | |||||
{FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||||
{FORMAT_ALL, "ALL"}}; | |||||
static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | ||||
"FRACTAL_Z", | "FRACTAL_Z", | ||||
@@ -73,137 +75,140 @@ static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | |||||
"FRACTAL_ZZ", | "FRACTAL_ZZ", | ||||
"FRACTAL_NZ", | "FRACTAL_NZ", | ||||
"NDC1HWC0", | "NDC1HWC0", | ||||
"FORMAT_FRACTAL_Z_3D"}; | |||||
"FORMAT_FRACTAL_Z_3D", | |||||
"FORMAT_FRACTAL_Z_3D_TRANSPOSE"}; | |||||
static const std::map<std::string, Format> kDataFormatMap = { | static const std::map<std::string, Format> kDataFormatMap = { | ||||
{"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"ND", FORMAT_ND}}; | |||||
{"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; | |||||
static const std::map<std::string, Format> kStringToFormatMap = { | static const std::map<std::string, Format> kStringToFormatMap = { | ||||
{"NCHW", FORMAT_NCHW}, | |||||
{"NHWC", FORMAT_NHWC}, | |||||
{"ND", FORMAT_ND}, | |||||
{"NC1HWC0", FORMAT_NC1HWC0}, | |||||
{"FRACTAL_Z", FORMAT_FRACTAL_Z}, | |||||
{"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, | |||||
{"NHWC1C0", FORMAT_NHWC1C0}, | |||||
{"FSR_NCHW", FORMAT_FSR_NCHW}, | |||||
{"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, | |||||
{"C1HWNC0", FORMAT_C1HWNC0}, | |||||
{"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, | |||||
{"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, | |||||
{"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, | |||||
{"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, | |||||
{"CHWN", FORMAT_CHWN}, | |||||
{"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, | |||||
{"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, | |||||
{"BN_WEIGHT", FORMAT_BN_WEIGHT}, | |||||
{"FILTER_HWCK", FORMAT_FILTER_HWCK}, | |||||
{"HWCN", FORMAT_HWCN}, | |||||
{"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, | |||||
{"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, | |||||
{"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, | |||||
{"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, | |||||
{"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, | |||||
{"MD", FORMAT_MD}, | |||||
{"C1HWNCoC0", FORMAT_C1HWNCoC0}, | |||||
{"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, | |||||
{"NDHWC", FORMAT_NDHWC}, | |||||
{"NCDHW", FORMAT_NCDHW}, | |||||
{"DHWCK", FORMAT_DHWCK}, | |||||
{"NDC1HWC0", FORMAT_NDC1HWC0}, | |||||
{"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, | |||||
{"CN", FORMAT_CN}, | |||||
{"NC", FORMAT_NC}, | |||||
{"FORMAT_RESERVED", FORMAT_RESERVED}, | |||||
{"ALL", FORMAT_ALL}}; | |||||
{"NCHW", FORMAT_NCHW}, | |||||
{"NHWC", FORMAT_NHWC}, | |||||
{"ND", FORMAT_ND}, | |||||
{"NC1HWC0", FORMAT_NC1HWC0}, | |||||
{"FRACTAL_Z", FORMAT_FRACTAL_Z}, | |||||
{"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, | |||||
{"NHWC1C0", FORMAT_NHWC1C0}, | |||||
{"FSR_NCHW", FORMAT_FSR_NCHW}, | |||||
{"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, | |||||
{"C1HWNC0", FORMAT_C1HWNC0}, | |||||
{"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, | |||||
{"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, | |||||
{"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, | |||||
{"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, | |||||
{"CHWN", FORMAT_CHWN}, | |||||
{"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, | |||||
{"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, | |||||
{"BN_WEIGHT", FORMAT_BN_WEIGHT}, | |||||
{"FILTER_HWCK", FORMAT_FILTER_HWCK}, | |||||
{"HWCN", FORMAT_HWCN}, | |||||
{"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, | |||||
{"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, | |||||
{"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, | |||||
{"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, | |||||
{"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, | |||||
{"MD", FORMAT_MD}, | |||||
{"C1HWNCoC0", FORMAT_C1HWNCoC0}, | |||||
{"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, | |||||
{"NDHWC", FORMAT_NDHWC}, | |||||
{"NCDHW", FORMAT_NCDHW}, | |||||
{"DHWCN", FORMAT_DHWCN}, | |||||
{"DHWNC", FORMAT_DHWNC}, | |||||
{"NDC1HWC0", FORMAT_NDC1HWC0}, | |||||
{"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, | |||||
{"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, | |||||
{"CN", FORMAT_CN}, | |||||
{"NC", FORMAT_NC}, | |||||
{"FORMAT_RESERVED", FORMAT_RESERVED}, | |||||
{"ALL", FORMAT_ALL}}; | |||||
static const std::map<DataType, std::string> kDataTypeToStringMap = { | static const std::map<DataType, std::string> kDataTypeToStringMap = { | ||||
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||||
{DT_FLOAT, "DT_FLOAT"}, // float type | |||||
{DT_FLOAT16, "DT_FLOAT16"}, // fp16 type | |||||
{DT_INT8, "DT_INT8"}, // int8 type | |||||
{DT_INT16, "DT_INT16"}, // int16 type | |||||
{DT_UINT16, "DT_UINT16"}, // uint16 type | |||||
{DT_UINT8, "DT_UINT8"}, // uint8 type | |||||
{DT_INT32, "DT_INT32"}, // uint32 type | |||||
{DT_INT64, "DT_INT64"}, // int64 type | |||||
{DT_UINT32, "DT_UINT32"}, // unsigned int32 | |||||
{DT_UINT64, "DT_UINT64"}, // unsigned int64 | |||||
{DT_BOOL, "DT_BOOL"}, // bool type | |||||
{DT_DOUBLE, "DT_DOUBLE"}, // double type | |||||
{DT_DUAL, "DT_DUAL"}, // dual output type | |||||
{DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type | |||||
{DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type | |||||
{DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type | |||||
{DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type | |||||
{DT_QINT8, "DT_QINT8"}, // qint8 type | |||||
{DT_QINT16, "DT_QINT16"}, // qint16 type | |||||
{DT_QINT32, "DT_QINT32"}, // qint32 type | |||||
{DT_QUINT8, "DT_QUINT8"}, // quint8 type | |||||
{DT_QUINT16, "DT_QUINT16"}, // quint16 type | |||||
{DT_RESOURCE, "DT_RESOURCE"}, // resource type | |||||
{DT_STRING_REF, "DT_STRING_REF"}, // string ref type | |||||
{DT_STRING, "DT_STRING"}, // string type | |||||
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||||
{DT_FLOAT, "DT_FLOAT"}, // float type | |||||
{DT_FLOAT16, "DT_FLOAT16"}, // fp16 type | |||||
{DT_INT8, "DT_INT8"}, // int8 type | |||||
{DT_INT16, "DT_INT16"}, // int16 type | |||||
{DT_UINT16, "DT_UINT16"}, // uint16 type | |||||
{DT_UINT8, "DT_UINT8"}, // uint8 type | |||||
{DT_INT32, "DT_INT32"}, // uint32 type | |||||
{DT_INT64, "DT_INT64"}, // int64 type | |||||
{DT_UINT32, "DT_UINT32"}, // unsigned int32 | |||||
{DT_UINT64, "DT_UINT64"}, // unsigned int64 | |||||
{DT_BOOL, "DT_BOOL"}, // bool type | |||||
{DT_DOUBLE, "DT_DOUBLE"}, // double type | |||||
{DT_DUAL, "DT_DUAL"}, // dual output type | |||||
{DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type | |||||
{DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type | |||||
{DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type | |||||
{DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type | |||||
{DT_QINT8, "DT_QINT8"}, // qint8 type | |||||
{DT_QINT16, "DT_QINT16"}, // qint16 type | |||||
{DT_QINT32, "DT_QINT32"}, // qint32 type | |||||
{DT_QUINT8, "DT_QUINT8"}, // quint8 type | |||||
{DT_QUINT16, "DT_QUINT16"}, // quint16 type | |||||
{DT_RESOURCE, "DT_RESOURCE"}, // resource type | |||||
{DT_STRING_REF, "DT_STRING_REF"}, // string ref type | |||||
{DT_STRING, "DT_STRING"}, // string type | |||||
}; | }; | ||||
static const std::map<std::string, DataType> kStringTodataTypeMap = { | static const std::map<std::string, DataType> kStringTodataTypeMap = { | ||||
{"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. | |||||
{"DT_FLOAT", DT_FLOAT}, // float type | |||||
{ | |||||
"DT_FLOAT16", | |||||
DT_FLOAT16, | |||||
}, // fp16 type | |||||
{"DT_INT8", DT_INT8}, // int8 type | |||||
{"DT_INT16", DT_INT16}, // int16 type | |||||
{"DT_UINT16", DT_UINT16}, // uint16 type | |||||
{"DT_UINT8", DT_UINT8}, // uint8 type | |||||
{"DT_INT32", DT_INT32}, // uint32 type | |||||
{"DT_INT64", DT_INT64}, // int64 type | |||||
{"DT_UINT32", DT_UINT32}, // unsigned int32 | |||||
{"DT_UINT64", DT_UINT64}, // unsigned int64 | |||||
{"DT_BOOL", DT_BOOL}, // bool type | |||||
{"DT_DOUBLE", DT_DOUBLE}, // double type | |||||
{"DT_DUAL", DT_DUAL}, // dual output type | |||||
{"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type | |||||
{"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type | |||||
{"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type | |||||
{"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type | |||||
{"DT_QINT8", DT_QINT8}, // qint8 type | |||||
{"DT_QINT16", DT_QINT16}, // qint16 type | |||||
{"DT_QINT32", DT_QINT32}, // qint32 type | |||||
{"DT_QUINT8", DT_QUINT8}, // quint8 type | |||||
{"DT_QUINT16", DT_QUINT16}, // quint16 type | |||||
{"DT_RESOURCE", DT_RESOURCE}, // resource type | |||||
{"DT_STRING_REF", DT_STRING_REF}, // string ref type | |||||
{"DT_STRING", DT_STRING}, // string type | |||||
{"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. | |||||
{"DT_FLOAT", DT_FLOAT}, // float type | |||||
{ | |||||
"DT_FLOAT16", | |||||
DT_FLOAT16, | |||||
}, // fp16 type | |||||
{"DT_INT8", DT_INT8}, // int8 type | |||||
{"DT_INT16", DT_INT16}, // int16 type | |||||
{"DT_UINT16", DT_UINT16}, // uint16 type | |||||
{"DT_UINT8", DT_UINT8}, // uint8 type | |||||
{"DT_INT32", DT_INT32}, // uint32 type | |||||
{"DT_INT64", DT_INT64}, // int64 type | |||||
{"DT_UINT32", DT_UINT32}, // unsigned int32 | |||||
{"DT_UINT64", DT_UINT64}, // unsigned int64 | |||||
{"DT_BOOL", DT_BOOL}, // bool type | |||||
{"DT_DOUBLE", DT_DOUBLE}, // double type | |||||
{"DT_DUAL", DT_DUAL}, // dual output type | |||||
{"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type | |||||
{"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type | |||||
{"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type | |||||
{"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type | |||||
{"DT_QINT8", DT_QINT8}, // qint8 type | |||||
{"DT_QINT16", DT_QINT16}, // qint16 type | |||||
{"DT_QINT32", DT_QINT32}, // qint32 type | |||||
{"DT_QUINT8", DT_QUINT8}, // quint8 type | |||||
{"DT_QUINT16", DT_QUINT16}, // quint16 type | |||||
{"DT_RESOURCE", DT_RESOURCE}, // resource type | |||||
{"DT_STRING_REF", DT_STRING_REF}, // string ref type | |||||
{"DT_STRING", DT_STRING}, // string type | |||||
}; | }; | ||||
static const std::map<ge::DataType, uint32_t> kDataTypeToLength = { | static const std::map<ge::DataType, uint32_t> kDataTypeToLength = { | ||||
{DT_BOOL, sizeof(bool)}, | |||||
{DT_INT64, sizeof(int64_t)}, | |||||
{DT_UINT64, sizeof(int64_t)}, | |||||
{DT_FLOAT, sizeof(float)}, | |||||
{DT_INT32, sizeof(int32_t)}, | |||||
{DT_UINT32, sizeof(int32_t)}, | |||||
{DT_INT8, sizeof(char)}, | |||||
{DT_UINT8, sizeof(char)}, | |||||
{DT_INT16, sizeof(int16_t)}, | |||||
{DT_UINT16, sizeof(int16_t)}, | |||||
{DT_FLOAT16, sizeof(int16_t)}, | |||||
{DT_DOUBLE, sizeof(double)}, | |||||
{DT_DUAL, sizeof(float) + sizeof(int8_t)}, | |||||
{DT_DUAL_SUB_INT8, sizeof(int8_t)}, | |||||
{DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, | |||||
{DT_COMPLEX64, sizeof(int64_t)}, | |||||
{DT_COMPLEX128, sizeof(int64_t) * 2}, | |||||
{DT_QINT8, sizeof(int8_t)}, | |||||
{DT_QINT16, sizeof(int16_t)}, | |||||
{DT_QINT32, sizeof(int32_t)}, | |||||
{DT_QUINT8, sizeof(uint8_t)}, | |||||
{DT_QUINT16, sizeof(uint16_t)}, | |||||
{DT_STRING_REF, sizeof(uint64_t) * 2}, | |||||
{DT_STRING, sizeof(uint64_t)}, | |||||
{DT_RESOURCE, sizeof(uint64_t)}, | |||||
{DT_BOOL, sizeof(bool)}, | |||||
{DT_INT64, sizeof(int64_t)}, | |||||
{DT_UINT64, sizeof(int64_t)}, | |||||
{DT_FLOAT, sizeof(float)}, | |||||
{DT_INT32, sizeof(int32_t)}, | |||||
{DT_UINT32, sizeof(int32_t)}, | |||||
{DT_INT8, sizeof(char)}, | |||||
{DT_UINT8, sizeof(char)}, | |||||
{DT_INT16, sizeof(int16_t)}, | |||||
{DT_UINT16, sizeof(int16_t)}, | |||||
{DT_FLOAT16, sizeof(int16_t)}, | |||||
{DT_DOUBLE, sizeof(double)}, | |||||
{DT_DUAL, sizeof(float) + sizeof(int8_t)}, | |||||
{DT_DUAL_SUB_INT8, sizeof(int8_t)}, | |||||
{DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, | |||||
{DT_COMPLEX64, sizeof(int64_t)}, | |||||
{DT_COMPLEX128, sizeof(int64_t) * 2}, | |||||
{DT_QINT8, sizeof(int8_t)}, | |||||
{DT_QINT16, sizeof(int16_t)}, | |||||
{DT_QINT32, sizeof(int32_t)}, | |||||
{DT_QUINT8, sizeof(uint8_t)}, | |||||
{DT_QUINT16, sizeof(uint16_t)}, | |||||
{DT_STRING_REF, sizeof(uint64_t) * 2}, | |||||
{DT_STRING, sizeof(uint64_t)}, | |||||
{DT_RESOURCE, sizeof(uint64_t)}, | |||||
}; | }; | ||||
bool TypeUtils::IsDataTypeValid(DataType dt) { | bool TypeUtils::IsDataTypeValid(DataType dt) { | ||||
@@ -13,7 +13,7 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
# libge.so & libge_train.so | |||||
# libge_compiler.so & libge_train.so | |||||
# will later be integrated into libgraph_runner.so, works for both training and inference | # will later be integrated into libgraph_runner.so, works for both training and inference | ||||
# compiling proto files generates some warnings, use no-unused-variable to suppress them | # compiling proto files generates some warnings, use no-unused-variable to suppress them | ||||
set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | ||||
@@ -49,7 +49,7 @@ include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||||
######### libge_train.so ############# | ######### libge_train.so ############# | ||||
# need to remove dependencies on pb files later | # need to remove dependencies on pb files later | ||||
file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"common/formats/format_transfers/*.cc" | "common/formats/format_transfers/*.cc" | ||||
"common/formats/formats.cc" | "common/formats/formats.cc" | ||||
"common/formats/utils/formats_trans_utils.cc" | "common/formats/utils/formats_trans_utils.cc" | ||||
@@ -57,20 +57,24 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
"common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
"engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
"ge_local_engine/engine/host_cpu_engine.cc" | |||||
"generator/ge_generator.cc" | "generator/ge_generator.cc" | ||||
"generator/generator_api.cc" | "generator/generator_api.cc" | ||||
"graph/build/graph_build.cc" | |||||
"graph/build/graph_builder.cc" | |||||
"graph/build/label_allocator.cc" | |||||
"graph/build/logical_stream_allocator.cc" | "graph/build/logical_stream_allocator.cc" | ||||
"graph/build/model_builder.cc" | "graph/build/model_builder.cc" | ||||
"graph/build/optimize_stream_graph.cc" | |||||
"graph/build/run_context.cc" | "graph/build/run_context.cc" | ||||
"graph/build/stream_allocator.cc" | "graph/build/stream_allocator.cc" | ||||
"graph/build/stream_graph_optimizer.cc" | |||||
"graph/build/task_generator.cc" | "graph/build/task_generator.cc" | ||||
"graph/common/bcast.cc" | "graph/common/bcast.cc" | ||||
"graph/common/omg_util.cc" | "graph/common/omg_util.cc" | ||||
"graph/common/transop_util.cc" | "graph/common/transop_util.cc" | ||||
"graph/execute/graph_execute.cc" | "graph/execute/graph_execute.cc" | ||||
"graph/label/*.cc" | |||||
"graph/load/graph_loader.cc" | "graph/load/graph_loader.cc" | ||||
"graph/load/new_model_manager/cpu_queue_schedule.cc" | |||||
"graph/load/new_model_manager/data_dumper.cc" | "graph/load/new_model_manager/data_dumper.cc" | ||||
"graph/load/new_model_manager/data_inputer.cc" | "graph/load/new_model_manager/data_inputer.cc" | ||||
"graph/load/new_model_manager/davinci_model.cc" | "graph/load/new_model_manager/davinci_model.cc" | ||||
@@ -92,10 +96,12 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||||
"graph/load/new_model_manager/task_info/task_info.cc" | "graph/load/new_model_manager/task_info/task_info.cc" | ||||
"graph/load/new_model_manager/tbe_handle_store.cc" | "graph/load/new_model_manager/tbe_handle_store.cc" | ||||
"graph/load/output/output.cc" | "graph/load/output/output.cc" | ||||
"graph/manager/custom/custom_op.cc" | |||||
"graph/manager/graph_context.cc" | "graph/manager/graph_context.cc" | ||||
"graph/manager/graph_manager.cc" | "graph/manager/graph_manager.cc" | ||||
"graph/manager/graph_manager_utils.cc" | "graph/manager/graph_manager_utils.cc" | ||||
@@ -105,12 +111,9 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
"graph/manager/util/debug.cc" | "graph/manager/util/debug.cc" | ||||
"graph/manager/util/hcom_util.cc" | "graph/manager/util/hcom_util.cc" | ||||
"graph/manager/util/node_searcher/need_rebuild_node_searcher.cc" | |||||
"graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
"graph/manager/util/variable_accelerate_ctrl.cc" | "graph/manager/util/variable_accelerate_ctrl.cc" | ||||
"graph/optimize/graph_functiondef.cc" | |||||
"graph/optimize/graph_optimize.cc" | "graph/optimize/graph_optimize.cc" | ||||
"graph/optimize/graph_optimizer.cc" | |||||
"graph/optimize/optimizer/allreduce_fusion_pass.cc" | "graph/optimize/optimizer/allreduce_fusion_pass.cc" | ||||
"graph/optimize/summary_optimize.cc" | "graph/optimize/summary_optimize.cc" | ||||
"graph/partition/engine_place.cc" | "graph/partition/engine_place.cc" | ||||
@@ -120,7 +123,9 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/assert_pass.cc" | "graph/passes/assert_pass.cc" | ||||
"graph/passes/atomic_addr_clean_pass.cc" | "graph/passes/atomic_addr_clean_pass.cc" | ||||
"graph/passes/base_pass.cc" | "graph/passes/base_pass.cc" | ||||
"graph/passes/cast_remove_pass.cc" | |||||
"graph/passes/cast_translate_pass.cc" | "graph/passes/cast_translate_pass.cc" | ||||
"graph/passes/common_subexpression_elimination_pass.cc" | |||||
"graph/passes/compile_nodes_pass.cc" | "graph/passes/compile_nodes_pass.cc" | ||||
"graph/passes/constant_folding_pass.cc" | "graph/passes/constant_folding_pass.cc" | ||||
"graph/passes/constant_fuse_same_pass.cc" | "graph/passes/constant_fuse_same_pass.cc" | ||||
@@ -159,12 +164,14 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/folding_kernel/shape_kernel.cc" | "graph/passes/folding_kernel/shape_kernel.cc" | ||||
"graph/passes/folding_kernel/shape_n_kernel.cc" | "graph/passes/folding_kernel/shape_n_kernel.cc" | ||||
"graph/passes/folding_kernel/size_kernel.cc" | "graph/passes/folding_kernel/size_kernel.cc" | ||||
"graph/passes/folding_kernel/slice_d_kernel.cc" | |||||
"graph/passes/folding_kernel/slice_kernel.cc" | "graph/passes/folding_kernel/slice_kernel.cc" | ||||
"graph/passes/folding_kernel/squeeze_kernel.cc" | "graph/passes/folding_kernel/squeeze_kernel.cc" | ||||
"graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | "graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | ||||
"graph/passes/folding_kernel/strided_slice_kernel.cc" | "graph/passes/folding_kernel/strided_slice_kernel.cc" | ||||
"graph/passes/folding_kernel/sub_kernel.cc" | "graph/passes/folding_kernel/sub_kernel.cc" | ||||
"graph/passes/folding_kernel/transdata_kernel.cc" | "graph/passes/folding_kernel/transdata_kernel.cc" | ||||
"graph/passes/folding_kernel/unpack_kernel.cc" | |||||
"graph/passes/folding_pass.cc" | "graph/passes/folding_pass.cc" | ||||
"graph/passes/get_original_format_pass.cc" | "graph/passes/get_original_format_pass.cc" | ||||
"graph/passes/guarantee_const_pass.cc" | "graph/passes/guarantee_const_pass.cc" | ||||
@@ -179,7 +186,6 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/multi_batch_pass.cc" | "graph/passes/multi_batch_pass.cc" | ||||
"graph/passes/net_output_pass.cc" | "graph/passes/net_output_pass.cc" | ||||
"graph/passes/next_iteration_pass.cc" | "graph/passes/next_iteration_pass.cc" | ||||
"graph/passes/no_reshape_op_remove_pass.cc" | |||||
"graph/passes/no_use_reshape_remove_pass.cc" | "graph/passes/no_use_reshape_remove_pass.cc" | ||||
"graph/passes/pass_manager.cc" | "graph/passes/pass_manager.cc" | ||||
"graph/passes/pass_utils.cc" | "graph/passes/pass_utils.cc" | ||||
@@ -188,6 +194,7 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/prevent_gradient_pass.cc" | "graph/passes/prevent_gradient_pass.cc" | ||||
"graph/passes/print_op_pass.cc" | "graph/passes/print_op_pass.cc" | ||||
"graph/passes/prune_pass.cc" | "graph/passes/prune_pass.cc" | ||||
"graph/passes/replace_with_empty_const_pass.cc" | |||||
"graph/passes/reshape_remove_pass.cc" | "graph/passes/reshape_remove_pass.cc" | ||||
"graph/passes/resource_pair_add_control_pass.cc" | "graph/passes/resource_pair_add_control_pass.cc" | ||||
"graph/passes/resource_pair_remove_control_pass.cc" | "graph/passes/resource_pair_remove_control_pass.cc" | ||||
@@ -206,14 +213,12 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/transpose_transdata_pass.cc" | "graph/passes/transpose_transdata_pass.cc" | ||||
"graph/passes/unused_const_pass.cc" | "graph/passes/unused_const_pass.cc" | ||||
"graph/passes/unused_op_remove_pass.cc" | "graph/passes/unused_op_remove_pass.cc" | ||||
"graph/passes/update_net_output_pass.cc" | |||||
"graph/passes/var_is_initialized_op_pass.cc" | "graph/passes/var_is_initialized_op_pass.cc" | ||||
"graph/passes/variable_format_pass.cc" | "graph/passes/variable_format_pass.cc" | ||||
"graph/passes/variable_op_pass.cc" | "graph/passes/variable_op_pass.cc" | ||||
"graph/passes/variable_prepare_op_pass.cc" | "graph/passes/variable_prepare_op_pass.cc" | ||||
"graph/passes/variable_ref_delete_op_pass.cc" | "graph/passes/variable_ref_delete_op_pass.cc" | ||||
"graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
"graph/preprocess/insert_op/base_insert_op.cc" | |||||
"graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
"graph/preprocess/multi_batch_copy_graph.cc" | "graph/preprocess/multi_batch_copy_graph.cc" | ||||
@@ -223,13 +228,8 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"opskernel_manager/ops_kernel_manager.cc" | "opskernel_manager/ops_kernel_manager.cc" | ||||
"session/inner_session.cc" | "session/inner_session.cc" | ||||
"session/session_manager.cc" | "session/session_manager.cc" | ||||
"single_op/single_op.cc" | |||||
"single_op/single_op_manager.cc" | |||||
"single_op/single_op_model.cc" | |||||
"single_op/stream_resource.cc" | |||||
"single_op/task/build_task_utils.cc" | |||||
"single_op/task/op_task.cc" | |||||
"single_op/task/tbe_task_builder.cc" | |||||
"single_op/*.cc" | |||||
"single_op/task/*.cc" | |||||
) | ) | ||||
@@ -261,9 +261,9 @@ target_link_libraries(ge_train | |||||
rt | rt | ||||
dl) | dl) | ||||
######### libge.so ############# | |||||
######### libge_compiler.so ############# | |||||
# need to remove dependencies on pb files later | # need to remove dependencies on pb files later | ||||
file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"common/formats/format_transfers/*.cc" | "common/formats/format_transfers/*.cc" | ||||
"common/formats/formats.cc" | "common/formats/formats.cc" | ||||
"common/formats/utils/formats_trans_utils.cc" | "common/formats/utils/formats_trans_utils.cc" | ||||
@@ -271,20 +271,24 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
"common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
"engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
"ge_local_engine/engine/host_cpu_engine.cc" | |||||
"generator/ge_generator.cc" | "generator/ge_generator.cc" | ||||
"generator/generator_api.cc" | "generator/generator_api.cc" | ||||
"graph/build/graph_build.cc" | |||||
"graph/build/graph_builder.cc" | |||||
"graph/build/label_allocator.cc" | |||||
"graph/build/logical_stream_allocator.cc" | "graph/build/logical_stream_allocator.cc" | ||||
"graph/build/model_builder.cc" | "graph/build/model_builder.cc" | ||||
"graph/build/optimize_stream_graph.cc" | |||||
"graph/build/run_context.cc" | "graph/build/run_context.cc" | ||||
"graph/build/stream_allocator.cc" | "graph/build/stream_allocator.cc" | ||||
"graph/build/stream_graph_optimizer.cc" | |||||
"graph/build/task_generator.cc" | "graph/build/task_generator.cc" | ||||
"graph/common/bcast.cc" | "graph/common/bcast.cc" | ||||
"graph/common/omg_util.cc" | "graph/common/omg_util.cc" | ||||
"graph/common/transop_util.cc" | "graph/common/transop_util.cc" | ||||
"graph/execute/graph_execute.cc" | "graph/execute/graph_execute.cc" | ||||
"graph/label/*.cc" | |||||
"graph/load/graph_loader.cc" | "graph/load/graph_loader.cc" | ||||
"graph/load/new_model_manager/cpu_queue_schedule.cc" | |||||
"graph/load/new_model_manager/data_dumper.cc" | "graph/load/new_model_manager/data_dumper.cc" | ||||
"graph/load/new_model_manager/data_inputer.cc" | "graph/load/new_model_manager/data_inputer.cc" | ||||
"graph/load/new_model_manager/davinci_model.cc" | "graph/load/new_model_manager/davinci_model.cc" | ||||
@@ -305,10 +309,12 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||||
"graph/load/new_model_manager/task_info/task_info.cc" | "graph/load/new_model_manager/task_info/task_info.cc" | ||||
"graph/load/new_model_manager/tbe_handle_store.cc" | "graph/load/new_model_manager/tbe_handle_store.cc" | ||||
"graph/load/output/output.cc" | "graph/load/output/output.cc" | ||||
"graph/manager/custom/custom_op.cc" | |||||
"graph/manager/graph_context.cc" | "graph/manager/graph_context.cc" | ||||
"graph/manager/graph_manager.cc" | "graph/manager/graph_manager.cc" | ||||
"graph/manager/graph_manager_utils.cc" | "graph/manager/graph_manager_utils.cc" | ||||
@@ -317,13 +323,9 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
"graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
"graph/manager/util/debug.cc" | "graph/manager/util/debug.cc" | ||||
"graph/manager/util/node_searcher/need_rebuild_node_searcher.cc" | |||||
"graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
"graph/manager/util/variable_accelerate_ctrl.cc" | "graph/manager/util/variable_accelerate_ctrl.cc" | ||||
"graph/optimize/graph_functiondef.cc" | |||||
"graph/optimize/graph_optimize.cc" | "graph/optimize/graph_optimize.cc" | ||||
"graph/optimize/graph_optimizer.cc" | |||||
"graph/optimize/optimizer/allreduce_fusion_inference_pass.cc" | |||||
"graph/optimize/summary_optimize.cc" | "graph/optimize/summary_optimize.cc" | ||||
"graph/partition/engine_place.cc" | "graph/partition/engine_place.cc" | ||||
"graph/partition/graph_partition.cc" | "graph/partition/graph_partition.cc" | ||||
@@ -332,7 +334,9 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/assert_pass.cc" | "graph/passes/assert_pass.cc" | ||||
"graph/passes/atomic_addr_clean_pass.cc" | "graph/passes/atomic_addr_clean_pass.cc" | ||||
"graph/passes/base_pass.cc" | "graph/passes/base_pass.cc" | ||||
"graph/passes/cast_remove_pass.cc" | |||||
"graph/passes/cast_translate_pass.cc" | "graph/passes/cast_translate_pass.cc" | ||||
"graph/passes/common_subexpression_elimination_pass.cc" | |||||
"graph/passes/compile_nodes_pass.cc" | "graph/passes/compile_nodes_pass.cc" | ||||
"graph/passes/constant_folding_pass.cc" | "graph/passes/constant_folding_pass.cc" | ||||
"graph/passes/constant_fuse_same_pass.cc" | "graph/passes/constant_fuse_same_pass.cc" | ||||
@@ -371,12 +375,14 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/folding_kernel/shape_kernel.cc" | "graph/passes/folding_kernel/shape_kernel.cc" | ||||
"graph/passes/folding_kernel/shape_n_kernel.cc" | "graph/passes/folding_kernel/shape_n_kernel.cc" | ||||
"graph/passes/folding_kernel/size_kernel.cc" | "graph/passes/folding_kernel/size_kernel.cc" | ||||
"graph/passes/folding_kernel/slice_d_kernel.cc" | |||||
"graph/passes/folding_kernel/slice_kernel.cc" | "graph/passes/folding_kernel/slice_kernel.cc" | ||||
"graph/passes/folding_kernel/squeeze_kernel.cc" | "graph/passes/folding_kernel/squeeze_kernel.cc" | ||||
"graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | "graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | ||||
"graph/passes/folding_kernel/strided_slice_kernel.cc" | "graph/passes/folding_kernel/strided_slice_kernel.cc" | ||||
"graph/passes/folding_kernel/sub_kernel.cc" | "graph/passes/folding_kernel/sub_kernel.cc" | ||||
"graph/passes/folding_kernel/transdata_kernel.cc" | "graph/passes/folding_kernel/transdata_kernel.cc" | ||||
"graph/passes/folding_kernel/unpack_kernel.cc" | |||||
"graph/passes/folding_pass.cc" | "graph/passes/folding_pass.cc" | ||||
"graph/passes/get_original_format_pass.cc" | "graph/passes/get_original_format_pass.cc" | ||||
"graph/passes/guarantee_const_pass.cc" | "graph/passes/guarantee_const_pass.cc" | ||||
@@ -391,7 +397,6 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/multi_batch_pass.cc" | "graph/passes/multi_batch_pass.cc" | ||||
"graph/passes/net_output_pass.cc" | "graph/passes/net_output_pass.cc" | ||||
"graph/passes/next_iteration_pass.cc" | "graph/passes/next_iteration_pass.cc" | ||||
"graph/passes/no_reshape_op_remove_pass.cc" | |||||
"graph/passes/no_use_reshape_remove_pass.cc" | "graph/passes/no_use_reshape_remove_pass.cc" | ||||
"graph/passes/pass_manager.cc" | "graph/passes/pass_manager.cc" | ||||
"graph/passes/pass_utils.cc" | "graph/passes/pass_utils.cc" | ||||
@@ -400,6 +405,7 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/prevent_gradient_pass.cc" | "graph/passes/prevent_gradient_pass.cc" | ||||
"graph/passes/print_op_pass.cc" | "graph/passes/print_op_pass.cc" | ||||
"graph/passes/prune_pass.cc" | "graph/passes/prune_pass.cc" | ||||
"graph/passes/replace_with_empty_const_pass.cc" | |||||
"graph/passes/reshape_remove_pass.cc" | "graph/passes/reshape_remove_pass.cc" | ||||
"graph/passes/resource_pair_add_control_pass.cc" | "graph/passes/resource_pair_add_control_pass.cc" | ||||
"graph/passes/resource_pair_remove_control_pass.cc" | "graph/passes/resource_pair_remove_control_pass.cc" | ||||
@@ -418,14 +424,12 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/transpose_transdata_pass.cc" | "graph/passes/transpose_transdata_pass.cc" | ||||
"graph/passes/unused_const_pass.cc" | "graph/passes/unused_const_pass.cc" | ||||
"graph/passes/unused_op_remove_pass.cc" | "graph/passes/unused_op_remove_pass.cc" | ||||
"graph/passes/update_net_output_pass.cc" | |||||
"graph/passes/var_is_initialized_op_pass.cc" | "graph/passes/var_is_initialized_op_pass.cc" | ||||
"graph/passes/variable_format_pass.cc" | "graph/passes/variable_format_pass.cc" | ||||
"graph/passes/variable_op_pass.cc" | "graph/passes/variable_op_pass.cc" | ||||
"graph/passes/variable_prepare_op_pass.cc" | "graph/passes/variable_prepare_op_pass.cc" | ||||
"graph/passes/variable_ref_delete_op_pass.cc" | "graph/passes/variable_ref_delete_op_pass.cc" | ||||
"graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
"graph/preprocess/insert_op/base_insert_op.cc" | |||||
"graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
"graph/preprocess/multi_batch_copy_graph.cc" | "graph/preprocess/multi_batch_copy_graph.cc" | ||||
@@ -442,16 +446,19 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"single_op/task/build_task_utils.cc" | "single_op/task/build_task_utils.cc" | ||||
"single_op/task/op_task.cc" | "single_op/task/op_task.cc" | ||||
"single_op/task/tbe_task_builder.cc" | "single_op/task/tbe_task_builder.cc" | ||||
########################################## | |||||
# "ir_build/ge_ir_build.cc" | |||||
# "offline/atc_ir_common.cc" | |||||
) | ) | ||||
add_library(ge SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||||
target_compile_definitions(ge PRIVATE | |||||
add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||||
target_compile_definitions(ge_compiler PRIVATE | |||||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
DAVINCI_SUPPORT_PROFILING | DAVINCI_SUPPORT_PROFILING | ||||
REUSE_MEMORY=1 | REUSE_MEMORY=1 | ||||
FMK_HOST_INFER | FMK_HOST_INFER | ||||
PLATFORM_CLOUD) | PLATFORM_CLOUD) | ||||
target_link_libraries(ge | |||||
target_link_libraries(ge_compiler | |||||
graph | graph | ||||
ge_common | ge_common | ||||
"-Wl,--whole-archive" | "-Wl,--whole-archive" | ||||
@@ -80,7 +80,7 @@ target_compile_definitions(ge_client_train PRIVATE | |||||
PLATFORM_CLOUD) | PLATFORM_CLOUD) | ||||
target_link_libraries(ge_client | target_link_libraries(ge_client | ||||
graph | graph | ||||
ge | |||||
ge_compiler | |||||
ge_common | ge_common | ||||
${PROTOBUF_LIBRARY} | ${PROTOBUF_LIBRARY} | ||||
${register} | ${register} | ||||
@@ -61,14 +61,14 @@ Status CheckDumpAndReuseMemory(const std::map<string, string> &options) { | |||||
const int kDecimal = 10; | const int kDecimal = 10; | ||||
auto dump_op_env = std::getenv("DUMP_OP"); | auto dump_op_env = std::getenv("DUMP_OP"); | ||||
int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; | int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; | ||||
auto disable_reuse_memory_iter = options.find("ge.exec.disableReuseMemory"); | |||||
if (disable_reuse_memory_iter != options.end()) { | |||||
if (disable_reuse_memory_iter->second == "0") { | |||||
auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory"); | |||||
if (disableReuseMemoryIter != options.end()) { | |||||
if (disableReuseMemoryIter->second == "0") { | |||||
GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); | GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); | ||||
if (dump_op_flag) { | if (dump_op_flag) { | ||||
GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); | GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); | ||||
} | } | ||||
} else if (disable_reuse_memory_iter->second == "1") { | |||||
} else if (disableReuseMemoryIter->second == "1") { | |||||
GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); | GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); | ||||
} else { | } else { | ||||
GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid"); | GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid"); | ||||
@@ -128,22 +128,29 @@ Status GEInitialize(const std::map<string, string> &options) { | |||||
OpsProtoManager *manager = OpsProtoManager::Instance(); | OpsProtoManager *manager = OpsProtoManager::Instance(); | ||||
std::map<string, string> option_tmp; | std::map<string, string> option_tmp; | ||||
option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | ||||
GE_TIMESTAMP_START(GEInitialize); | |||||
bool is_proto_init = manager->Initialize(option_tmp); | bool is_proto_init = manager->Initialize(option_tmp); | ||||
GE_TIMESTAMP_END(GEInitialize, "GEInitialize::ManagerInitialize"); | |||||
if (!is_proto_init) { | if (!is_proto_init) { | ||||
GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, ops proto path is invalid."); | GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, ops proto path is invalid."); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
// check options is valid | // check options is valid | ||||
GE_TIMESTAMP_START(CheckOptionsValid); | |||||
if (CheckOptionsValid(options) != SUCCESS) { | if (CheckOptionsValid(options) != SUCCESS) { | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid"); | |||||
GE_TIMESTAMP_START(InitPreparation); | |||||
SaveDdkVersion(options); | SaveDdkVersion(options); | ||||
GE_TIMESTAMP_END(InitPreparation, "GEInitialize::InitPreparation"); | |||||
// call Initialize | // call Initialize | ||||
GELOGT(TRACE_RUNNING, "Initializing environment"); | GELOGT(TRACE_RUNNING, "Initializing environment"); | ||||
GE_TIMESTAMP_START(GELibInitialize); | |||||
Status ret = ge::GELib::Initialize(options); | Status ret = ge::GELib::Initialize(options); | ||||
GE_TIMESTAMP_END(GELibInitialize, "GEInitialize::GELibInitialize"); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, error code = %u", ret); | GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, error code = %u", ret); | ||||
return FAILED; | return FAILED; | ||||
@@ -170,17 +177,20 @@ Status GEFinalize() { | |||||
std::lock_guard<std::mutex> lock(kGeReleaseMutex); | std::lock_guard<std::mutex> lock(kGeReleaseMutex); | ||||
// call Finalize | // call Finalize | ||||
Status ret = SUCCESS; | |||||
Status middle_ret; | |||||
GELOGT(TRACE_RUNNING, "Finalizing environment"); | GELOGT(TRACE_RUNNING, "Finalizing environment"); | ||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GEFinalize Failed: GE not initialized"); | |||||
return GE_CLI_GE_NOT_INITIALIZED; | |||||
} | |||||
Status ret = instance_ptr->Finalize(); | |||||
GELOGI("GEFinalize finalize gelib ret=%u", ret); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "GEFinalize Failed"); | |||||
return FAILED; | |||||
std::shared_ptr<GELib> instancePtr = ge::GELib::GetInstance(); | |||||
if (instancePtr == nullptr || !instancePtr->InitFlag()) { | |||||
GELOGW("GEFinalize Failed: GE not initialized."); | |||||
ret = GE_CLI_GE_NOT_INITIALIZED; | |||||
} | |||||
if (ret != GE_CLI_GE_NOT_INITIALIZED) { | |||||
middle_ret = instancePtr->Finalize(); | |||||
GELOGI("GEFinalize finalize gelib ret=%u", middle_ret); | |||||
if (middle_ret != SUCCESS) { | |||||
ret = middle_ret; | |||||
} | |||||
} | } | ||||
if (kGeInitialized && ret == SUCCESS) { | if (kGeInitialized && ret == SUCCESS) { | ||||
@@ -379,8 +389,6 @@ Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, s | |||||
} | } | ||||
Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) { | Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) { | ||||
GELOGW( | |||||
"The callback function will not be checked. Please ensure that the implementation of the function is trusted."); | |||||
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | ||||
} | } | ||||