Merge pull request !22 from yanghaoran/master
@@ -18,7 +18,6 @@ | |||
#define INC_COMMON_BLOCKING_QUEUE_H_ | |||
#include <stdint.h> | |||
#include <condition_variable> | |||
#include <list> | |||
#include <mutex> | |||
@@ -87,7 +86,7 @@ class BlockingQueue { | |||
is_stoped_ = false; | |||
} | |||
// if the queue stop , the function to release the unprocessed items will be call | |||
// if the queue is stopped, call this function to release the unprocessed items | |||
std::list<T> GetRemainItems() { | |||
std::unique_lock<std::mutex> lock(mutex_); | |||
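The hunk above ends inside GetRemainItems, so for context the overall pattern is easier to see in a complete toy queue. The following is a minimal, self-contained sketch (not the real blocking_queue.h; Push/Pop/Stop and the MiniBlockingQueue name are illustrative assumptions): once Stop() wakes every waiter, the owner can reclaim whatever was never consumed.
// Minimal, self-contained sketch of the blocking-queue pattern (illustrative only).
#include <condition_variable>
#include <list>
#include <mutex>
template <typename T>
class MiniBlockingQueue {
 public:
  bool Push(const T &item) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (is_stoped_) return false;
    queue_.push_back(item);
    cond_.notify_one();
    return true;
  }
  bool Pop(T &item) {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return is_stoped_ || !queue_.empty(); });
    if (queue_.empty()) return false;  // stopped and already drained
    item = queue_.front();
    queue_.pop_front();
    return true;
  }
  void Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    is_stoped_ = true;
    cond_.notify_all();  // wake every blocked consumer
  }
  // After Stop(), the unprocessed items can be fetched and released by the caller.
  std::list<T> GetRemainItems() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (!is_stoped_) return {};
    return queue_;
  }
 private:
  std::list<T> queue_;
  std::mutex mutex_;
  std::condition_variable cond_;
  bool is_stoped_ = false;
};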
@@ -19,10 +19,10 @@ | |||
#include <stdint.h> | |||
/// | |||
/// @ingroup dnn | |||
/// @brief struct define of dynamic aipp batch parameter. | |||
/// | |||
/** | |||
* @ingroup dnn | |||
* @brief struct definition of dynamic aipp batch parameter. | |||
*/ | |||
typedef struct tagAippDynamicBatchPara { | |||
int8_t cropSwitch; // crop switch | |||
int8_t scfSwitch; // resize switch | |||
@@ -66,10 +66,10 @@ typedef struct tagAippDynamicBatchPara { | |||
int8_t reserve1[16]; // 32B assign, for ub copy | |||
} kAippDynamicBatchPara; | |||
/// | |||
/// @ingroup dnn | |||
/// @brief struct definition of dynamic aipp parameter. lite:64+96*batchNum byte ; tiny:64+64*batchNum byte | |||
/// | |||
/** | |||
* @ingroup dnn | |||
* @brief struct definition of dynamic aipp parameter. lite: 64+96*batchNum bytes; tiny: 64+64*batchNum bytes | |||
*/ | |||
typedef struct tagAippDynamicPara { | |||
uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8 | |||
int8_t cscSwitch; // csc switch | |||
@@ -61,19 +61,19 @@ typedef enum tagHiAiNpuModuleId { | |||
HIAI_DP = 23, | |||
} HiAiNpuModuleId; | |||
// bit 31-bit30 to be hiai local | |||
/* bit 31-bit30 to be hiai local */ | |||
#define HIAI_NPULOCAL_MASK 0xC0000000 | |||
#define SHIFT_LOCAL_MASK 30 | |||
#define HIAI_NPULOCAL_VAL_MASK 0x3 | |||
// bit 29 -bit28 to be hiai aicpu code type | |||
/* bit 29 -bit28 to be hiai aicpu code type */ | |||
#define HIAI_CODE_TYPE_MASK 0x30000000 | |||
#define SHIFT_CODE_MASK 28 | |||
#define HIAI_CODE_TYPE_VAL_MASK 0x3 | |||
// bit 27 -bit25 to be hiai error level | |||
/* bit 27 -bit25 to be hiai error level */ | |||
#define HIAI_ERROR_LEVEL_MASK 0x0E000000 | |||
#define SHIFT_ERROR_LVL_MASK 25 | |||
#define HIAI_ERROR_LEVEL_VAL_MASK 0x7 | |||
// bit 24 -bit17 to be hiai mod | |||
/* bit 24 -bit17 to be hiai mod */ | |||
#define HIAI_MODE_ID_MASK 0x01FE0000 | |||
#define SHIFT_MODE_MASK 17 | |||
#define HIAI_MODE_ID_VAL_MASK 0xFF | |||
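A small standalone sketch (not from the header) of how the masks and shifts above carve a 32-bit HiAi error code into its fields; the sample value is arbitrary.
#include <cstdint>
#include <cstdio>
int main() {
  uint32_t code = 0x86123456u;                        // arbitrary example code
  uint32_t local = (code & 0xC0000000u) >> 30;        // HIAI_NPULOCAL_MASK / SHIFT_LOCAL_MASK
  uint32_t code_type = (code & 0x30000000u) >> 28;    // HIAI_CODE_TYPE_MASK / SHIFT_CODE_MASK
  uint32_t error_level = (code & 0x0E000000u) >> 25;  // HIAI_ERROR_LEVEL_MASK / SHIFT_ERROR_LVL_MASK
  uint32_t module_id = (code & 0x01FE0000u) >> 17;    // HIAI_MODE_ID_MASK / SHIFT_MODE_MASK
  std::printf("local=%u type=%u level=%u module=%u\n", local, code_type, error_level, module_id);
  return 0;
}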
@@ -19,13 +19,12 @@ | |||
#include <runtime/rt.h> | |||
#include <stdint.h> | |||
#include <string> | |||
#include <vector> | |||
using std::string; | |||
namespace ge { | |||
// DAVINCI_TRAIN/DAVINCI_CLOUD is not needed when GETaskKernelHcclInfo needed | |||
// DAVINCI_TRAIN/DAVINCI_CLOUD is not needed once GETaskKernelHcclInfo is eliminated | |||
struct GETaskKernelHcclInfo { | |||
string hccl_type; | |||
void *inputDataAddr; | |||
@@ -21,7 +21,6 @@ | |||
#include <map> | |||
#include <string> | |||
#include <vector> | |||
#include "./ge_task_info.h" | |||
#include "./ops_kernel_info_types.h" | |||
#include "cce/aicpu_engine_struct.h" | |||
@@ -29,7 +28,6 @@ | |||
#include "common/ge_inner_error_codes.h" | |||
#include "graph/node.h" | |||
#include "proto/task.pb.h" | |||
using std::map; | |||
using std::string; | |||
using std::to_string; | |||
@@ -47,7 +45,7 @@ class OpsKernelInfoStore { | |||
// initialize opsKernelInfoStore | |||
virtual Status Initialize(const map<string, string> &options) = 0; | |||
// finalize opsKernelInfoStore | |||
// close opsKernelInfoStore | |||
virtual Status Finalize() = 0; | |||
virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; } | |||
@@ -57,18 +55,20 @@ class OpsKernelInfoStore { | |||
// get all opsKernelInfo | |||
virtual void GetAllOpsKernelInfo(map<string, OpInfo> &infos) const = 0; | |||
// check whether opsKernelInfoStore is supported based on the operator attribute | |||
// whether the opsKernelInfoStore is supported based on the operator attribute | |||
virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; | |||
virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, | |||
bool realQuery = false) const { | |||
return CheckSupported(opDescPtr, un_supported_reason); | |||
} | |||
// opsFlag: opsFlag[0] indicates whether constant folding is supported | |||
virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){}; | |||
// requirement of memory allocation | |||
// memory allocation requirement | |||
virtual Status CalcOpRunningParam(Node &node) = 0; | |||
// generate task for op | |||
// generate task for op | |||
virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0; | |||
// only call fe engine interface to compile single op | |||
@@ -77,10 +77,10 @@ class OpsKernelInfoStore { | |||
// load task for op | |||
virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } | |||
// only to call aicpu interface for generating task struct | |||
// only call aicpu interface to generate task struct | |||
virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | |||
// only to call aicpu interface for generating task struct | |||
// only call aicpu interface to generate task struct | |||
virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | |||
}; | |||
} // namespace ge | |||
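To make the contract above concrete, here is a hedged sketch of a do-nothing store that only satisfies the pure virtual methods visible in this hunk (the include path and the DummyOpsKernelInfoStore name are assumptions; a real store also overrides the optional hooks).
#include <map>
#include <string>
#include <vector>
#include "common/opskernel/ops_kernel_info_store.h"  // assumed include path
namespace ge {
class DummyOpsKernelInfoStore : public OpsKernelInfoStore {
 public:
  Status Initialize(const std::map<std::string, std::string> &options) override { return SUCCESS; }
  Status Finalize() override { return SUCCESS; }
  void GetAllOpsKernelInfo(std::map<std::string, OpInfo> &infos) const override { infos.clear(); }
  bool CheckSupported(const OpDescPtr &op_desc_ptr, std::string &un_supported_reason) const override {
    un_supported_reason = "dummy store supports no op";
    return false;
  }
  Status CalcOpRunningParam(Node &node) override { return SUCCESS; }  // no memory requirement computed
  Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override {
    return SUCCESS;  // emits no tasks
  }
};
}  // namespace ge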
@@ -37,6 +37,7 @@ struct RunContext { | |||
ge::Buffer weightsBuffer; | |||
std::vector<rtStream_t> graphStreamList; // all streams of graph, order by ge stream id(0,1,...) | |||
std::vector<rtEvent_t> graphEventList; // all events of graph, order by ge event id(0,1,...) | |||
std::vector<rtLabel_t> graphLabelList; // all labels of graph, order by ge label id(0,1,...) | |||
}; | |||
struct Task { | |||
@@ -19,7 +19,6 @@ | |||
#include <map> | |||
#include <string> | |||
#include "./graph_optimizer_types.h" | |||
#include "common/ge_inner_error_codes.h" | |||
#include "common/opskernel/ops_kernel_info_types.h" | |||
@@ -39,19 +38,19 @@ class GraphOptimizer { | |||
// close graphOptimizer | |||
virtual Status Finalize() = 0; | |||
// optimize original graph for FE quant optimization | |||
// optimize original graph for FE quant optimize | |||
virtual Status OptimizeGraphPrepare(ComputeGraph &graph) { return SUCCESS; } | |||
// optimize original graph used in the graph preparation stage | |||
// optimize original graph, using in graph preparation stage | |||
virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; | |||
// optimize fused graph | |||
virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; | |||
// optimize the whole graph which will be used after graph merged | |||
// optimize whole graph, using after graph merged stage | |||
virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; | |||
// get attributes of graph optimizer | |||
// get attribute of graph optimizer | |||
virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; | |||
// optimize streamed Graph | |||
@@ -19,8 +19,6 @@ | |||
#include <stdint.h> | |||
#include <string> | |||
using std::string; | |||
namespace ge { | |||
enum OPTIMIZER_SCOPE { | |||
UNIT = 0, | |||
@@ -28,7 +26,7 @@ enum OPTIMIZER_SCOPE { | |||
}; | |||
struct GraphOptimizerAttribute { | |||
string engineName; | |||
std::string engineName; | |||
OPTIMIZER_SCOPE scope; | |||
}; | |||
} // namespace ge | |||
@@ -20,6 +20,7 @@ | |||
#include <cstdint> | |||
#include <string> | |||
#include <vector> | |||
#include <set> | |||
namespace ge { | |||
// Option key: graph run mode | |||
@@ -38,9 +39,11 @@ const char *const GE_AICPU_FLAG = "ge.aicpuFlag"; | |||
const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | |||
const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | |||
const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | |||
const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | |||
// Hccl flag: if ge.exec.hcclFlag = 1, the opskernel plugin is loaded; otherwise ge.exec.hcclFlag = 0 | |||
const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | |||
const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | |||
const char *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory"; | |||
// Option key: memory init | |||
const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | |||
@@ -141,19 +144,43 @@ const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; | |||
// configure outputDatatype to set the net output type | |||
const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; | |||
// configure opSelectImplmode to set the op select implmode | |||
const std::string kOpSelectImplmode = "ge.opSelectImplmode"; | |||
// configure whether to enable hcom parallel by session constructor options param, | |||
// its value should be "0" or "1", default value is "0" | |||
const std::string HCOM_PARALLEL = "ge.hcomParallel"; | |||
// configure whether to use dynamic batch size | |||
const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; | |||
// configure whether to use dynamic image size | |||
const char *const kDynamicImageSize = "ge.dynamicImageSize"; | |||
// Configure auto tune mode; this option only takes effect while AUTO_TUNE_FLAG is Y, | |||
// example: GA|RL; multiple values are supported, separated by | | |||
const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; | |||
// Configure soc version , example: "Ascend310" | |||
const std::string SOC_VERSION = "ge.socVersion"; | |||
// Configure core type "VectorEngine", default value is "AIcoreEngine" | |||
const std::string CORE_TYPE = "ge.engineType"; | |||
// Configure soc version , example: "Ascend310" | |||
const std::string SOC_VERSION = "ge.socVersion"; | |||
// Configure AICORE NUM | |||
const std::string AICORE_NUM = "ge.aicoreNum"; | |||
// Configure L1FUSION | |||
const std::string L1_FUSION = "ge.l1Fusion"; | |||
// Configure Small Channel flag | |||
const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | |||
// Configure Compress Weight flag | |||
const std::string ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; | |||
// Configure fusion switch file path | |||
const std::string FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; | |||
// Save original model | |||
const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | |||
@@ -194,6 +221,28 @@ struct TensorInfo { | |||
DataDesc data; // tensor data | |||
ShapeDesc shapeInfo; // tensor shape | |||
}; | |||
// for ir build | |||
namespace ir_option { | |||
static const char *const INPUT_FORMAT = "input_format"; | |||
static const char *const INPUT_SHAPE = "input_shape"; | |||
static const char *const OP_NAME_MAP = "op_name_map"; | |||
static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; | |||
static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; | |||
static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); | |||
static const char *const PRECISION_MODE = ge::PRECISION_MODE.c_str(); | |||
static const char *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY; | |||
static const char *const HEAD_STREAM = ge::HEAD_STREAM.c_str(); | |||
static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); | |||
static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); | |||
static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); | |||
// for interface: aclgrphBuildModel | |||
const std::set<std::string> ir_builder_suppported_options = { | |||
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, | |||
DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||
AUTO_TUNE_MODE}; | |||
// for interface: aclgrphBuildInitialize | |||
const std::set<std::string> global_options = {HEAD_STREAM, CORE_TYPE, SOC_VERSION}; | |||
} // namespace ir_option | |||
} // namespace ge | |||
#endif // INC_EXTERNAL_GE_GE_API_TYPES_H_ |
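As a usage note, the two option sets added above are what the aclgrphBuildInitialize and aclgrphBuildModel interfaces are meant to validate against. A hedged sketch of such a check (the include path and the helper name are assumptions, not part of this header):
#include <map>
#include <string>
#include "external/ge/ge_api_types.h"  // assumed include path
bool AreIrBuildOptionsSupported(const std::map<std::string, std::string> &options) {
  for (const auto &kv : options) {
    if (ge::ir_option::ir_builder_suppported_options.count(kv.first) == 0) {
      return false;  // unknown key, or a global-only key such as ge.socVersion
    }
  }
  return true;
}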
@@ -0,0 +1,75 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef INC_EXTERNAL_GE_IR_BUILD_H_ | |||
#define INC_EXTERNAL_GE_IR_BUILD_H_ | |||
#include <string> | |||
#include <map> | |||
#include <memory> | |||
#include "graph/graph.h" | |||
#include "graph/ge_error_codes.h" | |||
namespace ge { | |||
struct ModelBufferData { | |||
std::shared_ptr<uint8_t> data = nullptr; | |||
uint64_t length; | |||
}; | |||
/** | |||
* @ingroup AscendCL | |||
* @brief initialize the environment for model building | |||
* | |||
* @param global_options[IN] global init params for build | |||
* @retval GRAPH_SUCCESS The function is successfully executed. | |||
* @retval OtherValues Failure | |||
*/ | |||
graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options); | |||
/** | |||
* @ingroup AscendCL | |||
* @brief finalize the model build environment and release resources | |||
* | |||
*/ | |||
void aclgrphBuildFinalize(); | |||
/** | |||
* @ingroup AscendCL | |||
* @brief build the model. Note that the model is stored in a buffer | |||
* | |||
* @param graph[IN] the graph ready to build | |||
* @param build_options[IN] options used for build | |||
* @param model[OUT] built model | |||
* @retval GRAPH_SUCCESS The function is successfully executed. | |||
* @retval OtherValues Failure | |||
*/ | |||
graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | |||
ModelBufferData &model); | |||
/** | |||
* @ingroup AscendCL | |||
* @brief save model buffer to file | |||
* | |||
* @param output_file[IN] the file path to be saved | |||
* @param model[IN] model buffer data | |||
* @retval GRAPH_SUCCESS The function is successfully executed. | |||
* @retval OtherValues Failure | |||
*/ | |||
graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model); | |||
}; // namespace ge | |||
#endif |
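End to end, the four calls declared above are meant to be used in sequence. A hedged sketch (the include path, the "Ascend310" value, and the output file name are placeholders, not mandated by this header):
#include <map>
#include <string>
#include "external/ge/ge_ir_build.h"  // assumed include path
#include "graph/graph.h"
ge::graphStatus BuildAndSave(const ge::Graph &graph) {
  std::map<std::string, std::string> global_options = {{"ge.socVersion", "Ascend310"}};
  ge::graphStatus ret = ge::aclgrphBuildInitialize(global_options);
  if (ret != ge::GRAPH_SUCCESS) {
    return ret;
  }
  std::map<std::string, std::string> build_options;  // e.g. input_shape / input_format keys
  ge::ModelBufferData model;
  ret = ge::aclgrphBuildModel(graph, build_options, model);
  if (ret == ge::GRAPH_SUCCESS) {
    ret = ge::aclgrphSaveModel("./built_model.om", model);  // placeholder output path
  }
  ge::aclgrphBuildFinalize();  // always release the build environment
  return ret;
}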
@@ -22,7 +22,7 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "external/graph/ge_error_codes.h" | |||
#include "./ge_error_codes.h" | |||
using std::make_shared; | |||
using std::map; | |||
@@ -22,7 +22,7 @@ | |||
#include <utility> | |||
#include <vector> | |||
#include "external/graph/operator.h" | |||
#include "./operator.h" | |||
namespace ge { | |||
class GraphImpl; | |||
@@ -21,8 +21,8 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "external/graph/tensor.h" | |||
#include "external/graph/types.h" | |||
#include "./tensor.h" | |||
#include "./types.h" | |||
namespace ge { | |||
class InferenceContext; | |||
@@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | |||
static std::unique_ptr<InferenceContext> Create(); | |||
private: | |||
InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||
explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||
std::shared_ptr<InferenceContextImpl> inference_context_impl_; | |||
}; | |||
} // namespace ge | |||
@@ -23,9 +23,9 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "external/graph/ge_error_codes.h" | |||
#include "external/graph/inference_context.h" | |||
#include "external/graph/tensor.h" | |||
#include "./ge_error_codes.h" | |||
#include "./inference_context.h" | |||
#include "./tensor.h" | |||
#ifndef USER_GE_LOGI | |||
#define USER_GE_LOGI(...) | |||
@@ -22,8 +22,8 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "external/graph//operator.h" | |||
#include "external/graph/ge_error_codes.h" | |||
#include "./operator.h" | |||
#include "./ge_error_codes.h" | |||
namespace ge { | |||
using OpCreator = std::function<Operator(const std::string &)>; | |||
@@ -22,10 +22,10 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "external/graph/operator.h" | |||
#include "external/graph/operator_factory.h" | |||
#include "external/graph/tensor.h" | |||
#include "external/graph/types.h" | |||
#include "./operator.h" | |||
#include "./operator_factory.h" | |||
#include "./tensor.h" | |||
#include "./types.h" | |||
namespace ge { | |||
using std::function; | |||
@@ -60,7 +60,7 @@ class OpReg { | |||
\ | |||
private: \ | |||
void __##x() { \ | |||
OpReg() | |||
OpReg() | |||
#define ATTR(x, Type, ...) \ | |||
N(); \ | |||
@@ -86,7 +86,7 @@ class OpReg { | |||
void __attr_##x() { \ | |||
Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ | |||
string attr_name(#x); \ | |||
(void)OpReg() | |||
(void)OpReg() | |||
#define REQUIRED_ATTR(x, Type) \ | |||
N(); \ | |||
@@ -112,7 +112,7 @@ class OpReg { | |||
void __required_attr_##x() { \ | |||
Operator::RequiredAttrRegister(#x); \ | |||
string attr_name(#x); \ | |||
(void)OpReg() | |||
(void)OpReg() | |||
#define INPUT(x, t) \ | |||
N(); \ | |||
@@ -137,7 +137,7 @@ class OpReg { | |||
private: \ | |||
void __input_##x() { \ | |||
Operator::InputRegister(#x); \ | |||
(void)OpReg() | |||
(void)OpReg() | |||
#define OPTIONAL_INPUT(x, t) \ | |||
N(); \ | |||
@@ -162,7 +162,7 @@ class OpReg { | |||
private: \ | |||
void __optional_input_##x() { \ | |||
Operator::OptionalInputRegister(#x); \ | |||
(void)OpReg() | |||
(void)OpReg() | |||
#define OUTPUT(x, t) \ | |||
N(); \ | |||
@@ -179,7 +179,7 @@ class OpReg { | |||
private: \ | |||
void __out_##x() { \ | |||
Operator::OutputRegister(#x); \ | |||
(void)OpReg() | |||
(void)OpReg() | |||
#define DYNAMIC_INPUT(x, t) \ | |||
N(); \ | |||
@@ -206,7 +206,7 @@ class OpReg { | |||
\ | |||
private: \ | |||
void __dy_input_##x() { \ | |||
(void)OpReg() | |||
(void)OpReg() | |||
#define DYNAMIC_OUTPUT(x, t) \ | |||
N(); \ | |||
@@ -227,18 +227,18 @@ class OpReg { | |||
\ | |||
private: \ | |||
void __dy_output_##x() { \ | |||
(void)OpReg() | |||
(void)OpReg() | |||
#define PASTE(g_register, y) g_register##y | |||
#define __OP_END_IMPL__(x, y) \ | |||
N(); \ | |||
} \ | |||
static_assert( \ | |||
std::is_same<x, _THIS_TYPE>::value, \ | |||
"The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ | |||
} \ | |||
; \ | |||
static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ | |||
#define __OP_END_IMPL__(x, y) \ | |||
N(); \ | |||
} \ | |||
static_assert( \ | |||
std::is_same<x, _THIS_TYPE>::value, \ | |||
"The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ | |||
} \ | |||
; \ | |||
static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ | |||
} | |||
#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) | |||
@@ -286,7 +286,7 @@ class OpReg { | |||
// Common shape inferencer | |||
#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ | |||
[](Operator op)->graphStatus { \ | |||
[](Operator op) -> graphStatus { \ | |||
auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \ | |||
auto x_type = op.GetInputDesc(in_name).GetDataType(); \ | |||
TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ | |||
@@ -300,7 +300,7 @@ graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape, | |||
const function<void(const vector<int64_t> &y_shape)> &set_out_shape); | |||
#define BROADCAST_INFER(in1_name, in2_name, out_name) \ | |||
[](Operator op)->graphStatus { \ | |||
[](Operator op) -> graphStatus { \ | |||
return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \ | |||
[&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \ | |||
[&](const vector<int64_t> &y_shape) { \ | |||
@@ -22,8 +22,8 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "external/graph/ge_error_codes.h" | |||
#include "external/graph/types.h" | |||
#include "./ge_error_codes.h" | |||
#include "./types.h" | |||
namespace ge { | |||
class ShapeImpl; | |||
@@ -133,11 +133,13 @@ enum Format { | |||
FORMAT_FRACTAL_ZZ, | |||
FORMAT_FRACTAL_NZ, | |||
FORMAT_NCDHW, | |||
FORMAT_DHWCK, // 3D filter input tensor format | |||
FORMAT_DHWCN, // 3D filter input tensor format | |||
FORMAT_NDC1HWC0, | |||
FORMAT_FRACTAL_Z_3D, | |||
FORMAT_CN, | |||
FORMAT_NC, | |||
FORMAT_DHWNC, | |||
FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format | |||
FORMAT_RESERVED, | |||
FORMAT_ALL | |||
}; | |||
@@ -47,6 +47,12 @@ class Tensor; | |||
class TBEPluginManager; | |||
} // namespace ge | |||
namespace google { | |||
namespace protobuf { | |||
class Message; | |||
} | |||
} // namespace google | |||
namespace domi { | |||
Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | |||
Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | |||
@@ -56,6 +62,8 @@ using google::protobuf::Message; | |||
class OpRegistrationDataImpl; | |||
using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | |||
using FusionParseParamFunc = | |||
std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | |||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
public: | |||
@@ -71,15 +79,20 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); | |||
OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | |||
OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | |||
OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | |||
OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); | |||
domi::ImplyType GetImplyType() const; | |||
std::string GetOmOptype() const; | |||
std::set<std::string> GetOriginOpTypeSet() const; | |||
domi::FrameworkType GetFrameworkType() const; | |||
ParseParamFunc GetParseParamFn() const; | |||
FusionParseParamFunc GetFusionParseParamFn() const; | |||
private: | |||
std::shared_ptr<OpRegistrationDataImpl> impl_; | |||
@@ -103,5 +116,27 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||
namespace ge { | |||
using OpRegistrationData = domi::OpRegistrationData; | |||
using OpReceiver = domi::OpReceiver; | |||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { | |||
public: | |||
HostCpuOp() = default; | |||
virtual ~HostCpuOp() = default; | |||
virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs, | |||
std::map<std::string, Tensor> &outputs) = 0; | |||
}; | |||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { | |||
public: | |||
HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); | |||
}; | |||
#define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) | |||
#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) | |||
#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ | |||
static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \ | |||
::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); }) | |||
} // namespace ge | |||
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ |
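A hedged sketch of how the new host-CPU op hooks above are meant to be used: derive from HostCpuOp and register the type through the macro (MyAddOp and the include path are made-up placeholders).
#include <map>
#include <string>
#include "register/register.h"  // assumed include path
namespace ge {
class MyAddOp : public HostCpuOp {
 public:
  graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs,
                      std::map<std::string, Tensor> &outputs) override {
    // a real kernel would read inputs, compute on the host, and fill outputs here
    (void)op;
    (void)inputs;
    (void)outputs;
    return GRAPH_SUCCESS;
  }
};
// Expands to a static HostCpuOpRegistrar that constructs MyAddOp on demand.
REGISTER_HOST_CPU_OP_BUILDER("MyAddOp", MyAddOp);
}  // namespace ge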
@@ -22,7 +22,7 @@ | |||
#define DECLARE_ERRORNO(sysid, modid, name, value) \ | |||
const domi::Status name = \ | |||
((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); | |||
((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); | |||
#define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) | |||
@@ -33,6 +33,7 @@ using Status = uint32_t; | |||
DECLARE_ERRORNO(0, 0, SUCCESS, 0); | |||
DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); | |||
DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 | |||
DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201); | |||
} // namespace domi | |||
#endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ |
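A tiny standalone check of how DECLARE_ERRORNO above packs sysid/modid/value; with SYSID_FWK = 3 and MODID_COMMON = 0 (the values implied by the 50331649 noted next to PARAM_INVALID), the packing reproduces 0x03000001.
#include <cstdint>
#include <cstdio>
constexpr uint32_t MakeErrorNo(uint8_t sysid, uint8_t modid, uint16_t value) {
  return ((0xFFu & sysid) << 24) | ((0xFFu & modid) << 16) | (0xFFFFu & value);
}
int main() {
  static_assert(MakeErrorNo(3, 0, 1) == 50331649u, "PARAM_INVALID packs to 0x03000001");
  std::printf("PARAM_INVALID = %u\n", MakeErrorNo(3, 0, 1));
  return 0;
}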
@@ -48,6 +48,10 @@ typedef enum tagDomiTensorFormat { | |||
DOMI_TENSOR_BN_WEIGHT, | |||
DOMI_TENSOR_CHWN, // Android NN Depth CONV | |||
DOMI_TENSOR_FILTER_HWCK, // filter input tensor format | |||
DOMI_TENSOR_NDHWC, | |||
DOMI_TENSOR_NCDHW, | |||
DOMI_TENSOR_DHWCN, // 3D filter input tensor format | |||
DOMI_TENSOR_DHWNC, | |||
DOMI_TENSOR_RESERVED | |||
} domiTensorFormat_t; | |||
} // namespace domi | |||
@@ -18,11 +18,13 @@ | |||
#define INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ | |||
#include <cstdint> | |||
#include <unistd.h> | |||
#include <sys/syscall.h> | |||
#include "framework/common/ge_inner_error_codes.h" | |||
#include "toolchain/slog.h" | |||
#define GE_MODULE_NAME GE | |||
#define GE_MODULE_NAME static_cast<int>(GE) | |||
// trace status of log | |||
enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; | |||
@@ -35,15 +37,20 @@ enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; | |||
#define GELOGO(...) GE_LOG_OPLOG(GE_MODULE_NAME, __VA_ARGS__) | |||
#define GELOGT(VALUE, ...) GE_LOG_TRACE(GE_MODULE_NAME, VALUE, __VA_ARGS__) | |||
inline bool IsLogEnable(int module_name, int log_level) noexcept { | |||
int32_t enable_event = 0; | |||
int32_t dlog_level = dlog_getlevel(module_name, &enable_event); | |||
if (dlog_level <= log_level) { | |||
inline bool IsLogEnable(int module_name, int log_level) { | |||
int32_t enable = CheckLogLevel(module_name, log_level); | |||
// 1:enable, 0:disable | |||
if (enable == 1) { | |||
return true; | |||
} | |||
return false; | |||
} | |||
inline pid_t GetTid() { | |||
thread_local static pid_t tid = syscall(__NR_gettid); | |||
return tid; | |||
} | |||
#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | |||
#define GE_TIMESTAMP_END(stage, stage_name) \ | |||
@@ -68,29 +75,35 @@ inline bool IsLogEnable(int module_name, int log_level) noexcept { | |||
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ | |||
call_num_of##stage) | |||
#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ | |||
dlog_error(static_cast<int>(MOD_NAME), "%s: ErrorNo: %d(%s) " fmt, __FUNCTION__, ERROR_CODE, \ | |||
#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ | |||
dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \ | |||
((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__) | |||
#define GE_LOG_WARN(MOD_NAME, fmt, ...) \ | |||
if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_WARN)) \ | |||
dlog_warn(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOG_INFO(MOD_NAME, fmt, ...) \ | |||
if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_INFO)) \ | |||
dlog_info(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOG_DEBUG(MOD_NAME, fmt, ...) \ | |||
if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_DEBUG)) \ | |||
dlog_debug(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOG_EVENT(MOD_NAME, fmt, ...) dlog_event(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOG_WARN(MOD_NAME, fmt, ...) \ | |||
if (IsLogEnable(MOD_NAME, DLOG_WARN)) dlog_warn(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOG_INFO(MOD_NAME, fmt, ...) \ | |||
if (IsLogEnable(MOD_NAME, DLOG_INFO)) dlog_info(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOG_DEBUG(MOD_NAME, fmt, ...) \ | |||
if (IsLogEnable(MOD_NAME, DLOG_DEBUG)) dlog_debug(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOG_EVENT(MOD_NAME, fmt, ...) dlog_event(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOG_OPLOG(MOD_NAME, fmt, ...) \ | |||
Dlog(static_cast<int>(MOD_NAME), DLOG_OPLOG, "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOG_TRACE(MOD_NAME, value, fmt, ...) \ | |||
do { \ | |||
TraceStatus stat = value; \ | |||
const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \ | |||
int idx = static_cast<int>(stat); \ | |||
char *k = const_cast<char *>("status"); \ | |||
char *v = const_cast<char *>(TraceStatStr[idx]); \ | |||
KeyValue kv = {k, v}; \ | |||
DlogWithKV(static_cast<int>(MOD_NAME), DLOG_TRACE, &kv, 1, "%s:" fmt, __FUNCTION__, ##__VA_ARGS__); \ | |||
Dlog(MOD_NAME, DLOG_OPLOG, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOG_TRACE(MOD_NAME, value, fmt, ...) \ | |||
do { \ | |||
TraceStatus stat = value; \ | |||
const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \ | |||
int idx = static_cast<int>(stat); \ | |||
char *k = const_cast<char *>("status"); \ | |||
char *v = const_cast<char *>(TraceStatStr[idx]); \ | |||
KeyValue kv = {k, v}; \ | |||
DlogWithKV(static_cast<int>(MOD_NAME), DLOG_TRACE, &kv, 1, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__); \ | |||
} while (0) | |||
// print memory when it is greater than 1KB. | |||
#define GE_PRINT_DYNAMIC_MEMORY(FUNC, PURPOSE, SIZE) \ | |||
do { \ | |||
if ((SIZE) > 1024) { \ | |||
GELOGI("MallocMemory, func=%s, size=%zu, purpose=%s", (#FUNC), static_cast<size_t>(SIZE), (PURPOSE)); \ | |||
} \ | |||
} while (0); | |||
#endif // INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ |
@@ -29,7 +29,18 @@ | |||
using cce::CC_STATUS_SUCCESS; | |||
using cce::ccStatus_t; | |||
#define GE_LOGE(...) DAV_LOGE("GE", __VA_ARGS__) | |||
#if !defined(__ANDROID__) && !defined(ANDROID) | |||
#define DOMI_LOGE(...) DAV_LOGE("DOMI", __VA_ARGS__) | |||
#else | |||
#include <android/log.h> | |||
#if defined(BUILD_VERSION_PERF) | |||
#define DOMI_LOGE(fmt, ...) | |||
#else | |||
// The Android system has strict log control. Do not modify the log. | |||
#define DOMI_LOGE(fmt, ...) \ | |||
__android_log_print(ANDROID_LOG_ERROR, "NPU_FMK", "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||
#endif | |||
#endif | |||
// ge macro | |||
#define GE_LOGI_IF(condition, ...) \ | |||
@@ -44,7 +55,7 @@ using cce::ccStatus_t; | |||
#define GE_LOGE_IF(condition, ...) \ | |||
if ((condition)) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
} | |||
// If expr is not SUCCESS, print the log and return the same value | |||
@@ -52,7 +63,7 @@ using cce::ccStatus_t; | |||
do { \ | |||
const ge::Status _status = (expr); \ | |||
if (_status != ge::SUCCESS) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
return _status; \ | |||
} \ | |||
} while (0); | |||
@@ -62,7 +73,7 @@ using cce::ccStatus_t; | |||
do { \ | |||
const ge::Status _status = (expr); \ | |||
if (_status != ge::SUCCESS) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
} \ | |||
} while (0); | |||
@@ -75,6 +86,15 @@ using cce::ccStatus_t; | |||
} \ | |||
} while (0); | |||
// If expr is not GRAPH_SUCCESS, print the log and return FAILED | |||
#define GE_CHK_GRAPH_STATUS_RET(expr, ...) \ | |||
do { \ | |||
if ((expr) != ge::GRAPH_SUCCESS) { \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
return FAILED; \ | |||
} \ | |||
} while (0); | |||
// If expr is not SUCCESS, print the log and execute a custom statement | |||
#define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \ | |||
do { \ | |||
@@ -91,25 +111,11 @@ using cce::ccStatus_t; | |||
(void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | |||
(void)msg.append( \ | |||
ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||
GE_LOGE("%s", msg.c_str()); \ | |||
DOMI_LOGE("%s", msg.c_str()); \ | |||
return _status; \ | |||
} \ | |||
} while (0); | |||
// If expr is not true, print the Info log and return the specified status | |||
#define GE_CHK_BOOL_RET_STATUS_LOGI(expr, _status, ...) \ | |||
do { \ | |||
bool b = (expr); \ | |||
if (!b) { \ | |||
std::string msg; \ | |||
(void)msg.append(StringUtils::FormatString(__VA_ARGS__)); \ | |||
(void)msg.append( \ | |||
StringUtils::FormatString(" Check result false, status: 0x%X %s", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||
GELOGI("%s", msg.c_str()); \ | |||
return _status; \ | |||
} \ | |||
} while (0); | |||
// If expr is not true, print the log and return the specified status | |||
#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ | |||
do { \ | |||
@@ -124,7 +130,7 @@ using cce::ccStatus_t; | |||
{ \ | |||
bool b = (expr); \ | |||
if (!b) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
exec_expr; \ | |||
} \ | |||
}; | |||
@@ -163,7 +169,7 @@ using cce::ccStatus_t; | |||
{ \ | |||
bool b = (expr); \ | |||
if (b) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
exec_expr; \ | |||
} \ | |||
}; | |||
@@ -182,7 +188,7 @@ using cce::ccStatus_t; | |||
{ \ | |||
bool b = (expr); \ | |||
if (b) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
exec_expr; \ | |||
return; \ | |||
} \ | |||
@@ -193,7 +199,7 @@ using cce::ccStatus_t; | |||
{ \ | |||
bool b = (expr); \ | |||
if (b) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
exec_expr; \ | |||
return _status; \ | |||
} \ | |||
@@ -210,62 +216,42 @@ using cce::ccStatus_t; | |||
// -----------------runtime related macro definitions------------------------------- | |||
// If expr is not RT_ERROR_NONE, print the log | |||
#define GE_CHK_RT(expr) \ | |||
do { \ | |||
rtError_t _rt_ret = (expr); \ | |||
if (_rt_ret != RT_ERROR_NONE) { \ | |||
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
} \ | |||
#define GE_CHK_RT(expr) \ | |||
do { \ | |||
rtError_t _rt_ret = (expr); \ | |||
if (_rt_ret != RT_ERROR_NONE) { \ | |||
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
} \ | |||
} while (0); | |||
// If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression | |||
#define GE_CHK_RT_EXEC(expr, exec_expr) \ | |||
{ \ | |||
rtError_t _rt_ret = (expr); \ | |||
if (_rt_ret != RT_ERROR_NONE) { \ | |||
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
exec_expr; \ | |||
} \ | |||
#define GE_CHK_RT_EXEC(expr, exec_expr) \ | |||
{ \ | |||
rtError_t _rt_ret = (expr); \ | |||
if (_rt_ret != RT_ERROR_NONE) { \ | |||
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
exec_expr; \ | |||
} \ | |||
} | |||
// If expr is not RT_ERROR_NONE, print the log and return | |||
#define GE_CHK_RT_RET(expr) \ | |||
do { \ | |||
rtError_t _rt_ret = (expr); \ | |||
if (_rt_ret != RT_ERROR_NONE) { \ | |||
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
return ge::RT_FAILED; \ | |||
} \ | |||
#define GE_CHK_RT_RET(expr) \ | |||
do { \ | |||
rtError_t _rt_ret = (expr); \ | |||
if (_rt_ret != RT_ERROR_NONE) { \ | |||
DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
return ge::RT_FAILED; \ | |||
} \ | |||
} while (0); | |||
// ------------------------cce related macro definitions---------------------------- | |||
// If expr is not CC_STATUS_SUCCESS, print the log | |||
#define GE_CHK_CCE(expr) \ | |||
do { \ | |||
ccStatus_t _cc_ret = (expr); \ | |||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
} \ | |||
} while (0); | |||
// If expr is not CC_STATUS_SUCCESS, print the log and execute the exec_expr expression | |||
#define GE_CHK_CCE_EXEC(expr, exec_expr) \ | |||
do { \ | |||
ccStatus_t _cc_ret = (expr); \ | |||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
exec_expr; \ | |||
} \ | |||
} while (0); | |||
// If expr is not CC_STATUS_SUCCESS, print the log and return | |||
#define GE_CHK_CCE_RET(expr) \ | |||
do { \ | |||
ccStatus_t _cc_ret = (expr); \ | |||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
return ge::CCE_FAILED; \ | |||
} \ | |||
#define GE_CHK_CCE(expr) \ | |||
do { \ | |||
ccStatus_t _cc_ret = (expr); \ | |||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||
DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
} \ | |||
} while (0); | |||
// If expr is true, execute exec_expr without printing logs | |||
@@ -281,37 +267,8 @@ using cce::ccStatus_t; | |||
try { \ | |||
exec_expr0; \ | |||
} catch (const std::bad_alloc &) { \ | |||
GE_LOGE("Make shared failed"); \ | |||
DOMI_LOGE("Make shared failed"); \ | |||
exec_expr1; \ | |||
} | |||
#define GE_CHECK_INT32_MUL_OVERFLOW(a, b, ...) \ | |||
do { \ | |||
if ((a) > 0) { \ | |||
if ((b) > 0) { \ | |||
if ((a) > (INT32_MAX / (b))) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
return ge::FAILED; \ | |||
} \ | |||
} else { \ | |||
if ((b) < (INT32_MIN / (a))) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
return ge::FAILED; \ | |||
} \ | |||
} \ | |||
} else { \ | |||
if ((b) > 0) { \ | |||
if ((a) < (INT32_MAX / (b))) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
return ge::FAILED; \ | |||
} \ | |||
} else { \ | |||
if (((a) != 0) && ((b) < (INT32_MAX / (a)))) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
return ge::FAILED; \ | |||
} \ | |||
} \ | |||
} \ | |||
} while (0); | |||
#endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ |
@@ -152,7 +152,6 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g | |||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | |||
@@ -204,15 +203,16 @@ GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_GET_GRAPH_REBUILD_FAILED, 60, | |||
GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61, | |||
"Failed set graph finish rebuild in node searcher."); // 1343242301 | |||
GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302 | |||
// Optimize errocode | |||
GE_ERRORNO_GRAPH(TO_BE_DELETED, 200, "The node of the graph to be deleted."); | |||
GE_ERRORNO_GRAPH(NOT_CHANGED, 201, "NThe node of the graph not changed."); | |||
// Engine_manager module error code definition | |||
GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336 | |||
GE_ERRORNO_ENGINE(GE_ENG_FINALIZE_FAILED, 1, "Engine finalize failed."); // 1343246337 | |||
GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338 | |||
// Optimize error code | |||
GE_ERRORNO_GRAPH(TO_BE_DELETED, 63, "The node of the graph to be deleted."); // 1343242303 | |||
GE_ERRORNO_GRAPH(NOT_CHANGED, 64, "The node of the graph not changed.");        // 1343242304 | |||
// Ops module error code definition | |||
GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432 | |||
GE_ERRORNO_OPS(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, 1, "Failed to initialize GraphOptimizer."); // 1343250433 | |||
@@ -24,8 +24,7 @@ | |||
#include "common/fmk_error_codes.h" | |||
#include "ge/ge_api_error_codes.h" | |||
using std::string; | |||
#include "external/graph/types.h" | |||
namespace ge { | |||
enum RuntimeType { HOST = 0, DEVICE = 1 }; | |||
@@ -56,7 +55,7 @@ struct DataBuffer { | |||
/// | |||
/// @ingroup domi_ome | |||
/// @brief External inputdata | |||
/// @brief External input data | |||
/// | |||
struct InputData { | |||
uint32_t index; // Index of input data | |||
@@ -65,13 +64,14 @@ struct InputData { | |||
uint32_t model_id; // Model ID required for data processing | |||
uint64_t request_id = 0; // Request ID | |||
std::vector<DataBuffer> blobs; // Actual input data, currently only supports one input | |||
bool is_dynamic_batch = false; // Whether this is a dynamic batch size scene, default: false | |||
std::string batch_label; // Gear used for current inference in dynamic batch scene | |||
}; | |||
// The definition of output result structure | |||
/// Output result structure definition | |||
struct OutputData { | |||
uint32_t index; // Index of input data | |||
uint32_t model_id; // The model ID corresponding to the processing result | |||
/// Output data cache, arranged in sequence of output operators. | |||
/// If the operator has multiple outputs, | |||
/// the data buffer order of the operator is the same as that defined in the | |||
@@ -142,11 +142,31 @@ struct Options { | |||
bool deployMode; | |||
bool isAICPUMode; | |||
bool enable_atomic; | |||
string podName; | |||
std::string podName; | |||
int64_t rankId; | |||
string rankTableFile; | |||
std::string rankTableFile; | |||
int32_t ge_hccl_flag = 0; | |||
int32_t physical_device_id; | |||
}; | |||
// Profiling info of task | |||
struct TaskDescInfo { | |||
std::string op_name; | |||
uint32_t block_dim; | |||
uint32_t task_id; | |||
uint32_t stream_id; | |||
}; | |||
// Profiling info of graph | |||
struct ComputeGraphDescInfo { | |||
std::string op_name; | |||
std::string op_type; | |||
std::vector<Format> input_format; | |||
std::vector<std::vector<int64_t>> input_shape; | |||
std::vector<DataType> input_data_type; | |||
std::vector<Format> output_format; | |||
std::vector<std::vector<int64_t>> output_shape; | |||
std::vector<DataType> output_data_type; | |||
}; | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ |
@@ -19,7 +19,6 @@ | |||
#include <memory> | |||
#include <string> | |||
#include <memory> | |||
#include "common/fmk_types.h" | |||
#include "common/helper/om_file_helper.h" | |||
@@ -33,36 +32,41 @@ class ModelHelper { | |||
ModelHelper() = default; | |||
~ModelHelper(); | |||
Status SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, const std::string &output_file); | |||
Status SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file); | |||
Status LoadModel(const ge::ModelData &model_data); | |||
Status SaveToOmModel(const GeModelPtr& ge_model, const SaveParam& save_param, const std::string& output_file, | |||
ge::ModelBufferData& model); | |||
Status SaveOriginalGraphToOmModel(const ge::Graph& graph, const std::string& output_file); | |||
Status LoadModel(const ge::ModelData& model_data); | |||
Status GetModelBufferData(ge::ModelBufferData& model); | |||
ModelFileHeader *GetFileHeader() { return file_header_; } | |||
ModelFileHeader* GetFileHeader() { return file_header_; } | |||
GeModelPtr GetGeModel(); | |||
void SetSaveMode(bool val) { is_offline_ = val; } | |||
bool GetSaveMode(void) const { return is_offline_; } | |||
static Status TransModelToGeModel(const ModelPtr &model, GeModelPtr &ge_model); | |||
static Status TransGeModelToModel(const GeModelPtr &geModelPtr, ModelPtr &modelPtr); | |||
static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model); | |||
static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr); | |||
private: | |||
bool is_assign_model_ = false; | |||
ModelFileHeader *file_header_ = nullptr; | |||
bool is_offline_ = true; | |||
ModelFileHeader* file_header_ = nullptr; | |||
// Encrypted models need the temporary model deleted; unencrypted models do not | |||
uint8_t *model_addr_tmp_ = nullptr; | |||
uint8_t* model_addr_tmp_ = nullptr; | |||
uint32_t model_len_tmp_ = 0; | |||
GeModelPtr model_; | |||
ModelHelper(const ModelHelper &); | |||
ModelHelper &operator=(const ModelHelper &); | |||
Status GenerateGeModel(OmFileLoadHelper &om_load_helper); | |||
Status LoadModelData(OmFileLoadHelper &om_load_helper); | |||
void SetModelToGeModel(ge::Model &model); | |||
Status LoadWeights(OmFileLoadHelper &om_load_helper); | |||
Status LoadTask(OmFileLoadHelper &om_load_helper); | |||
Status LoadTBEKernelStore(OmFileLoadHelper &om_load_helper); | |||
ModelHelper(const ModelHelper&); | |||
ModelHelper& operator=(const ModelHelper&); | |||
Status GenerateGeModel(OmFileLoadHelper& om_load_helper); | |||
Status LoadModelData(OmFileLoadHelper& om_load_helper); | |||
void SetModelToGeModel(ge::Model& model); | |||
Status LoadWeights(OmFileLoadHelper& om_load_helper); | |||
Status LoadTask(OmFileLoadHelper& om_load_helper); | |||
Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | |||
Status ReleaseLocalModelData() noexcept; | |||
Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type, | |||
const uint8_t *data, size_t size); | |||
Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | |||
const uint8_t* data, size_t size); | |||
}; | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ |
@@ -20,10 +20,12 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "external/ge/ge_ir_build.h" | |||
#include "framework/common/fmk_types.h" | |||
#include "framework/common/ge_types.h" | |||
#include "framework/common/types.h" | |||
#include "framework/common/ge_types.h" | |||
using ProcParam = struct PROC_PARAM; | |||
using std::string; | |||
using std::vector; | |||
@@ -80,9 +82,10 @@ class OmFileSaveHelper { | |||
const std::vector<ModelPartition> &GetModelPartitions() const; | |||
Status SaveModel(const SaveParam &save_param, const char *target_file); | |||
Status SaveModel(const SaveParam &save_param, const char *target_file, ge::ModelBufferData &model, | |||
bool is_offline = true); | |||
Status SaveModelToFile(const char *output_file); | |||
Status SaveModelToFile(const char *output_file, ge::ModelBufferData &model, bool is_offline = true); | |||
ModelFileHeader model_header_; | |||
OmFileContext context_; | |||
@@ -120,4 +120,4 @@ class L2CacheOptimize { | |||
}; | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ | |||
#endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ |
@@ -649,6 +649,8 @@ extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_M | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_LABEL_NUM; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; | |||
@@ -801,6 +803,8 @@ extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_N | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||
// For constant folding | |||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||
} // namespace domi | |||
#endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ |
@@ -17,11 +17,12 @@ | |||
#ifndef INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | |||
#define INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | |||
#include <google/protobuf/map.h> | |||
#include <unordered_map> | |||
#include <string> | |||
#include <google/protobuf/map.h> | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "proto/om.pb.h" | |||
using domi::AttrDef; | |||
@@ -18,7 +18,6 @@ | |||
#define INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ | |||
#include <cce/dnn.h> | |||
#include <memory> | |||
#include <vector> | |||
@@ -56,6 +55,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_TR | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT; | |||
// FunctionOp | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t IF_COND_INPUT; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_START_INPUT; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT_INPUT; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; | |||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; | |||
class OpUtils { | |||
public: | |||
/// | |||
@@ -164,15 +172,23 @@ class OpUtils { | |||
/// | |||
static Status ConvertAippParams(const GeAttrValue::NamedAttrs &aipp_attr, domi::AippOpParams *aipp_params); | |||
static Status TransferDim(const std::vector<int64_t> &dim, std::vector<int64_t> &dim_vector); | |||
static void SliceData(std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, int64_t begin, | |||
int64_t out_dim, int64_t stride); | |||
template <typename T> | |||
static void SliceData(const std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, | |||
int64_t begin, int64_t out_dim, int64_t stride); | |||
template <typename T> | |||
static Status SetDataByDataType(size_t out_size, const std::vector<char *> &chunk_input, | |||
const std::vector<char *> &chunk_output, GeTensor *output); | |||
template <typename T> | |||
static Status SetOutputSliceDataByDataType(void *data, int64_t data_size, const std::vector<int64_t> &input_dims, | |||
const std::vector<int64_t> &begin, const std::vector<int64_t> &output_dims, | |||
ge::GeTensor *output, const std::vector<int64_t> &stride); | |||
static Status SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int64_t> &input_dims, | |||
std::vector<int64_t> &begin, std::vector<int64_t> &output_dims, ge::GeTensor *output, | |||
std::vector<int64_t> &stride); | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief Convert the convolution weight data from [h, w, c, k] to [k, c, h, w] | |||
/// @brief Convert the convolutional weight data from [h, w, c, k] to [k, c, h, w] | |||
/// @param [in] input Weight data in HWCK format | |||
/// @param [in] H value of H dimension | |||
/// @param [in] W value of W dimension | |||
@@ -183,7 +199,7 @@ class OpUtils { | |||
static void TransDataHWCK2KCHW(const void *input, int64_t H, int64_t W, int64_t C, int64_t K, void **output); | |||
/// | |||
/// @ingroup domi_omg | |||
/// @brief Converts the convolution weight data from [k, c, h, w] to [h, w, c, k]. | |||
/// @brief Converts the convolutional weight data from [k, c, h, w] to [h, w, c, k]. | |||
/// @param [in] input Weight data in HWCK format | |||
/// @param [in] K value of K dimension | |||
/// @param [in] C value of C dimension | |||
@@ -222,7 +238,6 @@ using CceTensorDescriptorPtr = std::shared_ptr<CceTensorDescriptor>; | |||
class CceTensorDescriptor { | |||
public: | |||
explicit CceTensorDescriptor(ccTensorDescriptor_t cc_tensor); | |||
CceTensorDescriptor(const CceTensorDescriptor &) = delete; | |||
CceTensorDescriptor &operator=(const CceTensorDescriptor &) = delete; | |||
@@ -22,7 +22,7 @@ | |||
#include <math.h> | |||
#include <stdint.h> | |||
namespace domi { | |||
namespace ge { | |||
// general | |||
const float DEFAULT_ALPHA_VALUE = 1.0; | |||
const float DEFAULT_BETA_VALUE = 0.0; | |||
@@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2; | |||
// Shufflechannel | |||
const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | |||
} // namespace domi | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ |
@@ -25,7 +25,7 @@ | |||
/// MAKE_GUARD([&] { Release Resource 1 }) | |||
/// Acquire Resource 2 | |||
// MAKE_GUARD([&] { Release Resource 2 }) | |||
#define GE_MAKE_GUARD(var, callback) ge::ScopeGuard make_guard_##var(callback) | |||
#define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) | |||
#define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | |||
namespace ge { | |||
@@ -156,6 +156,7 @@ REGISTER_OPTYPE_DECLARE(GATHER, "Gather"); | |||
REGISTER_OPTYPE_DECLARE(REALDIV, "RealDiv"); | |||
REGISTER_OPTYPE_DECLARE(PACK, "Pack"); | |||
REGISTER_OPTYPE_DECLARE(SLICE, "Slice"); | |||
REGISTER_OPTYPE_DECLARE(SLICED, "SliceD"); | |||
REGISTER_OPTYPE_DECLARE(FLOORDIV, "FloorDiv"); | |||
REGISTER_OPTYPE_DECLARE(SQUEEZE, "Squeeze"); | |||
REGISTER_OPTYPE_DECLARE(STRIDEDSLICE, "StridedSlice"); | |||
@@ -188,6 +189,19 @@ REGISTER_OPTYPE_DECLARE(REFNEXTITERATION, "RefNextIteration"); | |||
REGISTER_OPTYPE_DECLARE(EXIT, "Exit"); | |||
REGISTER_OPTYPE_DECLARE(REFEXIT, "RefExit"); | |||
REGISTER_OPTYPE_DECLARE(CONTROLTRIGGER, "ControlTrigger"); | |||
REGISTER_OPTYPE_DECLARE(SYMBOLICGRADIENT, "SymbolicGradient"); | |||
REGISTER_OPTYPE_DECLARE(REMOTECALL, "RemoteCall"); | |||
REGISTER_OPTYPE_DECLARE(_IF, "_If"); | |||
REGISTER_OPTYPE_DECLARE(STATELESSIF, "StatelessIf"); | |||
REGISTER_OPTYPE_DECLARE(IF, "If"); | |||
REGISTER_OPTYPE_DECLARE(CASE, "Case"); | |||
REGISTER_OPTYPE_DECLARE(_WHILE, "_While"); | |||
REGISTER_OPTYPE_DECLARE(WHILE, "While"); | |||
REGISTER_OPTYPE_DECLARE(STATELESSWHILE, "StatelessWhile"); | |||
REGISTER_OPTYPE_DECLARE(FOR, "For"); | |||
REGISTER_OPTYPE_DECLARE(PARTITIONEDCALL, "PartitionedCall"); | |||
REGISTER_OPTYPE_DECLARE(STATEFULPARTITIONEDCALL, "StatefulPartitionedCall"); | |||
REGISTER_OPTYPE_DECLARE(FAKEPARAM, "FakeParam"); | |||
REGISTER_OPTYPE_DECLARE(TRANSPOSE, "Transpose"); | |||
REGISTER_OPTYPE_DECLARE(TRANSPOSED, "TransposeD"); | |||
REGISTER_OPTYPE_DECLARE(CAST, "Cast"); | |||
@@ -424,6 +438,12 @@ REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | |||
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | |||
REGISTER_OPTYPE_DECLARE(SEND, "Send"); | |||
REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | |||
REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | |||
REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | |||
REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | |||
REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | |||
REGISTER_OPTYPE_DECLARE(ATOMICADDRCLEAN, "AtomicAddrClean"); | |||
REGISTER_OPTYPE_DECLARE(ABS_GRAD, "AbsGrad"); | |||
@@ -1032,14 +1052,11 @@ struct BasicInfo { | |||
uint32_t workspace_size; // workspace | |||
uint32_t total_size; // total memory size | |||
}; | |||
#pragma pack() // Cancels single-byte alignment | |||
} // namespace ge | |||
namespace domi { | |||
/// @brief Data structure definition related to task sinking | |||
/// Build model | |||
enum BuildMode { | |||
GEN_TASK_WITHOUT_L2FUSION = 3,  // Carrying task data (L2 fusion disabled) | |||
GEN_TASK_WITHOUT_FUSION = 4,    // Carrying task data (all fusion disabled) | |||
@@ -30,6 +30,14 @@ | |||
#include "framework/common/ge_inner_error_codes.h" | |||
#include "mmpa/mmpa_api.h" | |||
#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||
do { \ | |||
if (size <= 0) { \ | |||
DOMI_LOGE(param[#size] is not a positive number); \ | |||
return PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | |||
{ \ | |||
bool b = (expr); \ | |||
@@ -50,21 +58,6 @@ | |||
if (var) GE_CHK_RT(rtStreamDestroy(var)); \ | |||
}); | |||
#define GE_MAKE_GUARD_RTEVENT(var) \ | |||
GE_MAKE_GUARD(var, [&] { \ | |||
if (var) GE_CHK_RT(rtEventDestroy(var)); \ | |||
}); | |||
#define GE_MAKE_GUARD_TENSOR(var) \ | |||
GE_MAKE_GUARD(var, [&] { \ | |||
if (var) GE_CHK_CCE(ccDestroyTensorDescriptor(&var)); \ | |||
}); | |||
#define GE_MAKE_GUARD_FILTER_DESC(var) \ | |||
GE_MAKE_GUARD(var, [&] { \ | |||
if (var) GE_CHK_CCE(ccDestroyFilterDescriptor(&var)); \ | |||
}); | |||
// For propagating errors when calling a function. | |||
#define GE_RETURN_IF_ERROR(expr) \ | |||
do { \ | |||
@@ -76,7 +69,7 @@ | |||
do { \ | |||
const ::ge::Status _status = (expr); \ | |||
if (_status) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
return _status; \ | |||
} \ | |||
} while (0) | |||
@@ -85,7 +78,7 @@ | |||
#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ | |||
do { \ | |||
if (condition) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
return ge::FAILED; \ | |||
} \ | |||
} while (0) | |||
@@ -95,7 +88,7 @@ | |||
do { \ | |||
bool _condition = (condition); \ | |||
if (!_condition) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
return ge::FAILED; \ | |||
} \ | |||
} while (0) | |||
@@ -104,7 +97,7 @@ | |||
#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ | |||
do { \ | |||
if (condition) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
@@ -114,111 +107,90 @@ | |||
do { \ | |||
bool _condition = (condition); \ | |||
if (!_condition) { \ | |||
GE_LOGE(__VA_ARGS__); \ | |||
DOMI_LOGE(__VA_ARGS__); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
// Check if the parameter is null. If yes, return PARAM_INVALID and record the error | |||
#define GE_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
GE_LOGE(param[#val] must not be null.); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
#define GE_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE(param[#val] must not be null.); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
// Check if the parameter is null. If yes, return PARAM_INVALID and record the error | |||
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
GE_LOGE(param[#val] must not be null.); \ | |||
return; \ | |||
} \ | |||
// Check if the parameter is null. If yes, just return and record the error | |||
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE(param[#val] must not be null.); \ | |||
return; \ | |||
} \ | |||
} while (0) | |||
// Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | |||
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
GE_LOGE(param[#val] must not be null.); \ | |||
exec_expr; \ | |||
} \ | |||
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE(param[#val] must not be null.); \ | |||
exec_expr; \ | |||
} \ | |||
} while (0) | |||
// Check whether the parameter is null. If yes, return directly and record the error log | |||
#define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
GE_LOGE(param[#val] must not be null.); \ | |||
return; \ | |||
} \ | |||
#define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE(param[#val] must not be null.); \ | |||
return; \ | |||
} \ | |||
} while (0) | |||
// Check if the parameter is null. If yes, return false and record the error log | |||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
GE_LOGE(param[#val] must not be null.); \ | |||
return false; \ | |||
} \ | |||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||
do { \ | |||
if (val == nullptr) { \ | |||
DOMI_LOGE(param[#val] must not be null.); \ | |||
return false; \ | |||
} \ | |||
} while (0) | |||
// Check if the parameter is out of bounds | |||
#define GE_CHECK_SIZE(size) \ | |||
do { \ | |||
if (size == 0) { \ | |||
GE_LOGE(param[#size] is out of range); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
#define GE_CHECK_SIZE(size) \ | |||
do { \ | |||
if (size == 0) { \ | |||
DOMI_LOGE(param[#size] is out of range); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
// Macros that define the size variable | |||
#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ | |||
uint32_t _var_name; \ | |||
do { \ | |||
uint32_t _expr_size = (_expr); \ | |||
uint32_t _sizeof_size = (_sizeof); \ | |||
if (_expr_size > (0xffffffff) / _sizeof_size) { \ | |||
GE_LOGE(byte size : #_var_name is out of range); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
_var_name = _sizeof_size * _expr_size; \ | |||
} while (0); | |||
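A hedged sketch of what GE_DEFINE_BYTE_SIZE guards against: it declares the named uint32_t variable and bails out with ge::PARAM_INVALID before the multiplication could overflow 32 bits. The helper ComputeOutputBytes is hypothetical.

ge::Status ComputeOutputBytes(uint32_t element_count) {  // hypothetical helper
  // Declares total_bytes; returns ge::PARAM_INVALID if element_count * sizeof(float)
  // would not fit in a uint32_t, otherwise total_bytes holds the product.
  GE_DEFINE_BYTE_SIZE(total_bytes, element_count, sizeof(float));
  // ... pass total_bytes to an allocation request ...
  return ge::SUCCESS;
}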
// Check if the container is empty | |||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||
do { \ | |||
if (vector.empty()) { \ | |||
GE_LOGE(param[#vector] is empty !); \ | |||
return ge::FAILED; \ | |||
} \ | |||
} while (0) | |||
#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||
do { \ | |||
if (size <= 0) { \ | |||
GE_LOGE(param[#size] is not a positive number); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||
do { \ | |||
if (vector.empty()) { \ | |||
DOMI_LOGE(param[#vector] is empty !); \ | |||
return ge::FAILED; \ | |||
} \ | |||
} while (0) | |||
// Check if the value on the left is greater than or equal to the value on the right | |||
#define GE_CHECK_GE(lhs, rhs) \ | |||
do { \ | |||
if (lhs < rhs) { \ | |||
GE_LOGE(param[#lhs] is less than[#rhs]); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
#define GE_CHECK_GE(lhs, rhs) \ | |||
do { \ | |||
if (lhs < rhs) { \ | |||
DOMI_LOGE(param[#lhs] is less than[#rhs]); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
// Check if the value on the left is less than or equal to the value on the right | |||
#define GE_CHECK_LE(lhs, rhs) \ | |||
do { \ | |||
if (lhs > rhs) { \ | |||
GE_LOGE(param[#lhs] is greater than[#rhs]); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
#define GE_CHECK_LE(lhs, rhs) \ | |||
do { \ | |||
if (lhs > rhs) { \ | |||
DOMI_LOGE(param[#lhs] is greater than[#rhs]); \ | |||
return ge::PARAM_INVALID; \ | |||
} \ | |||
} while (0) | |||
#define GE_DELETE_NEW_SINGLE(var) \ | |||
@@ -52,10 +52,10 @@ | |||
#define DLOG_DECLARE(level) \ | |||
void Log_##level(const char *mod_name, const char *func, const char *file, int line, const char *format, ...) | |||
namespace ge { | |||
namespace domi { | |||
DLOG_DECLARE(INFO); | |||
DLOG_DECLARE(WARNING); | |||
DLOG_DECLARE(ERROR); | |||
} // namespace ge | |||
} // namespace domi | |||
#endif // INC_FRAMEWORK_DLOG_LOG_H_ |
@@ -38,7 +38,7 @@ struct DNNEngineAttribute { | |||
std::vector<std::string> mem_type; | |||
uint32_t compute_cost; | |||
enum RuntimeType runtime_type; // HOST, DEVICE | |||
// set this attribute if the inputformat of engine must be specific, otherwise set FORMAT_RESERVED | |||
// If engine input format must be specific, set this attribute, else set FORMAT_RESERVED | |||
Format engine_input_format; | |||
Format engine_output_format; | |||
}; | |||
@@ -26,6 +26,7 @@ | |||
#include "common/types.h" | |||
#include "graph/tensor.h" | |||
#include "runtime/base.h" | |||
#include "common/dynamic_aipp.h" | |||
namespace ge { | |||
class ModelListenerAdapter; | |||
@@ -33,12 +34,15 @@ class ModelListenerAdapter; | |||
class SingleOp; | |||
struct RunModelData { | |||
uint32_t index; // Data index | |||
uint32_t model_id; // Model id | |||
std::vector<DataBuffer> blobs; // All input/output data buffer | |||
uint32_t timestamp; // Data creation time | |||
uint32_t timeout; // Processing timeout | |||
uint64_t request_id = 0; // Request ID | |||
uint32_t index; // Data index | |||
uint32_t modelId; | |||
std::vector<DataBuffer> blobs; // All input/output data buffer | |||
uint32_t timestamp; // Data creation time | |||
uint32_t timeout; // Processing timeout | |||
uint64_t request_id = 0; // Request ID | |||
uint64_t dynamic_batch_size = 0; // Dynamic batch size scene, set dynamic size, not supported by default:0 | |||
uint64_t dynamic_image_height = 0; // Dynamic image size scene, set image height, not supported by default:0 | |||
uint64_t dynamic_image_width = 0; // Dynamic image size scene, set image width, not supported by default:0 | |||
}; | |||
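A hedged sketch of filling RunModelData for a dynamic-batch run; loaded_model_id and the buffer contents are placeholders, and only the fields relevant to the dynamic case are set explicitly.

ge::RunModelData run_input;
run_input.index = 0;
run_input.modelId = loaded_model_id;   // id returned by a prior load call (placeholder)
run_input.timeout = 0;
run_input.dynamic_batch_size = 4;      // one of the batch sizes the model was built to accept
// run_input.blobs is filled with DataBuffer entries describing user-owned input/output memory.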
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
@@ -46,12 +50,13 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
GeExecutor(); | |||
~GeExecutor() = default; | |||
ge::Status Initialize(); | |||
ge::Status Finalize(); | |||
// Load model | |||
ge::Status LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key, int32_t priority, | |||
std::shared_ptr<ge::ModelListener> listener); | |||
ge::Status UnloadModel(uint32_t model_id); | |||
ge::Status UnloadModel(uint32_t modelId); | |||
ge::Status RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data); | |||
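A condensed lifecycle sketch for the executor interface declared here; the model path, empty key, priority 0 and the null listener are placeholders and assumptions, and error handling is omitted for brevity.

ge::GeExecutor executor;
(void)executor.Initialize();
uint32_t model_id = 0;
std::shared_ptr<ge::ModelListener> listener;  // left null here; assumed acceptable for a simple run
(void)executor.LoadModelOffline(model_id, "/path/to/model.om", "", 0, listener);
ge::RunModelData input_data;
ge::RunModelData output_data;
(void)executor.RunModel(input_data, output_data);
(void)executor.UnloadModel(model_id);
(void)executor.Finalize();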
@@ -59,6 +64,52 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | |||
std::vector<ge::TensorDesc> &output_desc); | |||
/// | |||
/// @ingroup ge | |||
/// @brief Set dynamic batch size | |||
/// @param [in] model_id: model id allocated from manager | |||
/// @param [in] dynamic_input_addr: dynamic input addr created by user | |||
/// @param [in] length: length of dynamic input addr | |||
/// @param [in] batch_size: batch size entered by user in dynamic multi-batch scenario | |||
/// @return execute result | |||
/// | |||
ge::Status SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t batch_size); | |||
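A brief usage sketch for the declaration above; executor, model_id, dyn_addr and dyn_len are placeholders, where dyn_addr/dyn_len describe the user-allocated dynamic input buffer.

uint64_t batch_size = 4;  // must be one of the batch sizes the model supports
ge::Status ret = executor.SetDynamicBatchSize(model_id, dyn_addr, dyn_len, batch_size);
if (ret != ge::SUCCESS) {
  // handle the failure, e.g. an unsupported batch size
}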
/// | |||
/// @ingroup ge | |||
/// @brief Set dynamic image info | |||
/// @param [in] model_id: model id allocated from manager | |||
/// @param [in] dynamic_input_addr: dynamic input addr created by user | |||
/// @param [in] length: length of dynamic input addr | |||
/// @param [in] image_height: image height entered by user in dynamic multi-resolution scenario | |||
/// @param [in] image_width: image width entered by user in dynamic multi-resolution scenario | |||
/// @return execute result | |||
/// | |||
ge::Status SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height, | |||
uint64_t image_width); | |||
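The multi-resolution case follows the same pattern; the 224 x 224 values below are placeholders and must match a resolution the model was built to accept.

ge::Status ret = executor.SetDynamicImageSize(model_id, dyn_addr, dyn_len, 224, 224);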
/// | |||
/// @ingroup ge | |||
/// @brief Get dynamic batch_info | |||
/// @param [in] model_id | |||
/// @param [out] batch_info | |||
/// @return execute result | |||
/// | |||
ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info); | |||
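A sketch of querying the supported profiles before choosing one; treating each inner vector as one supported shape combination is an assumption based on the parameter name.

std::vector<std::vector<int64_t>> batch_info;
if (executor.GetDynamicBatchInfo(model_id, batch_info) == ge::SUCCESS) {
  for (const auto &profile : batch_info) {
    (void)profile;  // each entry is assumed to list one supported dynamic dimension combination
  }
}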
/// | |||
/// @ingroup ge | |||
/// @brief Set dynamic aipp data | |||
/// @param [in] model_id: model id allocated from manager | |||
/// @param [in] dynamic_input_addr: dynamic input addr created by user | |||
/// @param [in] length: length of dynamic input addr | |||
/// @param [in] aippBatchPara: kAippDynamicBatchPara vector entered by user in dynamic aipp scenario | |||
/// @param [in] aippParms: kAippDynamicPara entered by user in dynamic aipp scenario | |||
/// @return execute result | |||
/// | |||
ge::Status SetDynamicAippData(uint32_t model_id, void *dynamic_input_addr, uint64_t length, | |||
const std::vector<kAippDynamicBatchPara> &aippBatchPara, | |||
const kAippDynamicPara &aippParms); | |||
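A sketch for the dynamic AIPP case using the structures from common/dynamic_aipp.h; batch_num, dyn_addr and dyn_len are placeholders and the AIPP fields are left at their defaults here.

kAippDynamicPara aipp_param{};                              // fields set according to dynamic_aipp.h
std::vector<kAippDynamicBatchPara> batch_paras(batch_num);  // one entry per batch (placeholder count)
ge::Status ret = executor.SetDynamicAippData(model_id, dyn_addr, dyn_len, batch_paras, aipp_param);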
ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | |||
std::vector<ge::TensorDesc> &output_desc); | |||
@@ -147,7 +198,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
/// | |||
ge::Status GetMemAndWeightSize(const void *model_data, size_t model_size, size_t &mem_size, size_t &weight_size); | |||
static ge::Status LoadSingleOp(const std::string &model_name, const ge::ModelData &model_data, void *stream, | |||
static ge::Status LoadSingleOp(const std::string &modelName, const ge::ModelData &modelData, void *stream, | |||
SingleOp **single_op); | |||
static ge::Status ExecuteAsync(SingleOp *executor, const std::vector<DataBuffer> &inputs, | |||
@@ -156,8 +207,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
static ge::Status ReleaseSingleOpResource(void *stream); | |||
private: | |||
static bool is_init_; | |||
std::vector<std::shared_ptr<ModelListenerAdapter>> listener_adapters_; | |||
static bool isInit_; | |||
}; | |||
ge::Status ModelInfoParser(const ge::ModelData &model, ge::ModelInfo &model_info); | |||
@@ -21,7 +21,7 @@ | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include "ge/ge_ir_build.h" | |||
#include "common/ge_inner_error_codes.h" | |||
#include "graph/ge_tensor.h" | |||
#include "graph/graph.h" | |||
@@ -45,6 +45,8 @@ class GeGenerator { | |||
Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, | |||
const std::vector<GeTensor> &inputs = std::vector<GeTensor>()); | |||
Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData &model); | |||
/// | |||
/// @ingroup ge | |||
/// @brief: Build single OP in Model. | |||
@@ -58,6 +60,8 @@ class GeGenerator { | |||
const std::vector<GeTensor> &outputs, const std::string &model_file_name); | |||
private: | |||
Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | |||
ge::ModelBufferData &model, bool is_offline = true); | |||
class Impl; | |||
std::shared_ptr<Impl> impl_; | |||
@@ -24,7 +24,6 @@ extern "C" { | |||
#endif | |||
typedef uint32_t Status_t; | |||
using Status_t = uint32_t; | |||
typedef void *OpAttr_t; | |||
typedef void *OpTensor_t; | |||
@@ -23,7 +23,7 @@ | |||
#include "graph/node.h" | |||
namespace ge { | |||
const int64_t kMemAlignSize = 512; | |||
const int64_t MEM_ALIGN_SIZE = 512; | |||
class MemoryAssigner { | |||
public: | |||
explicit MemoryAssigner(ge::ComputeGraphPtr compute_graph) : compute_graph_(std::move(compute_graph)) {} | |||
@@ -39,4 +39,4 @@ class MemoryAssigner { | |||
ge::ComputeGraphPtr compute_graph_; | |||
}; | |||
} // namespace ge | |||
#endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ | |||
#endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ |
@@ -31,7 +31,6 @@ | |||
using domi::DOMI_TENSOR_ND; | |||
using domi::DOMI_TENSOR_RESERVED; | |||
using domi::domiTensorFormat_t; | |||
using domi::FMK_TYPE_RESERVED; | |||
using domi::FrameworkType; | |||
using std::map; | |||
using std::string; | |||
@@ -44,10 +43,10 @@ namespace ge { | |||
* @brief run model | |||
*/ | |||
enum RunMode { | |||
kGeOmModel = 0, // generate offline model file | |||
kModelToJson = 1, // convert to JSON file | |||
kOnlyPreCheck = 3, // only for pre-check | |||
kPbtxtToJson = 5 // pbtxt to json | |||
GEN_OM_MODEL = 0, // generate offline model file | |||
MODEL_TO_JSON = 1, // convert to JSON file | |||
ONLY_PRE_CHECK = 3, // only for pre-check | |||
PBTXT_TO_JSON = 5 // pbtxt to json | |||
}; | |||
/// | |||
@@ -56,10 +55,10 @@ enum RunMode { | |||
/// | |||
enum HighPrecisionMode { | |||
// the FP16 high-precision function is disabled in common mode | |||
kHighPrecisonDefault = 0, | |||
HIGH_PRECISION_DEFAULT = 0, | |||
// high-precision mode, in which FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved) is enable | |||
kHighPrecisionFP16 = 1 | |||
// high-precision mode, enabling FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved) | |||
HIGH_PRECISION_FP16 = 1 | |||
}; | |||
/// | |||
@@ -99,21 +98,23 @@ struct OmgContext { | |||
// preferential format used by the entire network | |||
domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | |||
domi::FrameworkType type = domi::FMK_TYPE_RESERVED; | |||
RunMode run_mode = kOnlyPreCheck; | |||
RunMode run_mode = ONLY_PRE_CHECK; | |||
bool train_flag = false; | |||
// whether to use FP16 high precision | |||
int32_t fp16_high_precision = kHighPrecisonDefault; | |||
int32_t fp16_high_precision = HIGH_PRECISION_DEFAULT; | |||
std::string output_type; | |||
// Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | |||
// network require special processing based on the specific network. | |||
// e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope | |||
// fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The | |||
// convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel. | |||
// network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module | |||
// is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the | |||
// FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape | |||
// operator needs to be deleted for the convolution kernel. | |||
std::string net_name; | |||
// whether to enable dynamic batch | |||
bool enable_l2dynamic = false; | |||
// Whether to use dynamic batch size or dynamic image size | |||
bool is_dynamic_input = false; | |||
std::string dynamic_batch_size; | |||
std::string dynamic_image_size; | |||
}; | |||
} // namespace ge | |||
@@ -32,15 +32,7 @@ class PlatformVersionManager { | |||
PlatformVersionManager() = delete; | |||
~PlatformVersionManager() = delete; | |||
static Status GetPlatformVersion(std::string &ver) { | |||
#if defined PLATFORM_PHOENIX | |||
ver = "3.51.z"; | |||
#elif defined PLATFORM_ORLANDO | |||
ver = "3.31.z"; | |||
#elif defined PLATFORM_MINI | |||
ver = "1.11.z"; | |||
#elif defined PLATFORM_CLOUD | |||
ver = "1.61.z"; | |||
#endif | |||
std::vector<std::string> version_splits = StringUtils::Split(ver, '.'); | |||
GE_IF_BOOL_EXEC(version_splits.size() < 3, GELOGW("Read platform version error!"); return FAILED;); | |||
@@ -20,13 +20,17 @@ | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include "graph/ge_error_codes.h" | |||
#include "graph/range_vistor.h" | |||
#include "graph/types.h" | |||
namespace ge { | |||
enum AnchorStatus { ANCHOR_SUSPEND = 0, ANCHOR_CONST = 1, ANCHOR_DATA = 2, ANCHOR_RESERVED = 3 }; | |||
enum AnchorStatus { | |||
ANCHOR_SUSPEND = 0,  // data null | |||
ANCHOR_CONST = 1, | |||
ANCHOR_DATA = 2, // Effective | |||
ANCHOR_RESERVED = 3 | |||
}; | |||
using std::string; | |||
using std::vector; | |||
@@ -81,17 +85,19 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable | |||
virtual ~Anchor() = default; | |||
protected: | |||
// Whether the two anchors are equal | |||
// Whether the two anchor is equal | |||
virtual bool Equal(AnchorPtr anchor) const = 0; | |||
virtual bool IsTypeOf(TYPE type) const; | |||
public: | |||
// Get all peer anchors connected to current anchor | |||
Vistor<AnchorPtr> GetPeerAnchors() const; | |||
// Get the first peer anchor | |||
// Get peer anchor size | |||
size_t GetPeerAnchorsSize() const; | |||
// Get first peer anchor | |||
AnchorPtr GetFirstPeerAnchor() const; | |||
// Get the node which is the owner of the anchor | |||
// Get the anchor belong to which node | |||
NodePtr GetOwnerNode() const; | |||
// Remove all links with the anchor | |||
@@ -100,22 +106,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable | |||
// Remove link with the given anchor | |||
graphStatus Unlink(const AnchorPtr &peer); | |||
// Replace the peeranchor with the new peeranchor | |||
// Replace peer with new peers | |||
graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); | |||
// Judge if the anchor is linked with the given anchor | |||
bool IsLinkedWith(const AnchorPtr &peer); | |||
// Get the anchor index of the node | |||
// Get anchor index of the node | |||
int GetIdx() const; | |||
// Set the anchor index of the node | |||
// set anchor index of the node | |||
void SetIdx(int index); | |||
protected: | |||
// All peer anchors connected to current anchor | |||
vector<std::weak_ptr<Anchor>> peer_anchors_; | |||
// The owner nodes of the anchor | |||
// The owner node of anchor | |||
std::weak_ptr<Node> owner_node_; | |||
// The index of current anchor | |||
int idx_; | |||
@@ -167,7 +173,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataA | |||
virtual ~InDataAnchor() = default; | |||
// Get source out data anchor | |||
// Get source out data anchor | |||
OutDataAnchorPtr GetPeerOutAnchor() const; | |||
// Build connection from OutDataAnchor to InDataAnchor | |||
@@ -19,10 +19,10 @@ | |||
#include <string> | |||
#include <vector> | |||
#include "graph/ge_attr_value.h" | |||
namespace ge { | |||
class GeAttrValue; | |||
class _GeSerializable { | |||
public: | |||
@@ -107,7 +107,6 @@ class _GeSerializable { | |||
static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; } | |||
}; | |||
#define _GE_FI(a) #a, a | |||
#define _GE_MAP_FIELDS1(a1) _GE_FI(a1) | |||
#define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) | |||
@@ -130,23 +129,23 @@ class _GeSerializable { | |||
#define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ | |||
_GE_FI(a1) \ | |||
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||
_GE_FI(a11) | |||
_GE_FI(a11) | |||
#define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ | |||
_GE_FI(a1) \ | |||
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||
_GE_FI(a11), _GE_FI(a12) | |||
_GE_FI(a11), _GE_FI(a12) | |||
#define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ | |||
_GE_FI(a1) \ | |||
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13) | |||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13) | |||
#define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ | |||
_GE_FI(a1) \ | |||
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) | |||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) | |||
#define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ | |||
_GE_FI(a1) \ | |||
, _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) | |||
_GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) | |||
#define _GE_PRIVATE_ARGS_GLUE(x, y) x y | |||
@@ -17,12 +17,11 @@ | |||
#ifndef INC_GRAPH_BUFFER_H_ | |||
#define INC_GRAPH_BUFFER_H_ | |||
#include <graph/types.h> | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include "detail/attributes_holder.h" | |||
#include "graph/types.h" | |||
namespace ge { | |||
#ifdef HOST_VISIBILITY | |||
@@ -72,7 +71,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { | |||
GeIrProtoHelper<proto::AttrDef> data_; | |||
std::string *buffer_ = nullptr; | |||
// Create buffer from protobuf obj | |||
// Create from protobuf obj | |||
Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); | |||
Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); | |||
@@ -17,7 +17,6 @@ | |||
#ifndef INC_GRAPH_COMPUTE_GRAPH_H_ | |||
#define INC_GRAPH_COMPUTE_GRAPH_H_ | |||
#include <deque> | |||
#include <map> | |||
#include <memory> | |||
#include <string> | |||
@@ -63,7 +62,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>; | |||
explicit ComputeGraph(const std::string &name); | |||
virtual ~ComputeGraph(); | |||
~ComputeGraph() override; | |||
std::string GetName() const; | |||
void SetName(const std::string &name); | |||
@@ -81,7 +80,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
Vistor<NodePtr> GetOutputNodes() const; | |||
NodePtr FindNode(const std::string &name) const; | |||
// Add node | |||
// AddNode with NodePtr | |||
NodePtr AddNode(NodePtr node); | |||
NodePtr AddNode(OpDescPtr op); | |||
NodePtr AddNodeFront(NodePtr node); | |||
@@ -94,9 +93,40 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
graphStatus RemoveOutputNode(const NodePtr &node); | |||
graphStatus RemoveConstInput(const NodePtr &node); | |||
/// Add a subgraph to this graph. The subgraph must have a parent graph and parent node, | |||
/// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph | |||
/// must be called before adding it to the root graph, and subgraph->GetParentNode()->GetOwnerGraph() | |||
/// must equal subgraph->GetOwnerGraph(). | |||
/// Subgraphs can only be added to a *root graph*; a root graph is a graph without any parent graph. | |||
/// The subgraph's name SHOULD (but need not) be the same as the parameter `name`. | |||
graphStatus AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph); | |||
graphStatus AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph); | |||
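A sketch of the ordering constraint described above; root_graph, sub_graph and if_node are placeholders, with if_node assumed to be a node that already belongs to root_graph.

auto root_graph = std::make_shared<ge::ComputeGraph>("root");
auto sub_graph = std::make_shared<ge::ComputeGraph>("then_branch");
sub_graph->SetParentGraph(root_graph);  // wire the parent graph first
sub_graph->SetParentNode(if_node);      // if_node: a node of root_graph (placeholder)
ge::graphStatus status = root_graph->AddSubgraph("then_branch", sub_graph);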
void RemoveSubgraph(const std::string &name); | |||
void RemoveSubgraph(const std::shared_ptr<ComputeGraph> &subgraph); | |||
std::shared_ptr<ComputeGraph> GetSubgraph(const std::string &name) const; | |||
std::vector<std::shared_ptr<ComputeGraph>> GetAllSubgraphs() const; | |||
// obsolete | |||
std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph); | |||
// obsolete | |||
graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph); | |||
/// | |||
/// @brief Update input-mapping | |||
/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input | |||
/// @return graphStatus | |||
/// | |||
graphStatus UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping); | |||
/// | |||
/// @brief Update output-mapping | |||
/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output | |||
/// @return graphStatus | |||
/// | |||
graphStatus UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping); | |||
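A sketch of the mapping direction documented above, current-graph index to new-graph index; the concrete indices and the graph pointer are placeholders.

std::map<uint32_t, uint32_t> input_mapping{{2, 0}, {3, 1}};  // old input index -> new input index
ge::graphStatus status = graph->UpdateInputMapping(input_mapping);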
graphStatus TopologicalSorting(); | |||
bool IsValid() const; | |||
void Dump() const; | |||
@@ -127,6 +157,11 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
} | |||
} | |||
shared_ptr<ComputeGraph> GetParentGraph(); | |||
void SetParentGraph(const shared_ptr<ComputeGraph> &parent); | |||
shared_ptr<Node> GetParentNode(); | |||
void SetParentNode(const shared_ptr<Node> &parent); | |||
const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; } | |||
void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } | |||
@@ -138,8 +173,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
uint32_t GetInputSize() const { return input_size_; } | |||
/// | |||
/// Set iteration needed. | |||
/// If set is true, it means this graph need run iteration some | |||
/// Set is need train iteration. | |||
/// If set true, it means this graph need to be run iteration some | |||
/// times(according variant "npu_runconfig/iterations_per_loop"). | |||
/// @param need_iteration is need iteration | |||
/// | |||
@@ -150,7 +185,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
const std::string GetOutput(); | |||
/// | |||
/// Get need_iteration. | |||
/// Get is need train iteration. | |||
/// @return is need iteration | |||
/// | |||
bool GetNeedIteration() const { return need_iteration_; } | |||
@@ -201,6 +236,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
std::deque<NodePtr> &stack); | |||
graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | |||
std::map<string, NodePtr> &breadth_node_map); | |||
graphStatus TopologicalSortingSubgraph(); | |||
graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | |||
size_t GetInEdgeSize(const NodePtr &node); | |||
size_t GetOutEdgeSize(const NodePtr &node); | |||
@@ -210,31 +246,38 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector, | |||
const std::vector<NodePtr> &l_node_ptr_vector) const; | |||
ProtoAttrMapHelper attrs_; | |||
friend class ModelSerializeImp; | |||
friend class GraphDebugImp; | |||
friend class OnnxUtils; | |||
std::string name_; | |||
uint32_t graph_id_ = 0; | |||
ProtoAttrMapHelper attrs_; | |||
std::vector<NodePtr> nodes_; | |||
std::map<OperatorImplPtr, NodePtr> all_nodes_infos_; | |||
std::vector<NodePtr> target_nodes_info_; | |||
std::vector<NodePtr> input_nodes_; | |||
std::vector<std::string> inputs_order_; | |||
uint32_t input_size_ = 1; | |||
std::map<std::string, std::vector<int32_t>> out_nodes_map_; | |||
uint32_t output_size_ = 1; | |||
std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_; | |||
std::vector<std::shared_ptr<ComputeGraph>> sub_graph_; | |||
std::string name_; | |||
std::map<std::string, std::shared_ptr<ComputeGraph>> names_to_subgraph_; | |||
std::weak_ptr<ComputeGraph> parent_graph_; | |||
std::weak_ptr<Node> parent_node_; | |||
// the following members should not be in the ComputeGraph class | |||
bool is_valid_flag_; | |||
bool is_summary_graph_ = false; | |||
// Indicates whether it is need iteration | |||
bool need_iteration_ = false; | |||
std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_; | |||
std::map<std::string, std::vector<int32_t>> out_nodes_map_; | |||
// TaskIdx -> op_name Map | |||
std::map<uint32_t, std::string> op_name_map_; | |||
std::vector<std::string> inputs_order_; | |||
uint32_t output_size_ = 1; | |||
uint32_t input_size_ = 1; | |||
std::map<OperatorImplPtr, NodePtr> all_nodes_infos_; | |||
std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_; | |||
std::vector<NodePtr> target_nodes_info_; | |||
uint64_t session_id_ = 0; | |||
uint32_t graph_id_ = 0; | |||
ge::Format data_format_ = ge::FORMAT_ND; | |||
}; | |||
} // namespace ge | |||
@@ -18,7 +18,6 @@ | |||
#define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | |||
#include <string> | |||
#include "graph/types.h" | |||
namespace ge { | |||
@@ -59,6 +58,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | |||
@@ -75,8 +76,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | |||
// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | |||
// ATTR_NAME_WEIGHTS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | |||
@@ -124,6 +124,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | |||
@@ -141,10 +148,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||
@@ -166,15 +178,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | |||
// _Arg | |||
@@ -263,7 +275,29 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||
// Huberloss | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; | |||
// SSDRealDivTileMul | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||
// SSDSumMulRealDivMean | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||
/// ConcatFive2Four | |||
/// ConcatFour2Five | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; | |||
// Scale | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | |||
@@ -300,7 +334,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||
// Roipooling | |||
@@ -313,6 +346,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI | |||
// DetectionOutput | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | |||
@@ -371,6 +405,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ | |||
// Permute | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; | |||
// SSD Normalize | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | |||
@@ -411,9 +446,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | |||
// Log | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; | |||
// Pack | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | |||
// Dynamic stitch | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||
// Unpack | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | |||
// Gathernd | |||
@@ -422,8 +463,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND | |||
// Argmax | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||
// Upsample | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||
// Relu | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | |||
@@ -439,6 +488,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_AT | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_WORKSPACE_TYPE; | |||
// Squeeze | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; | |||
@@ -461,6 +511,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_TF; | |||
// Generate_rpn_proposal | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | |||
@@ -493,6 +544,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | |||
static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||
// ENTER | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | |||
@@ -518,6 +570,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | |||
// RetinaNet | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; | |||
// MatMul | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | |||
@@ -566,10 +621,30 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||
// Upsample | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | |||
// PadV2 | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||
// MirrorPad | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||
// Filler | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | |||
@@ -590,36 +665,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | |||
@@ -634,6 +679,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | |||
@@ -642,12 +689,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||
// Public attribute | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | |||
@@ -685,11 +726,178 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; | |||
// Used for operators that do not generate task | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOTASK; | |||
// Used for operators that output reuse input | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_REUSE_INPUT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||
// L2_normalize | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||
// HCOM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||
// Log time stamp | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; | |||
// SpaceToDepth/DepthToSpace | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; | |||
// SparseSoftmaxCrossEntropyWithLogits | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||
// MaxPoolGradWithArgmax | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||
// AvgPoolGrad | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||
// Variable | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||
// Assign | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; | |||
// ShapeN | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; | |||
// Space2batch batch2space | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; | |||
// Depth_to_space space_to_depth | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||
// FakeQuantWithMinMaxVars | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||
// Mobilenet_ssd_conv_fusion | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||
// Lsh project | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; | |||
// Control flow | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||
// GatherV2 attr def | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||
// Reshape attr def | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||
// Axis attr def | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; | |||
// The node linked with SparseSoftmaxCrossEntropyWithLogits | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||
// For constant folding | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||
// Used to mark the active label list for finding the stream of the activated node | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE; | |||
// Multi batch | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; | |||
@@ -697,7 +905,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
// Control flow | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | |||
@@ -709,6 +916,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; | |||
// Function Op | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; | |||
// Used to mark that the active node is for a loop, type: bool | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; | |||
@@ -742,6 +952,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INS | |||
// For inserted op | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | |||
// For weight compression | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPRESS_WEIGHT; | |||
// For data dump | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; | |||
@@ -752,6 +965,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; | |||
// Used for L1 fusion and other fusions in the future | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | |||
// used for label switch | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | |||
// Variable | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||
@@ -20,10 +20,8 @@ | |||
#include <atomic> | |||
#include <memory> | |||
#include <vector> | |||
#include "graph/attr_value_serializable.h" | |||
#include "graph/buffer.h" | |||
namespace ge { | |||
#define DEF_TYPE_DEC(type, name) \ | |||
inline void set_##name(const type &value) { name = value; } \ | |||
@@ -49,10 +47,9 @@ namespace ge { | |||
inline void add_##name(type value) { name.push_back(value); } \ | |||
inline std::vector<type> *mutable_##name() { return &name; } | |||
#define DEF_TYPE_BYTES_DEC(name) \ | |||
inline void clear_##name() { name.ClearBuffer(); } \ | |||
inline void set_##name(const void *value, size_t size) { \ | |||
name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ | |||
#define DEF_TYPE_BYTES_DEC(name) \ | |||
inline void clear_##name() { name.ClearBuffer(); } \ | |||
inline void set_##name(const void *value, size_t size) { name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ | |||
inline Buffer *mutable_##name() { return &name; } | |||
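For orientation, a minimal sketch of how the accessors generated by DEF_TYPE_BYTES_DEC above are typically consumed; the holder struct and its field name are hypothetical, and it assumes the header that defines the macro and ge::Buffer is already included.

#include <cstdint>
#include <vector>

namespace ge {
struct ExampleBytesHolder {   // hypothetical holder, not part of this change
  Buffer data;
  DEF_TYPE_BYTES_DEC(data)    // expands to clear_data(), set_data(), mutable_data()
};

inline void FillHolder(ExampleBytesHolder &holder, const std::vector<uint8_t> &raw) {
  holder.clear_data();                      // drop any previously owned buffer
  holder.set_data(raw.data(), raw.size());  // copies the raw bytes into a Buffer
  Buffer *owned = holder.mutable_data();    // direct access to the owned buffer
  (void)owned;
}
}  // namespace ge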
struct CompressInfo { | |||
@@ -23,7 +23,6 @@ | |||
#include <unordered_set> | |||
#include <utility> | |||
#include <vector> | |||
#include "graph/detail/any_map.h" | |||
#include "graph/ge_error_codes.h" | |||
#include "graph/types.h" | |||
@@ -96,7 +95,7 @@ class GeIrProtoHelper { | |||
} | |||
} | |||
// protoMsg_ is part of protoOwner_ and they have the same runtime | |||
// protoMsg_ is part of protoOwner_, and they share the same lifetime | |||
ProtoMsgOwner protoOwner_ = nullptr; | |||
ProtoType *protoMsg_ = nullptr; | |||
friend class GeIrProtoHelper<typename std::conditional< | |||
@@ -21,9 +21,7 @@ | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include "graph/anchor.h" | |||
#include "graph/model.h" | |||
#include "detail/attributes_holder.h" | |||
#include "graph/ge_tensor.h" | |||
#include "graph/graph.h" | |||
@@ -48,15 +46,15 @@ struct NodeNameNodeReq { | |||
class ModelSerializeImp { | |||
public: | |||
bool SerializeModel(const Model &model, proto::ModelDef *modeProto); | |||
bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false); | |||
bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto); | |||
bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false); | |||
bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); | |||
bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto); | |||
bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false); | |||
bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto); | |||
bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false); | |||
bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); | |||
@@ -23,7 +23,6 @@ | |||
#include <string> | |||
#include <utility> | |||
#include <vector> | |||
#include "graph/buffer.h" | |||
#include "detail/attributes_holder.h" | |||
#include "graph/ge_error_codes.h" | |||
@@ -139,15 +138,14 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
template <typename vector_type> | |||
// To cols | |||
using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, | |||
int>::type; | |||
using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, int>::type; | |||
template <typename one_type> | |||
using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type; | |||
template <typename val_type> | |||
using enable_if_type_valid_t = | |||
typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type; | |||
typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type; | |||
template <typename seriliable_type> | |||
using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; | |||
@@ -18,7 +18,6 @@ | |||
#define INC_GRAPH_GE_CONTEXT_H_ | |||
#include <string> | |||
#include "graph/ge_error_codes.h" | |||
namespace ge { | |||
@@ -42,4 +41,4 @@ class GEContext { | |||
GEContext &GetContext(); | |||
} // namespace ge | |||
#endif // INC_GRAPH_GE_CONTEXT_H_ | |||
#endif // INC_GRAPH_GE_CONTEXT_H_ |
@@ -20,7 +20,6 @@ | |||
#include <map> | |||
#include <string> | |||
#include <vector> | |||
#include "graph/ge_error_codes.h" | |||
using std::map; | |||
@@ -42,5 +41,4 @@ class GEThreadLocalContext { | |||
GEThreadLocalContext &GetThreadLocalContext(); | |||
} // namespace ge | |||
#endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ |
@@ -21,12 +21,10 @@ | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include "detail/attributes_holder.h" | |||
#include "graph/buffer.h" | |||
#include "graph/ge_error_codes.h" | |||
#include "graph/types.h" | |||
namespace ge { | |||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||
public: | |||
@@ -43,6 +41,18 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||
int64_t GetShapeSize() const; | |||
std::string ToString() const; | |||
/// | |||
/// @brief Check whether the shape is unknown | |||
/// @return bool | |||
/// | |||
bool IsUnknownShape() const; | |||
/// | |||
/// @brief Check whether the shape is a scalar | |||
/// @return bool | |||
/// | |||
bool IsScalar() const; | |||
GeShape(const GeShape &other); | |||
GeShape(GeShape &&other); | |||
GeShape &operator=(const GeShape &other); | |||
@@ -51,7 +61,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||
private: | |||
GeIrProtoHelper<proto::ShapeDef> shape_def_; | |||
friend class GeTensorDesc; | |||
// Create geshape from proto obj | |||
// Create from proto obj | |||
GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); | |||
void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } | |||
@@ -112,7 +122,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH | |||
void Init(); | |||
// Create getensordesc from proto obj | |||
// Create from proto obj | |||
GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); | |||
friend class GeTensor; | |||
friend class GeAttrValueImp; | |||
@@ -159,10 +169,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { | |||
friend class GeAttrValueImp; | |||
friend class ModelSerializeImp; | |||
friend class OnnxUtils; | |||
// Create getensor from proto obj | |||
// Create from proto obj | |||
GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); | |||
GeIrProtoHelper<proto::TensorDef> tensor_def_; | |||
// Reference from tensorDef_, can not use it directly | |||
// Reference from tensorDef_, do not use it directly | |||
mutable GeTensorDesc __desc_; | |||
GeTensorDesc &DescReference() const; | |||
}; | |||
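A hedged usage sketch for the new IsUnknownShape/IsScalar queries declared above; it assumes the common GE conventions that a dimension value of -1 marks an unknown dimension, that a rank-0 shape is a scalar, and that GeShape accepts a std::vector<int64_t>.

#include "graph/ge_tensor.h"

inline void CheckShapes() {
  ge::GeShape static_shape({8, 3, 224, 224});
  ge::GeShape dynamic_shape({-1, 3, 224, 224});  // -1: dimension unknown at build time
  ge::GeShape scalar_shape;                      // default-constructed, rank 0
  bool is_static_unknown = static_shape.IsUnknownShape();    // expected: false
  bool is_dynamic_unknown = dynamic_shape.IsUnknownShape();  // expected: true
  bool is_scalar = scalar_shape.IsScalar();                  // expected: true
  (void)is_static_unknown; (void)is_dynamic_unknown; (void)is_scalar;
}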
@@ -21,7 +21,6 @@ | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include "detail/attributes_holder.h" | |||
#include "graph/ge_attr_value.h" | |||
#include "graph/graph.h" | |||
@@ -62,7 +61,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { | |||
using AttrHolder::HasAttr; | |||
using AttrHolder::SetAttr; | |||
graphStatus Save(Buffer &buffer) const; | |||
graphStatus Save(Buffer &buffer, bool is_dump = false) const; | |||
graphStatus SaveToFile(const string &file_name) const; | |||
// Model will be rewritten | |||
@@ -19,7 +19,6 @@ | |||
#include <map> | |||
#include <string> | |||
#include "graph/buffer.h" | |||
#include "graph/compute_graph.h" | |||
#include "graph/model.h" | |||
@@ -27,7 +26,7 @@ | |||
namespace ge { | |||
class ModelSerialize { | |||
public: | |||
Buffer SerializeModel(const Model &model); | |||
Buffer SerializeModel(const Model &model, bool is_dump = false); | |||
Model UnserializeModel(const uint8_t *data, size_t len); | |||
Model UnserializeModel(ge::proto::ModelDef &model_def); | |||
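A round-trip sketch for the extended serialization entry point above; the is_dump flag and Buffer accessors are used as declared, while the include path and the Model instance are assumptions.

#include "graph/model_serialize.h"

inline ge::Model RoundTrip(const ge::Model &model) {
  ge::ModelSerialize serializer;
  // is_dump = true requests the dump-oriented form of the serialized model.
  ge::Buffer buffer = serializer.SerializeModel(model, /*is_dump=*/true);
  return serializer.UnserializeModel(buffer.GetData(), buffer.GetSize());
}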
@@ -113,25 +113,25 @@ class Node : public std::enable_shared_from_this<Node> { | |||
bool IsAllInNodesSeen(std::unordered_set<Node *> &nodes_seen) const; | |||
// All inData nodes | |||
// All in Data nodes | |||
Vistor<NodePtr> GetInDataNodes() const; | |||
// All inControl nodes | |||
// All in Control nodes | |||
Vistor<NodePtr> GetInControlNodes() const; | |||
// GetInAllNodes = InDataNodes + InControlNodes | |||
Vistor<NodePtr> GetInAllNodes() const; | |||
// All outData nodes | |||
// All out Data nodes | |||
Vistor<NodePtr> GetOutDataNodes() const; | |||
uint32_t GetOutDataNodesSize() const; | |||
// All outControl nodes | |||
// All out Control nodes | |||
Vistor<NodePtr> GetOutControlNodes() const; | |||
// GetOutAllNodes = OutDataNodes + OutControlNodes | |||
Vistor<NodePtr> GetOutAllNodes() const; | |||
// Get all indata nodes and its outanchor | |||
// Get all in data nodes and their out-anchors | |||
Vistor<std::pair<NodePtr, OutDataAnchorPtr>> GetInDataNodesAndAnchors() const; | |||
// Get all outdata nodes and its inanchor | |||
// Get all out data nodes and their in-anchors | |||
Vistor<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesAndAnchors() const; | |||
graphStatus InferShapeAndType() const; | |||
@@ -176,7 +176,7 @@ class Node : public std::enable_shared_from_this<Node> { | |||
void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } | |||
NodePtr GetOrigNode(void) { return orig_node_; } | |||
NodePtr GetOrigNode() { return orig_node_; } | |||
private: | |||
bool NodeMembersAreEqual(const Node &r_node) const; | |||
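The renamed accessors above pair each neighbouring node with the anchor it connects through; a short traversal sketch, assuming node is a valid ge::NodePtr and that Vistor is range-iterable.

#include "graph/node.h"

inline void DumpInputs(const ge::NodePtr &node) {
  // Data-edge predecessors together with the producing out-anchor on each of them.
  for (const auto &node_and_anchor : node->GetInDataNodesAndAnchors()) {
    ge::NodePtr src_node = node_and_anchor.first;
    ge::OutDataAnchorPtr src_anchor = node_and_anchor.second;
    (void)src_node;
    (void)src_anchor;
  }
  // Control-edge predecessors only.
  for (const auto &ctrl_node : node->GetInControlNodes()) {
    (void)ctrl_node;
  }
}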
@@ -23,7 +23,6 @@ | |||
#include <string> | |||
#include <unordered_set> | |||
#include <vector> | |||
#include "detail/attributes_holder.h" | |||
#include "graph/range_vistor.h" | |||
@@ -108,6 +107,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
size_t GetInputsSize() const; | |||
size_t GetAllInputsSize() const; | |||
graphStatus AddOutputDesc(const GeTensorDesc &output_desc); | |||
graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); | |||
@@ -122,6 +123,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
GeTensorDescPtr MutableOutputDesc(uint32_t index) const; | |||
uint32_t GetAllOutputsDescSize() const; | |||
Vistor<GeTensorDesc> GetAllOutputsDesc() const; | |||
Vistor<GeTensorDescPtr> GetAllOutputsDescPtr() const; | |||
@@ -132,6 +135,10 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; | |||
ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const; | |||
ConstGeTensorDescPtr GetInputDescPtr(const string &name) const; | |||
graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||
graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||
@@ -140,7 +147,11 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
bool IsOptionalInput(uint32_t index) const; | |||
std::map<string, uint32_t> GetAllInputName(); | |||
std::map<string, uint32_t> GetAllInputName() const; | |||
void SetAllInputName(const std::map<string, uint32_t> &input_name_idx); | |||
std::vector<string> GetAllOptionalInputName() const; | |||
std::map<string, uint32_t> GetAllOutputName(); | |||
@@ -225,6 +236,14 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
std::string GetOpEngineName() const; | |||
graphStatus AddSubgraphName(const std::string &name); | |||
const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const; | |||
std::string GetSubgraphInstanceName(uint32_t index) const; | |||
const std::vector<std::string> &GetSubgraphInstanceNames() const; | |||
void AddSubgraphInstanceName(std::string name); | |||
void RemoveSubgraphInstanceName(const std::string &name); | |||
protected: | |||
ProtoAttrMapHelper MutableAttrMap() override; | |||
ConstProtoAttrMapHelper GetAttrMap() const override; | |||
@@ -236,9 +255,9 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; | |||
GeIrProtoHelper<ge::proto::OpDef> op_def_; | |||
std::vector<std::string> subgraph_instance_names_; | |||
std::map<std::string, uint32_t> subgraph_names_to_index_; | |||
vector<GeTensorDescPtr> inputs_desc_{}; | |||
map<string, uint32_t> input_name_idx_{}; | |||
std::unordered_set<string> optional_input_names_{}; | |||
vector<GeTensorDescPtr> outputs_desc_{}; | |||
map<string, uint32_t> output_name_idx_{}; | |||
std::function<graphStatus(Operator &)> infer_func_ = nullptr; | |||
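A sketch of how the new subgraph bookkeeping on OpDesc is expected to be used; the branch and instance names are illustrative only.

#include "graph/op_desc.h"

inline void RegisterBranches(const ge::OpDescPtr &if_desc) {
  // Declare the subgraph slots on the functional op ...
  (void)if_desc->AddSubgraphName("then_branch");
  (void)if_desc->AddSubgraphName("else_branch");
  // ... then bind the concrete graph instances that were built elsewhere.
  if_desc->AddSubgraphInstanceName("then_branch_instance");
  if_desc->AddSubgraphInstanceName("else_branch_instance");
  const std::map<std::string, uint32_t> &indexes = if_desc->GetSubgraphNameIndexes();
  (void)indexes;
}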
@@ -21,7 +21,6 @@ | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include "graph/operator_factory.h" | |||
namespace ge { | |||
@@ -47,7 +46,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl { | |||
static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); | |||
private: | |||
static shared_ptr<std::map<string, OpCreator>> operator_creators_; | |||
static shared_ptr<std::map<string, InferShapeFunc>> operator_infershape_funcs_; | |||
static shared_ptr<std::map<string, InferFormatFunc>> operator_inferformat_funcs_; | |||
@@ -18,8 +18,8 @@ | |||
#define INC_GRAPH_SHAPE_REFINER_H_ | |||
#include <string> | |||
#include "external/graph/inference_context.h" | |||
#include "external/graph/ge_error_codes.h" | |||
#include "graph/node.h" | |||
@@ -27,8 +27,10 @@ namespace ge { | |||
// ShapeRefiner performs shape inference for compute graphs | |||
class ShapeRefiner { | |||
public: | |||
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); | |||
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph); | |||
static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); | |||
static graphStatus InferShapeAndType(const NodePtr &node); | |||
static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); | |||
private: | |||
static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); | |||
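A hedged call sketch for the new ShapeRefiner overload; the meaning of before_subgraph is assumed here to control whether inference runs before the node's subgraphs are processed.

#include "graph/shape_refiner.h"

inline ge::graphStatus RefineNode(const ge::NodePtr &node) {
  return ge::ShapeRefiner::InferShapeAndType(node, /*before_subgraph=*/true);
}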
@@ -14,8 +14,8 @@ | |||
* limitations under the License. | |||
*/ | |||
#ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
#define INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
#ifndef INC_GRAPH_USR_TYPES_H_ | |||
#define INC_GRAPH_USR_TYPES_H_ | |||
#include <atomic> | |||
#include <memory> | |||
@@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams { | |||
#undef USR_TYPE_BYTES_DEC | |||
} // namespace ge | |||
#endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
#endif // INC_GRAPH_USR_TYPES_H_ |
@@ -99,8 +99,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||
static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | |||
static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); | |||
// Value will be moved | |||
static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, | |||
vector<Buffer> &listBuffer); | |||
static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | |||
static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | |||
static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value); | |||
@@ -116,6 +115,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||
static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); | |||
static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); | |||
class AttrHolderAdapter { | |||
public: | |||
AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} | |||
@@ -137,6 +137,18 @@ class GraphUtils { | |||
static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, | |||
const std::vector<OpDescPtr> &vec_op_desc); | |||
/// | |||
/// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst | |||
/// @param [in] src | |||
/// @param [in] dsts | |||
/// @param [in] insert_node | |||
/// @param [in] input_index | |||
/// @param [in] output_index | |||
/// @return graphStatus | |||
/// | |||
static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); | |||
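Following the doc comment above, a sketch that splices an already-created node between a producer anchor and a consumer; all three nodes are hypothetical and assumed to be constructed elsewhere.

#include "graph/utils/graph_utils.h"

inline ge::graphStatus SpliceCast(const ge::NodePtr &src_node, const ge::NodePtr &dst_node,
                                  const ge::NodePtr &cast_node) {
  // Re-route src:0 -> dst:0 so that it becomes src:0 -> cast:0 and cast:0 -> dst:0.
  std::vector<ge::InDataAnchorPtr> dsts = {dst_node->GetInDataAnchor(0)};
  return ge::GraphUtils::InsertNodeBefore(src_node->GetOutDataAnchor(0), dsts, cast_node,
                                          /*input_index=*/0, /*output_index=*/0);
}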
static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); | |||
static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); | |||
@@ -145,16 +157,12 @@ class GraphUtils { | |||
static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node); | |||
static bool CheckIsTrainGraph(const ge::ComputeGraphPtr &compute_graph); | |||
static bool MatchDumpStr(const std::string &suffix); | |||
static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false); | |||
static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | |||
static bool CheckGlobalStepNode(const ge::NodePtr &node); | |||
static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos); | |||
static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | |||
@@ -252,6 +260,315 @@ class GraphUtils { | |||
/// @return success: GRAPH_SUCESS | |||
/// | |||
static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | |||
static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | |||
}; | |||
class ComputeGraphBuilder { | |||
public: | |||
ComputeGraphBuilder() : owner_graph_(nullptr) {} | |||
ComputeGraphBuilder(const ComputeGraphBuilder &) = delete; | |||
ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete; | |||
ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete; | |||
ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete; | |||
~ComputeGraphBuilder() = default; | |||
/// | |||
/// @brief Add node to graph | |||
/// @param [in] op_desc | |||
/// @return ComputeGraphBuilder | |||
/// | |||
virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc); | |||
/// | |||
/// @brief Add data-link among nodes in graph | |||
/// @param [in] src_name | |||
/// @param [in] out_anchor_ind | |||
/// @param [in] dst_name | |||
/// @param [in] in_anchor_ind | |||
/// @return ComputeGraphBuilder | |||
/// | |||
virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, | |||
const std::string &dst_name, uint32_t in_anchor_ind); | |||
/// | |||
/// @brief Add ctrl-link among nodes in graph | |||
/// @param [in] src_name | |||
/// @param [in] dst_name | |||
/// @return ComputeGraphBuilder | |||
/// | |||
virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name); | |||
/// | |||
/// @brief Build graph | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return ComputeGraphPtr | |||
/// | |||
virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0; | |||
/// | |||
/// @brief Get node with name | |||
/// @param [in] name | |||
/// @return NodePtr | |||
/// | |||
NodePtr GetNode(const std::string &name); | |||
protected: | |||
/// | |||
/// @brief Build nodes | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
void BuildNodes(graphStatus &error_code, std::string &error_msg); | |||
/// | |||
/// @brief Build data-links | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
void BuildDataLinks(graphStatus &error_code, std::string &error_msg); | |||
/// | |||
/// @brief Build ctrl-links | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg); | |||
ComputeGraphPtr owner_graph_; | |||
// node_name -> node | |||
std::map<std::string, NodePtr> node_names_; | |||
std::vector<OpDescPtr> nodes_; | |||
// <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind> | |||
std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_; | |||
// src_node_name -> dst_node_name | |||
std::vector<std::pair<std::string, std::string>> ctrl_links_; | |||
}; | |||
class CompleteGraphBuilder : public ComputeGraphBuilder { | |||
public: | |||
explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {} | |||
CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; | |||
CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; | |||
CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; | |||
CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete; | |||
~CompleteGraphBuilder() = default; | |||
/// | |||
/// @brief Add node to graph | |||
/// @param [in] op_desc | |||
/// @return CompleteGraphBuilder | |||
/// | |||
CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override; | |||
/// | |||
/// @brief Add data-link among nodes in graph | |||
/// @param [in] src_name | |||
/// @param [in] out_anchor_ind | |||
/// @param [in] dst_name | |||
/// @param [in] in_anchor_ind | |||
/// @return CompleteGraphBuilder | |||
/// | |||
CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, | |||
uint32_t in_anchor_ind) override; | |||
/// | |||
/// @brief Add ctrl-link among nodes in graph | |||
/// @param [in] src_name | |||
/// @param [in] dst_name | |||
/// @return CompleteGraphBuilder | |||
/// | |||
CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; | |||
/// | |||
/// @brief Set index_th input anchor for graph | |||
/// @param [in] index | |||
/// @param [in] node_names | |||
/// @param [in] anchor_inds | |||
/// @return CompleteGraphBuilder | |||
/// | |||
CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names, | |||
const std::vector<uint32_t> &anchor_inds); | |||
/// | |||
/// @brief Set index_th input of graph as useless | |||
/// @param [in] index | |||
/// @return CompleteGraphBuilder | |||
/// | |||
CompleteGraphBuilder &SetUselessInput(uint32_t index); | |||
/// | |||
/// @brief Add output anchor for graph | |||
/// @param [in] owner_node_name | |||
/// @param [in] anchor_ind | |||
/// @return CompleteGraphBuilder | |||
/// | |||
CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); | |||
/// | |||
/// @brief Set parent-node of graph | |||
/// @param [in] parent_node | |||
/// @return CompleteGraphBuilder | |||
/// | |||
CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node); | |||
/// | |||
/// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node | |||
/// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node | |||
/// @return CompleteGraphBuilder | |||
/// | |||
CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping); | |||
/// | |||
/// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind | |||
/// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node | |||
/// @return CompleteGraphBuilder | |||
/// | |||
CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping); | |||
/// | |||
/// @brief Build graph | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return ComputeGraphPtr | |||
/// | |||
ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; | |||
private: | |||
/// | |||
/// @brief Build inputs | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
void BuildInputs(graphStatus &error_code, std::string &error_msg); | |||
/// | |||
/// @brief Add data node | |||
/// @param [in] index | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return NodePtr | |||
/// | |||
NodePtr AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||
/// | |||
/// @brief Build outputs | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
void BuildOutputs(graphStatus &error_code, std::string &error_msg); | |||
/// | |||
/// @brief Add NetOutput node | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return NodePtr | |||
/// | |||
NodePtr AddNetOutputNode(graphStatus &error_code, std::string &error_msg); | |||
/// | |||
/// @brief Add input/output tensor for NetOutput node | |||
/// @param [in] out_nodes_info | |||
/// @param [out] net_output_desc | |||
/// @return graphStatus | |||
/// | |||
graphStatus BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
OpDescPtr &net_output_desc); | |||
/// | |||
/// @brief Add edge for NetOutput node | |||
/// @param [in] out_nodes_info | |||
/// @param [out] net_output_node | |||
/// @return graphStatus | |||
/// | |||
graphStatus AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
const NodePtr &net_output_node); | |||
std::string name_; | |||
NodePtr parent_node_; | |||
std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_; | |||
std::vector<std::pair<std::string, uint32_t>> graph_outputs_; | |||
// index_of_graph_input -> in_anchor_index_of_parent_node | |||
std::map<uint32_t, uint32_t> input_mapping_; | |||
// index_of_graph_output -> out_anchor_index_of_parent_node | |||
std::map<uint32_t, uint32_t> output_mapping_; | |||
}; | |||
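A condensed sketch of the fluent CompleteGraphBuilder above, wiring two nodes into a new subgraph; the op descriptors and parent node are placeholders, and the node names passed to AddDataLink/SetInput/AddOutput are assumed to match the names of the op descriptors.

#include "graph/utils/graph_utils.h"

inline ge::ComputeGraphPtr BuildBranchGraph(const ge::OpDescPtr &mul_desc, const ge::OpDescPtr &relu_desc,
                                            const ge::NodePtr &parent_node) {
  ge::graphStatus error_code = ge::GRAPH_SUCCESS;
  std::string error_msg;
  ge::CompleteGraphBuilder builder("branch_graph");
  builder.AddNode(mul_desc)
      .AddNode(relu_desc)
      .AddDataLink("mul", 0, "relu", 0)   // mul:0 -> relu:0
      .SetInput(0, {"mul"}, {0})          // graph input 0 feeds mul:0
      .AddOutput("relu", 0)               // relu:0 becomes graph output 0
      .SetParentNode(parent_node)
      .SetInputMapping({{0, 0}})          // graph input 0 -> parent in-anchor 0
      .SetOutputMapping({{0, 0}});        // graph output 0 -> parent out-anchor 0
  ge::ComputeGraphPtr graph = builder.Build(error_code, error_msg);
  return error_code == ge::GRAPH_SUCCESS ? graph : nullptr;
}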
class PartialGraphBuilder : public ComputeGraphBuilder { | |||
public: | |||
PartialGraphBuilder() = default; | |||
PartialGraphBuilder(const PartialGraphBuilder &) = delete; | |||
PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete; | |||
PartialGraphBuilder(const PartialGraphBuilder &&) = delete; | |||
PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete; | |||
~PartialGraphBuilder() = default; | |||
/// | |||
/// @brief Add node to graph | |||
/// @param [in] op_desc | |||
/// @return PartialGraphBuilder | |||
/// | |||
PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override; | |||
/// | |||
/// @brief Add data-link among nodes in graph | |||
/// @param [in] src_name | |||
/// @param [in] out_anchor_ind | |||
/// @param [in] dst_name | |||
/// @param [in] in_anchor_ind | |||
/// @return PartialGraphBuilder | |||
/// | |||
PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, | |||
uint32_t in_anchor_ind) override; | |||
/// | |||
/// @brief Add ctrl-link among nodes in graph | |||
/// @param [in] src_name | |||
/// @param [in] dst_name | |||
/// @return PartialGraphBuilder | |||
/// | |||
PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; | |||
/// | |||
/// @brief Set owner graph | |||
/// @param [in] graph | |||
/// @return PartialGraphBuilder | |||
/// | |||
PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph); | |||
/// | |||
/// @brief Add exist node | |||
/// @param [in] node | |||
/// @return PartialGraphBuilder | |||
/// | |||
PartialGraphBuilder &AddExistNode(const NodePtr &node); | |||
/// | |||
/// @brief Build multi nodes with links | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return ComputeGraphPtr | |||
/// | |||
ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; | |||
private: | |||
/// | |||
/// @brief Build exist nodes | |||
/// @param [out] error_code | |||
/// @param [out] error_msg | |||
/// @return void | |||
/// | |||
void BuildExistNodes(graphStatus &error_code, std::string &error_msg); | |||
std::vector<NodePtr> exist_nodes_; | |||
}; | |||
} // namespace ge | |||
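PartialGraphBuilder, by contrast, patches nodes and links into an existing graph; a brief sketch in which the owner graph, the pre-existing node, and the new op descriptor are assumptions.

#include "graph/utils/graph_utils.h"

inline ge::ComputeGraphPtr ExtendGraph(const ge::ComputeGraphPtr &owner, const ge::NodePtr &exist_node,
                                       const ge::OpDescPtr &new_desc) {
  ge::graphStatus error_code = ge::GRAPH_SUCCESS;
  std::string error_msg;
  ge::PartialGraphBuilder builder;
  builder.SetOwnerGraph(owner)
      .AddExistNode(exist_node)   // reuse a node that is already part of the graph
      .AddNode(new_desc)          // create one new node from its op descriptor
      .AddDataLink(exist_node->GetName(), 0, new_desc->GetName(), 0);
  return builder.Build(error_code, error_msg);
}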
@@ -56,6 +56,11 @@ class NodeUtils { | |||
static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | |||
static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | |||
static std::string GetNodeType(const Node &node); | |||
static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | |||
static graphStatus AddSubgraph(Node &node, const ComputeGraphPtr &subgraph); | |||
private: | |||
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | |||
static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | |||
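A short sketch of the two new NodeUtils helpers above: attach a subgraph to a functional node and read it back by index (index 0 and the include path are assumptions).

#include "graph/utils/node_utils.h"

inline ge::graphStatus AttachSubgraph(const ge::NodePtr &func_node, const ge::ComputeGraphPtr &subgraph) {
  ge::graphStatus ret = ge::NodeUtils::AddSubgraph(*func_node, subgraph);
  if (ret != ge::GRAPH_SUCCESS) {
    return ret;
  }
  ge::ComputeGraphPtr attached = ge::NodeUtils::GetSubgraph(*func_node, 0);
  return attached == nullptr ? ge::GRAPH_FAILED : ge::GRAPH_SUCCESS;
}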
@@ -20,7 +20,6 @@ | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include "graph/def_types.h" | |||
#include "graph/node.h" | |||
#include "graph/op_desc.h" | |||
@@ -29,7 +28,6 @@ | |||
namespace ge { | |||
class OpDesc; | |||
using OpDescPtr = std::shared_ptr<OpDesc>; | |||
class OpDescUtils { | |||
@@ -39,55 +37,108 @@ class OpDescUtils { | |||
OpDescUtils() = default; | |||
~OpDescUtils() = default; | |||
static bool HasQuantizeFactorParams(const OpDescPtr &op_desc); | |||
static bool HasQuantizeFactorParams(const OpDesc &op_desc); | |||
static graphStatus GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant); | |||
static graphStatus GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant); | |||
static graphStatus SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant); | |||
static graphStatus SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant); | |||
static vector<ge::NodePtr> GetConstInputNode(const ge::Node &node); | |||
static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr> &input_nodes); | |||
static vector<ConstGeTensorPtr> GetWeights(const ge::Node &node); | |||
static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr &node); | |||
static vector<GeTensorPtr> MutableWeights(const ge::Node &node); | |||
static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); | |||
static bool HasQuantizeFactorParams(const OpDesc& op_desc); | |||
static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant); | |||
static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant); | |||
static graphStatus SetQuantizeFactorParams(const OpDescPtr& op_desc, const QuantizeFactorParams& quant); | |||
static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant); | |||
static vector<ge::NodePtr> GetConstInputNode(const ge::Node& node); | |||
static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr>& input_nodes); | |||
static vector<ConstGeTensorPtr> GetWeights(const ge::Node& node); | |||
static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr& node); | |||
static vector<GeTensorPtr> MutableWeights(const ge::Node& node); | |||
static vector<GeTensorPtr> MutableWeights(const ge::NodePtr node); | |||
static graphStatus SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights); | |||
static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights); | |||
static graphStatus SetWeights(ge::Node& node, const vector<ge::GeTensorPtr>& weights); | |||
static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr>& weights); | |||
static graphStatus ClearWeights(ge::NodePtr node); | |||
static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); | |||
static bool ClearInputDesc(const ge::NodePtr &node); | |||
static bool ClearOutputDesc(const ge::OpDescPtr &op_desc, uint32_t index); | |||
static bool ClearOutputDesc(const ge::NodePtr &node); | |||
static vector<ge::NodePtr> GetConstInputs(const ge::Node &node); | |||
static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr &node); | |||
static size_t GetNonConstInputsSize(const ge::Node &node); | |||
static bool ClearInputDesc(const ge::NodePtr& node); | |||
static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index); | |||
static bool ClearOutputDesc(const ge::NodePtr& node); | |||
static vector<ge::NodePtr> GetConstInputs(const ge::Node& node); | |||
static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr& node); | |||
static size_t GetNonConstInputsSize(const ge::Node& node); | |||
static size_t GetNonConstInputsSize(ge::ConstNodePtr node); | |||
// Index: Indicate the index of all non const inputs | |||
static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const = 0); | |||
static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const = 0); | |||
static bool GetNonConstInputIndex(const ge::Node &node, size_t index_non_const, size_t &index); | |||
static bool GetNonConstInputIndex(const ge::ConstNodePtr &node, size_t index_non_const, size_t &index); | |||
// Index: Indicate the index of all inputs | |||
static bool IsNonConstInput(const ge::Node &node, size_t index = 0); | |||
static bool IsNonConstInput(const ge::ConstNodePtr &node, size_t index = 0); | |||
static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr &node); | |||
static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr); | |||
// Index: Indicates the index of all non const inputs | |||
static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0); | |||
static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0); | |||
static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index); | |||
static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index); | |||
// Index: Indicates the index of all inputs | |||
static bool IsNonConstInput(const ge::Node& node, size_t index = 0); | |||
static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0); | |||
static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr& node); | |||
static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); | |||
static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); | |||
static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); | |||
static OpDescPtr GetOpDescFromOperator(const Operator &oprt); | |||
static OpDescPtr GetOpDescFromOperator(const Operator& oprt); | |||
static OpDescPtr CreateConstOp(const GeTensorPtr &tensor_ptr); | |||
static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); | |||
private: | |||
static GeTensorPtr MutableWeights(ge::OpDesc &op_desc); | |||
static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); | |||
static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | |||
static graphStatus SetWeights(ge::OpDesc &op_desc, const GeTensorPtr weight); | |||
static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); | |||
static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); | |||
}; | |||
class OpDescBuilder { | |||
public: | |||
OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} | |||
OpDescBuilder(const OpDescBuilder&) = delete; | |||
OpDescBuilder& operator=(const OpDescBuilder&) = delete; | |||
OpDescBuilder(const OpDescBuilder&&) = delete; | |||
OpDescBuilder& operator=(const OpDescBuilder&&) = delete; | |||
~OpDescBuilder() = default; | |||
/// | |||
/// @brief Add input | |||
/// @param [in] name | |||
/// @return OpDescBuilder | |||
/// | |||
OpDescBuilder& AddInput(const std::string& name); | |||
/// | |||
/// @brief Add dynamic input | |||
/// @param [in] name | |||
/// @param [in] num | |||
/// @return OpDescBuilder | |||
/// | |||
OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); | |||
/// | |||
/// @brief Add output | |||
/// @param [in] name | |||
/// @return OpDescBuilder | |||
/// | |||
OpDescBuilder& AddOutput(const std::string& name); | |||
/// | |||
/// @brief Add dynamic output | |||
/// @param [in] name | |||
/// @param [in] num | |||
/// @return OpDescBuilder | |||
/// | |||
OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); | |||
/// | |||
/// @brief Build op_desc | |||
/// @return OpDescPtr | |||
/// | |||
OpDescPtr Build(); | |||
private: | |||
std::string name_; | |||
std::string type_; | |||
std::vector<std::string> inputs_; | |||
std::vector<std::string> outputs_; | |||
}; | |||
} // namespace ge | |||
#endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ |
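A compact sketch of the OpDescBuilder declared above; the op name, type, and tensor names are illustrative.

#include "graph/utils/op_desc_utils.h"

inline ge::OpDescPtr MakeAddDesc() {
  return ge::OpDescBuilder("add_0", "Add")
      .AddInput("x1")
      .AddInput("x2")
      .AddOutput("y")
      .Build();
}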
@@ -18,15 +18,14 @@ | |||
#define INC_GRAPH_UTILS_TENSOR_UTILS_H_ | |||
#include <vector> | |||
#include "graph/def_types.h" | |||
#include "graph/ge_error_codes.h" | |||
#include "graph/ge_tensor.h" | |||
namespace ge { | |||
class TensorUtils { | |||
public: | |||
static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, uint32_t &size); | |||
static void SetSize(GeTensorDesc &tensorDesc, uint32_t size); | |||
static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size); | |||
static void SetSize(GeTensorDesc &tensorDesc, int64_t size); | |||
static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); | |||
static uint32_t GetWeightSize(const GeTensor &tensor); | |||
static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); | |||
@@ -62,16 +61,16 @@ class TensorUtils { | |||
static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); | |||
/// | |||
/// calculate mem size of the tensor. | |||
/// calculate tensor mem size. | |||
/// @param shape tensor shape | |||
/// @param format tensor format | |||
/// @param data_type tensor data type | |||
/// @param mem_size -1 means unknown shape, others means mem size | |||
/// @return GRAPH_SUCCESS: success, others: failed | |||
/// @param mem_size -1 means unknown shape, other means mem size | |||
/// @return GRAPH_SUCCESS: success, other: failed | |||
/// | |||
static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); | |||
static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp); | |||
static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp); | |||
static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); | |||
static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); | |||
}; | |||
} // namespace ge | |||
#endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ |
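A hedged sketch of the widened size query: per the doc comment above, an unknown dimension yields mem_size == -1 on success.

#include "graph/utils/tensor_utils.h"

inline int64_t QueryMemSize() {
  ge::GeShape shape({-1, 3, 224, 224});  // first dimension unknown
  int64_t mem_size = 0;
  ge::graphStatus ret =
      ge::TensorUtils::CalcTensorMemSize(shape, ge::FORMAT_NCHW, ge::DT_FLOAT, mem_size);
  return ret == ge::GRAPH_SUCCESS ? mem_size : -1;
}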
@@ -58,6 +58,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||
include_directories(${GE_SOURCE_DIR}/inc/graph) | |||
include_directories(${GE_SOURCE_DIR}/inc/common) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | |||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | |||
include_directories(${CMAKE_BINARY_DIR}) | |||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
@@ -26,6 +26,8 @@ Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), id | |||
bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf<Anchor>(), type) == 0; } | |||
size_t Anchor::GetPeerAnchorsSize() const { return peer_anchors_.size(); } | |||
Anchor::Vistor<AnchorPtr> Anchor::GetPeerAnchors() const { | |||
vector<AnchorPtr> ret; | |||
for (const auto &anchor : peer_anchors_) { | |||
@@ -32,8 +32,7 @@ Buffer::Buffer(const Buffer &other) { | |||
buffer_ = other.buffer_; | |||
} | |||
// default | |||
Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { | |||
Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default | |||
auto proto_msg = data_.GetProtoMsg(); | |||
if (proto_msg != nullptr) { | |||
try { | |||
@@ -15,9 +15,7 @@ | |||
*/ | |||
#include "graph/compute_graph.h" | |||
#include <deque> | |||
#include "./format_refiner.h" | |||
#include "./ge_context.h" | |||
#include "debug/ge_attr_define.h" | |||
@@ -41,7 +39,7 @@ const size_t OUTPUT_PARAM_SIZE = 2; | |||
} // namespace | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) | |||
: nodes_(), input_nodes_(), sub_graph_(), name_(name), is_valid_flag_(false), need_iteration_(false) { | |||
: name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { | |||
attrs_.InitDefault(); | |||
} | |||
ComputeGraph::~ComputeGraph() {} | |||
@@ -154,7 +152,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNod | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( | |||
const ComputeGraph &r_graph) const { | |||
return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.sub_graph_.size()") && | |||
return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") && | |||
IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && | |||
VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && | |||
IsEqual(this->name_, r_graph.name_, "graph.name_") && | |||
@@ -398,6 +396,165 @@ graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &su | |||
} | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph) { | |||
if (subgraph == nullptr) { | |||
GE_LOGE("Try to add a null subgraph, name %s", name.c_str()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
auto parent_graph = subgraph->GetParentGraph(); | |||
if (parent_graph == nullptr) { | |||
GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
auto parent_node = subgraph->GetParentNode(); | |||
if (parent_node == nullptr) { | |||
GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
if (parent_node->GetOwnerComputeGraph() != parent_graph) { | |||
GE_LOGE( | |||
"Try to add a subgraph which parent node's parent graph is not equal to " | |||
"the subgraph's parent graph, subgraph name %s, parent node name %s", | |||
subgraph->GetName().c_str(), parent_node->GetName().c_str()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
if (!this->parent_graph_.expired()) { | |||
GE_LOGE("The subgraphs can only be added to the root graph"); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
if (name != subgraph->GetName()) { | |||
GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); | |||
} | |||
sub_graph_.push_back(subgraph); | |||
names_to_subgraph_[name] = subgraph; | |||
return GRAPH_SUCCESS; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
ComputeGraph::AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph) { | |||
if (subgraph == nullptr) { | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
return AddSubgraph(subgraph->GetName(), subgraph); | |||
} | |||
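The checks above mean a subgraph must already know both its parent graph and its parent node before registration, and that registration happens on the root graph; a caller-side sketch with hypothetical graphs and node.

#include "graph/compute_graph.h"

inline ge::graphStatus RegisterSubgraph(const ge::ComputeGraphPtr &root_graph,
                                        const ge::ComputeGraphPtr &parent_graph,
                                        const ge::NodePtr &if_node,
                                        const ge::ComputeGraphPtr &subgraph) {
  subgraph->SetParentGraph(parent_graph);  // parent_graph must be the graph that owns if_node
  subgraph->SetParentNode(if_node);
  return root_graph->AddSubgraph(subgraph->GetName(), subgraph);
}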
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) { | |||
auto iter = names_to_subgraph_.find(name); | |||
if (iter == names_to_subgraph_.end()) { | |||
return; | |||
} | |||
for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) { | |||
if (*vec_iter == iter->second) { | |||
sub_graph_.erase(vec_iter); | |||
break; | |||
} | |||
} | |||
names_to_subgraph_.erase(iter); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph( | |||
const std::shared_ptr<ComputeGraph> &subgraph) { | |||
if (subgraph != nullptr) { | |||
RemoveSubgraph(subgraph->GetName()); | |||
} | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph( | |||
const std::string &name) const { | |||
auto iter = names_to_subgraph_.find(name); | |||
return iter == names_to_subgraph_.end() ? nullptr : iter->second; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>> | |||
ComputeGraph::GetAllSubgraphs() const { | |||
return sub_graph_; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<ComputeGraph> ComputeGraph::GetParentGraph() { | |||
return parent_graph_.lock(); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph( | |||
const shared_ptr<ComputeGraph> &parent) { | |||
parent_graph_ = parent; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<Node> ComputeGraph::GetParentNode() { | |||
return parent_node_.lock(); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr<Node> &parent) { | |||
parent_node_ = parent; | |||
} | |||
/// | |||
/// @brief Update input-mapping | |||
/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input | |||
/// @return graphStatus | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) { | |||
for (auto &input : input_nodes_) { | |||
uint32_t cur_index = 0; | |||
if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | |||
continue; | |||
} | |||
auto iter = input_mapping.find(cur_index); | |||
if (iter == input_mapping.end()) { | |||
continue; | |||
} | |||
if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||
GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// @brief Update output-mapping | |||
/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output | |||
/// @return graphStatus | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) { | |||
NodePtr net_output = FindNode(kNodeNameNetOutput); | |||
if (net_output == nullptr) { | |||
GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", kNodeNameNetOutput); | |||
return GRAPH_FAILED; | |||
} | |||
OpDescPtr op_desc = net_output->GetOpDesc(); | |||
if (op_desc == nullptr) { | |||
GE_LOGE("UpdateOutputMapping failed: op_desc is NULL."); | |||
return GRAPH_FAILED; | |||
} | |||
size_t num = op_desc->GetInputsSize(); | |||
for (size_t i = 0; i < num; i++) { | |||
GeTensorDesc tensor = op_desc->GetInputDesc(i); | |||
uint32_t cur_index = 0; | |||
if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | |||
continue; | |||
} | |||
auto iter = output_mapping.find(cur_index); | |||
if (iter == output_mapping.end()) { | |||
continue; | |||
} | |||
if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||
GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||
return GRAPH_FAILED; | |||
} | |||
if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) { | |||
GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
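A short usage sketch for the two mapping helpers above, assuming a valid subgraph whose Data nodes and NetOutput inputs already carry ATTR_NAME_PARENT_NODE_INDEX; the concrete index values are hypothetical:

graphStatus RemapParentIndices(const ComputeGraphPtr &subgraph) {
  // Example remap: the parent node's inputs 0 and 1 were swapped, output 0 kept its slot.
  std::map<uint32_t, uint32_t> input_mapping = {{0U, 1U}, {1U, 0U}};
  std::map<uint32_t, uint32_t> output_mapping = {{0U, 0U}};
  if (subgraph->UpdateInputMapping(input_mapping) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return subgraph->UpdateOutputMapping(output_mapping);
}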
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { | |||
std::vector<NodePtr> node_vec = nodes_; | |||
for (const auto &node : GetAllNodes()) { | |||
@@ -551,6 +708,23 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { | |||
auto ret = TopologicalSortingSubgraph(); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Sub graph partition Failed"); | |||
return ret; | |||
} | |||
// topologically sort each subgraph | |||
for (const auto &sub_graph : GetAllSubgraphs()) { | |||
ret = sub_graph->TopologicalSortingSubgraph(); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Sub graph topological sort Failed"); | |||
return ret; | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSortingSubgraph() { | |||
std::vector<NodePtr> node_vec; | |||
std::map<NodePtr, uint32_t> map_in_edge_num; | |||
bool use_BFS = false; | |||
@@ -598,6 +772,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog | |||
node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] | |||
nodes_.push_back(node); | |||
} | |||
is_valid_flag_ = true; | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -614,7 +789,7 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||
verify_isolated = true; | |||
} | |||
} | |||
for (const auto &node : GetAllNodes()) { | |||
for (const auto &node : GetDirectNode()) { | |||
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | |||
map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node)); | |||
if (map_in_edge_num[node] == 0) { | |||
@@ -640,16 +815,16 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||
/// 2. Compare the two indices; if they do not match, swap the positions of the two inputs | |||
/// Note: the stack is in reverse order | |||
for (size_t i = 0; i < stack.size(); ++i) { | |||
// [stack: should not be null] | |||
// If not found in 'inputs_order_', skip it | |||
auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); | |||
GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); | |||
auto inx_i = it_i - inputs_order_.begin(); | |||
for (size_t j = i + 1; j < stack.size(); ++j) { | |||
// If not found in 'inputs_order_', skip it | |||
auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); | |||
GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); | |||
auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); | |||
GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); | |||
// Compare the indices and swap if needed | |||
auto inx_i = it_i - inputs_order_.begin(); | |||
auto inx_j = it_j - inputs_order_.begin(); | |||
GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); | |||
} | |||
@@ -663,7 +838,7 @@ size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | |||
return in_edge_size; | |||
} | |||
for (const auto &anchor : node->GetAllInDataAnchors()) { | |||
in_edge_size = in_edge_size + anchor->GetPeerAnchors().size(); | |||
in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize(); | |||
// Break flow control data loop. | |||
OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); | |||
if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) { | |||
@@ -680,10 +855,11 @@ size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | |||
} | |||
} | |||
if (node->GetInControlAnchor() != nullptr) { | |||
in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchors().size(); | |||
in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize(); | |||
} | |||
return in_edge_size; | |||
} | |||
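GetInEdgeSize feeds the in-degree table that SortNodes and TopologicalSortingSubgraph consume; the overall ordering is Kahn's algorithm. A standalone, self-contained illustration of that bookkeeping on plain integer node ids (illustrative only, not GE code):

#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

std::vector<int> TopoOrder(std::size_t node_count, const std::vector<std::pair<int, int>> &edges) {
  std::vector<int> in_degree(node_count, 0);
  std::vector<std::vector<int>> adjacency(node_count);
  for (const auto &edge : edges) {
    adjacency[edge.first].push_back(edge.second);  // edge.first -> edge.second
    ++in_degree[edge.second];
  }
  std::queue<int> ready;
  for (std::size_t i = 0; i < node_count; ++i) {
    if (in_degree[i] == 0) {
      ready.push(static_cast<int>(i));  // nodes with no incoming edges are scheduled first
    }
  }
  std::vector<int> order;
  while (!ready.empty()) {
    const int current = ready.front();
    ready.pop();
    order.push_back(current);
    for (const int next : adjacency[current]) {
      if (--in_degree[next] == 0) {
        ready.push(next);
      }
    }
  }
  return order;  // order.size() < node_count means the graph contains a cycle
}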
size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | |||
size_t out_edge_size = 0; | |||
if (node == nullptr) { | |||
@@ -699,7 +875,7 @@ size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | |||
} | |||
} | |||
if (node->GetOutControlAnchor() != nullptr) { | |||
if (out_edge_size > (UINT32_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { | |||
if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { | |||
return 0; | |||
} | |||
out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size(); | |||
@@ -724,17 +900,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||
} | |||
} | |||
GE_IF_BOOL_EXEC(node->GetOutControlAnchor() == nullptr, GELOGE(GRAPH_FAILED, "Out control anchor is null"); | |||
return ); | |||
for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { | |||
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||
} | |||
for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInDataAnchors()) { | |||
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||
auto out_control_anchor = node->GetOutControlAnchor(); | |||
if (out_control_anchor != nullptr) { | |||
for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) { | |||
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||
} | |||
for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) { | |||
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||
peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||
} | |||
} | |||
} | |||
} | |||
@@ -18,21 +18,9 @@ | |||
#define COMMON_GRAPH_DEBUG_GE_LOG_H_ | |||
#include "graph/ge_error_codes.h" | |||
#include "toolchain/slog.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#define GE_MOD_ID GE | |||
#ifdef _MSC_VER | |||
#define FUNC_NAME __FUNCTION__ | |||
#else | |||
#define FUNC_NAME __PRETTY_FUNCTION__ | |||
#endif | |||
#define D_GE_LOGE(fmt, ...) \ | |||
dlog_error(static_cast<int>(GE_MOD_ID), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
#define GE_LOGE(...) D_GE_LOGE(__VA_ARGS__) | |||
#define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||
#define GE_LOGI_IF(condition, ...) \ | |||
if ((condition)) { \ | |||
@@ -44,15 +32,15 @@ | |||
GELOGW(__VA_ARGS__); \ | |||
} | |||
#define GE_LOGE_IF(condition, ...) \ | |||
if ((condition)) { \ | |||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
#define GE_LOGE_IF(condition, ...) \ | |||
if ((condition)) { \ | |||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
} | |||
#define GE_CHK_STATUS_RET_NOLOG(expr) \ | |||
do { \ | |||
const ge::graphStatus _status = (expr); \ | |||
if (_status != ge::GRAPH_SUCCESS) { \ | |||
if (ge::SUCCESS != _status) { \ | |||
return _status; \ | |||
} \ | |||
} while (0) | |||
@@ -61,7 +49,7 @@ | |||
do { \ | |||
bool b = (expr); \ | |||
if (!b) { \ | |||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
return _status; \ | |||
} \ | |||
} while (0) | |||
@@ -85,7 +73,7 @@ | |||
do { \ | |||
const ge::graphStatus _status = (expr); \ | |||
if (_status) { \ | |||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
return _status; \ | |||
} \ | |||
} while (0) | |||
@@ -95,7 +83,7 @@ | |||
{ \ | |||
bool b = (expr); \ | |||
if (b) { \ | |||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
exec_expr; \ | |||
} \ | |||
} | |||
@@ -119,63 +107,41 @@ | |||
} while (0) | |||
// If expr is not true, the log is printed and a custom statement is executed | |||
#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ | |||
{ \ | |||
bool b = (expr); \ | |||
if (!b) { \ | |||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
exec_expr; \ | |||
} \ | |||
} | |||
// If expr is not true, the log is printed and a custom statement is executed | |||
#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||
{ \ | |||
bool b = (expr); \ | |||
if (!b) { \ | |||
GELOGI(__VA_ARGS__); \ | |||
exec_expr; \ | |||
} \ | |||
#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ | |||
{ \ | |||
bool b = (expr); \ | |||
if (!b) { \ | |||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
exec_expr; \ | |||
} \ | |||
} | |||
// If expr is not true, the log is printed and a custom statement is executed | |||
#define GE_CHK_BOOL_EXEC_DEBUG(expr, exec_expr, ...) \ | |||
{ \ | |||
bool b = (expr); \ | |||
if (!b) { \ | |||
GELOGD(__VA_ARGS__); \ | |||
exec_expr; \ | |||
} \ | |||
#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||
{ \ | |||
bool b = (expr); \ | |||
if (!b) { \ | |||
GELOGI(__VA_ARGS__); \ | |||
exec_expr; \ | |||
} \ | |||
} | |||
// If expr is not GRAPH_SUCCESS, print the log and return the same value | |||
#define GE_CHK_STATUS_RET(expr, ...) \ | |||
do { \ | |||
const ge::graphStatus _status = (expr); \ | |||
if (_status != ge::GRAPH_SUCCESS) { \ | |||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
return _status; \ | |||
} \ | |||
#define GE_CHK_STATUS_RET(expr, ...) \ | |||
do { \ | |||
const ge::graphStatus _status = (expr); \ | |||
if (ge::SUCCESS != _status) { \ | |||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
return _status; \ | |||
} \ | |||
} while (0) | |||
#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||
try { \ | |||
exec_expr0; \ | |||
} catch (...) { \ | |||
GELOGE(ge::GRAPH_FAILED, "Make shared failed"); \ | |||
exec_expr1; \ | |||
#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||
try { \ | |||
exec_expr0; \ | |||
} catch (...) { \ | |||
GELOGE(ge::FAILED, "Make shared failed"); \ | |||
exec_expr1; \ | |||
} | |||
/// CCE related macro definition | |||
/// If expr is not CC_STATUS_GRAPH_SUCCESS, print the log and return | |||
#define GE_CHK_CCE_RET(expr) \ | |||
do { \ | |||
ccgraphStatus_t _cc_ret = (expr); \ | |||
if (_cc_ret != CC_STATUS_GRAPH_SUCCESS) { \ | |||
GELOGE(ge::GRAPH_FAILED, "Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
return ge::GRAPH_FAILED; \ | |||
} \ | |||
} while (0) | |||
#endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ | |||
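A hypothetical caller of the check macros defined above, sketching the intended status-propagation style; the function name is illustrative, not part of this change:

ge::graphStatus SortGraphOrFail(const ge::ComputeGraphPtr &graph) {
  // Validate the argument and return immediately with a parameter error on nullptr.
  GE_CHK_BOOL_RET_STATUS(graph != nullptr, ge::GRAPH_PARAM_INVALID, "graph is null");
  // Propagate any non-success status from TopologicalSorting to the caller.
  GE_CHK_STATUS_RET(graph->TopologicalSorting(), "Topological sorting failed for graph %s",
                    graph->GetName().c_str());
  return ge::GRAPH_SUCCESS;
}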
@@ -25,7 +25,6 @@ | |||
#include <string> | |||
#include <utility> | |||
#include <vector> | |||
#include "framework/common/debug/ge_log.h" | |||
#include "graph/debug/ge_log.h" | |||
#include "graph/ge_error_codes.h" | |||
@@ -15,12 +15,10 @@ | |||
*/ | |||
#include "graph/debug/graph_debug.h" | |||
#include <algorithm> | |||
#include <unordered_set> | |||
#include <vector> | |||
#include "debug/ge_util.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#define TAB " " | |||
@@ -16,13 +16,11 @@ | |||
#ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | |||
#define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | |||
#include <cstdint> | |||
#include <fstream> | |||
#include <iostream> | |||
#include <sstream> | |||
#include <string> | |||
#include "external/graph/graph.h" | |||
#include "./ge_error_codes.h" | |||
#include "graph/compute_graph.h" | |||
@@ -15,9 +15,7 @@ | |||
*/ | |||
#include "detail/attributes_holder.h" | |||
#include <map> | |||
#include "debug/ge_log.h" | |||
#include "debug/ge_util.h" | |||
#include "framework/common/debug/ge_log.h" | |||
@@ -14,14 +14,12 @@ | |||
* limitations under the License. | |||
*/ | |||
#include "graph/format_refiner.h" | |||
#include "format_refiner.h" | |||
#include <deque> | |||
#include <iostream> | |||
#include <set> | |||
#include <unordered_map> | |||
#include <unordered_set> | |||
#include "./compute_graph.h" | |||
#include "./ge_error_codes.h" | |||
#include "./graph/ge_tensor.h" | |||
@@ -57,6 +55,7 @@ graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | |||
std::vector<ge::NodePtr> &data_nodes, | |||
std::unordered_map<ge::NodePtr, bool> &node_status) { | |||
@@ -82,10 +81,10 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
// consider special node save process | |||
// get all input desc format | |||
bool node_is_all_nd = false; | |||
for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetInputsSize()); i++) { | |||
auto input_desc = op_desc->GetInputDesc(i); | |||
auto input_size = static_cast<uint32_t>(op_desc->GetInputsSize()); | |||
for (uint32_t i = 0; i < input_size; i++) { | |||
// Operator pre-set format but not origin format | |||
auto input_format = input_desc.GetFormat(); | |||
auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); | |||
// Pre-save data node and default infer fail | |||
if (node_ptr->GetType() == DATA) { | |||
data_nodes.push_back(node_ptr); | |||
@@ -95,9 +94,9 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
} | |||
} | |||
// Get all output desc format | |||
for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetOutputsSize()); i++) { | |||
GeTensorDesc output_desc = op_desc->GetOutputDesc(i); | |||
auto output_format = output_desc.GetFormat(); | |||
auto output_size = static_cast<uint32_t>(op_desc->GetOutputsSize()); | |||
for (uint32_t i = 0; i < output_size; i++) { | |||
auto output_format = op_desc->MutableOutputDesc(i)->GetFormat(); | |||
if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { | |||
node_is_all_nd = true; | |||
} | |||
@@ -145,7 +144,8 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||
for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||
GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); | |||
auto in_data_anchor_idx = in_anchor->GetIdx(); | |||
auto to_be_set_format = (node->GetOpDesc()->GetInputDesc(in_data_anchor_idx)).GetOriginFormat(); | |||
auto to_be_set_format = | |||
node->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx))->GetOriginFormat(); | |||
if (to_be_set_format == FORMAT_ND) { | |||
GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); | |||
continue; | |||
@@ -162,7 +162,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||
} | |||
// Check whether the format has already been set | |||
int idx = peer_out_data_anchor->GetIdx(); | |||
auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(idx); | |||
auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx)); | |||
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | |||
auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | |||
if (dim_num == 0) { | |||
@@ -182,7 +182,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||
ge_tensor_desc.SetFormat(to_be_set_format); | |||
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(idx, ge_tensor_desc); | |||
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||
// Call operator infer format api (forward) to get out format | |||
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | |||
@@ -205,7 +205,8 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||
GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); | |||
GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); | |||
auto out_data_anchor_idx = out_data_anchor->GetIdx(); | |||
auto to_be_set_format = (node->GetOpDesc()->GetOutputDesc(out_data_anchor_idx)).GetOriginFormat(); | |||
auto to_be_set_format = | |||
node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(out_data_anchor_idx))->GetOriginFormat(); | |||
if (to_be_set_format == FORMAT_ND) { | |||
GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); | |||
continue; | |||
@@ -222,7 +223,7 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||
} | |||
// Check whether the format has already been set | |||
int idx = peer_in_data_anchor->GetIdx(); | |||
auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(idx); | |||
auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx)); | |||
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | |||
auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | |||
if (dim_num == 0) { | |||
@@ -285,9 +286,9 @@ void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = | |||
graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | |||
std::unordered_map<ge::NodePtr, bool> &node_status) { | |||
bool is_internal_format = TypeUtils::IsInternalFormat(data_format); | |||
bool need_process = ((!is_first_infer) && (is_internal_format == false) && (data_format != FORMAT_ND)); | |||
bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND); | |||
if (!need_process) { | |||
GELOGI("no necessary to do DataNodeFormatProcess.IsFirstInfer: %d, data_format:%s", is_first_infer, | |||
GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer, | |||
TypeUtils::FormatToSerialString(data_format).c_str()); | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -378,9 +379,9 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||
/// Notice: ignore 5D formats | |||
auto data_format = graph->GetDataFormat(); | |||
status = DataNodeFormatProcess(data_nodes, data_format, node_status); | |||
// Set infer flag to false | |||
SetInferOrigineFormatFlag(false); | |||
return status; | |||
} | |||
} // namespace ge |
@@ -42,6 +42,8 @@ const std::string ATTR_NAME_BIAS = "bias"; | |||
const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | |||
const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; | |||
const std::string ATTR_NAME_PAD = "pad"; | |||
const std::string ATTR_NAME_PADS = "pad"; | |||
@@ -83,6 +85,7 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; | |||
const std::string ATTR_NAME_AXIS = "axis"; | |||
const std::string ATTR_NAME_BROADCAST = "broadcast"; | |||
const std::string ATTR_NAME_OUTPUT = "output"; | |||
const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | |||
const std::string ATTR_NAME_TIDX = "t_idx"; | |||
@@ -103,6 +106,13 @@ const std::string ATTR_NAME_TSHAPE = "Tshape"; | |||
const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | |||
const std::string ATTR_NAME_AIPP = "aipp"; | |||
const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | |||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; | |||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; | |||
const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||
const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | |||
const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | |||
@@ -111,6 +121,7 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; | |||
const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | |||
const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | |||
const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | |||
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||
const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | |||
const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | |||
@@ -122,15 +133,11 @@ const std::string ATTR_NAME_WEIGHTS = "value"; | |||
const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | |||
const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | |||
const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | |||
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||
const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||
const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||
const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||
const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | |||
const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | |||
const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; | |||
const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||
const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||
// To be deleted | |||
const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | |||
@@ -144,15 +151,13 @@ const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; | |||
const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | |||
const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||
// Refinedet | |||
const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | |||
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | |||
const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | |||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||
// _Arg | |||
const std::string ATTR_NAME_INDEX = "index"; | |||
@@ -242,6 +247,30 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; | |||
const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | |||
const std::string BATCHNORM_ATTR_SCALE = "scale"; | |||
const std::string BATCHNORM_ATTR_BIAS = "bias"; | |||
const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; | |||
const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; | |||
const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; | |||
// huberloss | |||
const std::string HUBER_LOSS_ATTR_DELTA = "delta"; | |||
// SSDRealDivTileMul | |||
const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; | |||
// SSDSumMulRealDivMean | |||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; | |||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; | |||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; | |||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; | |||
// ConcatFive2Four | |||
// ConcatFour2Five | |||
const std::string SSD_BOX_TYPE_NUM = "box_type_num"; | |||
const std::string SSD_CLASS_NUM = "class_num"; | |||
const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; | |||
const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; | |||
const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; | |||
const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; | |||
// Scale | |||
const std::string SCALE_ATTR_SCALE = "scale"; | |||
@@ -346,6 +375,7 @@ const std::string SOFTMAX_ATTR_AXIS = "axis"; | |||
// Permute | |||
const std::string PERMUTE_ATTR_ORDER = "order"; | |||
const std::string PERMUTE_ATTR_PERM = "perm"; | |||
// SSD Normalize | |||
const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | |||
@@ -373,6 +403,10 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; | |||
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
// RefinedetDetectionOutput | |||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
// PRelu | |||
const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | |||
@@ -386,11 +420,16 @@ const std::string POWER_ATTR_NAME_POWER = "power"; | |||
const std::string POWER_ATTR_NAME_SCALE = "scale"; | |||
const std::string POWER_ATTR_NAME_SHIFT = "shift"; | |||
// log | |||
const std::string LOG_ATTR_NAME_SCALE = "scale"; | |||
const std::string LOG_ATTR_NAME_SHIFT = "shift"; | |||
const std::string LOG_ATTR_NAME_BASE = "base"; | |||
// Pack | |||
const std::string PACK_ATTR_NAME_NUM = "N"; | |||
// Unpack | |||
const std::string UNPACK_ATTR_NAME_NUM = "num"; | |||
const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||
// Gathernd | |||
const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | |||
const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | |||
@@ -400,6 +439,13 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; | |||
const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | |||
const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | |||
const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | |||
const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; | |||
const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; | |||
const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; | |||
// upsample | |||
const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; | |||
const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; | |||
// Relu | |||
const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | |||
@@ -416,6 +462,7 @@ const std::string SPLIT_ATTR_NAME_NUM_SPLIT = "num_split"; | |||
const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; | |||
const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; | |||
const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; | |||
const std::string TVM_ATTR_NAME_WORKSPACE_TYPE = "tvm_workspace_type"; | |||
// Squeeze | |||
const std::string SQUEEZE_ATTR_AXIS = "axis"; | |||
@@ -438,6 +485,7 @@ const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; | |||
const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; | |||
const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; | |||
const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; | |||
const std::string ROIALIGN_ATTR_NAME_TF = "roialign_tf"; | |||
// Generate_rpn_proposal | |||
const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; | |||
@@ -536,19 +584,42 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape | |||
const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | |||
// Rnn | |||
const std::string RNN_MODE_ = "rnn_"; | |||
const std::string CNN_RNN = "cnn_rnn"; | |||
const std::string RNN_TENSORFLOW = "rnn_tensorflow"; | |||
const std::string RNN_MODE_STATIC = "rnn_static"; | |||
const std::string MUTI_RNN = "multi_rnn"; | |||
const std::string CNN_RNN = "cnn_rnn"; | |||
const std::string RNN_MODE_ = "rnn_"; | |||
const std::string CELL_MODE = "mode"; | |||
const std::string LSTM_CELL = "lstm_cell"; | |||
const std::string GRU_CELL = "gru_cell"; | |||
const std::string RNN_HT = "ht"; | |||
const std::string RNN_XT_HT = "xt_ht"; | |||
const std::string RNN_BATCH_SIZE = "batch_size"; | |||
const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; | |||
const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; | |||
const std::string LSTM_ACTIVATE = "lstm_activate"; | |||
const std::string LSTM_OUT_MAP = "lstm_out_map"; | |||
const std::string LSTM_OUT_MODE = "lstm_out_mode"; | |||
const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; | |||
const std::string LSTM_TIME_MAJOR = "lstm_time_major"; | |||
const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; | |||
// Upsample | |||
const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | |||
// PadV2 | |||
const std::string PADV2_ATTR_NAME_MODE = "mode"; | |||
const std::string PADV2_ATTR_NAME_PADS = "paddings"; | |||
const std::string PADV2_ATTR_NAME_T = "T"; | |||
const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||
const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; | |||
// MirrorPad | |||
const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; | |||
const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; | |||
const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||
const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; | |||
// Filler | |||
const std::string FILLER_TYPE = "filler_type"; | |||
const std::string FILLER_VALUE = "filler_value"; | |||
@@ -559,9 +630,6 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group"; | |||
// TopKV2 | |||
const std::string TOPKV2_ATTR_K = "k"; | |||
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||
const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||
// Calibration | |||
const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | |||
const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | |||
@@ -616,6 +684,8 @@ const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; | |||
const std::string ATTR_MODEL_EVENT_NUM = "event_num"; | |||
const std::string ATTR_MODEL_LABEL_NUM = "label_num"; | |||
const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | |||
const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | |||
@@ -630,6 +700,8 @@ const std::string ATTR_MODEL_VAR_SIZE = "variable_size"; | |||
const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; | |||
const std::string ATTR_MODEL_CORE_TYPE = "core_type"; | |||
// Public attribute | |||
const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; | |||
@@ -661,17 +733,145 @@ const std::string TARGET_TYPE_TINY = "TINY"; | |||
const std::string TARGET_TYPE_LITE = "LITE"; | |||
// l2_normalize | |||
const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; | |||
const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||
const std::string POOL_PARAMA_ATTR_WINDOW = "window"; | |||
const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; | |||
const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; | |||
const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; | |||
const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; | |||
const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; | |||
// HCOM | |||
const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||
const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||
const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||
const std::string HCOM_ATTR_GROUP = "group"; | |||
const std::string HCOM_ATTR_SR_TAG = "sr_tag"; | |||
const std::string HCOM_ATTR_SRC_RANK = "src_rank"; | |||
const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; | |||
const std::string HCOM_ATTR_FUSION = "fusion"; | |||
const std::string HCOM_ATTR_SHAPE = "shape"; | |||
const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||
// SpaceToDepth/DepthToSpace | |||
const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; | |||
// SparseSoftmaxCrossEntropyWithLogits | |||
const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; | |||
// MaxPoolGradWithArgmax | |||
const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; | |||
// AvgPoolGrad | |||
const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; | |||
// Pad | |||
const std::string ATTR_PAD_FORMAT = "attr_pad_format"; | |||
// Variable | |||
const std::string VAR_ATTR_FORMAT = "_var_format"; | |||
const std::string VAR_ATTR_NAME = "var_name"; | |||
const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; | |||
const std::string VAR_ATTR_4D_FORMAT = "4D"; | |||
const std::string VAR_ATTR_5D_FORMAT = "5D"; | |||
const std::string VAR_ATTR_DATA_TYPE = "data_format"; | |||
const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; | |||
const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; | |||
const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; | |||
const std::string VAR_ATTR_SHAPE = "shape"; | |||
const std::string HALF_VAR_NAME_END = "_fp16"; | |||
const std::string VAR_ATTR_INITED = "var_is_inited"; | |||
const std::string VAR_ATTR_CONTAINER = "container"; | |||
const std::string VAR_ATTR_SHARED_NAME = "shared_name"; | |||
const std::string VAR_ATTR_DTYPE = "dtype"; | |||
const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||
const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; | |||
const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||
const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||
const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||
const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||
// Assign | |||
const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; | |||
// space2batch batch2space | |||
const std::string BATCH_SPACE_ATTR_BLOCK = "block"; | |||
const std::string BATCH_SPACE_ATTR_PADDING = "padding"; | |||
// depth_to_space space_to_depth | |||
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||
// FakeQuantWithMinMaxVars | |||
const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; | |||
const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; | |||
// mobilenet_ssd_conv_fusion | |||
const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; | |||
const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; | |||
const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; | |||
// lsh project | |||
const std::string LSH_PROJ_TYPE = "lsh_project_type"; | |||
// log time stamp | |||
const std::string LOG_TIME_STAMP_LOGID = "logid"; | |||
const std::string LOG_TIME_STAMP_NOTIFY = "notify"; | |||
// ShapeN | |||
const std::string SHAPEN_ATTR_N = "N"; | |||
const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; | |||
const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; | |||
// GatherV2 attr def | |||
const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; | |||
const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; | |||
const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; | |||
// Reshape attr def | |||
const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; | |||
const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; | |||
// axis attr def | |||
const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; | |||
const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; | |||
const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; | |||
const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; | |||
// For constant folding | |||
const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; | |||
const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | |||
const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | |||
const std::string ATTR_NAME_REFERENCE = "reference"; | |||
const std::string ATTR_NAME_NOTASK = "_no_task"; | |||
const std::string ATTR_NAME_OUTPUT_REUSE_INPUT = "_output_reuse_input"; | |||
const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX = "_reuse_input_on_dim_index"; | |||
const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT = "_no_padding_continuous_input"; | |||
const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT = "_no_padding_continuous_output"; | |||
const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; | |||
// Used to mark the active label list stream of the activated node | |||
const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; | |||
// Used for l2cache, true: the memory of all inputs is used for the last time. | |||
const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_lifecycle"; | |||
// Multi batch | |||
const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; | |||
const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; | |||
@@ -682,6 +882,8 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; | |||
const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | |||
const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | |||
const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | |||
const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | |||
const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | |||
const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | |||
const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | |||
@@ -691,6 +893,9 @@ const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG = "_cyclic_dependence_flag"; | |||
const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; | |||
// Function Op | |||
const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; | |||
// Used to mark whether the active node is for a loop, type: bool | |||
const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; | |||
@@ -702,6 +907,20 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; | |||
const std::string MODEL_ATTR_SESSION_ID = "session_id"; | |||
// l1 fusion and other fusion in future | |||
const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | |||
const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | |||
const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; | |||
const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; | |||
const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; | |||
const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; | |||
const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; | |||
const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; | |||
const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; | |||
const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | |||
const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | |||
const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | |||
// Atomic addr clean attrs | |||
const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | |||
const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; | |||
@@ -722,6 +941,9 @@ const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; | |||
// For inserted op | |||
const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | |||
// For compress weight | |||
const std::string ATTR_NAME_COMPRESS_WEIGHT = "_is_compress_weight"; | |||
// For data dump | |||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; | |||
const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; | |||
@@ -732,24 +954,17 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou | |||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | |||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | |||
// Variable | |||
const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||
const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||
const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||
const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||
const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||
// HCOM | |||
const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||
const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||
const std::string HCOM_ATTR_SHAPE = "shape"; | |||
const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||
// functional ops attr | |||
const std::string ATTR_NAME_TCOND = "Tcond"; | |||
const std::string ATTR_NAME_TIN = "Tin"; | |||
const std::string ATTR_NAME_TOUT = "Tout"; | |||
const std::string ATTR_NAME_THEN_BRANCH = "then_branch"; | |||
const std::string ATTR_NAME_ELSE_BRANCH = "else_branch"; | |||
const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||
// used for label switch | |||
const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | |||
const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | |||
const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | |||
const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | |||
// Dynamic stitch | |||
const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||
} // namespace ge |
@@ -22,7 +22,7 @@ | |||
#include "graph/model_serialize.h" | |||
#include "proto/ge_ir.pb.h" | |||
#include "detail/model_serialize_imp.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "debug/ge_attr_define.h" | |||
#include "debug/ge_log.h" | |||
#include "debug/ge_util.h" | |||
@@ -53,7 +53,7 @@ string GeAttrValue::NamedAttrs::GetName() const { | |||
GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const { | |||
GeAttrValue value; | |||
(void)GetAttr(key, value); | |||
GetAttr(key, value); | |||
return value; | |||
} | |||
@@ -1081,6 +1081,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstA | |||
if (!GetListInt(std::move(obj), name, int64_list)) { | |||
return false; | |||
} | |||
for (size_t i = 0; i < int64_list.size(); ++i) { | |||
if (int64_list[i] > INT32_MAX) { | |||
GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); | |||
@@ -1098,6 +1099,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstA | |||
if (!GetListInt(std::move(obj), name, int64_list)) { | |||
return false; | |||
} | |||
for (size_t i = 0; i < int64_list.size(); ++i) { | |||
if (int64_list[i] > UINT32_MAX) { | |||
GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); | |||
@@ -1215,6 +1217,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( | |||
GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); | |||
op_desc->extAttrs_ = org_op_desc->extAttrs_; | |||
if (op_desc->HasAttr("_input_name_idx_key")) { | |||
if (op_desc->DelAttr("_input_name_idx_key") != SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_key failed."); | |||
} | |||
} | |||
if (op_desc->HasAttr("_input_name_idx_value")) { | |||
if (op_desc->DelAttr("_input_name_idx_value") != SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_value failed."); | |||
} | |||
} | |||
if (op_desc->HasAttr("_opt_input")) { | |||
if (op_desc->DelAttr("_opt_input") != SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "DelAttr _opt_input failed."); | |||
} | |||
} | |||
return op_desc; | |||
} | |||
@@ -1237,11 +1256,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||
op_desc->extAttrs_ = org_op_desc->extAttrs_; | |||
op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end()); | |||
op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(), | |||
org_op_desc->optional_input_names_.end()); | |||
op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | |||
op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | |||
op_desc->infer_func_ = org_op_desc->infer_func_; | |||
@@ -1250,4 +1264,25 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||
return op_desc; | |||
} | |||
std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) { | |||
auto holder = obj.get(); | |||
if (holder == nullptr) { | |||
return ""; | |||
} | |||
auto attrs_map = holder->GetAttrMap(); | |||
if (attrs_map.GetProtoMsg() == nullptr) { | |||
return ""; | |||
} | |||
std::map<std::string, std::string> ordered_attrs; | |||
for (auto &attr : *(attrs_map.GetProtoMsg())) { | |||
ordered_attrs[attr.first] = attr.second.SerializeAsString(); | |||
} | |||
std::stringstream ss; | |||
for (auto &attr : ordered_attrs) { | |||
ss << attr.first << ":" << attr.second << ";"; | |||
} | |||
return ss.str(); | |||
} | |||
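A usage sketch for GetAllAttrsStr above, assuming the attr-holder adapter accepts an OpDescPtr the same way the other AttrUtils helpers in this file do; op_a and op_b are hypothetical:

bool HasSameAttrs(const OpDescPtr &op_a, const OpDescPtr &op_b) {
  // Attributes are sorted by name before serialization, so equal attr sets
  // produce identical strings regardless of insertion order.
  const std::string str_a = AttrUtils::GetAllAttrsStr(op_a);
  const std::string str_b = AttrUtils::GetAllAttrsStr(op_b);
  // An empty result means the holder or its proto message was null; treat as not comparable.
  return !str_a.empty() && (str_a == str_b);
}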
} // namespace ge |
@@ -163,6 +163,34 @@ int64_t GeShape::GetShapeSize() const { | |||
return res; | |||
} | |||
/// | |||
/// @brief Check whether the shape is unknown | |||
/// @return bool | |||
/// | |||
bool GeShape::IsUnknownShape() const { | |||
auto proto_msg = shape_def_.GetProtoMsg(); | |||
if (proto_msg != nullptr) { | |||
for (auto i : proto_msg->dim()) { | |||
if (i < 0) { | |||
return true; | |||
} | |||
} | |||
} | |||
return false; | |||
} | |||
/// | |||
/// @brief Check whether the shape is a scalar | |||
/// @return bool | |||
/// | |||
bool GeShape::IsScalar() const { | |||
auto proto_msg = shape_def_.GetProtoMsg(); | |||
if (proto_msg != nullptr) { | |||
return proto_msg->dim().empty(); | |||
} | |||
return false; | |||
} | |||
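The two helpers above encode a simple rule: any negative dim marks the shape as unknown (dynamic), and an empty dim list is a scalar. A standalone sketch of the same rule on plain vectors (illustrative only, not GE code):

#include <cstdint>
#include <vector>

bool IsUnknownDims(const std::vector<int64_t> &dims) {
  for (const int64_t dim : dims) {
    if (dim < 0) {
      return true;  // e.g. {-1, 224, 224, 3}: dynamic batch dimension
    }
  }
  return false;
}

bool IsScalarDims(const std::vector<int64_t> &dims) { return dims.empty(); }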
const string TENSOR_UTILS_SIZE = "size"; | |||
const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; | |||
const string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; | |||
@@ -639,14 +667,14 @@ GeTensor &GeTensor::operator=(const GeTensor &other) { | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, | |||
uint32_t &size) { | |||
int64_t &size) { | |||
auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | |||
GE_CHECK_NOTNULL(tensor_descriptor_msg); | |||
size = static_cast<uint32_t>(tensor_descriptor_msg->size()); | |||
size = static_cast<int64_t>(tensor_descriptor_msg->size()); | |||
return GRAPH_SUCCESS; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, uint32_t size) { | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) { | |||
auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | |||
if (tensor_descriptor_msg != nullptr) { | |||
tensor_descriptor_msg->set_size(size); | |||
@@ -49,6 +49,7 @@ void Model::Init() { | |||
(void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); | |||
(void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); | |||
(void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); | |||
(void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); | |||
(void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); | |||
(void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); | |||
version_ = 0; | |||
@@ -77,9 +78,9 @@ void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; } | |||
Graph Model::GetGraph() const { return graph_; } | |||
graphStatus Model::Save(Buffer &buffer) const { | |||
graphStatus Model::Save(Buffer &buffer, bool is_dump) const { | |||
ModelSerialize serialize; | |||
buffer = serialize.SerializeModel(*this); | |||
buffer = serialize.SerializeModel(*this, is_dump); | |||
return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED; | |||
} | |||
@@ -113,7 +114,7 @@ graphStatus Model::SaveToFile(const string &file_name) const { | |||
} | |||
int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); | |||
if (fd < 0) { | |||
GELOGE(GRAPH_FAILED, "open file failed, file path [%s] ", real_path); | |||
GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno)); | |||
return GRAPH_FAILED; | |||
} | |||
bool ret = ge_proto.SerializeToFileDescriptor(fd); | |||
@@ -129,6 +130,10 @@ graphStatus Model::SaveToFile(const string &file_name) const { | |||
GELOGE(GRAPH_FAILED, "close file descriptor fail."); | |||
return GRAPH_FAILED; | |||
} | |||
if (!ret) { | |||
GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed"); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -152,7 +157,7 @@ graphStatus Model::LoadFromFile(const string &file_name) { | |||
} | |||
int fd = open(real_path, O_RDONLY); | |||
if (fd < 0) { | |||
GELOGE(GRAPH_FAILED, "open file failed"); | |||
GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno)); | |||
return GRAPH_FAILED; | |||
} | |||
@@ -170,6 +175,10 @@ graphStatus Model::LoadFromFile(const string &file_name) { | |||
GELOGE(GRAPH_FAILED, "close file descriptor fail."); | |||
return GRAPH_FAILED; | |||
} | |||
if (!ret) { | |||
GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed"); | |||
return GRAPH_FAILED; | |||
} | |||
return Load(model_def); | |||
} | |||
@@ -15,10 +15,8 @@ | |||
*/ | |||
#include "graph/model_serialize.h" | |||
#include <google/protobuf/text_format.h> | |||
#include <iostream> | |||
#include "debug/ge_attr_define.h" | |||
#include "debug/ge_log.h" | |||
#include "debug/ge_util.h" | |||
@@ -26,6 +24,7 @@ | |||
#include "graph/detail/model_serialize_imp.h" | |||
#include "proto/ge_ir.pb.h" | |||
#include "utils/graph_utils.h" | |||
#include "debug/ge_op_types.h" | |||
using std::string; | |||
@@ -84,20 +83,29 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_ | |||
return true; | |||
} | |||
bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { | |||
bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { | |||
if (op_desc == nullptr || op_def_proto == nullptr) { | |||
GELOGE(GRAPH_FAILED, "Input Para Invalid"); | |||
return false; | |||
} | |||
if (op_desc->op_def_.GetProtoMsg() != nullptr) { | |||
*op_def_proto = *op_desc->op_def_.GetProtoMsg(); | |||
// Delete unnecessary attr | |||
if (is_dump) { | |||
auto attr = op_def_proto->mutable_attr(); | |||
attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF); | |||
attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF); | |||
attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF); | |||
GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP), | |||
attr->erase(ATTR_NAME_WEIGHTS)); | |||
} | |||
op_def_proto->clear_input_desc(); | |||
op_def_proto->clear_output_desc(); | |||
// Input descs | |||
if (op_desc->GetInputsSize() > 0) { | |||
auto size = static_cast<uint32_t>(op_desc->GetInputsSize()); | |||
if (op_desc->GetAllInputsSize() > 0) { | |||
auto size = static_cast<uint32_t>(op_desc->GetAllInputsSize()); | |||
for (uint32_t i = 0; i < size; i++) { | |||
auto tensor_desc = op_desc->GetInputDescPtr(i); | |||
auto tensor_desc = op_desc->GetInputDescPtrDfault(i); | |||
if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { | |||
*op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); | |||
} | |||
@@ -117,12 +125,12 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op | |||
return true; | |||
} | |||
bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto) { | |||
bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { | |||
if (node == nullptr || op_def_proto == nullptr) { | |||
GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); | |||
return false; | |||
} | |||
if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto)) { | |||
if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) { | |||
GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); | |||
return false; | |||
} | |||
@@ -134,7 +142,8 @@ bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_ | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, | |||
proto::GraphDef *graph_proto) { | |||
proto::GraphDef *graph_proto, | |||
bool is_dump) { | |||
if (graph == nullptr || graph_proto == nullptr) { | |||
GELOGE(GRAPH_FAILED, "Input para Invalid"); | |||
return false; | |||
@@ -156,7 +165,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize | |||
*graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); | |||
} | |||
for (const auto &node : graph->GetDirectNode()) { | |||
if (!SerializeNode(node, graph_proto->add_op())) { | |||
if (!SerializeNode(node, graph_proto->add_op(), is_dump)) { | |||
if (node->GetOpDesc() != nullptr) { | |||
GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); | |||
} | |||
@@ -166,7 +175,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize | |||
return true; | |||
} | |||
bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto) { | |||
bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) { | |||
if (model_proto == nullptr) { | |||
GELOGE(GRAPH_FAILED, "model_proto para Invalid"); | |||
return false; | |||
@@ -183,7 +192,7 @@ bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *mode | |||
GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); | |||
return false; | |||
} | |||
if (!SerializeGraph(compute_graph, model_proto->add_graph())) { | |||
if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) { | |||
GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | |||
return false; | |||
} | |||
@@ -390,10 +399,10 @@ bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf:: | |||
return true; | |||
} | |||
Buffer ModelSerialize::SerializeModel(const Model &model) { | |||
Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) { | |||
proto::ModelDef model_def; | |||
ModelSerializeImp imp; | |||
if (!imp.SerializeModel(model, &model_def)) { | |||
if (!imp.SerializeModel(model, &model_def, is_dump)) { | |||
return Buffer(); | |||
} | |||
#if !defined(__ANDROID__) && !defined(ANDROID) | |||
@@ -401,7 +401,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get | |||
vec.push_back(in_anchor); | |||
} | |||
} | |||
// Push back in_control_anchor_ | |||
// Push back in_control_anchor_ | |||
if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || | |||
(in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { | |||
auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_); | |||
@@ -512,7 +512,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn | |||
auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); | |||
for (const auto &out_anchor : peer_out_anchors) { | |||
GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, " in_control_anchor_ peer out data anchors is nullptr"); | |||
GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr"); | |||
auto node = out_anchor->GetOwnerNode(); | |||
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||
vec.push_back(node); | |||
@@ -521,7 +521,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn | |||
auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); | |||
for (const auto &out_control_anchor : peer_out_control_anchors) { | |||
GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, | |||
" in_control_anchor_ peer out control anchors is nullptr"); | |||
"in_control_anchor_ peer out control anchors is nullptr"); | |||
auto node = out_control_anchor->GetOwnerNode(); | |||
GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||
vec.push_back(node); | |||
@@ -785,6 +785,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(co | |||
GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, | |||
"Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), | |||
op_desc->GetInputsSize()); | |||
GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, | |||
"Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), | |||
op_desc->GetOutputsSize()); | |||
@@ -61,6 +61,12 @@ const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; | |||
const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | |||
const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; | |||
const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; | |||
const std::string ATTR_NAME_INPUT_NAME_IDX_VALUE = "_input_name_idx_value"; | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { | |||
op_def_.InitDefault(); | |||
if (op_def_.GetProtoMsg() != nullptr) { | |||
@@ -202,7 +208,8 @@ graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_d | |||
} | |||
graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | |||
if (input_name_idx_.find(name) != input_name_idx_.end()) { | |||
auto input_name_idx = GetAllInputName(); | |||
if (input_name_idx.find(name) != input_name_idx.end()) { | |||
GELOGI("input %s is exist, update it", name.c_str()); | |||
graphStatus ret = UpdateInputDesc(name, input_desc); | |||
return ret; | |||
@@ -214,15 +221,17 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp | |||
return GRAPH_FAILED; | |||
} | |||
inputs_desc_.push_back(in_desc); | |||
(void)input_name_idx_.insert(make_pair(name, index)); | |||
(void)input_name_idx.insert(make_pair(name, index)); | |||
SetAllInputName(input_name_idx); | |||
return GRAPH_SUCCESS; | |||
} | |||
} | |||
graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | |||
auto input_name_idx = GetAllInputName(); | |||
for (unsigned int i = 0; i < num; i++) { | |||
string input_name = name + std::to_string(i); | |||
GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, | |||
GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, | |||
"Add input tensor_desc is existed. name[%s]", input_name.c_str()); | |||
std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | |||
@@ -234,12 +243,13 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n | |||
(void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | |||
// Update index in input_name_idx | |||
for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { | |||
for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { | |||
it->second += 1; | |||
} | |||
(void)input_name_idx_.insert(make_pair(input_name, 0)); | |||
(void)input_name_idx.insert(make_pair(input_name, 0)); | |||
} | |||
SetAllInputName(input_name_idx); | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -270,10 +280,19 @@ graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int | |||
graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | |||
if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; | |||
(void)optional_input_names_.insert(name); | |||
vector<string> optional_input_names; | |||
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
optional_input_names.push_back(name); | |||
(void)AttrUtils::SetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
return GRAPH_SUCCESS; | |||
} | |||
std::vector<string> OpDesc::GetAllOptionalInputName() const { | |||
vector<string> optional_input_names; | |||
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
return optional_input_names; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | |||
GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); | |||
@@ -288,11 +307,12 @@ OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { | |||
return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && | |||
IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && | |||
IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") && | |||
IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && | |||
IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); | |||
return ( | |||
IsEqual(this->GetAllInputName(), r_op_desc.GetAllInputName(), "OpDesc.GetAllInputName()") && | |||
IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && | |||
IsEqual(this->GetAllOptionalInputName(), r_op_desc.GetAllOptionalInputName(), "OpDesc.GetAllOptionalInputName()") && | |||
IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && | |||
IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { | |||
@@ -366,8 +386,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpD | |||
} | |||
graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { | |||
auto it = input_name_idx_.find(name); | |||
if (it == input_name_idx_.end()) { | |||
auto input_name_idx = GetAllInputName(); | |||
auto it = input_name_idx.find(name); | |||
if (it == input_name_idx.end()) { | |||
GELOGW("Cann't find the input desc. name[%s]", name.c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
@@ -387,8 +408,9 @@ graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc & | |||
} | |||
bool OpDesc::InputIsSet(const string &name) const { | |||
auto it = input_name_idx_.find(name); | |||
if (it != input_name_idx_.end()) { | |||
auto input_name_idx = GetAllInputName(); | |||
auto it = input_name_idx.find(name); | |||
if (it != input_name_idx.end()) { | |||
GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); | |||
auto tensor_desc = inputs_desc_[it->second]; | |||
GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); | |||
@@ -406,18 +428,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc | |||
} | |||
GeTensorDesc OpDesc::GetInputDesc(const string &name) const { | |||
auto it = input_name_idx_.find(name); | |||
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc()); | |||
auto input_name_idx = GetAllInputName(); | |||
auto it = input_name_idx.find(name); | |||
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), GeTensorDesc()); | |||
GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); | |||
return *(inputs_desc_[it->second].get()); | |||
} | |||
GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const { | |||
auto input_name_idx = GetAllInputName(); | |||
vector<string> names; | |||
if (input_name_idx_.empty()) { | |||
if (input_name_idx.empty()) { | |||
return OpDesc::Vistor<string>(shared_from_this(), names); | |||
} | |||
for (std::pair<string, uint32_t> input : input_name_idx_) { | |||
for (std::pair<string, uint32_t> input : input_name_idx) { | |||
names.push_back(input.first); | |||
} | |||
return OpDesc::Vistor<string>(shared_from_this(), names); | |||
@@ -483,6 +507,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() co | |||
return size; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); } | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { | |||
int index = static_cast<int>(outputs_desc_.size()); | |||
return AddOutputDesc("__output" + std::to_string(index), output_desc); | |||
@@ -548,6 +574,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu | |||
return outputs_desc_[index]; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { | |||
return static_cast<uint32_t>(outputs_desc_.size()); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllOutputsDesc() const { | |||
vector<GeTensorDesc> temp{}; | |||
for (const auto &it : outputs_desc_) { | |||
@@ -580,6 +610,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetI | |||
} | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr | |||
OpDesc::GetInputDescPtrDfault(uint32_t index) const { | |||
GE_CHK_BOOL_RET_STATUS_NOLOG((index) < (uint32_t)(inputs_desc_.size()), nullptr); | |||
return inputs_desc_[(int32_t)index]; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const { | |||
auto input_name_idx = GetAllInputName(); | |||
auto it = input_name_idx.find(name); | |||
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), shared_ptr<const GeTensorDesc>()); | |||
return inputs_desc_[it->second]; | |||
} | |||
graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { | |||
if (is_push_back) { | |||
for (unsigned int i = 0; i < num; i++) { | |||
@@ -603,12 +646,45 @@ graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int | |||
} | |||
bool OpDesc::IsOptionalInput(const string &name) const { | |||
return optional_input_names_.find(name) != optional_input_names_.end(); | |||
vector<string> optional_input_names; | |||
(void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
for (auto &item : optional_input_names) { | |||
if (item == name) { | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
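
The change above drops the in-memory optional_input_names_ set and keeps optional input names in the serializable ATTR_NAME_OPT_INPUT string-list attribute, so membership becomes a linear scan over that list. A minimal standalone sketch of the same membership test over a plain string vector (names and values here are illustrative, not the GE attribute API):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Membership test over a serializable string list, mirroring how the diff
// replaces the optional_input_names_ set with the ATTR_NAME_OPT_INPUT list.
bool IsOptionalInputName(const std::vector<std::string> &optional_input_names, const std::string &name) {
  return std::find(optional_input_names.begin(), optional_input_names.end(), name) != optional_input_names.end();
}

int main() {
  std::vector<std::string> names = {"bias", "offset"};  // hypothetical optional inputs
  std::cout << std::boolalpha << IsOptionalInputName(names, "bias") << std::endl;  // prints true
  return 0;
}
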
bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } | |||
std::map<string, uint32_t> OpDesc::GetAllInputName() { return input_name_idx_; } | |||
std::map<string, uint32_t> OpDesc::GetAllInputName() const { | |||
std::map<string, uint32_t> input_name_idx; | |||
std::vector<string> key; | |||
std::vector<uint32_t> value; | |||
(void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); | |||
(void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); | |||
if (key.size() != value.size()) { | |||
GE_LOGE("twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size()); | |||
} else { | |||
for (uint32_t i = 0; i < key.size(); ++i) { | |||
input_name_idx.insert(std::pair<string, uint32_t>(key.at(i), value.at(i))); | |||
} | |||
} | |||
return input_name_idx; | |||
} | |||
void OpDesc::SetAllInputName(const std::map<string, uint32_t> &input_name_idx) { | |||
std::vector<string> key; | |||
std::vector<uint32_t> value; | |||
for (auto &item : input_name_idx) { | |||
key.emplace_back(item.first); | |||
value.emplace_back(item.second); | |||
} | |||
(void)AttrUtils::SetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); | |||
(void)AttrUtils::SetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); | |||
} | |||
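
The GetAllInputName/SetAllInputName pair above replaces the in-memory input_name_idx_ member with two parallel list attributes (ATTR_NAME_INPUT_NAME_IDX_KEY and ATTR_NAME_INPUT_NAME_IDX_VALUE): the map is flattened into a key list and a value list on write, and rebuilt with a size check on read. A minimal standalone sketch of that round trip, using plain std containers rather than the GE AttrUtils calls:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Encode a name->index map as two parallel vectors, mirroring how the diff
// stores it in ATTR_NAME_INPUT_NAME_IDX_KEY / ATTR_NAME_INPUT_NAME_IDX_VALUE.
void EncodeNameIdx(const std::map<std::string, uint32_t> &name_idx,
                   std::vector<std::string> &keys, std::vector<uint32_t> &values) {
  for (const auto &item : name_idx) {
    keys.emplace_back(item.first);
    values.emplace_back(item.second);
  }
}

// Decode the two vectors back into a map; mismatched sizes indicate a
// corrupted attribute pair, so the result is left empty in that case.
std::map<std::string, uint32_t> DecodeNameIdx(const std::vector<std::string> &keys,
                                              const std::vector<uint32_t> &values) {
  std::map<std::string, uint32_t> name_idx;
  if (keys.size() != values.size()) {
    return name_idx;
  }
  for (size_t i = 0; i < keys.size(); ++i) {
    name_idx.emplace(keys[i], values[i]);
  }
  return name_idx;
}

int main() {
  std::map<std::string, uint32_t> input_name_idx = {{"x", 0}, {"y", 1}};  // hypothetical inputs
  std::vector<std::string> keys;
  std::vector<uint32_t> values;
  EncodeNameIdx(input_name_idx, keys, values);
  std::cout << DecodeNameIdx(keys, values).size() << std::endl;  // prints 2
  return 0;
}
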
std::map<string, uint32_t> OpDesc::GetAllOutputName() { return output_name_idx_; } | |||
@@ -619,6 +695,7 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||
auto factory_map_size = input_name_idx.size(); | |||
// It indicates that some inputs have no optional name.
// The redundant optional names of the factory need to be deleted and then assigned
auto all_input_name_idx = GetAllInputName(); | |||
if (input_map_size < factory_map_size) { | |||
GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, | |||
factory_map_size); | |||
@@ -631,22 +708,23 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||
} | |||
if (input_name_idx.size() == input_map_size) { | |||
GELOGI("UpdateInputName"); | |||
input_name_idx_ = input_name_idx; | |||
all_input_name_idx = input_name_idx; | |||
} else { | |||
ret = false; | |||
GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); | |||
} | |||
} else if (input_map_size == factory_map_size) { | |||
input_name_idx_ = input_name_idx; | |||
all_input_name_idx = input_name_idx; | |||
} else { | |||
ret = false; | |||
GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); | |||
} | |||
SetAllInputName(all_input_name_idx); | |||
return ret; | |||
} | |||
bool OpDesc::UpdateOutputName(std::map<string, uint32_t> output_name_idx) { | |||
size_t output_map_size = GetAllOutputsDesc().size(); | |||
size_t output_map_size = GetAllOutputsDescSize(); | |||
size_t factory_map_size = output_name_idx.size(); | |||
if (output_map_size < factory_map_size) { | |||
GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size, | |||
@@ -754,17 +832,17 @@ graphStatus OpDesc::OpVerify() { | |||
} | |||
graphStatus OpDesc::CommonVerify() const { | |||
for (string iname : GetAllInputNames()) { | |||
for (const string &iname : GetAllInputNames()) { | |||
// Checking shape of all inputs | |||
vector<int64_t> ishape = GetInputDesc(iname).GetShape().GetDims(); | |||
vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims(); | |||
for (int64_t dim : ishape) { | |||
GE_CHK_BOOL_RET_STATUS(dim >= -1, GRAPH_FAILED, "operator input %s shape contains a dimension less than -1.",
iname.c_str()); | |||
} | |||
} | |||
// Check all attributes defined | |||
const auto all_attributes = GetAllAttrs(); | |||
for (const auto name : GetAllAttrNames()) { | |||
const auto &all_attributes = GetAllAttrs(); | |||
for (const auto &name : GetAllAttrNames()) { | |||
GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, | |||
"operator attribute %s is empty.", name.c_str()); | |||
} | |||
@@ -773,19 +851,21 @@ graphStatus OpDesc::CommonVerify() const { | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { | |||
auto it = input_name_idx_.begin(); | |||
for (; it != input_name_idx_.end(); ++it) { | |||
auto input_name_idx = GetAllInputName(); | |||
auto it = input_name_idx.begin(); | |||
for (; it != input_name_idx.end(); ++it) { | |||
if (it->second == index) { | |||
break; | |||
} | |||
} | |||
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); | |||
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), ""); | |||
return it->first; | |||
} | |||
int OpDesc::GetInputIndexByName(const string &name) const { | |||
auto it_find = input_name_idx_.find(name); | |||
GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); | |||
auto input_name_idx = GetAllInputName(); | |||
auto it_find = input_name_idx.find(name); | |||
GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx.end(), -1); | |||
return static_cast<int>(it_find->second); | |||
} | |||
@@ -1065,10 +1145,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<bool> OpDesc::GetIsInputCo | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, | |||
const int &index) { | |||
if (input_name_idx_.find(name) != input_name_idx_.end()) { | |||
auto input_name_idx = GetAllInputName(); | |||
if (input_name_idx.find(name) != input_name_idx.end()) { | |||
GELOGI("Restore input name index is existed. name[%s]", name.c_str()); | |||
} | |||
(void)input_name_idx_.insert(make_pair(name, index)); | |||
(void)input_name_idx.insert(make_pair(name, index)); | |||
SetAllInputName(input_name_idx); | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -1104,4 +1186,45 @@ graphStatus OpDesc::CallInferFormatFunc(Operator &op) { | |||
} | |||
return (graphStatus)infer_format_func_(op); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const { | |||
if (static_cast<size_t>(index) >= subgraph_instance_names_.size()) { | |||
return ""; | |||
} | |||
return subgraph_instance_names_.at(index); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector<std::string> &OpDesc::GetSubgraphInstanceNames() | |||
const { | |||
return subgraph_instance_names_; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::AddSubgraphInstanceName(std::string name) { | |||
subgraph_instance_names_.emplace_back(std::move(name)); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { | |||
for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { | |||
if (*iter == name) { | |||
subgraph_instance_names_.erase(iter); | |||
return; | |||
} | |||
} | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { | |||
auto iter = subgraph_names_to_index_.find(name); | |||
if (iter != subgraph_names_to_index_.end()) { | |||
GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); | |||
return GRAPH_FAILED; | |||
} | |||
auto size = subgraph_names_to_index_.size(); | |||
subgraph_names_to_index_[name] = size; | |||
return GRAPH_SUCCESS; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, uint32_t> &OpDesc::GetSubgraphNameIndexes() | |||
const { | |||
return subgraph_names_to_index_; | |||
} | |||
} // namespace ge
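
The subgraph bookkeeping added above keeps two structures on OpDesc: subgraph_names_to_index_ maps a declared subgraph name to a stable, insertion-ordered index (duplicates are rejected), and subgraph_instance_names_ stores instance names positionally. A minimal standalone sketch of the same registration pattern; the class and names below are illustrative only, not the GE API:

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Illustrative holder mirroring the registration pattern used by
// AddSubgraphName / AddSubgraphInstanceName in the diff above.
class SubgraphRegistry {
 public:
  bool AddSubgraphName(const std::string &name) {
    if (names_to_index_.count(name) > 0) {
      return false;  // duplicate declaration, mirrors the GRAPH_FAILED branch
    }
    names_to_index_[name] = static_cast<uint32_t>(names_to_index_.size());
    return true;
  }
  void AddSubgraphInstanceName(std::string name) { instance_names_.emplace_back(std::move(name)); }
  std::string GetSubgraphInstanceName(uint32_t index) const {
    return index < instance_names_.size() ? instance_names_[index] : "";
  }

 private:
  std::map<std::string, uint32_t> names_to_index_;
  std::vector<std::string> instance_names_;
};

int main() {
  SubgraphRegistry reg;
  reg.AddSubgraphName("then_branch");  // index 0
  reg.AddSubgraphName("else_branch");  // index 1
  reg.AddSubgraphInstanceName("if_0_then_branch");  // hypothetical instance name
  std::cout << reg.GetSubgraphInstanceName(0) << std::endl;
  return 0;
}
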
@@ -20,8 +20,7 @@ | |||
#include "debug/ge_log.h" | |||
#include "debug/ge_util.h" | |||
using std::function; | |||
using std::vector; | |||
using namespace std; | |||
namespace ge { | |||
@@ -15,13 +15,12 @@ | |||
*/ | |||
#include "external/graph/operator.h" | |||
#include <stdint.h> | |||
#include <algorithm> | |||
#include <mutex> | |||
#include <queue> | |||
#include <set> | |||
#include "array_ops.h" | |||
#include "debug/ge_log.h" | |||
#include "debug/ge_op_types.h" | |||
#include "debug/ge_util.h" | |||
@@ -33,7 +32,6 @@ | |||
#include "graph/ge_tensor.h" | |||
#include "graph/node.h" | |||
#include "graph/op_desc.h" | |||
#include "graph/operator_factory.h" | |||
#include "graph/usr_types.h" | |||
#include "utils/graph_utils.h" | |||
#include "utils/op_desc_utils.h" | |||
@@ -48,10 +46,6 @@ using std::string; | |||
using std::to_string; | |||
using std::vector; | |||
namespace { | |||
const char *const kValue = "value"; | |||
} // namespace | |||
namespace ge { | |||
class OpIO { | |||
public: | |||
@@ -148,6 +142,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||
for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) { | |||
is_input_const.push_back(false); | |||
} | |||
is_input_const[dst_index] = is_const; | |||
op_desc_->SetIsInputConst(is_input_const); | |||
@@ -179,8 +174,8 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||
GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), | |||
op_desc_->GetName().c_str()); | |||
auto out_op_impl = out_handler->GetOwner(); | |||
GE_CHK_BOOL_EXEC(out_op_impl && out_op_impl->GetOpDescImpl(), return, "out_handler invalid. name[%s]", | |||
dst_name.c_str()); | |||
GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return, | |||
"out_handler invalid. name[%s]", dst_name.c_str()); | |||
bool is_const = false; | |||
if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { | |||
is_const = true; | |||
@@ -193,7 +188,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||
op_desc_->SetIsInputConst(is_input_const); | |||
OpIO in_handler(dst_name, dst_index, shared_from_this()); | |||
GE_CHK_BOOL_EXEC(!!out_op_impl, return, "Get out_handler's impl failed."); | |||
GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed."); | |||
out_op_impl->UpdateLinkMapImpl(src_name, in_handler); | |||
auto src_output_desc = out_op_impl->GetOutputDesc(src_name); | |||
@@ -210,7 +205,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||
void AddControlInputImp(const ge::Operator &src_oprt) { | |||
if (src_oprt.operator_impl_ == nullptr) { | |||
GELOGE(GRAPH_FAILED, "Src operator impl is nullptr"); | |||
GELOGE(FAILED, "Src operator impl is nullptr"); | |||
return; | |||
} | |||
for (auto &input : control_input_link_) { | |||
@@ -520,9 +515,9 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co | |||
if (peer_node_ptr->GetOpDesc() != nullptr) { | |||
const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); | |||
if (op_descType == CONSTANTOP) { | |||
return const_op.GetAttr(kValue, data); | |||
return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||
} else if (op_descType == CONSTANT) { | |||
return const_op.GetAttr(kValue, data); | |||
return const_op.GetAttr(op::Const::name_attr_value(), data); | |||
} | |||
} | |||
} else { | |||
@@ -542,9 +537,9 @@ graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) | |||
Operator const_op(out_handle.GetOwner()); | |||
const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); | |||
if (op_desc_impl_type == CONSTANTOP) { | |||
return const_op.GetAttr(kValue, data); | |||
return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||
} else if (op_desc_impl_type == CONSTANT) { | |||
return const_op.GetAttr(kValue, data); | |||
return const_op.GetAttr(op::Const::name_attr_value(), data); | |||
} | |||
} | |||
return GRAPH_FAILED; | |||
@@ -709,6 +704,7 @@ void Operator::InputRegister(const string &name) { | |||
void Operator::OptionalInputRegister(const string &name) { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
// [No need to verify return value] | |||
(void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, | |||
GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); | |||
} | |||
@@ -716,24 +712,28 @@ void Operator::OptionalInputRegister(const string &name) { | |||
void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
// [No need to verify return value] | |||
(void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); | |||
} | |||
void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
// [No need to verify return value] | |||
(void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); | |||
} | |||
void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
// [No need to verify return value] | |||
(void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); | |||
} | |||
void Operator::OutputRegister(const string &name) { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
// [No need to verify return value] | |||
(void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); | |||
} | |||
@@ -757,7 +757,8 @@ int Operator::GetDynamicInputNum(const string &name) const { | |||
void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
(void)AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num); | |||
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return, | |||
"Set %s int failed", name.c_str()); | |||
(void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); | |||
} | |||
@@ -765,7 +766,8 @@ int Operator::GetDynamicOutputNum(const string &name) const { | |||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); | |||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); | |||
int num = 0; | |||
(void)AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num); | |||
GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num, | |||
"Get %s int failed", name.c_str()); | |||
return num; | |||
} | |||
@@ -1141,7 +1143,9 @@ class GraphBuilderImpl { | |||
GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); | |||
} | |||
} | |||
GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, | |||
"User Input do not include operator such as \ | |||
Data, Variable operator or operator that has output but no input."); | |||
auto ret = WalkAllOperators(vec_inputs); | |||
GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); | |||
@@ -1163,7 +1167,8 @@ class GraphBuilderImpl { | |||
que.pop(); | |||
for (const auto &op_impl : vec_tem) { | |||
GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") | |||
GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue) | |||
GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue, | |||
"This node %s has created.", op_impl->GetName().c_str()) | |||
auto node_ptr = graph_->AddNode(op_impl->op_desc_); | |||
GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); | |||
all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); | |||
@@ -1202,10 +1207,13 @@ class GraphBuilderImpl { | |||
for (const auto &node_info : all_nodes_info_) { | |||
auto src_op_impl_ptr = node_info.first; | |||
auto src_node_ptr = node_info.second; | |||
GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); | |||
auto out_links = src_op_impl_ptr->output_links_; | |||
GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED, | |||
"Src operator impl's op_desc is null."); | |||
auto &op_desc = src_op_impl_ptr->op_desc_; | |||
GE_IF_BOOL_EXEC(op_desc == nullptr, continue); | |||
for (const auto &out : out_links) { | |||
auto src_idx = op_desc->GetOutputIndexByName(out.first); | |||
GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); | |||
@@ -1216,7 +1224,9 @@ class GraphBuilderImpl { | |||
for (const auto &dst_opio : out.second) { | |||
auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); | |||
GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); | |||
GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); | |||
auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); | |||
GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); | |||
@@ -1260,8 +1270,7 @@ inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { | |||
ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) { | |||
auto graph_builder_impl = GraphBuilderImpl(name); | |||
ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); | |||
GE_IF_BOOL_EXEC(compute_graph == nullptr, return compute_graph); | |||
GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Compute graph is nullptr");
compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); | |||
if (HasSameNameNode(compute_graph)) { | |||
GELOGW("Compute do not allow has same name nodes."); | |||
@@ -15,13 +15,11 @@ | |||
*/ | |||
#include "graph/opsproto_manager.h" | |||
#include <algorithm> | |||
#include <cstdlib> | |||
#include <algorithm> | |||
#include <functional> | |||
#include <iostream> | |||
#include <sstream> | |||
#include "debug/ge_util.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "graph/debug/ge_log.h" | |||
@@ -155,7 +153,7 @@ void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) { | |||
// Load .so file | |||
for (auto elem : file_list) { | |||
void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE); | |||
void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||
if (handle == nullptr) { | |||
GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); | |||
continue; | |||
@@ -15,7 +15,6 @@ | |||
*/ | |||
#include "./ge_context.h" | |||
#include "./ge_global_options.h" | |||
#include "./ge_local_context.h" | |||
#include "framework/common/debug/ge_log.h" | |||
@@ -87,4 +86,5 @@ uint32_t GEContext::DeviceId() { return device_id_; } | |||
uint64_t GEContext::TraceId() { return trace_id_; } | |||
void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } | |||
} // namespace ge
@@ -22,6 +22,7 @@ | |||
#include <utility> | |||
#include <vector> | |||
#include "graph/utils/graph_utils.h" | |||
#include "debug/ge_log.h" | |||
#include "debug/ge_op_types.h" | |||
#include "external/graph/operator.h" | |||
@@ -34,6 +35,122 @@ | |||
#include "utils/type_utils.h" | |||
namespace ge { | |||
namespace { | |||
constexpr const char *kRefIndex = "parent_node_index"; | |||
graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
auto op_desc = node->GetOpDesc(); | |||
auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
if (sub_graph_names.empty()) { | |||
return GRAPH_SUCCESS; | |||
} | |||
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
for (const auto &name : sub_graph_names) { | |||
auto sub_graph = root_graph->GetSubgraph(name); | |||
if (sub_graph == nullptr) { | |||
GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
for (const auto &node_sub : sub_graph->GetDirectNode()) { | |||
if (node_sub->GetType() != DATA) { | |||
continue; | |||
} | |||
int ref_i; | |||
auto data_opdesc = node_sub->GetOpDesc(); | |||
if (data_opdesc == nullptr) { | |||
GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), | |||
node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
if (!AttrUtils::GetInt(node_sub->GetOpDesc(), kRefIndex, ref_i)) { | |||
GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | |||
node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
auto input_desc = op_desc->MutableInputDesc(ref_i); | |||
if (input_desc == nullptr) { | |||
GE_LOGE( | |||
"The ref index(%d) on the data %s on the sub graph %s " | |||
"parent node %s are incompatible, inputs num %u", | |||
ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); | |||
return GRAPH_FAILED; | |||
} | |||
auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); | |||
if (ret != GRAPH_SUCCESS) { | |||
GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", | |||
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||
return ret; | |||
} | |||
ret = data_opdesc->UpdateOutputDesc(0, *input_desc); | |||
if (ret != GRAPH_SUCCESS) { | |||
GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s", | |||
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||
return ret; | |||
} | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
auto op_desc = node->GetOpDesc(); | |||
auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
if (sub_graph_names.empty()) { | |||
return GRAPH_SUCCESS; | |||
} | |||
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
for (const auto &name : sub_graph_names) { | |||
auto sub_graph = root_graph->GetSubgraph(name); | |||
if (sub_graph == nullptr) { | |||
GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
NodePtr netoutput = nullptr; | |||
auto sub_nodes = sub_graph->GetDirectNode(); | |||
for (size_t i = sub_nodes.size(); i > 0; --i) { | |||
auto sub_node = sub_nodes.at(i - 1); | |||
if (sub_node->GetType() == NETOUTPUT) { | |||
netoutput = sub_node; | |||
break; | |||
} | |||
} | |||
if (netoutput == nullptr) { | |||
GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
auto netoutput_opdesc = netoutput->GetOpDesc(); | |||
if (netoutput_opdesc == nullptr) { | |||
GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(), | |||
node->GetName().c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { | |||
auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); | |||
if (edge_desc == nullptr) { | |||
GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(), | |||
node->GetName().c_str(), edge_anchor->GetIdx()); | |||
return GRAPH_FAILED; | |||
} | |||
int ref_i; | |||
if (!AttrUtils::GetInt(edge_desc, kRefIndex, ref_i)) { | |||
// If there is no ref index on the TensorDesc, the output data will be ignored by the outer graph.
continue; | |||
} | |||
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i)); | |||
if (output_desc == nullptr) { | |||
GE_LOGE( | |||
"The ref index(%d) on the input %d of netoutput %s on the sub graph %s " | |||
"parent node %s are incompatible, outputs num %u", | |||
ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||
node->GetAllOutDataAnchorsSize()); | |||
return GRAPH_FAILED; | |||
} | |||
op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc); | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
} // namespace | |||
void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { | |||
if (node == nullptr) { | |||
GELOGE(GRAPH_FAILED, "node is null"); | |||
@@ -42,7 +159,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||
ge::OpDescPtr op_desc = node->GetOpDesc(); | |||
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); | |||
std::string str; | |||
if (!op_desc->GetAllInputsDescPtr().empty()) { | |||
if (op_desc->GetInputsSize() != 0) { | |||
std::string input_desc_str = "input shape: "; | |||
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||
input_desc_str += "["; | |||
@@ -56,7 +173,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||
str += input_desc_str; | |||
} | |||
if (!op_desc->GetAllOutputsDescPtr().empty()) { | |||
if (op_desc->GetAllOutputsDescSize() != 0) { | |||
std::string output_desc_str = "output shape: "; | |||
for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | |||
if (output_desc == nullptr) { | |||
@@ -76,13 +193,24 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||
} | |||
graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { | |||
return InferShapeAndType(node, op, true); | |||
} | |||
graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) { | |||
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | |||
auto op_desc = node->GetOpDesc(); | |||
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); | |||
const auto &op_type = op_desc->GetType(); | |||
graphStatus ret; | |||
if (before_subgraph) { | |||
ret = UpdateSubGraphDataNodes(node); | |||
if (ret != GRAPH_SUCCESS) { | |||
return ret; | |||
} | |||
} | |||
// Get infer func and execute | |||
graphStatus ret = op_desc->CallInferFunc(op); | |||
ret = op_desc->CallInferFunc(op); | |||
if (ret == GRAPH_PARAM_INVALID) { | |||
// Op ir no infer func, try to get infer func from operator factory | |||
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); | |||
@@ -113,7 +241,14 @@ graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator & | |||
ret = op_desc->CallInferFunc(op); | |||
GELOGI("op CallInferFunc second. ret: %u", ret); | |||
} | |||
return ret; | |||
if (ret != GRAPH_SUCCESS) { | |||
return ret; | |||
} | |||
if (!before_subgraph) { | |||
return UpdateParentNodeOutTensor(node); | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
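
The before_subgraph flag added above splits shape refinement for nodes that own subgraphs into two passes: push the parent's input descs into the subgraph Data nodes before inference, run the op's own infer function, and on the post-subgraph pass pull the NetOutput descs back up into the parent's outputs. A minimal control-flow sketch of that ordering, with the three steps passed in as callables; the names below are placeholders, not the GE functions:

#include <functional>
#include <iostream>

// Hypothetical status codes standing in for graphStatus values in the diff.
enum SketchStatus { SKETCH_SUCCESS = 0, SKETCH_FAILED = 1 };

// Two-phase refinement: subgraph inputs are updated only on the pass that
// runs before the subgraph, parent outputs only on the pass that runs after.
SketchStatus RefineWithSubgraphs(bool before_subgraph,
                                 const std::function<SketchStatus()> &update_subgraph_data_nodes,
                                 const std::function<SketchStatus()> &call_infer_func,
                                 const std::function<SketchStatus()> &update_parent_node_out_tensor) {
  if (before_subgraph) {
    SketchStatus ret = update_subgraph_data_nodes();  // push parent input descs down
    if (ret != SKETCH_SUCCESS) return ret;
  }
  SketchStatus ret = call_infer_func();               // the op's registered infer function
  if (ret != SKETCH_SUCCESS) return ret;
  if (!before_subgraph) {
    return update_parent_node_out_tensor();           // pull NetOutput descs back up
  }
  return SKETCH_SUCCESS;
}

int main() {
  auto ok = []() { return SKETCH_SUCCESS; };
  std::cout << RefineWithSubgraphs(true, ok, ok, ok) << std::endl;  // prints 0
  return 0;
}
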
InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | |||
@@ -179,8 +314,11 @@ InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, Inf | |||
namespace { | |||
std::unordered_map<NodePtr, InferenceContextPtr> context_map; | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { | |||
return InferShapeAndType(node, true); | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, | |||
bool before_subgraph) { | |||
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | |||
if (node->Verify() != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); | |||
@@ -199,7 +337,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh | |||
Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||
op.SetInferenceContext(inference_context); | |||
graphStatus status = InferShapeAndType(node, op); | |||
graphStatus status = InferShapeAndType(node, op, before_subgraph); | |||
if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | |||
(void)ge::NodeUtils::UpdatePeerNodeInputDesc(node); | |||
} else { | |||
@@ -353,6 +353,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) | |||
} | |||
} | |||
} | |||
impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); | |||
} | |||
@@ -516,13 +517,14 @@ graphStatus Tensor::IsValid() { | |||
GELOGW("mul overflow: %lu, %u", shape_size, type_length); | |||
} else { | |||
if (shape_size * type_length != data_size) { | |||
// [Just log] Constructor | |||
GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | |||
data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
} | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
@@ -539,7 +541,7 @@ GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_des | |||
tensor_desc.GetDataType()); | |||
ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | |||
ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | |||
auto size = static_cast<uint32_t>(tensor_desc.GetSize()); | |||
auto size = tensor_desc.GetSize(); | |||
TensorUtils::SetSize(ge_tensor_desc, size); | |||
auto real_dim_cnt = static_cast<uint32_t>(tensor_desc.GetRealDimCnt()); | |||
@@ -552,7 +554,7 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ | |||
ge_tensor_desc.GetDataType()); | |||
tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | |||
tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | |||
uint32_t size = 0; | |||
int64_t size = 0; | |||
(void)TensorUtils::GetSize(ge_tensor_desc, size); | |||
tensor_desc.SetSize(size); | |||
@@ -15,18 +15,21 @@ | |||
*/ | |||
#include "graph/utils/ge_ir_utils.h" | |||
#include <utility> | |||
#include "framework/common/debug/ge_log.h" | |||
namespace { | |||
const char *const kControlAnchorIndex = ":-1"; | |||
const char *const kNodeTypeForSubgraph = "subgraph"; | |||
const char *const kPrefixForInputDesc = "input_desc_attr_"; | |||
const char *const kPrefixForOutputDesc = "output_desc_attr_"; | |||
const char *const kDumpGEGraph = "DUMP_GE_GRAPH"; | |||
const int8_t kMaxRecursionDepth = 10; | |||
const char *const kDumpGeGraph = std::getenv(kDumpGEGraph); | |||
const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP; | |||
const int64_t kInputPrefixLength = 5; | |||
const int64_t kOutputPrefixLength = 6; | |||
using AttrDefPair = ::google::protobuf::MapPair<std::string, ge::proto::AttrDef>; | |||
} // namespace | |||
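
kDumpLevel above is resolved once at static-initialization time from the DUMP_GE_GRAPH environment variable: getenv returns the raw string (or nullptr) and strtol converts it, falling back to NO_DUMP when the variable is unset. A hedged standalone sketch of the same parsing with an explicit default; the enum values here are illustrative, the real ones come from ge::OnnxUtils:

#include <cstdint>
#include <cstdlib>
#include <iostream>

// Illustrative dump levels standing in for ge::OnnxUtils::NO_DUMP / DUMP_ALL.
enum DumpLevel : int64_t { NO_DUMP = 0, DUMP_ALL = 1 };

// Read DUMP_GE_GRAPH and convert it to a dump level, defaulting to NO_DUMP
// when the variable is unset, just as the static kDumpLevel initializer does.
int64_t ReadDumpLevel() {
  const char *raw = std::getenv("DUMP_GE_GRAPH");
  return (raw != nullptr) ? std::strtol(raw, nullptr, 10) : NO_DUMP;
}

int main() {
  std::cout << "dump level: " << ReadDumpLevel() << std::endl;
  return 0;
}
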
namespace ge { | |||
@@ -198,7 +201,7 @@ void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_A | |||
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, | |||
::google::protobuf::RepeatedField<bool> data) { | |||
if (node_proto == nullptr) { | |||
GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); | |||
GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); | |||
return; | |||
} | |||
if (!data.empty()) { | |||
@@ -320,7 +323,16 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const | |||
auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | |||
"input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset); | |||
const auto &tensor_desc_map = tensor_descriptor->attr(); | |||
std::string suffix = ":" + std::to_string(i); | |||
AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix); | |||
} else { | |||
GELOGW("Tensor descriptor is nullptr"); | |||
continue; | |||
} | |||
} else { | |||
GELOGW("Input desc is nullptr"); | |||
continue; | |||
} | |||
} | |||
} | |||
@@ -360,16 +372,25 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const | |||
auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | |||
"output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); | |||
const auto &tensor_desc_map = tensor_descriptor->attr(); | |||
std::string suffix = ":" + std::to_string(i); | |||
AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix); | |||
} else { | |||
GELOGW("Tensor descriptor is nullptr"); | |||
continue; | |||
} | |||
} else { | |||
GELOGW("Output desc is nullptr"); | |||
continue; | |||
} | |||
} | |||
} | |||
} | |||
void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto) { | |||
GE_CHK_BOOL_EXEC(op_def != nullptr, return, "Opdef is nullptr"); | |||
const auto &op_def_attr_map = op_def->attr(); | |||
for (const auto &item : op_def_attr_map) { | |||
void OnnxUtils::AddAttrProtoForAttrsFromAttrMap( | |||
const ::google::protobuf::Map<std::string, ::ge::proto::AttrDef> &attr_map, onnx::NodeProto *node_proto, | |||
const std::string &prefix, const std::string &suffix) { | |||
for (const auto &item : attr_map) { | |||
auto attr_name = item.first; | |||
auto attr_def = item.second; | |||
auto attr_type = attr_def.value_case(); | |||
@@ -377,36 +398,40 @@ void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, on | |||
const auto &tensor_def = attr_def.t(); | |||
const auto &tensor_desc = tensor_def.desc(); | |||
auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_desc_dtype:", &data_type); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix, | |||
&data_type); | |||
auto dims = tensor_desc.shape().dim(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name + "_desc_shape:", dims); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix, | |||
dims); | |||
auto layout = tensor_desc.layout(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_desc_layout:", &layout); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix, | |||
&layout); | |||
auto device_type = tensor_desc.device_type(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, | |||
attr_name + "_desc_device_type:", &device_type); | |||
prefix + attr_name + "_desc_device_type" + suffix, &device_type); | |||
if (kDumpLevel == DUMP_ALL) { | |||
auto data = tensor_def.data(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_data", &data); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix, | |||
&data); | |||
} | |||
} | |||
if (attr_type == ge::proto::AttrDef::kS) { | |||
if (kDumpLevel == DUMP_ALL) { | |||
auto str_value = attr_def.s(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name, &str_value); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value); | |||
} | |||
} | |||
if (attr_type == ge::proto::AttrDef::kI) { | |||
auto int_value = attr_def.i(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, attr_name, &int_value); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); | |||
} | |||
if (attr_type == ge::proto::AttrDef::kF) { | |||
auto float_value = attr_def.f(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, attr_name, &float_value); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value); | |||
} | |||
if (attr_type == ge::proto::AttrDef::kB) { | |||
auto int_value = static_cast<int64_t>(attr_def.b()); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, attr_name, &int_value); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); | |||
} | |||
if (attr_type == ge::proto::AttrDef::kList) { | |||
const auto &list_value = attr_def.list(); | |||
@@ -415,21 +440,21 @@ void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, on | |||
ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { | |||
if (kDumpLevel == DUMP_ALL) { | |||
const auto &strings = list_value.s(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, attr_name, strings); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings); | |||
} | |||
} | |||
if (list_value_type == | |||
ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { | |||
const auto &floats = list_value.f(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, attr_name, floats); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats); | |||
} | |||
if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { | |||
const auto &ints = list_value.i(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name, ints); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints); | |||
} | |||
if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { | |||
const auto &bools = list_value.b(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name, bools); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools); | |||
} | |||
} | |||
} | |||
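
The refactor above generalizes AddAttrProtoForAttrsFromOpDef into AddAttrProtoForAttrsFromAttrMap, which namespaces every emitted attribute with a caller-supplied prefix (e.g. kPrefixForInputDesc, "input_desc_attr_") and a ":" + tensor-index suffix, so per-tensor attributes from different descs cannot collide in the dumped node proto. A tiny sketch of the key-building convention; the attribute name used below is hypothetical:

#include <iostream>
#include <string>

// Build an attribute key of the form <prefix><attr_name><field><":" + index>,
// following the prefix/suffix convention introduced in the diff above.
std::string BuildAttrKey(const std::string &prefix, const std::string &attr_name,
                         const std::string &field, int index) {
  return prefix + attr_name + field + ":" + std::to_string(index);
}

int main() {
  // e.g. input_desc_attr_origin_format_desc_dtype:0
  std::cout << BuildAttrKey("input_desc_attr_", "origin_format", "_desc_dtype", 0) << std::endl;
  return 0;
}
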
@@ -481,8 +506,15 @@ void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes); | |||
const auto &is_input_const = op_def->is_input_const(); | |||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const); | |||
AddAttrProtoForAttrsFromOpDef(op_def, node_proto); | |||
const auto &op_def_attr_map = op_def->attr(); | |||
AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto); | |||
} else { | |||
GELOGE(FAILED, "Opdef is nullptr"); | |||
return; | |||
} | |||
} else { | |||
GELOGE(FAILED, "Opdesc is nullptr"); | |||
return; | |||
} | |||
} | |||
@@ -526,15 +558,13 @@ bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) | |||
node_proto->clear_input(); | |||
// 1. Add input by in data edge | |||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
if (in_data_anchor != nullptr) { | |||
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { | |||
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + | |||
std::to_string(peer_out_anchor->GetIdx())); | |||
} else { | |||
// Add "" input | |||
node_proto->add_input(""); | |||
} | |||
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { | |||
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + | |||
std::to_string(peer_out_anchor->GetIdx())); | |||
} else { | |||
// Add "" input | |||
node_proto->add_input(""); | |||
} | |||
} | |||
@@ -547,6 +577,9 @@ bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) | |||
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); | |||
} | |||
} | |||
} else { | |||
GELOGE(FAILED, "Incontrol anchor is nullptr"); | |||
return false; | |||
} | |||
// 3. Add output for Netron visual support | |||
@@ -584,7 +617,7 @@ void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_T | |||
} | |||
const auto &op_desc = node->GetOpDesc(); | |||
if (op_desc != nullptr) { | |||
auto size_out = op_desc->GetOutputsSize(); | |||
uint32_t size_out = static_cast<uint32_t>(op_desc->GetOutputsSize()); | |||
if (size_out > 0) { | |||
for (uint32_t i = 0; i < size_out; i++) { | |||
const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); | |||
@@ -598,7 +631,13 @@ void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_T | |||
auto dim = shape->add_dim(); | |||
dim->set_dim_value(d); | |||
} | |||
} else { | |||
GELOGW("Shape is nullptr"); | |||
continue; | |||
} | |||
} else { | |||
GELOGW("Ge tensor is nullptr"); | |||
continue; | |||
} | |||
} | |||
} | |||
@@ -666,7 +705,7 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr | |||
} | |||
// For subgraphs: a subgraph is represented by a node | |||
for (const auto &sub_compute_graph : compute_graph->sub_graph_) { | |||
for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) { | |||
if (sub_compute_graph != nullptr) { | |||
auto node_proto = graph_proto->add_node(); | |||
if (node_proto == nullptr) { | |||
@@ -679,6 +718,10 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr | |||
attr->set_name("graph"); | |||
attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); | |||
auto sub_graph_proto = attr->mutable_g(); | |||
if (sub_graph_proto == nullptr) { | |||
GELOGW("Sub graph proto is nullptr"); | |||
continue; | |||
} | |||
if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { | |||
GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); | |||
continue; | |||
@@ -831,56 +874,116 @@ void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t | |||
value = attr_proto.i(); | |||
} | |||
void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||
const std::string &attr_name_for_input_output_desc, int32_t index, | |||
OpDescPtr &op_desc) { | |||
if (op_desc == nullptr || op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||
GELOGE(GRAPH_FAILED, "op_desc or op_desc->MutableInputDesc(index) is nullptr"); | |||
void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, | |||
const std::string &attr_name_for_input_desc, int32_t index, | |||
OpDescPtr &op_desc) { | |||
if (op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||
GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast<uint32_t>(index)) is nullptr", | |||
op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); | |||
return; | |||
} | |||
if (attr_name_for_input_output_desc == "input_desc_dtype") { | |||
if (attr_name_for_input_desc == "input_desc_dtype") { | |||
auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | |||
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | |||
} else if (attr_name_for_input_output_desc == "input_desc_shape") { | |||
} else if (attr_name_for_input_desc == "input_desc_shape") { | |||
std::vector<std::int64_t> ints; | |||
DecodeAttribute(attr_proto, ints); | |||
GeShape ge_shape(ints); | |||
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | |||
} else if (attr_name_for_input_output_desc == "input_desc_layout") { | |||
} else if (attr_name_for_input_desc == "input_desc_layout") { | |||
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | |||
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | |||
} else if (attr_name_for_input_output_desc == "input_desc_origin_shape") { | |||
} else if (attr_name_for_input_desc == "input_desc_origin_shape") { | |||
std::vector<std::int64_t> ints; | |||
DecodeAttribute(attr_proto, ints); | |||
GeShape ge_shape(ints); | |||
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | |||
} else if (attr_name_for_input_output_desc == "input_desc_origin_layout") { | |||
} else if (attr_name_for_input_desc == "input_desc_origin_layout") { | |||
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | |||
op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | |||
} else if (attr_name_for_input_output_desc == "output_desc_dtype") { | |||
} else if (attr_name_for_input_desc == "input_desc_size") { | |||
int64_t input_size = 0; | |||
auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||
DecodeAttribute(attr_proto, input_size); | |||
tensor_descriptor->set_size(input_size); | |||
} else if (attr_name_for_input_desc == "input_desc_data_offset") { | |||
auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||
int64_t offset = 0; | |||
DecodeAttribute(attr_proto, offset); | |||
tensor_descriptor->set_data_offset(offset); | |||
} else { | |||
return; | |||
} | |||
} | |||
void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, | |||
const std::string &attr_name_for_output_desc, int32_t index, | |||
OpDescPtr &op_desc) { | |||
if (op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||
GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) is nullptr", | |||
op_desc->GetName().c_str(), attr_name_for_output_desc.c_str()); | |||
return; | |||
} | |||
if (attr_name_for_output_desc == "output_desc_dtype") { | |||
auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | |||
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | |||
} else if (attr_name_for_input_output_desc == "output_desc_shape") { | |||
} else if (attr_name_for_output_desc == "output_desc_shape") { | |||
std::vector<std::int64_t> ints; | |||
DecodeAttribute(attr_proto, ints); | |||
GeShape ge_shape(ints); | |||
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | |||
} else if (attr_name_for_input_output_desc == "output_desc_layout") { | |||
} else if (attr_name_for_output_desc == "output_desc_layout") { | |||
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | |||
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | |||
} else if (attr_name_for_input_output_desc == "output_desc_origin_shape") { | |||
} else if (attr_name_for_output_desc == "output_desc_origin_shape") { | |||
std::vector<std::int64_t> ints; | |||
DecodeAttribute(attr_proto, ints); | |||
GeShape ge_shape(ints); | |||
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | |||
} else if (attr_name_for_input_output_desc == "output_desc_origin_layout") { | |||
} else if (attr_name_for_output_desc == "output_desc_origin_layout") { | |||
auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | |||
op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | |||
} else if (attr_name_for_output_desc == "output_desc_size") { | |||
int64_t output_size = 0; | |||
auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||
DecodeAttribute(attr_proto, output_size); | |||
tensor_descriptor->set_size(output_size); | |||
} else if (attr_name_for_output_desc == "output_desc_data_offset") { | |||
auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||
int64_t offset = 0; | |||
DecodeAttribute(attr_proto, offset); | |||
tensor_descriptor->set_data_offset(offset); | |||
} else { | |||
return; | |||
} | |||
} | |||
void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||
const std::string &attr_name_for_input_output_desc, int32_t index, | |||
OpDescPtr &op_desc) { | |||
if (op_desc == nullptr) { | |||
GELOGE(GRAPH_FAILED, "op_desc is nullptr"); | |||
return; | |||
} | |||
if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") { | |||
DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); | |||
} else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") { | |||
DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); | |||
} else { | |||
return; | |||
} | |||
} | |||
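The rename above splits the old combined decoder into per-direction helpers selected by the attribute-name prefix. A minimal standalone sketch of that dispatch follows; the values of kInputPrefixLength and kOutputPrefixLength are assumptions here (the lengths of "input" and "output"), since the constants themselves are not part of this hunk.

```cpp
#include <string>

// Assumed prefix lengths; the real constants are defined elsewhere in the file.
constexpr size_t kInputPrefixLength = 5;   // strlen("input")
constexpr size_t kOutputPrefixLength = 6;  // strlen("output")

enum class DescKind { kInput, kOutput, kOther };

// Mirrors DecodeNodeAttributeForOpInAndOutDesc: route by name prefix,
// silently ignore anything else.
DescKind ClassifyAttrName(const std::string &attr_name) {
  if (attr_name.compare(0, kInputPrefixLength, "input") == 0) {
    return DescKind::kInput;   // e.g. "input_desc_shape", "input_desc_data_offset"
  }
  if (attr_name.compare(0, kOutputPrefixLength, "output") == 0) {
    return DescKind::kOutput;  // e.g. "output_desc_dtype", "output_desc_size"
  }
  return DescKind::kOther;
}
```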
void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) { | |||
auto attr_map = op_def.mutable_attr(); | |||
const auto &attr_name = attr_proto.name(); | |||
ge::proto::AttrDef op_attr; | |||
int64_t value = 0; | |||
DecodeAttribute(attr_proto, value); | |||
op_attr.set_i(value); | |||
attr_map->insert(AttrDefPair(attr_name, op_attr)); | |||
} | |||
void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { | |||
if (op_desc == nullptr) { | |||
GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); | |||
@@ -910,6 +1013,16 @@ void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_pr | |||
std::vector<std::int64_t> ints; | |||
DecodeAttribute(attr_proto, ints); | |||
op_desc->SetDstIndex(ints); | |||
} else if (attr_name == "fusion_scope") { | |||
DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg()); | |||
} else if (attr_name == "input_i") { | |||
std::vector<std::int64_t> ints; | |||
DecodeAttribute(attr_proto, ints); | |||
op_desc->SetInputOffset(ints); | |||
} else if (attr_name == "output_i") { | |||
std::vector<std::int64_t> ints; | |||
DecodeAttribute(attr_proto, ints); | |||
op_desc->SetOutputOffset(ints); | |||
} else { | |||
return; | |||
} | |||
@@ -939,20 +1052,14 @@ bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_ | |||
auto size_in = attr.i(); | |||
for (int64_t i = 0; i < size_in; i++) { | |||
GeTensorDesc ge_tensor_desc; | |||
if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) { | |||
GELOGW("Add inputdesc failed"); | |||
continue; | |||
} | |||
GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed."); | |||
} | |||
} | |||
if (attr.name() == "output_desc_nums") { | |||
auto size_out = attr.i(); | |||
for (int64_t i = 0; i < size_out; i++) { | |||
GeTensorDesc ge_tensor_desc; | |||
if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) { | |||
GELOGW("add inputdesc failed"); | |||
continue; | |||
} | |||
GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed."); | |||
} | |||
} | |||
} | |||
@@ -970,10 +1077,7 @@ bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_p | |||
} | |||
graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name()); | |||
if (graph == nullptr) { | |||
GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); | |||
return false; | |||
} | |||
GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed"); | |||
/// 1. Decode all nodes first; the nodes should include input | |||
/// and output nodes as well as nodes that represent subgraphs | |||
std::map<std::string, NodePtr> node_map; | |||
@@ -131,6 +131,10 @@ class OnnxUtils { | |||
static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); | |||
static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map<std::string, ge::proto::AttrDef> &attr_map, | |||
onnx::NodeProto *node_proto, const std::string &prefix = "", | |||
const std::string &suffix = ""); | |||
static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto); | |||
static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); | |||
@@ -172,10 +176,20 @@ class OnnxUtils { | |||
static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); | |||
static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, | |||
const std::string &attr_name_for_output_desc, int32_t index, | |||
OpDescPtr &op_desc); | |||
static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, | |||
const std::string &attr_name_for_input_desc, int32_t index, | |||
OpDescPtr &op_desc); | |||
static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||
const std::string &attr_name_for_input_output_desc, int32_t index, | |||
OpDescPtr &op_desc); | |||
static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def); | |||
static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); | |||
static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); | |||
@@ -15,10 +15,12 @@ | |||
*/ | |||
#include "utils/node_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "debug/ge_op_types.h" | |||
#include "debug/ge_util.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "graph/anchor.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "utils/tensor_utils.h" | |||
#include "utils/type_utils.h" | |||
@@ -109,6 +111,7 @@ graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_pt | |||
graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { | |||
GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, | |||
"node or in_data_anchor is nullptr"); | |||
bool find_flag = false; | |||
uint32_t index = 0; | |||
vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end(); | |||
@@ -358,4 +361,45 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const | |||
input_desc->SetShape(shape); | |||
return GRAPH_SUCCESS; | |||
} | |||
std::string NodeUtils::GetNodeType(const Node &node) { | |||
if (node.GetType() != FRAMEWORKOP) { | |||
return node.GetType(); | |||
} | |||
std::string type; | |||
(void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||
return type; | |||
} | |||
ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | |||
auto op_desc = node.GetOpDesc(); | |||
if (op_desc == nullptr) { | |||
return nullptr; | |||
} | |||
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||
if (root_graph == nullptr) { | |||
return nullptr; | |||
} | |||
return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); | |||
} | |||
graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) { | |||
if (subgraph == nullptr) { | |||
GE_LOGE("Failed to add subgraph to node %s, null subgraph", node.GetName().c_str()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
auto op_desc = node.GetOpDesc(); | |||
if (op_desc == nullptr) { | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||
if (root_graph == nullptr) { | |||
GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); | |||
return GRAPH_PARAM_INVALID; | |||
} | |||
op_desc->AddSubgraphInstanceName(subgraph->GetName()); | |||
subgraph->SetParentNode(node.shared_from_this()); | |||
subgraph->SetParentGraph(node.GetOwnerComputeGraph()); | |||
root_graph->AddSubgraph(subgraph); | |||
return GRAPH_SUCCESS; | |||
} | |||
} // namespace ge |
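For context on the two new helpers, a hypothetical caller might look like the sketch below. Only the AddSubgraph/GetSubgraph calls and ComGraphMakeShared come from this change; the node setup and the placeholder helper are assumptions.

```cpp
// Hypothetical usage sketch (node/graph setup assumed, not part of this change).
ge::ComputeGraphPtr body_graph = ge::ComGraphMakeShared<ge::ComputeGraph>("while_body");
ge::NodePtr while_node = GetSomeNodeFromRootGraph();  // placeholder helper, assumed

if (while_node != nullptr && body_graph != nullptr &&
    ge::NodeUtils::AddSubgraph(*while_node, body_graph) == ge::GRAPH_SUCCESS) {
  // AddSubgraph records the instance name on the OpDesc, sets the parent
  // node/graph, and registers the subgraph on the root graph, so it should be
  // resolvable again through the parent node (index 0 for the first subgraph).
  ge::ComputeGraphPtr found = ge::NodeUtils::GetSubgraph(*while_node, 0);
}
```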
@@ -15,9 +15,7 @@ | |||
*/ | |||
#include "utils/op_desc_utils.h" | |||
#include <algorithm> | |||
#include "debug/ge_attr_define.h" | |||
#include "debug/ge_op_types.h" | |||
#include "debug/ge_util.h" | |||
@@ -209,6 +207,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData( | |||
const vector<ge::NodePtr> &input_nodes) { | |||
vector<ConstGeTensorPtr> ret; | |||
for (const auto &input_node : input_nodes) { | |||
auto temp_weight = MutableWeights(input_node->GetOpDesc()); | |||
if (temp_weight == nullptr) { | |||
@@ -379,7 +378,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt | |||
if (NodeUtils::IsAnchorStatusSet(*node)) { | |||
for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||
if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { | |||
(void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||
} | |||
} | |||
} else { | |||
@@ -389,7 +388,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt | |||
continue; | |||
} | |||
if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { | |||
(void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||
} | |||
} | |||
} | |||
@@ -572,4 +571,80 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
/// | |||
/// @brief Add input | |||
/// @param [in] name | |||
/// @return OpDescBuilder | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { | |||
inputs_.emplace_back(name); | |||
return *this; | |||
} | |||
/// | |||
/// @brief Add dynamic input | |||
/// @param [in] name | |||
/// @param [in] num | |||
/// @return OpDescBuilder | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, | |||
uint32_t num) { | |||
for (uint32_t i = 0; i < num; i++) { | |||
inputs_.emplace_back(name + std::to_string(i)); | |||
} | |||
return *this; | |||
} | |||
/// | |||
/// @brief Add output | |||
/// @param [in] name | |||
/// @return OpDescBuilder | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { | |||
outputs_.emplace_back(name); | |||
return *this; | |||
} | |||
/// | |||
/// @brief Add dynamic output | |||
/// @param [in] name | |||
/// @param [in] num | |||
/// @return OpDescBuilder | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, | |||
uint32_t num) { | |||
for (uint32_t i = 0; i < num; i++) { | |||
outputs_.emplace_back(name + std::to_string(i)); | |||
} | |||
return *this; | |||
} | |||
/// | |||
/// @brief Build op_desc | |||
/// @return OpDescPtr | |||
/// | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { | |||
OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_)); | |||
if (op_desc == nullptr) { | |||
GELOGE(GRAPH_FAILED, "OpDesc is nullptr"); | |||
return nullptr; | |||
} | |||
for (auto &input : inputs_) { | |||
if (op_desc->AddInputDesc(input, GeTensorDesc()) != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Add input_desc failed."); | |||
return nullptr; | |||
} | |||
} | |||
for (auto &output : outputs_) { | |||
if (op_desc->AddOutputDesc(output, GeTensorDesc()) != GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "Add output_desc failed."); | |||
return nullptr; | |||
} | |||
} | |||
return op_desc; | |||
} | |||
} // namespace ge |
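A hypothetical use of the new builder follows; the (name, type) constructor is inferred from the name_/type_ members and is not shown in this change, so treat it as an assumption.

```cpp
// Hypothetical usage sketch of OpDescBuilder (constructor signature assumed).
ge::OpDescPtr op_desc = ge::OpDescBuilder("concat_0", "ConcatV2")
                            .AddDynamicInput("x", 2)  // expands to inputs "x0" and "x1"
                            .AddInput("concat_dim")
                            .AddOutput("y")
                            .Build();
if (op_desc == nullptr) {
  // Build() logs and returns nullptr when allocation or AddInputDesc/AddOutputDesc fails.
}
```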
@@ -15,7 +15,6 @@ | |||
*/ | |||
#include "graph/utils/tensor_utils.h" | |||
#include <cmath> | |||
#include "debug/ge_log.h" | |||
@@ -276,6 +275,14 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||
break; | |||
case FORMAT_FRACTAL_NZ: | |||
case FORMAT_FRACTAL_ZZ: | |||
case FORMAT_NDHWC: | |||
case FORMAT_NCDHW: | |||
case FORMAT_DHWCN: | |||
case FORMAT_DHWNC: | |||
case FORMAT_FRACTAL_Z_3D: | |||
case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | |||
case FORMAT_NDC1HWC0: | |||
case FORMAT_FRACTAL_Z_C04: | |||
graph_status = CalcElementCntByDims(dims, element_cnt); | |||
break; | |||
default: | |||
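The layouts added to this case list all fall through to CalcElementCntByDims, i.e. a plain product over the dimensions. The standalone sketch below models that computation with an overflow guard; it illustrates the intent rather than reproducing the GE helper.

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: element count as the product of all dims, with an
// overflow guard and rejection of unknown (negative) dims.
static bool CalcElementCnt(const std::vector<int64_t> &dims, int64_t &element_cnt) {
  element_cnt = 1;
  for (int64_t dim : dims) {
    if (dim < 0) {
      return false;  // unknown dimension, size cannot be computed
    }
    if (dim != 0 && element_cnt > INT64_MAX / dim) {
      return false;  // would overflow int64_t
    }
    element_cnt *= dim;
  }
  return true;
}
// e.g. an NDHWC tensor of {8, 4, 32, 32, 16} -> 524288 elements
```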
@@ -351,21 +358,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTens | |||
} | |||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp) { | |||
TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { | |||
graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); | |||
if (graph_status != GRAPH_SUCCESS) { | |||
return GRAPH_FAILED; | |||
} | |||
// 64-byte alignment, if size is 0, align to 32 bytes | |||
if (size_temp > (UINT32_MAX - kNum2 * kDataMemAlignSize)) { | |||
GELOGW("The updated mem size %u is bigger than UINT32_MAX", size_temp); | |||
if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) { | |||
GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp); | |||
} else { | |||
size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
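The padding expression above rounds the raw size up to the alignment boundary and always adds at least one extra block of headroom. Below is a standalone model of the arithmetic, assuming kDataMemAlignSize = 32 and kNum2 = 2 (neither constant's value appears in this hunk).

```cpp
#include <cstdint>

// Assumed values (not shown in this change): kDataMemAlignSize = 32, kNum2 = 2.
constexpr int64_t kDataMemAlignSize = 32;
constexpr int64_t kNum2 = 2;

// Same expression as in GetTensorMemorySizeInBytes.
int64_t PadTensorMemSize(int64_t size) {
  return ((size + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize;
}
// PadTensorMemSize(0)   == 32   (a zero-sized tensor still gets 32 bytes)
// PadTensorMemSize(1)   == 64
// PadTensorMemSize(100) == 160
```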
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp) { | |||
TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { | |||
GeShape output_shape = desc_temp.GetShape(); | |||
Format format = desc_temp.GetFormat(); | |||
DataType data_type = desc_temp.GetDataType(); | |||
@@ -376,13 +383,13 @@ TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_ | |||
return GRAPH_FAILED; | |||
} | |||
if ((output_mem_size > UINT32_MAX) || (output_mem_size < 0)) { | |||
GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %u]", | |||
output_mem_size, UINT32_MAX); | |||
if (output_mem_size < 0) { | |||
GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]", | |||
output_mem_size, INT64_MAX); | |||
return GRAPH_FAILED; | |||
} | |||
size_temp = static_cast<uint32_t>(output_mem_size); | |||
size_temp = output_mem_size; | |||
return GRAPH_SUCCESS; | |||
} | |||
} // namespace ge |
@@ -19,43 +19,45 @@ | |||
namespace ge { | |||
static const std::map<Format, std::string> kFormatToStringMap = { | |||
{FORMAT_NCHW, "NCHW"}, | |||
{FORMAT_NHWC, "NHWC"}, | |||
{FORMAT_ND, "ND"}, | |||
{FORMAT_NC1HWC0, "NC1HWC0"}, | |||
{FORMAT_FRACTAL_Z, "FRACTAL_Z"}, | |||
{FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, | |||
{FORMAT_NHWC1C0, "NHWC1C0"}, | |||
{FORMAT_FSR_NCHW, "FSR_NCHW"}, | |||
{FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, | |||
{FORMAT_C1HWNC0, "C1HWNC0"}, | |||
{FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, | |||
{FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, | |||
{FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, | |||
{FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, | |||
{FORMAT_CHWN, "CHWN"}, | |||
{FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, | |||
{FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, | |||
{FORMAT_BN_WEIGHT, "BN_WEIGHT"}, | |||
{FORMAT_FILTER_HWCK, "FILTER_HWCK"}, | |||
{FORMAT_HWCN, "HWCN"}, | |||
{FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, | |||
{FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, | |||
{FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, | |||
{FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, | |||
{FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, | |||
{FORMAT_MD, "MD"}, | |||
{FORMAT_NDHWC, "NDHWC"}, | |||
{FORMAT_NCDHW, "NCDHW"}, | |||
{FORMAT_DHWCK, "DHWCK"}, | |||
{FORMAT_NDC1HWC0, "NDC1HWC0"}, | |||
{FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, | |||
{FORMAT_C1HWNCoC0, "C1HWNCoC0"}, | |||
{FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||
{FORMAT_CN, "CN"}, | |||
{FORMAT_NC, "NC"}, | |||
{FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||
{FORMAT_ALL, "ALL"}}; | |||
{FORMAT_NCHW, "NCHW"}, | |||
{FORMAT_NHWC, "NHWC"}, | |||
{FORMAT_ND, "ND"}, | |||
{FORMAT_NC1HWC0, "NC1HWC0"}, | |||
{FORMAT_FRACTAL_Z, "FRACTAL_Z"}, | |||
{FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, | |||
{FORMAT_NHWC1C0, "NHWC1C0"}, | |||
{FORMAT_FSR_NCHW, "FSR_NCHW"}, | |||
{FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, | |||
{FORMAT_C1HWNC0, "C1HWNC0"}, | |||
{FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, | |||
{FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, | |||
{FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, | |||
{FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, | |||
{FORMAT_CHWN, "CHWN"}, | |||
{FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, | |||
{FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, | |||
{FORMAT_BN_WEIGHT, "BN_WEIGHT"}, | |||
{FORMAT_FILTER_HWCK, "FILTER_HWCK"}, | |||
{FORMAT_HWCN, "HWCN"}, | |||
{FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, | |||
{FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, | |||
{FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, | |||
{FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, | |||
{FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, | |||
{FORMAT_MD, "MD"}, | |||
{FORMAT_NDHWC, "NDHWC"}, | |||
{FORMAT_NCDHW, "NCDHW"}, | |||
{FORMAT_DHWCN, "DHWCN"}, | |||
{FORMAT_DHWNC, "DHWNC"}, | |||
{FORMAT_NDC1HWC0, "NDC1HWC0"}, | |||
{FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, | |||
{FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, | |||
{FORMAT_C1HWNCoC0, "C1HWNCoC0"}, | |||
{FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||
{FORMAT_CN, "CN"}, | |||
{FORMAT_NC, "NC"}, | |||
{FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||
{FORMAT_ALL, "ALL"}}; | |||
static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | |||
"FRACTAL_Z", | |||
@@ -73,137 +75,140 @@ static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | |||
"FRACTAL_ZZ", | |||
"FRACTAL_NZ", | |||
"NDC1HWC0", | |||
"FORMAT_FRACTAL_Z_3D"}; | |||
"FORMAT_FRACTAL_Z_3D", | |||
"FORMAT_FRACTAL_Z_3D_TRANSPOSE"}; | |||
static const std::map<std::string, Format> kDataFormatMap = { | |||
{"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"ND", FORMAT_ND}}; | |||
{"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; | |||
static const std::map<std::string, Format> kStringToFormatMap = { | |||
{"NCHW", FORMAT_NCHW}, | |||
{"NHWC", FORMAT_NHWC}, | |||
{"ND", FORMAT_ND}, | |||
{"NC1HWC0", FORMAT_NC1HWC0}, | |||
{"FRACTAL_Z", FORMAT_FRACTAL_Z}, | |||
{"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, | |||
{"NHWC1C0", FORMAT_NHWC1C0}, | |||
{"FSR_NCHW", FORMAT_FSR_NCHW}, | |||
{"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, | |||
{"C1HWNC0", FORMAT_C1HWNC0}, | |||
{"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, | |||
{"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, | |||
{"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, | |||
{"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, | |||
{"CHWN", FORMAT_CHWN}, | |||
{"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, | |||
{"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, | |||
{"BN_WEIGHT", FORMAT_BN_WEIGHT}, | |||
{"FILTER_HWCK", FORMAT_FILTER_HWCK}, | |||
{"HWCN", FORMAT_HWCN}, | |||
{"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, | |||
{"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, | |||
{"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, | |||
{"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, | |||
{"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, | |||
{"MD", FORMAT_MD}, | |||
{"C1HWNCoC0", FORMAT_C1HWNCoC0}, | |||
{"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, | |||
{"NDHWC", FORMAT_NDHWC}, | |||
{"NCDHW", FORMAT_NCDHW}, | |||
{"DHWCK", FORMAT_DHWCK}, | |||
{"NDC1HWC0", FORMAT_NDC1HWC0}, | |||
{"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, | |||
{"CN", FORMAT_CN}, | |||
{"NC", FORMAT_NC}, | |||
{"FORMAT_RESERVED", FORMAT_RESERVED}, | |||
{"ALL", FORMAT_ALL}}; | |||
{"NCHW", FORMAT_NCHW}, | |||
{"NHWC", FORMAT_NHWC}, | |||
{"ND", FORMAT_ND}, | |||
{"NC1HWC0", FORMAT_NC1HWC0}, | |||
{"FRACTAL_Z", FORMAT_FRACTAL_Z}, | |||
{"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, | |||
{"NHWC1C0", FORMAT_NHWC1C0}, | |||
{"FSR_NCHW", FORMAT_FSR_NCHW}, | |||
{"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, | |||
{"C1HWNC0", FORMAT_C1HWNC0}, | |||
{"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, | |||
{"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, | |||
{"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, | |||
{"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, | |||
{"CHWN", FORMAT_CHWN}, | |||
{"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, | |||
{"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, | |||
{"BN_WEIGHT", FORMAT_BN_WEIGHT}, | |||
{"FILTER_HWCK", FORMAT_FILTER_HWCK}, | |||
{"HWCN", FORMAT_HWCN}, | |||
{"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, | |||
{"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, | |||
{"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, | |||
{"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, | |||
{"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, | |||
{"MD", FORMAT_MD}, | |||
{"C1HWNCoC0", FORMAT_C1HWNCoC0}, | |||
{"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, | |||
{"NDHWC", FORMAT_NDHWC}, | |||
{"NCDHW", FORMAT_NCDHW}, | |||
{"DHWCN", FORMAT_DHWCN}, | |||
{"DHWNC", FORMAT_DHWNC}, | |||
{"NDC1HWC0", FORMAT_NDC1HWC0}, | |||
{"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, | |||
{"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, | |||
{"CN", FORMAT_CN}, | |||
{"NC", FORMAT_NC}, | |||
{"FORMAT_RESERVED", FORMAT_RESERVED}, | |||
{"ALL", FORMAT_ALL}}; | |||
static const std::map<DataType, std::string> kDataTypeToStringMap = { | |||
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||
{DT_FLOAT, "DT_FLOAT"}, // float type | |||
{DT_FLOAT16, "DT_FLOAT16"}, // fp16 type | |||
{DT_INT8, "DT_INT8"}, // int8 type | |||
{DT_INT16, "DT_INT16"}, // int16 type | |||
{DT_UINT16, "DT_UINT16"}, // uint16 type | |||
{DT_UINT8, "DT_UINT8"}, // uint8 type | |||
{DT_INT32, "DT_INT32"}, // uint32 type | |||
{DT_INT64, "DT_INT64"}, // int64 type | |||
{DT_UINT32, "DT_UINT32"}, // unsigned int32 | |||
{DT_UINT64, "DT_UINT64"}, // unsigned int64 | |||
{DT_BOOL, "DT_BOOL"}, // bool type | |||
{DT_DOUBLE, "DT_DOUBLE"}, // double type | |||
{DT_DUAL, "DT_DUAL"}, // dual output type | |||
{DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type | |||
{DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type | |||
{DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type | |||
{DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type | |||
{DT_QINT8, "DT_QINT8"}, // qint8 type | |||
{DT_QINT16, "DT_QINT16"}, // qint16 type | |||
{DT_QINT32, "DT_QINT32"}, // qint32 type | |||
{DT_QUINT8, "DT_QUINT8"}, // quint8 type | |||
{DT_QUINT16, "DT_QUINT16"}, // quint16 type | |||
{DT_RESOURCE, "DT_RESOURCE"}, // resource type | |||
{DT_STRING_REF, "DT_STRING_REF"}, // string ref type | |||
{DT_STRING, "DT_STRING"}, // string type | |||
{DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||
{DT_FLOAT, "DT_FLOAT"}, // float type | |||
{DT_FLOAT16, "DT_FLOAT16"}, // fp16 type | |||
{DT_INT8, "DT_INT8"}, // int8 type | |||
{DT_INT16, "DT_INT16"}, // int16 type | |||
{DT_UINT16, "DT_UINT16"}, // uint16 type | |||
{DT_UINT8, "DT_UINT8"}, // uint8 type | |||
{DT_INT32, "DT_INT32"}, // uint32 type | |||
{DT_INT64, "DT_INT64"}, // int64 type | |||
{DT_UINT32, "DT_UINT32"}, // unsigned int32 | |||
{DT_UINT64, "DT_UINT64"}, // unsigned int64 | |||
{DT_BOOL, "DT_BOOL"}, // bool type | |||
{DT_DOUBLE, "DT_DOUBLE"}, // double type | |||
{DT_DUAL, "DT_DUAL"}, // dual output type | |||
{DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type | |||
{DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type | |||
{DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type | |||
{DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type | |||
{DT_QINT8, "DT_QINT8"}, // qint8 type | |||
{DT_QINT16, "DT_QINT16"}, // qint16 type | |||
{DT_QINT32, "DT_QINT32"}, // qint32 type | |||
{DT_QUINT8, "DT_QUINT8"}, // quint8 type | |||
{DT_QUINT16, "DT_QUINT16"}, // quint16 type | |||
{DT_RESOURCE, "DT_RESOURCE"}, // resource type | |||
{DT_STRING_REF, "DT_STRING_REF"}, // string ref type | |||
{DT_STRING, "DT_STRING"}, // string type | |||
}; | |||
static const std::map<std::string, DataType> kStringTodataTypeMap = { | |||
{"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. | |||
{"DT_FLOAT", DT_FLOAT}, // float type | |||
{ | |||
"DT_FLOAT16", | |||
DT_FLOAT16, | |||
}, // fp16 type | |||
{"DT_INT8", DT_INT8}, // int8 type | |||
{"DT_INT16", DT_INT16}, // int16 type | |||
{"DT_UINT16", DT_UINT16}, // uint16 type | |||
{"DT_UINT8", DT_UINT8}, // uint8 type | |||
{"DT_INT32", DT_INT32}, // uint32 type | |||
{"DT_INT64", DT_INT64}, // int64 type | |||
{"DT_UINT32", DT_UINT32}, // unsigned int32 | |||
{"DT_UINT64", DT_UINT64}, // unsigned int64 | |||
{"DT_BOOL", DT_BOOL}, // bool type | |||
{"DT_DOUBLE", DT_DOUBLE}, // double type | |||
{"DT_DUAL", DT_DUAL}, // dual output type | |||
{"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type | |||
{"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type | |||
{"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type | |||
{"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type | |||
{"DT_QINT8", DT_QINT8}, // qint8 type | |||
{"DT_QINT16", DT_QINT16}, // qint16 type | |||
{"DT_QINT32", DT_QINT32}, // qint32 type | |||
{"DT_QUINT8", DT_QUINT8}, // quint8 type | |||
{"DT_QUINT16", DT_QUINT16}, // quint16 type | |||
{"DT_RESOURCE", DT_RESOURCE}, // resource type | |||
{"DT_STRING_REF", DT_STRING_REF}, // string ref type | |||
{"DT_STRING", DT_STRING}, // string type | |||
{"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. | |||
{"DT_FLOAT", DT_FLOAT}, // float type | |||
{ | |||
"DT_FLOAT16", | |||
DT_FLOAT16, | |||
}, // fp16 type | |||
{"DT_INT8", DT_INT8}, // int8 type | |||
{"DT_INT16", DT_INT16}, // int16 type | |||
{"DT_UINT16", DT_UINT16}, // uint16 type | |||
{"DT_UINT8", DT_UINT8}, // uint8 type | |||
{"DT_INT32", DT_INT32}, // uint32 type | |||
{"DT_INT64", DT_INT64}, // int64 type | |||
{"DT_UINT32", DT_UINT32}, // unsigned int32 | |||
{"DT_UINT64", DT_UINT64}, // unsigned int64 | |||
{"DT_BOOL", DT_BOOL}, // bool type | |||
{"DT_DOUBLE", DT_DOUBLE}, // double type | |||
{"DT_DUAL", DT_DUAL}, // dual output type | |||
{"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type | |||
{"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type | |||
{"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type | |||
{"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type | |||
{"DT_QINT8", DT_QINT8}, // qint8 type | |||
{"DT_QINT16", DT_QINT16}, // qint16 type | |||
{"DT_QINT32", DT_QINT32}, // qint32 type | |||
{"DT_QUINT8", DT_QUINT8}, // quint8 type | |||
{"DT_QUINT16", DT_QUINT16}, // quint16 type | |||
{"DT_RESOURCE", DT_RESOURCE}, // resource type | |||
{"DT_STRING_REF", DT_STRING_REF}, // string ref type | |||
{"DT_STRING", DT_STRING}, // string type | |||
}; | |||
static const std::map<ge::DataType, uint32_t> kDataTypeToLength = { | |||
{DT_BOOL, sizeof(bool)}, | |||
{DT_INT64, sizeof(int64_t)}, | |||
{DT_UINT64, sizeof(int64_t)}, | |||
{DT_FLOAT, sizeof(float)}, | |||
{DT_INT32, sizeof(int32_t)}, | |||
{DT_UINT32, sizeof(int32_t)}, | |||
{DT_INT8, sizeof(char)}, | |||
{DT_UINT8, sizeof(char)}, | |||
{DT_INT16, sizeof(int16_t)}, | |||
{DT_UINT16, sizeof(int16_t)}, | |||
{DT_FLOAT16, sizeof(int16_t)}, | |||
{DT_DOUBLE, sizeof(double)}, | |||
{DT_DUAL, sizeof(float) + sizeof(int8_t)}, | |||
{DT_DUAL_SUB_INT8, sizeof(int8_t)}, | |||
{DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, | |||
{DT_COMPLEX64, sizeof(int64_t)}, | |||
{DT_COMPLEX128, sizeof(int64_t) * 2}, | |||
{DT_QINT8, sizeof(int8_t)}, | |||
{DT_QINT16, sizeof(int16_t)}, | |||
{DT_QINT32, sizeof(int32_t)}, | |||
{DT_QUINT8, sizeof(uint8_t)}, | |||
{DT_QUINT16, sizeof(uint16_t)}, | |||
{DT_STRING_REF, sizeof(uint64_t) * 2}, | |||
{DT_STRING, sizeof(uint64_t)}, | |||
{DT_RESOURCE, sizeof(uint64_t)}, | |||
{DT_BOOL, sizeof(bool)}, | |||
{DT_INT64, sizeof(int64_t)}, | |||
{DT_UINT64, sizeof(int64_t)}, | |||
{DT_FLOAT, sizeof(float)}, | |||
{DT_INT32, sizeof(int32_t)}, | |||
{DT_UINT32, sizeof(int32_t)}, | |||
{DT_INT8, sizeof(char)}, | |||
{DT_UINT8, sizeof(char)}, | |||
{DT_INT16, sizeof(int16_t)}, | |||
{DT_UINT16, sizeof(int16_t)}, | |||
{DT_FLOAT16, sizeof(int16_t)}, | |||
{DT_DOUBLE, sizeof(double)}, | |||
{DT_DUAL, sizeof(float) + sizeof(int8_t)}, | |||
{DT_DUAL_SUB_INT8, sizeof(int8_t)}, | |||
{DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, | |||
{DT_COMPLEX64, sizeof(int64_t)}, | |||
{DT_COMPLEX128, sizeof(int64_t) * 2}, | |||
{DT_QINT8, sizeof(int8_t)}, | |||
{DT_QINT16, sizeof(int16_t)}, | |||
{DT_QINT32, sizeof(int32_t)}, | |||
{DT_QUINT8, sizeof(uint8_t)}, | |||
{DT_QUINT16, sizeof(uint16_t)}, | |||
{DT_STRING_REF, sizeof(uint64_t) * 2}, | |||
{DT_STRING, sizeof(uint64_t)}, | |||
{DT_RESOURCE, sizeof(uint64_t)}, | |||
}; | |||
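These per-type widths are what turn an element count into a raw buffer size. A hedged, standalone sketch of that multiplication is shown below; the actual GE call sites are not reproduced here.

```cpp
#include <cstdint>

// Illustrative only: bytes = element_count * per-type width, with an overflow
// check. Widths correspond to the kDataTypeToLength table (e.g. DT_FLOAT16 -> 2).
bool CalcTensorBytes(int64_t element_cnt, uint32_t type_len, int64_t &bytes) {
  if (element_cnt < 0 || type_len == 0) {
    return false;
  }
  if (element_cnt > INT64_MAX / static_cast<int64_t>(type_len)) {
    return false;  // would overflow int64_t
  }
  bytes = element_cnt * static_cast<int64_t>(type_len);
  return true;
}
// e.g. 524288 DT_FLOAT16 elements -> 524288 * 2 = 1048576 bytes
```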
bool TypeUtils::IsDataTypeValid(DataType dt) { | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
# ============================================================================ | |||
# libge.so & libge_train.so | |||
# libge_compiler.so & libge_train.so | |||
# will later be integrated into libgraph_runner.so, which works for both training and inference | |||
# compiling proto files generates some warnings; use -Wno-unused-variable to suppress them | |||
set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | |||
@@ -49,7 +49,7 @@ include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
######### libge_train.so ############# | |||
# need to remove dependencies on pb files later | |||
file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"common/formats/format_transfers/*.cc" | |||
"common/formats/formats.cc" | |||
"common/formats/utils/formats_trans_utils.cc" | |||
@@ -57,20 +57,24 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"common/ge/plugin_manager.cc" | |||
"common/profiling/profiling_manager.cc" | |||
"engine_manager/dnnengine_manager.cc" | |||
"ge_local_engine/engine/host_cpu_engine.cc" | |||
"generator/ge_generator.cc" | |||
"generator/generator_api.cc" | |||
"graph/build/graph_build.cc" | |||
"graph/build/graph_builder.cc" | |||
"graph/build/label_allocator.cc" | |||
"graph/build/logical_stream_allocator.cc" | |||
"graph/build/model_builder.cc" | |||
"graph/build/optimize_stream_graph.cc" | |||
"graph/build/run_context.cc" | |||
"graph/build/stream_allocator.cc" | |||
"graph/build/stream_graph_optimizer.cc" | |||
"graph/build/task_generator.cc" | |||
"graph/common/bcast.cc" | |||
"graph/common/omg_util.cc" | |||
"graph/common/transop_util.cc" | |||
"graph/execute/graph_execute.cc" | |||
"graph/label/*.cc" | |||
"graph/load/graph_loader.cc" | |||
"graph/load/new_model_manager/cpu_queue_schedule.cc" | |||
"graph/load/new_model_manager/data_dumper.cc" | |||
"graph/load/new_model_manager/data_inputer.cc" | |||
"graph/load/new_model_manager/davinci_model.cc" | |||
@@ -92,10 +96,12 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | |||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | |||
"graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | |||
"graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
"graph/load/new_model_manager/task_info/task_info.cc" | |||
"graph/load/new_model_manager/tbe_handle_store.cc" | |||
"graph/load/output/output.cc" | |||
"graph/manager/custom/custom_op.cc" | |||
"graph/manager/graph_context.cc" | |||
"graph/manager/graph_manager.cc" | |||
"graph/manager/graph_manager_utils.cc" | |||
@@ -105,12 +111,9 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/manager/trans_var_data_utils.cc" | |||
"graph/manager/util/debug.cc" | |||
"graph/manager/util/hcom_util.cc" | |||
"graph/manager/util/node_searcher/need_rebuild_node_searcher.cc" | |||
"graph/manager/util/rt_context_util.cc" | |||
"graph/manager/util/variable_accelerate_ctrl.cc" | |||
"graph/optimize/graph_functiondef.cc" | |||
"graph/optimize/graph_optimize.cc" | |||
"graph/optimize/graph_optimizer.cc" | |||
"graph/optimize/optimizer/allreduce_fusion_pass.cc" | |||
"graph/optimize/summary_optimize.cc" | |||
"graph/partition/engine_place.cc" | |||
@@ -120,7 +123,9 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/assert_pass.cc" | |||
"graph/passes/atomic_addr_clean_pass.cc" | |||
"graph/passes/base_pass.cc" | |||
"graph/passes/cast_remove_pass.cc" | |||
"graph/passes/cast_translate_pass.cc" | |||
"graph/passes/common_subexpression_elimination_pass.cc" | |||
"graph/passes/compile_nodes_pass.cc" | |||
"graph/passes/constant_folding_pass.cc" | |||
"graph/passes/constant_fuse_same_pass.cc" | |||
@@ -159,12 +164,14 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/folding_kernel/shape_kernel.cc" | |||
"graph/passes/folding_kernel/shape_n_kernel.cc" | |||
"graph/passes/folding_kernel/size_kernel.cc" | |||
"graph/passes/folding_kernel/slice_d_kernel.cc" | |||
"graph/passes/folding_kernel/slice_kernel.cc" | |||
"graph/passes/folding_kernel/squeeze_kernel.cc" | |||
"graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | |||
"graph/passes/folding_kernel/strided_slice_kernel.cc" | |||
"graph/passes/folding_kernel/sub_kernel.cc" | |||
"graph/passes/folding_kernel/transdata_kernel.cc" | |||
"graph/passes/folding_kernel/unpack_kernel.cc" | |||
"graph/passes/folding_pass.cc" | |||
"graph/passes/get_original_format_pass.cc" | |||
"graph/passes/guarantee_const_pass.cc" | |||
@@ -179,7 +186,6 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/multi_batch_pass.cc" | |||
"graph/passes/net_output_pass.cc" | |||
"graph/passes/next_iteration_pass.cc" | |||
"graph/passes/no_reshape_op_remove_pass.cc" | |||
"graph/passes/no_use_reshape_remove_pass.cc" | |||
"graph/passes/pass_manager.cc" | |||
"graph/passes/pass_utils.cc" | |||
@@ -188,6 +194,7 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/prevent_gradient_pass.cc" | |||
"graph/passes/print_op_pass.cc" | |||
"graph/passes/prune_pass.cc" | |||
"graph/passes/replace_with_empty_const_pass.cc" | |||
"graph/passes/reshape_remove_pass.cc" | |||
"graph/passes/resource_pair_add_control_pass.cc" | |||
"graph/passes/resource_pair_remove_control_pass.cc" | |||
@@ -206,14 +213,12 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/transpose_transdata_pass.cc" | |||
"graph/passes/unused_const_pass.cc" | |||
"graph/passes/unused_op_remove_pass.cc" | |||
"graph/passes/update_net_output_pass.cc" | |||
"graph/passes/var_is_initialized_op_pass.cc" | |||
"graph/passes/variable_format_pass.cc" | |||
"graph/passes/variable_op_pass.cc" | |||
"graph/passes/variable_prepare_op_pass.cc" | |||
"graph/passes/variable_ref_delete_op_pass.cc" | |||
"graph/preprocess/graph_preprocess.cc" | |||
"graph/preprocess/insert_op/base_insert_op.cc" | |||
"graph/preprocess/insert_op/ge_aipp_op.cc" | |||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | |||
"graph/preprocess/multi_batch_copy_graph.cc" | |||
@@ -223,13 +228,8 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"opskernel_manager/ops_kernel_manager.cc" | |||
"session/inner_session.cc" | |||
"session/session_manager.cc" | |||
"single_op/single_op.cc" | |||
"single_op/single_op_manager.cc" | |||
"single_op/single_op_model.cc" | |||
"single_op/stream_resource.cc" | |||
"single_op/task/build_task_utils.cc" | |||
"single_op/task/op_task.cc" | |||
"single_op/task/tbe_task_builder.cc" | |||
"single_op/*.cc" | |||
"single_op/task/*.cc" | |||
) | |||
@@ -261,9 +261,9 @@ target_link_libraries(ge_train | |||
rt | |||
dl) | |||
######### libge.so ############# | |||
######### libge_compiler.so ############# | |||
# need to remove dependencies on pb files later | |||
file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"common/formats/format_transfers/*.cc" | |||
"common/formats/formats.cc" | |||
"common/formats/utils/formats_trans_utils.cc" | |||
@@ -271,20 +271,24 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"common/ge/plugin_manager.cc" | |||
"common/profiling/profiling_manager.cc" | |||
"engine_manager/dnnengine_manager.cc" | |||
"ge_local_engine/engine/host_cpu_engine.cc" | |||
"generator/ge_generator.cc" | |||
"generator/generator_api.cc" | |||
"graph/build/graph_build.cc" | |||
"graph/build/graph_builder.cc" | |||
"graph/build/label_allocator.cc" | |||
"graph/build/logical_stream_allocator.cc" | |||
"graph/build/model_builder.cc" | |||
"graph/build/optimize_stream_graph.cc" | |||
"graph/build/run_context.cc" | |||
"graph/build/stream_allocator.cc" | |||
"graph/build/stream_graph_optimizer.cc" | |||
"graph/build/task_generator.cc" | |||
"graph/common/bcast.cc" | |||
"graph/common/omg_util.cc" | |||
"graph/common/transop_util.cc" | |||
"graph/execute/graph_execute.cc" | |||
"graph/label/*.cc" | |||
"graph/load/graph_loader.cc" | |||
"graph/load/new_model_manager/cpu_queue_schedule.cc" | |||
"graph/load/new_model_manager/data_dumper.cc" | |||
"graph/load/new_model_manager/data_inputer.cc" | |||
"graph/load/new_model_manager/davinci_model.cc" | |||
@@ -305,10 +309,12 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | |||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | |||
"graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | |||
"graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
"graph/load/new_model_manager/task_info/task_info.cc" | |||
"graph/load/new_model_manager/tbe_handle_store.cc" | |||
"graph/load/output/output.cc" | |||
"graph/manager/custom/custom_op.cc" | |||
"graph/manager/graph_context.cc" | |||
"graph/manager/graph_manager.cc" | |||
"graph/manager/graph_manager_utils.cc" | |||
@@ -317,13 +323,9 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/manager/model_manager/event_manager.cc" | |||
"graph/manager/trans_var_data_utils.cc" | |||
"graph/manager/util/debug.cc" | |||
"graph/manager/util/node_searcher/need_rebuild_node_searcher.cc" | |||
"graph/manager/util/rt_context_util.cc" | |||
"graph/manager/util/variable_accelerate_ctrl.cc" | |||
"graph/optimize/graph_functiondef.cc" | |||
"graph/optimize/graph_optimize.cc" | |||
"graph/optimize/graph_optimizer.cc" | |||
"graph/optimize/optimizer/allreduce_fusion_inference_pass.cc" | |||
"graph/optimize/summary_optimize.cc" | |||
"graph/partition/engine_place.cc" | |||
"graph/partition/graph_partition.cc" | |||
@@ -332,7 +334,9 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/assert_pass.cc" | |||
"graph/passes/atomic_addr_clean_pass.cc" | |||
"graph/passes/base_pass.cc" | |||
"graph/passes/cast_remove_pass.cc" | |||
"graph/passes/cast_translate_pass.cc" | |||
"graph/passes/common_subexpression_elimination_pass.cc" | |||
"graph/passes/compile_nodes_pass.cc" | |||
"graph/passes/constant_folding_pass.cc" | |||
"graph/passes/constant_fuse_same_pass.cc" | |||
@@ -371,12 +375,14 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/folding_kernel/shape_kernel.cc" | |||
"graph/passes/folding_kernel/shape_n_kernel.cc" | |||
"graph/passes/folding_kernel/size_kernel.cc" | |||
"graph/passes/folding_kernel/slice_d_kernel.cc" | |||
"graph/passes/folding_kernel/slice_kernel.cc" | |||
"graph/passes/folding_kernel/squeeze_kernel.cc" | |||
"graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | |||
"graph/passes/folding_kernel/strided_slice_kernel.cc" | |||
"graph/passes/folding_kernel/sub_kernel.cc" | |||
"graph/passes/folding_kernel/transdata_kernel.cc" | |||
"graph/passes/folding_kernel/unpack_kernel.cc" | |||
"graph/passes/folding_pass.cc" | |||
"graph/passes/get_original_format_pass.cc" | |||
"graph/passes/guarantee_const_pass.cc" | |||
@@ -391,7 +397,6 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/multi_batch_pass.cc" | |||
"graph/passes/net_output_pass.cc" | |||
"graph/passes/next_iteration_pass.cc" | |||
"graph/passes/no_reshape_op_remove_pass.cc" | |||
"graph/passes/no_use_reshape_remove_pass.cc" | |||
"graph/passes/pass_manager.cc" | |||
"graph/passes/pass_utils.cc" | |||
@@ -400,6 +405,7 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/prevent_gradient_pass.cc" | |||
"graph/passes/print_op_pass.cc" | |||
"graph/passes/prune_pass.cc" | |||
"graph/passes/replace_with_empty_const_pass.cc" | |||
"graph/passes/reshape_remove_pass.cc" | |||
"graph/passes/resource_pair_add_control_pass.cc" | |||
"graph/passes/resource_pair_remove_control_pass.cc" | |||
@@ -418,14 +424,12 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"graph/passes/transpose_transdata_pass.cc" | |||
"graph/passes/unused_const_pass.cc" | |||
"graph/passes/unused_op_remove_pass.cc" | |||
"graph/passes/update_net_output_pass.cc" | |||
"graph/passes/var_is_initialized_op_pass.cc" | |||
"graph/passes/variable_format_pass.cc" | |||
"graph/passes/variable_op_pass.cc" | |||
"graph/passes/variable_prepare_op_pass.cc" | |||
"graph/passes/variable_ref_delete_op_pass.cc" | |||
"graph/preprocess/graph_preprocess.cc" | |||
"graph/preprocess/insert_op/base_insert_op.cc" | |||
"graph/preprocess/insert_op/ge_aipp_op.cc" | |||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | |||
"graph/preprocess/multi_batch_copy_graph.cc" | |||
@@ -442,16 +446,19 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"single_op/task/build_task_utils.cc" | |||
"single_op/task/op_task.cc" | |||
"single_op/task/tbe_task_builder.cc" | |||
########################################## | |||
# "ir_build/ge_ir_build.cc" | |||
# "offline/atc_ir_common.cc" | |||
) | |||
add_library(ge SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
target_compile_definitions(ge PRIVATE | |||
add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
target_compile_definitions(ge_compiler PRIVATE | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
DAVINCI_SUPPORT_PROFILING | |||
REUSE_MEMORY=1 | |||
FMK_HOST_INFER | |||
PLATFORM_CLOUD) | |||
target_link_libraries(ge | |||
target_link_libraries(ge_compiler | |||
graph | |||
ge_common | |||
"-Wl,--whole-archive" | |||
@@ -80,7 +80,7 @@ target_compile_definitions(ge_client_train PRIVATE | |||
PLATFORM_CLOUD) | |||
target_link_libraries(ge_client | |||
graph | |||
ge | |||
ge_compiler | |||
ge_common | |||
${PROTOBUF_LIBRARY} | |||
${register} | |||
@@ -61,14 +61,14 @@ Status CheckDumpAndReuseMemory(const std::map<string, string> &options) { | |||
const int kDecimal = 10; | |||
auto dump_op_env = std::getenv("DUMP_OP"); | |||
int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; | |||
auto disable_reuse_memory_iter = options.find("ge.exec.disableReuseMemory"); | |||
if (disable_reuse_memory_iter != options.end()) { | |||
if (disable_reuse_memory_iter->second == "0") { | |||
auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory"); | |||
if (disableReuseMemoryIter != options.end()) { | |||
if (disableReuseMemoryIter->second == "0") { | |||
GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); | |||
if (dump_op_flag) { | |||
GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); | |||
} | |||
} else if (disable_reuse_memory_iter->second == "1") { | |||
} else if (disableReuseMemoryIter->second == "1") { | |||
GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); | |||
} else { | |||
GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid"); | |||
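For readers skimming this hunk, the behaviour being preserved (only identifiers were renamed) is: DUMP_OP is read from the environment as a base-10 integer, and dumping while reuse memory stays enabled is the combination that triggers the warning. A minimal standalone model, using only calls that appear above:

```cpp
#include <cstdlib>
#include <map>
#include <string>

// Illustrative only: returns true for the "dump with reuse memory enabled" case
// that the code above warns about.
bool DumpWhileReusingMemory(const std::map<std::string, std::string> &options) {
  const int kDecimal = 10;
  const char *dump_op_env = std::getenv("DUMP_OP");
  long dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0;
  auto it = options.find("ge.exec.disableReuseMemory");
  return (dump_op_flag != 0) && (it != options.end()) && (it->second == "0");
}
```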
@@ -128,22 +128,29 @@ Status GEInitialize(const std::map<string, string> &options) { | |||
OpsProtoManager *manager = OpsProtoManager::Instance(); | |||
std::map<string, string> option_tmp; | |||
option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | |||
GE_TIMESTAMP_START(GEInitialize); | |||
bool is_proto_init = manager->Initialize(option_tmp); | |||
GE_TIMESTAMP_END(GEInitialize, "GEInitialize::ManagerInitialize"); | |||
if (!is_proto_init) { | |||
GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, ops proto path is invalid."); | |||
return FAILED; | |||
} | |||
// check options is valid | |||
GE_TIMESTAMP_START(CheckOptionsValid); | |||
if (CheckOptionsValid(options) != SUCCESS) { | |||
return FAILED; | |||
} | |||
GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid"); | |||
GE_TIMESTAMP_START(InitPreparation); | |||
SaveDdkVersion(options); | |||
GE_TIMESTAMP_END(InitPreparation, "GEInitialize::InitPreparation"); | |||
// call Initialize | |||
GELOGT(TRACE_RUNNING, "Initializing environment"); | |||
GE_TIMESTAMP_START(GELibInitialize); | |||
Status ret = ge::GELib::Initialize(options); | |||
GE_TIMESTAMP_END(GELibInitialize, "GEInitialize::GELibInitialize"); | |||
if (ret != SUCCESS) { | |||
GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, error code = %u", ret); | |||
return FAILED; | |||
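The GE_TIMESTAMP_START/END pairs added above bracket each initialization phase. Their real definitions are not part of this change; a minimal stand-in with the same usage shape, shown purely for orientation, could look like this:

```cpp
#include <chrono>
#include <cstdio>

// Assumed stand-in for the GE_TIMESTAMP_* macros; the real ones live elsewhere.
#define GE_TIMESTAMP_START(stage) auto stage##_start_ts = std::chrono::steady_clock::now()
#define GE_TIMESTAMP_END(stage, msg)                                                  \
  do {                                                                                \
    auto stage##_end_ts = std::chrono::steady_clock::now();                           \
    auto stage##_us = std::chrono::duration_cast<std::chrono::microseconds>(          \
                          stage##_end_ts - stage##_start_ts)                          \
                          .count();                                                   \
    std::printf("[%s] cost %lld us\n", msg, static_cast<long long>(stage##_us));      \
  } while (0)

// Usage, mirroring the code above:
//   GE_TIMESTAMP_START(CheckOptionsValid);
//   /* ... work ... */
//   GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid");
```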
@@ -170,17 +177,20 @@ Status GEFinalize() { | |||
std::lock_guard<std::mutex> lock(kGeReleaseMutex); | |||
// call Finalize | |||
Status ret = SUCCESS; | |||
Status middle_ret; | |||
GELOGT(TRACE_RUNNING, "Finalizing environment"); | |||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GEFinalize Failed: GE not initialized"); | |||
return GE_CLI_GE_NOT_INITIALIZED; | |||
} | |||
Status ret = instance_ptr->Finalize(); | |||
GELOGI("GEFinalize finalize gelib ret=%u", ret); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "GEFinalize Failed"); | |||
return FAILED; | |||
std::shared_ptr<GELib> instancePtr = ge::GELib::GetInstance(); | |||
if (instancePtr == nullptr || !instancePtr->InitFlag()) { | |||
GELOGW("GEFinalize Failed: GE not initialized."); | |||
ret = GE_CLI_GE_NOT_INITIALIZED; | |||
} | |||
if (ret != GE_CLI_GE_NOT_INITIALIZED) { | |||
middle_ret = instancePtr->Finalize(); | |||
GELOGI("GEFinalize finalize gelib ret=%u", middle_ret); | |||
if (middle_ret != SUCCESS) { | |||
ret = middle_ret; | |||
} | |||
} | |||
if (kGeInitialized && ret == SUCCESS) { | |||
@@ -379,8 +389,6 @@ Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, s | |||
} | |||
Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) { | |||
GELOGW( | |||
"The callback function will not be checked. Please ensure that the implementation of the function is trusted."); | |||
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | |||
} | |||