@@ -16,6 +16,7 @@ | |||||
cmake_minimum_required(VERSION 3.14) | cmake_minimum_required(VERSION 3.14) | ||||
project (GraphEngine[CXX]) | project (GraphEngine[CXX]) | ||||
set(CMAKE_CXX_STANDARD 14) | set(CMAKE_CXX_STANDARD 14) | ||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) | |||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) | ||||
set(GE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) | set(GE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) | ||||
@@ -71,6 +72,7 @@ elseif(DEFINED ENV{D_LINK_PATH}) | |||||
find_library(register libregister.so ${GE_LIB_PATH}) | find_library(register libregister.so ${GE_LIB_PATH}) | ||||
find_library(hccl libhccl.so ${GE_LIB_PATH}) | find_library(hccl libhccl.so ${GE_LIB_PATH}) | ||||
find_library(resource libresource.so ${GE_LIB_PATH}) | find_library(resource libresource.so ${GE_LIB_PATH}) | ||||
find_library(error_manager liberror_manager.so ${GE_LIB_PATH}) | |||||
else() | else() | ||||
# Ascend mode | # Ascend mode | ||||
if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | ||||
@@ -88,6 +90,7 @@ else() | |||||
find_library(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | find_library(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | ||||
find_library(register libregister.so ${ASCEND_RUNTIME_DIR}) | find_library(register libregister.so ${ASCEND_RUNTIME_DIR}) | ||||
find_library(resource libresource.so ${ASCEND_RUNTIME_DIR}) | find_library(resource libresource.so ${ASCEND_RUNTIME_DIR}) | ||||
find_library(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) | |||||
endif() | endif() | ||||
# add compile flags | # add compile flags | ||||
@@ -44,6 +44,9 @@ class GraphOptimizer { | |||||
// optimize original graph, using in graph preparation stage | // optimize original graph, using in graph preparation stage | ||||
virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; | virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; | ||||
// optimize original graph, using for conversion operator insert in graph preparation stage | |||||
virtual Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) { return SUCCESS; } | |||||
// optimize fused graph | // optimize fused graph | ||||
virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; | virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; | ||||
@@ -0,0 +1,36 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef COMPRESS_H | |||||
#define COMPRESS_H | |||||
#include <uchar.h> | |||||
enum CmpStatus { RET_SUCCESS = 0, RET_ERROR = -1 }; | |||||
struct CompressConfig { | |||||
size_t inputSize; // length of data to compress | |||||
size_t engineNum; // how many decompress engines | |||||
size_t maxRatio; // how much size of a basic compression block, only 64 supported now (8x: 64 4x: 32) | |||||
size_t channel; // channels of L2 or DDR. For load balance | |||||
size_t fractalSize; // size of compressing block | |||||
bool isTight; // whether compose compressed data tightly | |||||
}; | |||||
CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, | |||||
size_t& compressedLength); | |||||
#endif // COMPRESS_H |
@@ -0,0 +1,83 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef ERROR_MANAGER_H_ | |||||
#define ERROR_MANAGER_H_ | |||||
#include <map> | |||||
#include <string> | |||||
#include <vector> | |||||
class ErrorManager { | |||||
public: | |||||
/// | |||||
/// @brief Obtain ErrorManager instance | |||||
/// @return ErrorManager instance | |||||
/// | |||||
static ErrorManager &GetInstance(); | |||||
/// | |||||
/// @brief init | |||||
/// @param [in] path current so path | |||||
/// @return int 0(success) -1(fail) | |||||
/// | |||||
int Init(std::string path); | |||||
/// | |||||
/// @brief Report error message | |||||
/// @param [in] errCode error code | |||||
/// @param [in] mapArgs parameter map | |||||
/// @return int 0(success) -1(fail) | |||||
/// | |||||
int ReportErrMessage(std::string error_code, const std::map<std::string, std::string> &args_map); | |||||
/// @brief output error message | |||||
/// @param [in] handle print handle | |||||
/// @return int 0(success) -1(fail) | |||||
/// | |||||
int OutputErrMessage(int handle); | |||||
/// @brief Report error message | |||||
/// @param [in] vector parameter key, vector parameter value | |||||
/// | |||||
void ATCReportErrMessage(std::string error_code, const std::vector<std::string> &key = {}, | |||||
const std::vector<std::string> &value = {}); | |||||
private: | |||||
struct ErrorInfo { | |||||
std::string error_id; | |||||
std::string error_message; | |||||
std::vector<std::string> arglist; | |||||
}; | |||||
ErrorManager() {} | |||||
~ErrorManager() {} | |||||
ErrorManager(const ErrorManager &) = delete; | |||||
ErrorManager(ErrorManager &&) = delete; | |||||
ErrorManager &operator=(const ErrorManager &) = delete; | |||||
ErrorManager &operator=(ErrorManager &&) = delete; | |||||
int ParseJsonFile(std::string path); | |||||
int ReadJsonFile(const std::string &file_path, void *handle); | |||||
bool is_init_ = false; | |||||
std::map<std::string, ErrorInfo> error_map_; | |||||
std::vector<std::string> error_message_evc_; | |||||
}; | |||||
#endif // ERROR_MANAGER_H_ |
@@ -65,6 +65,8 @@ class PlatformInfoManager { | |||||
void ParseUBOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | void ParseUBOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | ||||
void ParseUnzipOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
void ParseAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | void ParseAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | ||||
void ParseBufferOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | void ParseBufferOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | ||||
@@ -65,6 +65,10 @@ typedef struct tagAiCoreSpec { | |||||
uint64_t ubbankNum; | uint64_t ubbankNum; | ||||
uint64_t ubburstInOneBlock; | uint64_t ubburstInOneBlock; | ||||
uint64_t ubbankGroupNum; | uint64_t ubbankGroupNum; | ||||
uint32_t unzipEngines; | |||||
uint32_t unzipMaxRatios; | |||||
uint32_t unzipChannels; | |||||
uint8_t unzipIsTight; | |||||
} AiCoreSpec; | } AiCoreSpec; | ||||
typedef struct tagAiCoreMemoryRates { | typedef struct tagAiCoreMemoryRates { | ||||
@@ -82,14 +82,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||||
/// @brief run graph in the session with specific session id asynchronously | /// @brief run graph in the session with specific session id asynchronously | ||||
/// @param [in] graphId: graph id | /// @param [in] graphId: graph id | ||||
/// @param [in] inputs: input data | /// @param [in] inputs: input data | ||||
/// @param [out] outputs: output data | |||||
/// @param [out] callback: callback while runing graph has been finished. | /// @param [out] callback: callback while runing graph has been finished. | ||||
/// The callback function will not be checked. | /// The callback function will not be checked. | ||||
/// Please ensure that the implementation of the function is trusted. | /// Please ensure that the implementation of the function is trusted. | ||||
/// @return Status result of function | /// @return Status result of function | ||||
/// | /// | ||||
Status RunGraphAsync(uint32_t graphId, const std::vector<ge::TensorInfo> &inputs, | |||||
std::vector<ge::TensorInfo> &outputs, std::function<void(Status)> callback); | |||||
Status RunGraphAsync(uint32_t graphId, const std::vector<ge::InputTensorInfo> &inputs, RunAsyncCallback callback); | |||||
/// | /// | ||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
@@ -21,6 +21,8 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include <set> | #include <set> | ||||
#include <functional> | |||||
#include <memory> | |||||
namespace ge { | namespace ge { | ||||
// Option key: graph run mode | // Option key: graph run mode | ||||
@@ -40,6 +42,12 @@ const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | |||||
const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | ||||
const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | ||||
const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | ||||
const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; | |||||
const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; | |||||
const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; | |||||
// profiling flag | |||||
const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; | |||||
const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; | |||||
// Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 | // Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 | ||||
const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | ||||
const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | ||||
@@ -173,6 +181,9 @@ const std::string AICORE_NUM = "ge.aicoreNum"; | |||||
// Configure L1FUSION | // Configure L1FUSION | ||||
const std::string L1_FUSION = "ge.l1Fusion"; | const std::string L1_FUSION = "ge.l1Fusion"; | ||||
// Configure l1,l2,and others optimize option | |||||
const std::string BUFFER_OPTIMIZE = "ge.bufferOptimize"; | |||||
// Configure Small Channel flag | // Configure Small Channel flag | ||||
const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | ||||
@@ -188,6 +199,9 @@ const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | |||||
// Save original model file name | // Save original model file name | ||||
const std::string ORIGINAL_MODEL_FILE = "ge.originalModelFile"; | const std::string ORIGINAL_MODEL_FILE = "ge.originalModelFile"; | ||||
// FE enable quant optimize | |||||
const std::string QUANT_OPTIMIZE = "ge.quantOptimize"; | |||||
const char *const OPTION_GE_MAX_DUMP_FILE_NUM = "ge.maxDumpFileNum"; | const char *const OPTION_GE_MAX_DUMP_FILE_NUM = "ge.maxDumpFileNum"; | ||||
const char *const OPTION_GE_MAX_DUMP_FILE_SIZE = "ge.maxDumpFileSize"; | const char *const OPTION_GE_MAX_DUMP_FILE_SIZE = "ge.maxDumpFileSize"; | ||||
const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; | const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; | ||||
@@ -196,36 +210,49 @@ const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; | |||||
// Its value should be "0" or "1", default value is "1" | // Its value should be "0" or "1", default value is "1" | ||||
const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; | const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; | ||||
// Configure whether to use single stream. | |||||
// Its value should be "true" or "false", default value is "false" | |||||
const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; | |||||
// Graph run mode | // Graph run mode | ||||
enum GraphRunMode { PREDICTION = 0, TRAIN }; | enum GraphRunMode { PREDICTION = 0, TRAIN }; | ||||
// Data description | |||||
struct DataDesc { | |||||
void *data = nullptr; // data address | |||||
uint32_t length = 0; // data size | |||||
bool isDataSupportMemShare = false; | |||||
// Input/Output tensor info | |||||
struct InputTensorInfo { | |||||
uint32_t data_type; // data type | |||||
std::vector<int64_t> dims; // shape description | |||||
void *data; // tensor data | |||||
int64_t length; // tensor length | |||||
}; | }; | ||||
// Input/Output shape description | |||||
struct ShapeDesc { | |||||
int64_t num = 0; | |||||
int64_t channel = 0; | |||||
int64_t height = 0; | |||||
int64_t width = 0; | |||||
std::vector<int64_t> dims; | |||||
struct OutputTensorInfo { | |||||
uint32_t data_type; // data type | |||||
std::vector<int64_t> dims; // shape description | |||||
std::unique_ptr<uint8_t[]> data; // tensor data | |||||
int64_t length; // tensor length | |||||
OutputTensorInfo() : data_type(0), dims({}), data(nullptr), length(0) {} | |||||
OutputTensorInfo(OutputTensorInfo &&out) | |||||
: data_type(out.data_type), dims(out.dims), data(std::move(out.data)), length(out.length) {} | |||||
OutputTensorInfo &operator=(OutputTensorInfo &&out) { | |||||
if (this != &out) { | |||||
data_type = out.data_type; | |||||
dims = out.dims; | |||||
data = std::move(out.data); | |||||
length = out.length; | |||||
} | |||||
return *this; | |||||
} | |||||
OutputTensorInfo(const OutputTensorInfo &) = delete; | |||||
OutputTensorInfo &operator=(const OutputTensorInfo &) = delete; | |||||
}; | }; | ||||
// Input/Output tensor info | |||||
struct TensorInfo { | |||||
uint32_t dataType; // data type | |||||
DataDesc data; // tensor data | |||||
ShapeDesc shapeInfo; // tensor shape | |||||
}; | |||||
using Status = uint32_t; | |||||
using RunAsyncCallback = std::function<void(Status, std::vector<ge::OutputTensorInfo> &)>; | |||||
// for ir build | // for ir build | ||||
namespace ir_option { | namespace ir_option { | ||||
static const char *const INPUT_FORMAT = "input_format"; | static const char *const INPUT_FORMAT = "input_format"; | ||||
static const char *const INPUT_SHAPE = "input_shape"; | static const char *const INPUT_SHAPE = "input_shape"; | ||||
static const char *const OP_NAME_MAP = "op_name_map"; | |||||
static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; | static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; | ||||
static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; | static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; | ||||
static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); | static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); | ||||
@@ -235,13 +262,15 @@ static const char *const HEAD_STREAM = ge::HEAD_STREAM.c_str(); | |||||
static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); | static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); | ||||
static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); | static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); | ||||
static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); | static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); | ||||
static const char *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM; | |||||
// for interface: aclgrphBuildModel | // for interface: aclgrphBuildModel | ||||
const std::set<std::string> ir_builder_suppported_options = { | |||||
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, | |||||
DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||||
AUTO_TUNE_MODE}; | |||||
const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, INPUT_SHAPE, DYNAMIC_BATCH_SIZE, | |||||
DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE}; | |||||
// for interface: aclgrphBuildInitialize | // for interface: aclgrphBuildInitialize | ||||
const std::set<std::string> global_options = {HEAD_STREAM, CORE_TYPE, SOC_VERSION}; | |||||
const std::set<std::string> global_options = { | |||||
HEAD_STREAM, CORE_TYPE, SOC_VERSION, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||||
AUTO_TUNE_MODE, ENABLE_SINGLE_STREAM}; | |||||
} // namespace ir_option | } // namespace ir_option | ||||
} // namespace ge | } // namespace ge | ||||
@@ -55,12 +55,16 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { | |||||
graphStatus FindOpByName(const string &name, ge::Operator &op) const; | graphStatus FindOpByName(const string &name, ge::Operator &op) const; | ||||
graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const; | |||||
graphStatus GetAllOpName(std::vector<string> &op_name) const; | graphStatus GetAllOpName(std::vector<string> &op_name) const; | ||||
graphStatus SaveToFile(const string &file_name) const; | graphStatus SaveToFile(const string &file_name) const; | ||||
graphStatus LoadFromFile(const string &file_name); | graphStatus LoadFromFile(const string &file_name); | ||||
const std::string &GetName() const; | |||||
/// | /// | ||||
/// Set is need train iteration. | /// Set is need train iteration. | ||||
/// If set true, it means this graph need to be run iteration some | /// If set true, it means this graph need to be run iteration some | ||||
@@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | |||||
static std::unique_ptr<InferenceContext> Create(); | static std::unique_ptr<InferenceContext> Create(); | ||||
private: | private: | ||||
InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||||
explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||||
std::shared_ptr<InferenceContextImpl> inference_context_impl_; | std::shared_ptr<InferenceContextImpl> inference_context_impl_; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -44,11 +44,16 @@ | |||||
namespace ge { | namespace ge { | ||||
class OperatorImpl; | class OperatorImpl; | ||||
class NamedAttrs; | |||||
class Graph; | |||||
class AttrValue; | class AttrValue; | ||||
using SubgraphBuilder = std::function<Graph(const std::string &name)>; | |||||
using OperatorImplPtr = std::shared_ptr<OperatorImpl>; | using OperatorImplPtr = std::shared_ptr<OperatorImpl>; | ||||
class Graph; | |||||
using GraphBuilderCallback = std::function<Graph()>; | |||||
class OpIO; | class OpIO; | ||||
using OutHandler = std::shared_ptr<OpIO>; | using OutHandler = std::shared_ptr<OpIO>; | ||||
using InHandler = std::shared_ptr<OpIO>; | using InHandler = std::shared_ptr<OpIO>; | ||||
@@ -69,6 +74,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
using OpBool = bool; | using OpBool = bool; | ||||
using OpTensor = Tensor; | using OpTensor = Tensor; | ||||
using OpType = ge::DataType; | using OpType = ge::DataType; | ||||
using OpNamedAttrs = ge::NamedAttrs; | |||||
using OpListInt = std::vector<int64_t>; | using OpListInt = std::vector<int64_t>; | ||||
using OpListFloat = std::vector<float>; | using OpListFloat = std::vector<float>; | ||||
using OpListString = std::vector<string>; | using OpListString = std::vector<string>; | ||||
@@ -77,6 +83,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
using OpBytes = std::vector<uint8_t>; | using OpBytes = std::vector<uint8_t>; | ||||
using OpListListInt = std::vector<std::vector<int64_t>>; | using OpListListInt = std::vector<std::vector<int64_t>>; | ||||
using OpListType = std::vector<ge::DataType>; | using OpListType = std::vector<ge::DataType>; | ||||
using OpListNamedAttrs = std::vector<ge::NamedAttrs>; | |||||
Operator() {} | Operator() {} | ||||
@@ -132,6 +139,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
void SetInferenceContext(const InferenceContextPtr &inference_context); | void SetInferenceContext(const InferenceContextPtr &inference_context); | ||||
InferenceContextPtr GetInferenceContext() const; | InferenceContextPtr GetInferenceContext() const; | ||||
void SetGraphBuilder(const GraphBuilderCallback &builder); | |||||
graphStatus GetGraphBuilder(GraphBuilderCallback &builder) const; | |||||
void AddSubgraphName(const string &name); | |||||
string GetSubgraphName(int index) const; | |||||
graphStatus VerifyAllAttr(bool disable_common_verifier = false); | graphStatus VerifyAllAttr(bool disable_common_verifier = false); | ||||
size_t GetInputsSize() const; | size_t GetInputsSize() const; | ||||
@@ -190,8 +203,21 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
Operator &SetAttr(const string &name, const ge::DataType &attr_value); | Operator &SetAttr(const string &name, const ge::DataType &attr_value); | ||||
graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; | graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; | ||||
// func type | |||||
Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); | |||||
graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; | |||||
Operator &SetAttr(const string &name, const std::vector<ge::NamedAttrs> &attr_value); | |||||
graphStatus GetAttr(const string &name, std::vector<ge::NamedAttrs> &attr_value) const; | |||||
void BreakConnect() const; | void BreakConnect() const; | ||||
size_t GetSubgraphNamesCount() const; | |||||
std::vector<std::string> GetSubgraphNames() const; | |||||
SubgraphBuilder GetSubgraphBuilder(const string &name) const; | |||||
Graph GetSubgraph(const string &name) const; | |||||
SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; | |||||
Graph GetDynamicSubgraph(const string &name, uint32_t index) const; | |||||
protected: | protected: | ||||
void AttrRegister(const string &name, float attr_value); | void AttrRegister(const string &name, float attr_value); | ||||
void AttrRegister(const string &name, const std::vector<float> &attr_value); | void AttrRegister(const string &name, const std::vector<float> &attr_value); | ||||
@@ -207,6 +233,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
void AttrRegister(const string &name, const std::vector<std::vector<int64_t>> &attr_value); | void AttrRegister(const string &name, const std::vector<std::vector<int64_t>> &attr_value); | ||||
void AttrRegister(const string &name, const std::vector<ge::DataType> &attr_value); | void AttrRegister(const string &name, const std::vector<ge::DataType> &attr_value); | ||||
void AttrRegister(const string &name, const ge::DataType &attr_value); | void AttrRegister(const string &name, const ge::DataType &attr_value); | ||||
void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); | |||||
void AttrRegister(const string &name, const std::vector<ge::NamedAttrs> &attr_value); | |||||
explicit Operator(OperatorImplPtr &&op_impl); | explicit Operator(OperatorImplPtr &&op_impl); | ||||
@@ -224,6 +252,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); | void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); | ||||
void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); | |||||
void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); | void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); | ||||
void RequiredAttrRegister(const string &name); | void RequiredAttrRegister(const string &name); | ||||
@@ -235,6 +265,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); | Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); | ||||
void SubgraphRegister(const std::string &name, bool dynamic); | |||||
void SubgraphCountRegister(const std::string &name, uint32_t count); | |||||
void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder); | |||||
private: | private: | ||||
Operator &SetInput(const string &dst_name, const OutHandler &out_handler); | Operator &SetInput(const string &dst_name, const OutHandler &out_handler); | ||||
@@ -22,10 +22,11 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "./operator.h" | |||||
#include "./operator_factory.h" | |||||
#include "./tensor.h" | |||||
#include "./types.h" | |||||
#include "graph/operator.h" | |||||
#include "graph/operator_factory.h" | |||||
#include "graph/tensor.h" | |||||
#include "graph/types.h" | |||||
#include "graph/graph.h" | |||||
namespace ge { | namespace ge { | ||||
using std::function; | using std::function; | ||||
@@ -46,6 +47,10 @@ class OpReg { | |||||
OpReg &OUTPUT() { return *this; } | OpReg &OUTPUT() { return *this; } | ||||
OpReg &GRAPH() { return *this; } | |||||
OpReg &DYNAMIC_GRAPH() { return *this; } | |||||
OpReg &INFER_SHAPE_AND_TYPE() { return *this; } | OpReg &INFER_SHAPE_AND_TYPE() { return *this; } | ||||
}; | }; | ||||
@@ -191,6 +196,10 @@ class OpReg { | |||||
Operator::DynamicInputRegister(#x, num, isPushBack); \ | Operator::DynamicInputRegister(#x, num, isPushBack); \ | ||||
return *this; \ | return *this; \ | ||||
} \ | } \ | ||||
_THIS_TYPE &create_dynamic_input_byindex_##x(unsigned int num, size_t index) { \ | |||||
Operator::DynamicInputRegisterByIndex(#x, num, index); \ | |||||
return *this; \ | |||||
} \ | |||||
TensorDesc get_dynamic_input_desc_##x(unsigned int index) const { return Operator::GetDynamicInputDesc(#x, index); } \ | TensorDesc get_dynamic_input_desc_##x(unsigned int index) const { return Operator::GetDynamicInputDesc(#x, index); } \ | ||||
graphStatus update_dynamic_input_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \ | graphStatus update_dynamic_input_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \ | ||||
return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ | return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ | ||||
@@ -229,6 +238,51 @@ class OpReg { | |||||
void __dy_output_##x() { \ | void __dy_output_##x() { \ | ||||
(void)OpReg() | (void)OpReg() | ||||
#define GRAPH(x) \ | |||||
N(); \ | |||||
__graph_##x(); \ | |||||
} \ | |||||
\ | |||||
public: \ | |||||
static const string name_graph_##x() { return #x; } \ | |||||
SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \ | |||||
_THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ | |||||
Operator::SetSubgraphBuilder(#x, 0, v); \ | |||||
return *this; \ | |||||
} \ | |||||
Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \ | |||||
\ | |||||
private: \ | |||||
void __graph_##x() { \ | |||||
Operator::SubgraphRegister(#x, false); \ | |||||
Operator::SubgraphCountRegister(#x, 1); \ | |||||
(void)OpReg() | |||||
#define DYNAMIC_GRAPH(x) \ | |||||
N(); \ | |||||
__graph_##x(); \ | |||||
} \ | |||||
\ | |||||
public: \ | |||||
static const string name_graph_##x() { return #x; } \ | |||||
_THIS_TYPE &create_dynamic_subgraph_##x(unsigned int num) { \ | |||||
Operator::SubgraphCountRegister(#x, num); \ | |||||
return *this; \ | |||||
} \ | |||||
SubgraphBuilder get_dynamic_subgraph_builder_##x(unsigned int index) const { \ | |||||
return Operator::GetDynamicSubgraphBuilder(#x, index); \ | |||||
} \ | |||||
Graph get_dynamic_subgraph_##x(unsigned int index) const { return Operator::GetDynamicSubgraph(#x, index); } \ | |||||
_THIS_TYPE &set_dynamic_subgraph_builder_##x(unsigned int index, const SubgraphBuilder &v) { \ | |||||
Operator::SetSubgraphBuilder(#x, index, v); \ | |||||
return *this; \ | |||||
} \ | |||||
\ | |||||
private: \ | |||||
void __graph_##x() { \ | |||||
Operator::SubgraphRegister(#x, true); \ | |||||
(void)OpReg() | |||||
#define PASTE(g_register, y) g_register##y | #define PASTE(g_register, y) g_register##y | ||||
#define __OP_END_IMPL__(x, y) \ | #define __OP_END_IMPL__(x, y) \ | ||||
N(); \ | N(); \ | ||||
@@ -21,6 +21,7 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include <utility> | |||||
#include "./ge_error_codes.h" | #include "./ge_error_codes.h" | ||||
#include "./types.h" | #include "./types.h" | ||||
@@ -62,6 +63,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { | |||||
void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); | void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); | ||||
Shape GetShape() const; | Shape GetShape() const; | ||||
void SetShape(const Shape &shape); | void SetShape(const Shape &shape); | ||||
// set shape with -2, it stand for unknown shape | |||||
graphStatus SetUnknownDimNumShape(); | |||||
// for unknown shape | |||||
graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range); | |||||
graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const; | |||||
Format GetFormat() const; | Format GetFormat() const; | ||||
void SetFormat(Format format); | void SetFormat(Format format); | ||||
@@ -23,7 +23,9 @@ | |||||
namespace ge { | namespace ge { | ||||
static const int64_t UNKNOWN_DIM = -1; | static const int64_t UNKNOWN_DIM = -1; | ||||
static const int64_t UNKNOWN_DIM_NUM = -2; | |||||
static const std::vector<int64_t> UNKNOWN_SHAPE = {0}; | static const std::vector<int64_t> UNKNOWN_SHAPE = {0}; | ||||
static const std::vector<int64_t> UNKNOWN_RANK = {-2}; | |||||
#ifdef HOST_VISIBILITY | #ifdef HOST_VISIBILITY | ||||
#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) | #define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) | ||||
@@ -140,10 +142,19 @@ enum Format { | |||||
FORMAT_NC, | FORMAT_NC, | ||||
FORMAT_DHWNC, | FORMAT_DHWNC, | ||||
FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format | FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format | ||||
FORMAT_FRACTAL_ZN_LSTM, | |||||
FORMAT_RESERVED, | FORMAT_RESERVED, | ||||
FORMAT_ALL | FORMAT_ALL | ||||
}; | }; | ||||
// for unknown shape op type | |||||
enum UnknowShapeOpType { | |||||
DEPEND_IN_SHAPE = 1, // op out shape get by input shape | |||||
DEPEND_CONST_VALUE = 2, // op out shape get by const op value | |||||
DEPEND_SHAPE_RANGE = 3, // op out shape get by range | |||||
DEPEND_COMPUTE = 4 // op out shape get by totally computing | |||||
}; | |||||
struct TensorDescInfo { | struct TensorDescInfo { | ||||
Format format_ = FORMAT_RESERVED; // tbe op register support format | Format format_ = FORMAT_RESERVED; // tbe op register support format | ||||
DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype | DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype | ||||
@@ -58,12 +58,18 @@ Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | |||||
Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | ||||
std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value, | std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value, | ||||
int in_pos = -1, int out_pos = -1); | int in_pos = -1, int out_pos = -1); | ||||
Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function<int(int data_index)> &input, | |||||
const std::function<int(int netoutput_index)> &output); | |||||
Status AutoMappingSubgraphIndex(const ge::Graph &graph, | |||||
const std::function<Status(int data_index, int &parent_input_index)> &input, | |||||
const std::function<Status(int netoutput_index, int &parent_output_index)> &output); | |||||
using google::protobuf::Message; | using google::protobuf::Message; | ||||
class OpRegistrationDataImpl; | class OpRegistrationDataImpl; | ||||
using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | ||||
using FusionParseParamFunc = | using FusionParseParamFunc = | ||||
std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | ||||
using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>; | |||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | ||||
public: | public: | ||||
@@ -81,6 +87,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||||
OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | ||||
OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); | |||||
OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | ||||
OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | ||||
@@ -93,6 +101,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||||
domi::FrameworkType GetFrameworkType() const; | domi::FrameworkType GetFrameworkType() const; | ||||
ParseParamFunc GetParseParamFn() const; | ParseParamFunc GetParseParamFn() const; | ||||
FusionParseParamFunc GetFusionParseParamFn() const; | FusionParseParamFunc GetFusionParseParamFn() const; | ||||
ParseSubgraphFunc GetParseSubgraphPostFn() const; | |||||
private: | private: | ||||
std::shared_ptr<OpRegistrationDataImpl> impl_; | std::shared_ptr<OpRegistrationDataImpl> impl_; | ||||
@@ -116,27 +125,5 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||||
namespace ge { | namespace ge { | ||||
using OpRegistrationData = domi::OpRegistrationData; | using OpRegistrationData = domi::OpRegistrationData; | ||||
using OpReceiver = domi::OpReceiver; | using OpReceiver = domi::OpReceiver; | ||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { | |||||
public: | |||||
HostCpuOp() = default; | |||||
virtual ~HostCpuOp() = default; | |||||
virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs, | |||||
std::map<std::string, Tensor> &outputs) = 0; | |||||
}; | |||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { | |||||
public: | |||||
HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); | |||||
}; | |||||
#define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) | |||||
#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) | |||||
#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ | |||||
static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \ | |||||
::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); }) | |||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ |
@@ -51,24 +51,24 @@ inline pid_t GetTid() { | |||||
return tid; | return tid; | ||||
} | } | ||||
#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = domi::GetCurrentTimestap() | |||||
#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | |||||
#define GE_TIMESTAMP_END(stage, stage_name) \ | #define GE_TIMESTAMP_END(stage, stage_name) \ | ||||
do { \ | do { \ | ||||
uint64_t endUsec_##stage = domi::GetCurrentTimestap(); \ | |||||
uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ | |||||
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | ||||
(endUsec_##stage - startUsec_##stage)); \ | (endUsec_##stage - startUsec_##stage)); \ | ||||
} while (0); | } while (0); | ||||
#define GE_TIMESTAMP_CALLNUM_START(stage) \ | |||||
uint64_t startUsec_##stage = domi::GetCurrentTimestap(); \ | |||||
uint64_t call_num_of##stage = 0; \ | |||||
#define GE_TIMESTAMP_CALLNUM_START(stage) \ | |||||
uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ | |||||
uint64_t call_num_of##stage = 0; \ | |||||
uint64_t time_of##stage = 0 | uint64_t time_of##stage = 0 | ||||
#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = domi::GetCurrentTimestap()) | |||||
#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) | |||||
#define GE_TIMESTAMP_ADD(stage) \ | |||||
time_of##stage += domi::GetCurrentTimestap() - startUsec_##stage; \ | |||||
#define GE_TIMESTAMP_ADD(stage) \ | |||||
time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ | |||||
call_num_of##stage++ | call_num_of##stage++ | ||||
#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ | #define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ | ||||
@@ -22,7 +22,6 @@ | |||||
#include "cce/cce_def.hpp" | #include "cce/cce_def.hpp" | ||||
#include "common/string_util.h" | #include "common/string_util.h" | ||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "dlog/log.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
@@ -30,7 +29,7 @@ using cce::CC_STATUS_SUCCESS; | |||||
using cce::ccStatus_t; | using cce::ccStatus_t; | ||||
#if !defined(__ANDROID__) && !defined(ANDROID) | #if !defined(__ANDROID__) && !defined(ANDROID) | ||||
#define DOMI_LOGE(...) DAV_LOGE("DOMI", __VA_ARGS__) | |||||
#define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||||
#else | #else | ||||
#include <android/log.h> | #include <android/log.h> | ||||
#if defined(BUILD_VERSION_PERF) | #if defined(BUILD_VERSION_PERF) | ||||
@@ -103,17 +102,17 @@ using cce::ccStatus_t; | |||||
} while (0); | } while (0); | ||||
// If expr is not true, print the log and return the specified status | // If expr is not true, print the log and return the specified status | ||||
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||||
do { \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
std::string msg; \ | |||||
(void)msg.append(domi::StringUtils::FormatString(__VA_ARGS__)); \ | |||||
(void)msg.append( \ | |||||
domi::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||||
DOMI_LOGE("%s", msg.c_str()); \ | |||||
return _status; \ | |||||
} \ | |||||
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||||
do { \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
std::string msg; \ | |||||
(void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | |||||
(void)msg.append( \ | |||||
ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||||
DOMI_LOGE("%s", msg.c_str()); \ | |||||
return _status; \ | |||||
} \ | |||||
} while (0); | } while (0); | ||||
// If expr is not true, print the log and return the specified status | // If expr is not true, print the log and return the specified status | ||||
@@ -152,7 +152,6 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g | |||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255 | |||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | ||||
@@ -25,6 +25,7 @@ | |||||
#include "common/fmk_error_codes.h" | #include "common/fmk_error_codes.h" | ||||
#include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
#include "external/graph/types.h" | #include "external/graph/types.h" | ||||
#include "external/ge/ge_api_types.h" | |||||
namespace ge { | namespace ge { | ||||
enum RuntimeType { HOST = 0, DEVICE = 1 }; | enum RuntimeType { HOST = 0, DEVICE = 1 }; | ||||
@@ -130,7 +131,8 @@ class ModelListener { | |||||
/// @param [in] data_index Index of the input_data | /// @param [in] data_index Index of the input_data | ||||
/// @param [in] resultCode Execution results | /// @param [in] resultCode Execution results | ||||
/// | /// | ||||
virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code) = 0; | |||||
virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code, | |||||
std::vector<ge::OutputTensorInfo> &outputs) = 0; | |||||
}; | }; | ||||
// OMM configuration item | // OMM configuration item | ||||
@@ -147,6 +149,8 @@ struct Options { | |||||
std::string rankTableFile; | std::string rankTableFile; | ||||
int32_t ge_hccl_flag = 0; | int32_t ge_hccl_flag = 0; | ||||
int32_t physical_device_id; | int32_t physical_device_id; | ||||
std::string profiling_mode; | |||||
std::string profiling_options; | |||||
}; | }; | ||||
// Profiling info of task | // Profiling info of task | ||||
@@ -20,7 +20,7 @@ | |||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include <string> | #include <string> | ||||
namespace domi { | |||||
namespace ge { | |||||
class GflagsUtils { | class GflagsUtils { | ||||
public: | public: | ||||
static bool IsSetCommandTrue(const char *name) { | static bool IsSetCommandTrue(const char *name) { | ||||
@@ -66,6 +66,6 @@ class GflagsUtils { | |||||
} | } | ||||
} | } | ||||
}; | }; | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_GFLAGS_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_GFLAGS_UTIL_H_ |
@@ -26,7 +26,7 @@ | |||||
#include "graph/model.h" | #include "graph/model.h" | ||||
#include "model/ge_model.h" | #include "model/ge_model.h" | ||||
namespace domi { | |||||
namespace ge { | |||||
class ModelHelper { | class ModelHelper { | ||||
public: | public: | ||||
ModelHelper() = default; | ModelHelper() = default; | ||||
@@ -38,7 +38,7 @@ class ModelHelper { | |||||
Status LoadModel(const ge::ModelData& model_data); | Status LoadModel(const ge::ModelData& model_data); | ||||
Status GetModelBufferData(ge::ModelBufferData& model); | Status GetModelBufferData(ge::ModelBufferData& model); | ||||
ModelFileHeader* GetFileHeader() { return file_header_; } | |||||
const ModelFileHeader* GetFileHeader() const { return file_header_; } | |||||
GeModelPtr GetGeModel(); | GeModelPtr GetGeModel(); | ||||
void SetSaveMode(bool val) { is_offline_ = val; } | void SetSaveMode(bool val) { is_offline_ = val; } | ||||
@@ -65,9 +65,8 @@ class ModelHelper { | |||||
Status LoadTask(OmFileLoadHelper& om_load_helper); | Status LoadTask(OmFileLoadHelper& om_load_helper); | ||||
Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | ||||
Status ReleaseLocalModelData() noexcept; | Status ReleaseLocalModelData() noexcept; | ||||
Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | ||||
const uint8_t* data, size_t size); | const uint8_t* data, size_t size); | ||||
}; | }; | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ | #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ |
@@ -26,8 +26,10 @@ | |||||
#include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
using ProcParam = struct PROC_PARAM; | using ProcParam = struct PROC_PARAM; | ||||
using std::string; | |||||
using std::vector; | |||||
namespace domi { | |||||
namespace ge { | |||||
struct ModelPartition { | struct ModelPartition { | ||||
ModelPartitionType type; | ModelPartitionType type; | ||||
uint8_t *data = 0; | uint8_t *data = 0; | ||||
@@ -88,5 +90,5 @@ class OmFileSaveHelper { | |||||
ModelFileHeader model_header_; | ModelFileHeader model_header_; | ||||
OmFileContext context_; | OmFileContext context_; | ||||
}; | }; | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ | #endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ |
@@ -30,7 +30,7 @@ | |||||
using std::vector; | using std::vector; | ||||
namespace domi { | |||||
namespace ge { | |||||
// Size of RC memory alignment, 2M | // Size of RC memory alignment, 2M | ||||
constexpr size_t ALIGN_SIZE = 2097152; | constexpr size_t ALIGN_SIZE = 2097152; | ||||
@@ -118,6 +118,6 @@ class L2CacheOptimize { | |||||
bool Cross(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | bool Cross(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | ||||
bool Connect(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | bool Connect(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | ||||
}; | }; | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ | #endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ |
@@ -1,810 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | |||||
#define INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | |||||
#include <string> | |||||
#include "framework/common/fmk_types.h" | |||||
namespace domi { | |||||
// Public Attribute | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WEIGHT_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IS_QUANTIZE_FACTOR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ALPHA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BETA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADMODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADMODES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FILTER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BIAS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BIAS_TERM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WINDOWS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_GLOBAL_POOLING; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CEIL_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDE_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RELU_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ALGO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FILTER_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_K; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_NORM_REGION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_LOCAL_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_ALPHA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_BETA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BROADCAST; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TIDX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TPADDINGS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_IMG_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_IMG_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NET_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NET_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TMULTIPLES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTIPLES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_T; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_N; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TSHAPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NAN_OPT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AIPP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string NEW_AIPP_CONV_OP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||||
static const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | |||||
static const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_BATCH_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INPUT_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_NODE_DEF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_OP_DEF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INPUT_TENSOR_DESC; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INFERRED_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WEIGHTS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DIM_ALIGN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||||
// To be deleted | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_TO_BE_DELETED; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_LOC_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_CONF_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_OCR_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | |||||
// Refinedet | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_LOC_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_CONF_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIORBOX_CONCAT; | |||||
// _Arg | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INDEX; | |||||
// _RetVal | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETVAL_ATTR_NAME_INDEX; | |||||
// Data | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DATA_ATTR_NAME_DATA_TYPE; | |||||
// Send | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SEND_ATTR_EVENT_ID; | |||||
// Recv | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RECV_ATTR_EVENT_ID; | |||||
// convolution | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_COEF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATIONS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_ALGO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_GROUP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_PAD_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_STRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_DILATION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_NUM_OUTPUT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_KERNEL; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_FILTER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_BIAS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_RELU_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_ADJ; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_TARGET_SHAPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_BEFORE_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_HAS_BIAS; | |||||
// Pooling | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_NAN_OPT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_PAD_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_GLOBAL_POOLING; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_WINDOW; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_STRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_CEIL_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_DATA_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_BEFORE_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_NAME_ALGO; | |||||
// Eltwise | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_COEFF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_WEIGHT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_RELU_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_ALPHA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_BETA; | |||||
// BatchNorm | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_EPSILON; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_ESTIMATED_MEAN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_BIAS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_IS_TRAINING; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||||
// Huberloss | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HUBER_LOSS_ATTR_DELTA; | |||||
// SSDRealDivTileMul | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||||
// SSDSumMulRealDivMean | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||||
SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||||
/// ConcatFive2Four | |||||
/// ConcatFour2Five | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_CLASS_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TRANS_FOR_LOSS_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOX_TYPE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_HIGH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_WIDTH; | |||||
// Scale | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SCALE_ATTR_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SCALE_ATTR_BIAS; | |||||
// FullConnection | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_FILTER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_BIAS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_RELU_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_ATTR_NAME_ALGO; | |||||
// SoftmaxOpParams | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_ALGO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_MODE; | |||||
// SparseSoftmaxCrossEntropy | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING; | |||||
// Activation | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ACTIVATION_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ACTIVATION_ATTR_COEF; | |||||
// Concat | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_ATTR_NAME_AXIS; | |||||
// Const | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_DATA_TRANSTYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | |||||
// Roipooling | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLED_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLED_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO; | |||||
// DetectionOutput | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_TOP_K; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IMG_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IMG_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE; | |||||
// Ssd DetectionOutput | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_ETA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||||
DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K; | |||||
// Refinedet DetectionOutput | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE; | |||||
// yolo DetectionOutput | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_ClASSES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BIASES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_RELATIVE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION; | |||||
// DetectionPostprocess | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_CLS_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_POST_NMS_TOPN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT; | |||||
// Spatialtransfrom | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_OUTPUT_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_OUTPUT_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM; | |||||
// Proposal | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_BASE_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_MIN_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_RATIO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_NMS_THRESH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_TOP_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_IMG_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_IMG_W; | |||||
// Softmax | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_AXIS; | |||||
// Permute | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_ATTR_ORDER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_ATTR_PERM; | |||||
// SSD Normalize | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_EPS; | |||||
// Flatten | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_ATTR_END_AXIS; | |||||
// SsdPRIORBOX | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_FLIP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_CLIP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_IMG_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_IMG_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_STEP_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_STEP_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_OFFSET; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_VARIANCE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||||
// RefinedetPRIORBOX | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | |||||
// PRelu | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PRELU_ATTR_CHANNEL_SHARED; | |||||
// Psroi pooling | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_OUTPUT_DIM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_GROUP_SIZE; | |||||
// Power | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_POWER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_SHIFT; | |||||
// Log | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_SHIFT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_BASE; | |||||
// Pack | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PACK_ATTR_NAME_NUM; | |||||
// Dynamic stitch | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||||
// Unpack | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UNPACK_ATTR_NAME_NUM; | |||||
// Gathernd | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERND_ATTR_NAME_TINDICES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERND_ATTR_NAME_TPARAMS; | |||||
// Argmax | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_TOPK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_OUTMAX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||||
// Upsample | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||||
// Relu | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NEGATIVE_SLOPE; | |||||
// FreeSpaceExtract | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT; | |||||
// split | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_SLICE_POINT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_SIZE_SPLIT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_NUM_SPLIT; | |||||
// Tvm | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_MAGIC; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_BLOCKDIM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_METADATA; | |||||
// Squeeze | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_ATTR_DIMS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_OP_NAME; | |||||
// Stride slice | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_BEGIN_MASK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_END_MASK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; | |||||
// Slice | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SLICE_ATTR_NAME_BEGINS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SLICE_ATTR_NAME_SIZES; | |||||
// Roialign | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_SPATIAL_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_NAME_POOLED_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_NAME_POOLED_W; | |||||
// Generate_rpn_proposal | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||||
GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||||
GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; | |||||
// Decode_bbox | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DECODE_BBOX_ATTR_DECODECLIP; | |||||
// Cast | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CAST_ATTR_DSTT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CAST_ATTR_SRCT; | |||||
// Fastrcnnn predications | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES; | |||||
// REORG | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REORG_ATTR_STRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REORG_ATTR_REVERSE; | |||||
// MERGE | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MERGE_DEAD_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MERGE_PRENODE_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TO_BE_OUTPUT; | |||||
static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||||
// Concatv2 | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_V2_ATTR_TIDX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_V2_ATTR_N; | |||||
// SUM | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_TIDX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_KEEP_DIMS; | |||||
// ResizeBilinear | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_HEIGHT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_WIDTH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_PAD_END; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ALPHA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_BETA; | |||||
// RetinaNet | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETINANET_ANCHOR_FUSION; | |||||
// MatMul | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_TRANSPOSE_X; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_TRANSPOSE_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_HAS_BIAS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_ATTR_IS_TRAINING; | |||||
// Flatten | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_START_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_END_AXIS; | |||||
// Reshape | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NUM_AXES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_SHAPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_ALPHA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_BETA; | |||||
// Frameoworkop | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string T_IN_DATATYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string T_OUT_DATATYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_N; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_C; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_PAD_DEPTH_CONV; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_PAD_CONV; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BEFORE_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ANN_MEAN_KEEPDIMS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_ATTR_PADDINGDS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_ATTR_CONSTANT_VALUE; | |||||
// ConvGradFilter | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE; | |||||
// ConvGradInput | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE; | |||||
// Rnn | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_MODE_STATIC; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MUTI_RNN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CELL_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CNN_RNN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_CELL; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GRU_CELL; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_HT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_XT_HT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_BATCH_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_CELL_CLIP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_PROJ_CLIP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_ACTIVATE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_OUT_MAP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_OUT_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_STATE_OUT_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_TIME_MAJOR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||||
// Upsample | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE; | |||||
// PadV2 | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_PADS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_T; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||||
// MirrorPad | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_PADS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||||
// Filler | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FILLER_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FILLER_VALUE; | |||||
// Shufflechannel | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHUFFLE_CHANNEL_GROUP; | |||||
// TopKV2 | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TOPKV2_ATTR_K; | |||||
// Calibaration | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_H_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_W_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_TOP_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_BOTTOM_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_RIGHT_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_LEFT_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IS_CONST; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_GROUP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATION_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_EPSILON; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_POOLING_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CLASS_NUM; | |||||
// model | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TARGET_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_STREAM_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_LABEL_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | |||||
// Public Attribute | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IMPLY_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BYTE_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_INFERENCE_ID; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_OPDEF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_SCOPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OPATTR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RELUFLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SEQLEN_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_X_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CONT_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_XSTATIC_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_MINI; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_TINY; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_LITE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STREAM_LABEL; | |||||
// L2_normalize | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string L2_NORMALIZE_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string L2_NORMALIZE_ATTR_EPS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_WINDOW; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_NAN_OP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||||
// HCOM | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_ROOT_RANK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_REDUCE_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_RANK_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_REDUCTION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_GROUP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SR_TAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SRC_RANK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_DEST_RANK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SHAPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_DATA_TYPE; | |||||
// Log time stamp | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_TIME_STAMP_LOGID; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_TIME_STAMP_NOTIFY; | |||||
// SpaceToDepth/DepthToSpace | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BLOCK_SIZE; | |||||
// SparseSoftmaxCrossEntropyWithLogits | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||||
// MaxPoolGradWithArgmax | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||||
// AvgPoolGrad | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||||
// Pad | |||||
extern const std::string ATTR_PAD_FORMAT; | |||||
// Varible | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_4D_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_5D_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_DATA_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_OUT_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SHAPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HALF_VAR_NAME_END; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_CONTAINER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SHARED_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_DTYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_ADDR_OFFSET; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SRC_VAR_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_SAVE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_RESTORE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REF_VAR_SRC_VAR_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||||
// Assign | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ASSIGN_VALIDATE_SHAPE; | |||||
// ShapeN | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_N; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_IN_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_OUT_TYPE; | |||||
// Space2bacth batch2space | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCH_SPACE_ATTR_BLOCK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCH_SPACE_ATTR_PADDING; | |||||
// Depth_to_space space_to_depth | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||||
// FakeQuantWithMinMaxVars | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||||
// Mobilenet_ssd_conv_fusion | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||||
// Lsh project | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSH_PROJ_TYPE; | |||||
// Control flow | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||||
// GatherV2 attr def | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TAXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TINDICES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||||
// Reshape attr def | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||||
// Axis attr def | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AXIS_ORG_OP; | |||||
// The node link with SparseSoftmaxCrossEntropyWithLogits | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LINK_WITH_SPARE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||||
// For constant folding | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ |
@@ -21,11 +21,17 @@ | |||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <string> | #include <string> | ||||
#include "common/op/attr_define.h" | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
namespace domi { | |||||
using domi::AttrDef; | |||||
using domi::AttrDef_ListValue; | |||||
using domi::ModelDef; | |||||
using domi::NamedAttrs; | |||||
using domi::OpDef; | |||||
namespace ge { | |||||
using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; | using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; | ||||
using AttrDefPair = ::google::protobuf::MapPair<std::string, domi::AttrDef>; | using AttrDefPair = ::google::protobuf::MapPair<std::string, domi::AttrDef>; | ||||
@@ -150,6 +156,6 @@ bool GetAttrDefListValue(const std::string &key, int idx, int32_t *value, const | |||||
bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); | bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); | ||||
bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); | bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); | ||||
bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); | bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ |
@@ -62,6 +62,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; | |||||
class OpUtils { | class OpUtils { | ||||
public: | public: | ||||
/// | /// | ||||
@@ -22,7 +22,7 @@ | |||||
#include <math.h> | #include <math.h> | ||||
#include <stdint.h> | #include <stdint.h> | ||||
namespace domi { | |||||
namespace ge { | |||||
// general | // general | ||||
const float DEFAULT_ALPHA_VALUE = 1.0; | const float DEFAULT_ALPHA_VALUE = 1.0; | ||||
const float DEFAULT_BETA_VALUE = 0.0; | const float DEFAULT_BETA_VALUE = 0.0; | ||||
@@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2; | |||||
// Shufflechannel | // Shufflechannel | ||||
const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ |
@@ -20,7 +20,7 @@ | |||||
#include <set> | #include <set> | ||||
#include <string> | #include <string> | ||||
namespace domi { | |||||
namespace ge { | |||||
class OpTypeContainer { | class OpTypeContainer { | ||||
public: | public: | ||||
static OpTypeContainer *Instance() { | static OpTypeContainer *Instance() { | ||||
@@ -57,6 +57,6 @@ class OpTypeRegistrar { | |||||
const OpTypeRegistrar g_##var_name##_reg(str_name); | const OpTypeRegistrar g_##var_name##_reg(str_name); | ||||
#define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) | #define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_OP_TYPES_H_ | #endif // INC_FRAMEWORK_COMMON_OP_TYPES_H_ |
@@ -25,10 +25,10 @@ | |||||
/// MAKE_GUARD([&] { Release Resource 1 }) | /// MAKE_GUARD([&] { Release Resource 1 }) | ||||
/// Acquire Resource 2 | /// Acquire Resource 2 | ||||
// MAKE_GUARD([&] { Release Resource 2 }) | // MAKE_GUARD([&] { Release Resource 2 }) | ||||
#define GE_MAKE_GUARD(var, callback) domi::ScopeGuard make_guard_##var(callback) | |||||
#define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) | |||||
#define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | #define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | ||||
namespace domi { | |||||
namespace ge { | |||||
class ScopeGuard { | class ScopeGuard { | ||||
public: | public: | ||||
// Noncopyable | // Noncopyable | ||||
@@ -55,6 +55,6 @@ class ScopeGuard { | |||||
std::function<void()> on_exit_scope_; | std::function<void()> on_exit_scope_; | ||||
bool dismissed_; | bool dismissed_; | ||||
}; | }; | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ | #endif // INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ |
@@ -25,7 +25,7 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
namespace domi { | |||||
namespace ge { | |||||
class StringUtils { | class StringUtils { | ||||
public: | public: | ||||
static std::string &Ltrim(std::string &s) { | static std::string &Ltrim(std::string &s) { | ||||
@@ -151,6 +151,6 @@ class StringUtils { | |||||
return ret > 0 ? buffer : ""; | return ret > 0 ? buffer : ""; | ||||
} | } | ||||
}; | }; | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_ |
@@ -26,6 +26,7 @@ | |||||
#include <string> | #include <string> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "framework/common/fmk_error_codes.h" | #include "framework/common/fmk_error_codes.h" | ||||
#include "framework/common/fmk_types.h" | #include "framework/common/fmk_types.h" | ||||
#include "framework/common/op_types.h" | #include "framework/common/op_types.h" | ||||
@@ -46,9 +47,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_A | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; | ||||
} // namespace ge | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE; | |||||
namespace domi { | |||||
// Supported public properties name | // Supported public properties name | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path | ||||
@@ -68,14 +68,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFIL | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map<std::string, std::string> PROFILE_COMPONENT_MAP; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map<std::string, std::string> PROFILE_COMPONENT_MAP; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; | ||||
/// @brief Data structure definition related to task sinking | |||||
/// Build model | |||||
enum BuildMode { | |||||
GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | |||||
GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | |||||
GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) | |||||
}; | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; | ||||
@@ -341,8 +333,9 @@ REGISTER_OPTYPE_DECLARE(END, "End"); | |||||
REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); | REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); | ||||
REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); | REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); | ||||
REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); | REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); | ||||
REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape") | |||||
/***************ANN dedicated operator *************************/ | |||||
// ANN dedicated operator | |||||
REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); | REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); | ||||
REGISTER_OPTYPE_DECLARE(ANN_CONVOLUTION, "AnnConvolution"); | REGISTER_OPTYPE_DECLARE(ANN_CONVOLUTION, "AnnConvolution"); | ||||
REGISTER_OPTYPE_DECLARE(ANN_DEPCONVOLUTION, "AnnDepthConv"); | REGISTER_OPTYPE_DECLARE(ANN_DEPCONVOLUTION, "AnnDepthConv"); | ||||
@@ -359,7 +352,7 @@ REGISTER_OPTYPE_DECLARE(ANN_QUANTIZE, "AnnQuant"); | |||||
REGISTER_OPTYPE_DECLARE(ANN_PAD, "AnnPad"); | REGISTER_OPTYPE_DECLARE(ANN_PAD, "AnnPad"); | ||||
REGISTER_OPTYPE_DECLARE(ANN_RESIZE_BILINEAR, "AnnResizeBilinear"); | REGISTER_OPTYPE_DECLARE(ANN_RESIZE_BILINEAR, "AnnResizeBilinear"); | ||||
/********************Training operator ***********************/ | |||||
// Training operator | |||||
REGISTER_OPTYPE_DECLARE(GATHERV2, "GatherV2"); | REGISTER_OPTYPE_DECLARE(GATHERV2, "GatherV2"); | ||||
REGISTER_OPTYPE_DECLARE(CONVGRADFILTER, "Conv2DBackpropFilter"); | REGISTER_OPTYPE_DECLARE(CONVGRADFILTER, "Conv2DBackpropFilter"); | ||||
REGISTER_OPTYPE_DECLARE(CONV2D, "Conv2D"); | REGISTER_OPTYPE_DECLARE(CONV2D, "Conv2D"); | ||||
@@ -438,11 +431,13 @@ REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); | |||||
REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | ||||
REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | ||||
REGISTER_OPTYPE_DECLARE(LogTimeStamp, "LogTimeStamp"); | REGISTER_OPTYPE_DECLARE(LogTimeStamp, "LogTimeStamp"); | ||||
REGISTER_OPTYPE_DECLARE(PARALLELCONCATSTART, "_ParallelConcatStart"); | |||||
REGISTER_OPTYPE_DECLARE(CONSTANTOP, "Constant"); | REGISTER_OPTYPE_DECLARE(CONSTANTOP, "Constant"); | ||||
REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch"); | REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch"); | ||||
REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); | REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); | ||||
REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); | REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); | ||||
REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | ||||
REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||||
REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | ||||
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | ||||
REGISTER_OPTYPE_DECLARE(SEND, "Send"); | REGISTER_OPTYPE_DECLARE(SEND, "Send"); | ||||
@@ -450,6 +445,7 @@ REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | |||||
REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | ||||
REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | ||||
REGISTER_OPTYPE_DECLARE(LABELGOTOEX, "LabelGotoEx"); | |||||
REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | ||||
REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | ||||
@@ -828,9 +824,6 @@ static constexpr int32_t PARTITION_TYPE_TASK_INFO = 2; | |||||
// number of partitions in the current model | // number of partitions in the current model | ||||
static constexpr uint32_t PARTITION_SIZE = 4; | static constexpr uint32_t PARTITION_SIZE = 4; | ||||
#define SIZE_OF_MODEL_PARTITION_TABLE(table) \ | |||||
(sizeof(domi::ModelPartitionTable) + sizeof(domi::ModelPartitionMemInfo) * (table).num) | |||||
enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS }; | enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS }; | ||||
struct ModelPartitionMemInfo { | struct ModelPartitionMemInfo { | ||||
@@ -844,6 +837,8 @@ struct ModelPartitionTable { | |||||
ModelPartitionMemInfo partition[0]; | ModelPartitionMemInfo partition[0]; | ||||
}; | }; | ||||
#define SIZE_OF_MODEL_PARTITION_TABLE(table) (sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * (table).num) | |||||
static constexpr int32_t PTHREAD_CREAT_SUCCESS = 0; // pthread_creat success | static constexpr int32_t PTHREAD_CREAT_SUCCESS = 0; // pthread_creat success | ||||
// Filter format | // Filter format | ||||
@@ -975,8 +970,8 @@ typedef enum tagDomiNanPropagation { | |||||
// mode of cropandresize | // mode of cropandresize | ||||
typedef enum tagDomiCropAndResizeMode { | typedef enum tagDomiCropAndResizeMode { | ||||
DOMI_RESIZE_METHOD_BILINEAR = 0, /**< resize bilinear */ | |||||
DOMI_RESIZE_METHOD_NEAREST, /**< resize nearest */ | |||||
DOMI_RESIZE_METHOD_BILINEAR = 0, // resize bilinear | |||||
DOMI_RESIZE_METHOD_NEAREST, // resize nearest | |||||
DOMI_RESIZE_RESERVED | DOMI_RESIZE_RESERVED | ||||
} domiCropAndResizeMode_t; | } domiCropAndResizeMode_t; | ||||
@@ -1063,6 +1058,15 @@ struct BasicInfo { | |||||
uint32_t total_size; // total memory size | uint32_t total_size; // total memory size | ||||
}; | }; | ||||
#pragma pack() // Cancels single-byte alignment | #pragma pack() // Cancels single-byte alignment | ||||
} // namespace ge | |||||
namespace domi { | |||||
/// @brief Data structure definition related to task sinking | |||||
enum BuildMode { | |||||
GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | |||||
GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | |||||
GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) | |||||
}; | |||||
} // namespace domi | } // namespace domi | ||||
#endif // INC_FRAMEWORK_COMMON_TYPES_H_ | #endif // INC_FRAMEWORK_COMMON_TYPES_H_ |
@@ -30,12 +30,12 @@ | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size <= 0) { \ | |||||
DOMI_LOGE(param[#size] is not a positive number); \ | |||||
return PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size <= 0) { \ | |||||
DOMI_LOGE("param[%s] is not a positive number", #size); \ | |||||
return PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | #define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | ||||
@@ -44,7 +44,7 @@ | |||||
if (!b) { \ | if (!b) { \ | ||||
exec_expr; \ | exec_expr; \ | ||||
} \ | } \ | ||||
}; | |||||
} | |||||
// new ge marco | // new ge marco | ||||
// Encapsulate common resource releases | // Encapsulate common resource releases | ||||
@@ -113,101 +113,101 @@ | |||||
} while (0) | } while (0) | ||||
// Check if the parameter is null. If yes, return PARAM_INVALID and record the error | // Check if the parameter is null. If yes, return PARAM_INVALID and record the error | ||||
#define GE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE(param[#val] must not be null.); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE("param[%s] must not be null.", #val); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the parameter is null. If yes, just return and record the error | // Check if the parameter is null. If yes, just return and record the error | ||||
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE(param[#val] must not be null.); \ | |||||
return; \ | |||||
} \ | |||||
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE("param[%s] must not be null.", #val); \ | |||||
return; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | // Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | ||||
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE(param[#val] must not be null.); \ | |||||
exec_expr; \ | |||||
} \ | |||||
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE("param[%s] must not be null.", #val); \ | |||||
exec_expr; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check whether the parameter is null. If yes, return directly and record the error log | // Check whether the parameter is null. If yes, return directly and record the error log | ||||
#define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE(param[#val] must not be null.); \ | |||||
return; \ | |||||
} \ | |||||
#define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE("param[%s] must not be null.", #val); \ | |||||
return; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the parameter is null. If yes, return false and record the error log | // Check if the parameter is null. If yes, return false and record the error log | ||||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE(param[#val] must not be null.); \ | |||||
return false; \ | |||||
} \ | |||||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
DOMI_LOGE("param[%s] must not be null.", #val); \ | |||||
return false; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the parameter is out of bounds | // Check if the parameter is out of bounds | ||||
#define GE_CHECK_SIZE(size) \ | |||||
do { \ | |||||
if (size == 0) { \ | |||||
DOMI_LOGE(param[#size] is out of range); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_SIZE(size) \ | |||||
do { \ | |||||
if (size == 0) { \ | |||||
DOMI_LOGE("param[%s] is out of range", #size); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the container is empty | // Check if the container is empty | ||||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
do { \ | |||||
if (vector.empty()) { \ | |||||
DOMI_LOGE(param[#vector] is empty !); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
do { \ | |||||
if (vector.empty()) { \ | |||||
DOMI_LOGE("param[%s] is empty!", #vector); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the value on the left is greater than or equal to the value on the right | // Check if the value on the left is greater than or equal to the value on the right | ||||
#define GE_CHECK_GE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs < rhs) { \ | |||||
DOMI_LOGE(param[#lhs] is less than[#rhs]); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_GE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs < rhs) { \ | |||||
DOMI_LOGE("param[%s] is less than[%s]", #lhs, #rhs); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the value on the left is less than or equal to the value on the right | // Check if the value on the left is less than or equal to the value on the right | ||||
#define GE_CHECK_LE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs > rhs) { \ | |||||
DOMI_LOGE(param[#lhs] is greater than[#rhs]); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_LE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs > rhs) { \ | |||||
DOMI_LOGE("param[%s] is greater than[%s]", #lhs, #rhs); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
#define GE_DELETE_NEW_SINGLE(var) \ | #define GE_DELETE_NEW_SINGLE(var) \ | ||||
{ \ | |||||
do { \ | |||||
if (var != nullptr) { \ | if (var != nullptr) { \ | ||||
delete var; \ | delete var; \ | ||||
var = nullptr; \ | var = nullptr; \ | ||||
} \ | } \ | ||||
}; | |||||
} while (0) | |||||
#define GE_DELETE_NEW_ARRAY(var) \ | #define GE_DELETE_NEW_ARRAY(var) \ | ||||
{ \ | |||||
do { \ | |||||
if (var != nullptr) { \ | if (var != nullptr) { \ | ||||
delete[] var; \ | delete[] var; \ | ||||
var = nullptr; \ | var = nullptr; \ | ||||
} \ | } \ | ||||
}; | |||||
} while (0) | |||||
/** | /** | ||||
* @ingroup domi_common | * @ingroup domi_common | ||||
@@ -220,7 +220,7 @@ static constexpr int32_t OM_PROTO_VERSION = 2; | |||||
*/ | */ | ||||
#define CEIL(N, n) (((N) + (n)-1) / (n)) | #define CEIL(N, n) (((N) + (n)-1) / (n)) | ||||
namespace domi { | |||||
namespace ge { | |||||
using google::protobuf::Message; | using google::protobuf::Message; | ||||
/// | /// | ||||
@@ -373,7 +373,7 @@ std::string RealPath(const char *path); | |||||
/// @param [in] file_path path of input file | /// @param [in] file_path path of input file | ||||
/// @param [out] result | /// @param [out] result | ||||
/// | /// | ||||
bool CheckInputPathValid(const std::string &file_path); | |||||
bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param = ""); | |||||
/// | /// | ||||
/// @ingroup domi_common | /// @ingroup domi_common | ||||
@@ -381,7 +381,7 @@ bool CheckInputPathValid(const std::string &file_path); | |||||
/// @param [in] file_path path of output file | /// @param [in] file_path path of output file | ||||
/// @param [out] result | /// @param [out] result | ||||
/// | /// | ||||
bool CheckOutputPathValid(const std::string &file_path); | |||||
bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param = ""); | |||||
/// | /// | ||||
/// @ingroup domi_common | /// @ingroup domi_common | ||||
@@ -390,6 +390,6 @@ bool CheckOutputPathValid(const std::string &file_path); | |||||
/// @param [out] result | /// @param [out] result | ||||
/// | /// | ||||
bool ValidateStr(const std::string &filePath, const std::string &mode); | bool ValidateStr(const std::string &filePath, const std::string &mode); | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_COMMON_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_UTIL_H_ |
@@ -47,6 +47,8 @@ class GeGenerator { | |||||
Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData &model); | Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData &model); | ||||
Status GenerateInfershapeGraph(const Graph &graph); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief: Build single OP in Model. | /// @brief: Build single OP in Model. | ||||
@@ -33,7 +33,7 @@ class MemoryAssigner { | |||||
MemoryAssigner &operator=(const MemoryAssigner &) = delete; | MemoryAssigner &operator=(const MemoryAssigner &) = delete; | ||||
Status AssignMemory(bool is_loop_graph, size_t &mem_offset); | |||||
Status AssignMemory(bool is_loop_graph, size_t &mem_offset, size_t &zero_copy_mem_size); | |||||
private: | private: | ||||
ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
@@ -28,21 +28,27 @@ | |||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
using domi::DOMI_TENSOR_ND; | |||||
using domi::DOMI_TENSOR_RESERVED; | |||||
using domi::domiTensorFormat_t; | |||||
using domi::FMK_TYPE_RESERVED; | |||||
using domi::FrameworkType; | |||||
using std::map; | using std::map; | ||||
using std::string; | using std::string; | ||||
using std::unordered_map; | using std::unordered_map; | ||||
using std::vector; | using std::vector; | ||||
namespace domi { | |||||
namespace ge { | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
* @brief run model | * @brief run model | ||||
*/ | */ | ||||
enum RunMode { | enum RunMode { | ||||
GEN_OM_MODEL = 0, // generate offline model file | |||||
MODEL_TO_JSON = 1, // convert to JSON file | |||||
ONLY_PRE_CHECK = 3, // only for pre-check | |||||
PBTXT_TO_JSON = 5 // pbtxt to json | |||||
GEN_OM_MODEL = 0, // generate offline model file | |||||
MODEL_TO_JSON = 1, // convert to JSON file | |||||
MODEL_TO_JSON_WITH_SHAPE = 2, // convert to json file with shape | |||||
ONLY_PRE_CHECK = 3, // only for pre-check | |||||
PBTXT_TO_JSON = 5 // pbtxt to json | |||||
}; | }; | ||||
/// | /// | ||||
@@ -93,7 +99,7 @@ struct OmgContext { | |||||
std::string ddk_version; | std::string ddk_version; | ||||
// preferential format used by the entire network | // preferential format used by the entire network | ||||
domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | ||||
FrameworkType type = FMK_TYPE_RESERVED; | |||||
domi::FrameworkType type = domi::FMK_TYPE_RESERVED; | |||||
RunMode run_mode = ONLY_PRE_CHECK; | RunMode run_mode = ONLY_PRE_CHECK; | ||||
bool train_flag = false; | bool train_flag = false; | ||||
// whether to use FP16 high precision | // whether to use FP16 high precision | ||||
@@ -102,23 +108,25 @@ struct OmgContext { | |||||
std::string output_type; | std::string output_type; | ||||
// Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | // Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | ||||
// network require special processing based on the specific network. | |||||
// e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope | |||||
// fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The | |||||
// convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel. | |||||
// network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module | |||||
// is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the | |||||
// FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape | |||||
// operator needs to be deleted for the convolution kernel. | |||||
std::string net_name; | std::string net_name; | ||||
// Whether to use dynamic batch size or dynamic image size | // Whether to use dynamic batch size or dynamic image size | ||||
bool is_dynamic_input = false; | bool is_dynamic_input = false; | ||||
std::string dynamic_batch_size; | std::string dynamic_batch_size; | ||||
std::string dynamic_image_size; | std::string dynamic_image_size; | ||||
}; | }; | ||||
} // namespace ge | |||||
namespace domi { | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
* @brief get OMG context | * @brief get OMG context | ||||
* @return OmgContext context | * @return OmgContext context | ||||
*/ | */ | ||||
OmgContext &GetContext(); | |||||
ge::OmgContext &GetContext(); | |||||
struct TEBinInfo { | struct TEBinInfo { | ||||
// It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. | // It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. | ||||
@@ -26,7 +26,7 @@ | |||||
#include "common/string_util.h" | #include "common/string_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
namespace domi { | |||||
namespace ge { | |||||
class PlatformVersionManager { | class PlatformVersionManager { | ||||
public: | public: | ||||
PlatformVersionManager() = delete; | PlatformVersionManager() = delete; | ||||
@@ -40,6 +40,6 @@ class PlatformVersionManager { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
}; // class PlatformManager | }; // class PlatformManager | ||||
} // namespace domi | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_OMG_VERSION_H_ | #endif // INC_FRAMEWORK_OMG_VERSION_H_ |
@@ -86,16 +86,16 @@ class _GeSerializable { | |||||
} | } | ||||
template <class T, class... Args> | template <class T, class... Args> | ||||
static void SaveItem(GeAttrValue::NamedAttrs &namedAttrs, string itemName, T &item, Args &... args) { | |||||
static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { | |||||
GeAttrValue itemVal = SaveItemAsAttrValue(item); | GeAttrValue itemVal = SaveItemAsAttrValue(item); | ||||
(void)namedAttrs.SetAttr(itemName, itemVal); | (void)namedAttrs.SetAttr(itemName, itemVal); | ||||
SaveItem(namedAttrs, args...); | SaveItem(namedAttrs, args...); | ||||
} | } | ||||
static void SaveItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) {} | |||||
static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) {} | |||||
template <class T, class... Args> | template <class T, class... Args> | ||||
static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs, string itemName, T &item, Args &... args) { | |||||
static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { | |||||
auto itemVal = namedAttrs.GetItem(itemName); | auto itemVal = namedAttrs.GetItem(itemName); | ||||
auto status = LoadItemFromAttrValue(item, itemVal); | auto status = LoadItemFromAttrValue(item, itemVal); | ||||
if (status != GRAPH_SUCCESS) { | if (status != GRAPH_SUCCESS) { | ||||
@@ -104,7 +104,9 @@ class _GeSerializable { | |||||
return LoadItem(namedAttrs, args...); | return LoadItem(namedAttrs, args...); | ||||
} | } | ||||
static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; } | |||||
static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) { | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
}; | }; | ||||
#define _GE_FI(a) #a, a | #define _GE_FI(a) #a, a | ||||
@@ -171,13 +173,13 @@ class _GeSerializable { | |||||
\ | \ | ||||
private: \ | private: \ | ||||
ge::graphStatus Save(GeAttrValue &ar) const { \ | ge::graphStatus Save(GeAttrValue &ar) const { \ | ||||
GeAttrValue::NamedAttrs named_attrs; \ | |||||
GeAttrValue::NAMED_ATTRS named_attrs; \ | |||||
_GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ | _GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ | ||||
return ar.SetValue<GeAttrValue::NamedAttrs>(named_attrs); \ | |||||
return ar.SetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \ | |||||
} \ | } \ | ||||
ge::graphStatus Load(const GeAttrValue &ar) { \ | ge::graphStatus Load(const GeAttrValue &ar) { \ | ||||
GeAttrValue::NamedAttrs named_attrs; \ | |||||
ge::graphStatus status = ar.GetValue<GeAttrValue::NamedAttrs>(named_attrs); \ | |||||
GeAttrValue::NAMED_ATTRS named_attrs; \ | |||||
ge::graphStatus status = ar.GetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \ | |||||
if (status != GRAPH_SUCCESS) { \ | if (status != GRAPH_SUCCESS) { \ | ||||
return status; \ | return status; \ | ||||
} \ | } \ | ||||
@@ -83,6 +83,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
// AddNode with NodePtr | // AddNode with NodePtr | ||||
NodePtr AddNode(NodePtr node); | NodePtr AddNode(NodePtr node); | ||||
NodePtr AddNode(OpDescPtr op); | NodePtr AddNode(OpDescPtr op); | ||||
NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize. | |||||
NodePtr AddNodeFront(NodePtr node); | NodePtr AddNodeFront(NodePtr node); | ||||
NodePtr AddNodeFront(const OpDescPtr &op); | NodePtr AddNodeFront(const OpDescPtr &op); | ||||
NodePtr AddInputNode(NodePtr node); | NodePtr AddInputNode(NodePtr node); | ||||
@@ -236,8 +237,9 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
std::deque<NodePtr> &stack); | std::deque<NodePtr> &stack); | ||||
graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::map<string, NodePtr> &breadth_node_map); | std::map<string, NodePtr> &breadth_node_map); | ||||
graphStatus TopologicalSortingSubgraph(); | |||||
graphStatus TopologicalSortingGraph(); | |||||
graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | ||||
Vistor<NodePtr> AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const; | |||||
size_t GetInEdgeSize(const NodePtr &node); | size_t GetInEdgeSize(const NodePtr &node); | ||||
size_t GetOutEdgeSize(const NodePtr &node); | size_t GetOutEdgeSize(const NodePtr &node); | ||||
graphStatus RemoveExtraOutEdge(const NodePtr &node); | graphStatus RemoveExtraOutEdge(const NodePtr &node); | ||||
@@ -32,6 +32,12 @@ namespace ge { | |||||
#define GE_FUNC_DEV_VISIBILITY | #define GE_FUNC_DEV_VISIBILITY | ||||
#endif | #endif | ||||
// Public attribute | // Public attribute | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAME; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TYPE; | ||||
@@ -58,6 +64,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | ||||
@@ -74,8 +82,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | ||||
// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | |||||
// ATTR_NAME_WEIGHTS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | ||||
@@ -123,6 +130,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | ||||
@@ -140,12 +154,24 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||||
// to be deleted | // to be deleted | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; | ||||
@@ -158,15 +184,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | ||||
// _Arg | // _Arg | ||||
@@ -255,7 +281,29 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||||
// Huberloss | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; | |||||
// SSDRealDivTileMul | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||||
// SSDSumMulRealDivMean | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||||
/// ConcatFive2Four | |||||
/// ConcatFour2Five | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; | |||||
// Scale | // Scale | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | ||||
@@ -292,7 +340,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | ||||
// Roipooling | // Roipooling | ||||
@@ -305,6 +352,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI | |||||
// DetectionOutput | // DetectionOutput | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | ||||
@@ -363,6 +411,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ | |||||
// Permute | // Permute | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; | |||||
// SSD Normalize | // SSD Normalize | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | ||||
@@ -403,9 +452,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | ||||
// Log | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; | |||||
// Pack | // Pack | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | ||||
// Dynamic stitch | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||||
// Unpack | // Unpack | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | ||||
// Gathernd | // Gathernd | ||||
@@ -414,8 +469,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND | |||||
// Argmax | // Argmax | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||||
// Upsample | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||||
// Relu | // Relu | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | ||||
@@ -486,6 +549,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | ||||
static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||||
// ENTER | // ENTER | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | ||||
@@ -511,6 +575,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | ||||
// RetinaNet | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; | |||||
// MatMul | // MatMul | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | ||||
@@ -559,10 +626,30 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||||
// Upsample | // Upsample | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | ||||
// PadV2 | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||||
// MirrorPad | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||||
// Filler | // Filler | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | ||||
@@ -583,36 +670,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | ||||
@@ -627,24 +684,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_HUGE_STREAM_LIST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||||
// Public attribute | // Public attribute | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | ||||
@@ -678,6 +731,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_T | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_OUTPUT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_OUTPUT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; | ||||
@@ -696,6 +751,161 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||||
// L2_normalize | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||||
// HCOM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||||
// Log time stamp | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; | |||||
// SpaceToDepth/DepthToSpace | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; | |||||
// SparseSoftmaxCrossEntropyWithLogits | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||||
// MaxPoolGradWithArgmax | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||||
// AvgPoolGrad | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||||
// Varible | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||||
// Assign | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; | |||||
// ShapeN | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; | |||||
// Space2bacth batch2space | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; | |||||
// Depth_to_space space_to_depth | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||||
// FakeQuantWithMinMaxVars | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||||
// Mobilenet_ssd_conv_fusion | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||||
// Lsh project | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; | |||||
// Control flow | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||||
// GatherV2 attr def | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||||
// Reshape attr def | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||||
// Axis attr def | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; | |||||
// The node link with SparseSoftmaxCrossEntropyWithLogits | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||||
// For constant folding | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||||
// Used for mark the active label list to find stream of activated node | // Used for mark the active label list to find stream of activated node | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | ||||
@@ -708,7 +918,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
// Control flow | // Control flow | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | ||||
@@ -722,6 +931,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
// Function Op | // Function Op | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_CONST_TYPE; | |||||
// Used for mark the active node is for loop, type:bool | // Used for mark the active node is for loop, type:bool | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; | ||||
@@ -752,6 +962,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEE | |||||
// For mutil-batch | // For mutil-batch | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; | |||||
// For inserted op | // For inserted op | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | ||||
@@ -772,6 +983,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
// used for l1 fusion and other fusion in future | // used for l1 fusion and other fusion in future | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; | ||||
@@ -782,10 +994,44 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; | |||||
// functional ops attr | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; | |||||
// used for label switch | // used for label switch | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | ||||
// Varible | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||||
// HCOM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; | |||||
// used for LX tiling | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_L1_SPACE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_TYPE_LIST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST; | |||||
// Dynamic stitch | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ |
@@ -22,7 +22,7 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/anchor.h" | #include "graph/anchor.h" | ||||
#include "detail/attributes_holder.h" | |||||
#include "graph/detail/attributes_holder.h" | |||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
@@ -77,6 +77,8 @@ class ModelSerializeImp { | |||||
void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; } | void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; } | ||||
private: | private: | ||||
bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map<std::string, ComputeGraphPtr> &subgraphs); | |||||
std::vector<NodeNameGraphReq> graph_input_node_names_; | std::vector<NodeNameGraphReq> graph_input_node_names_; | ||||
std::vector<NodeNameGraphReq> graph_output_node_names_; | std::vector<NodeNameGraphReq> graph_output_node_names_; | ||||
std::vector<NodeNameNodeReq> node_input_node_names_; | std::vector<NodeNameNodeReq> node_input_node_names_; | ||||
@@ -43,30 +43,31 @@ using ComputeGraphPtr = std::shared_ptr<ComputeGraph>; | |||||
using ConstComputeGraphPtr = std::shared_ptr<const ComputeGraph>; | using ConstComputeGraphPtr = std::shared_ptr<const ComputeGraph>; | ||||
class GeTensorDesc; | class GeTensorDesc; | ||||
class GeAttrValue; | |||||
class GeAttrValueImp; | class GeAttrValueImp; | ||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder { | |||||
public: | public: | ||||
class NamedAttrs : public AttrHolder { | |||||
public: | |||||
NamedAttrs(); | |||||
virtual ~NamedAttrs() = default; | |||||
void SetName(const std::string &name); | |||||
string GetName() const; | |||||
GeAttrValue GetItem(const string &key) const; | |||||
protected: | |||||
ProtoAttrMapHelper MutableAttrMap() override; | |||||
ConstProtoAttrMapHelper GetAttrMap() const override; | |||||
private: | |||||
// Create namedAttrs from protobuf obj | |||||
NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); | |||||
GeIrProtoHelper<proto::NamedAttrs> named_attrs_; | |||||
friend class GeAttrValueImp; | |||||
}; | |||||
NamedAttrs(); | |||||
virtual ~NamedAttrs() = default; | |||||
void SetName(const std::string &name); | |||||
string GetName() const; | |||||
GeAttrValue GetItem(const string &key) const; | |||||
protected: | |||||
ProtoAttrMapHelper MutableAttrMap() override; | |||||
ConstProtoAttrMapHelper GetAttrMap() const override; | |||||
private: | |||||
// Create namedAttrs from protobuf obj | |||||
NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); | |||||
GeIrProtoHelper<proto::NamedAttrs> named_attrs_; | |||||
friend class GeAttrValueImp; | |||||
friend class GeAttrValue; | |||||
}; | |||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||||
public: | |||||
using INT = int64_t; | using INT = int64_t; | ||||
using FLOAT = float; | using FLOAT = float; | ||||
using BOOL = bool; | using BOOL = bool; | ||||
@@ -75,7 +76,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||||
using TENSOR_DESC = GeTensorDesc; | using TENSOR_DESC = GeTensorDesc; | ||||
using GRAPH = ComputeGraphPtr; | using GRAPH = ComputeGraphPtr; | ||||
using BYTES = Buffer; | using BYTES = Buffer; | ||||
using NAMED_ATTRS = NamedAttrs; | |||||
using NAMED_ATTRS = ge::NamedAttrs; | |||||
using DATA_TYPE = ge::DataType; | using DATA_TYPE = ge::DataType; | ||||
using LIST_INT = vector<INT>; | using LIST_INT = vector<INT>; | ||||
@@ -90,6 +91,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||||
using LIST_LIST_INT = vector<vector<int64_t>>; | using LIST_LIST_INT = vector<vector<int64_t>>; | ||||
using LIST_DATA_TYPE = vector<ge::DataType>; | using LIST_DATA_TYPE = vector<ge::DataType>; | ||||
using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs). | |||||
enum ValueType { | enum ValueType { | ||||
VT_NONE = 0, | VT_NONE = 0, | ||||
VT_STRING, | VT_STRING, | ||||
@@ -87,6 +87,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH | |||||
GeShape &MutableShape(); | GeShape &MutableShape(); | ||||
void SetShape(GeShape shape); | void SetShape(GeShape shape); | ||||
// set shape with -2, it stand for unknown shape | |||||
void SetUnknownDimNumShape(); | |||||
// for unknown shape | |||||
graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range); | |||||
graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const; | |||||
GeShape GetOriginShape() const; | GeShape GetOriginShape() const; | ||||
void SetOriginShape(const GeShape &originShape); | void SetOriginShape(const GeShape &originShape); | ||||
@@ -25,11 +25,7 @@ | |||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
namespace domi { | |||||
class ModelHelper; | |||||
} | |||||
namespace ge { | namespace ge { | ||||
using domi::ModelHelper; | |||||
using std::map; | using std::map; | ||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
@@ -50,6 +50,8 @@ class GeAttrValue; | |||||
using ConstOpDesc = const OpDesc; | using ConstOpDesc = const OpDesc; | ||||
enum SubgraphType { kStatic, kDynamic, kSubgraphTypeEnd }; | |||||
class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | ||||
public: | public: | ||||
template <class T> | template <class T> | ||||
@@ -83,6 +85,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
graphStatus AddInputDescForward(const string &name, const unsigned int num); | graphStatus AddInputDescForward(const string &name, const unsigned int num); | ||||
graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); | |||||
graphStatus AddOutputDescForward(const string &name, const unsigned int num); | graphStatus AddOutputDescForward(const string &name, const unsigned int num); | ||||
graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); | graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); | ||||
@@ -141,6 +145,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | ||||
graphStatus AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index); | |||||
graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | ||||
bool IsOptionalInput(const string &name) const; | bool IsOptionalInput(const string &name) const; | ||||
@@ -214,6 +220,9 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
void SetIsInputConst(const vector<bool> &is_input_const); | void SetIsInputConst(const vector<bool> &is_input_const); | ||||
vector<bool> GetIsInputConst() const; | vector<bool> GetIsInputConst() const; | ||||
void SetOpInferDepends(const vector<string> &depend_names); | |||||
vector<string> GetOpInferDepends() const; | |||||
string GetInputNameByIndex(uint32_t index) const; | string GetInputNameByIndex(uint32_t index) const; | ||||
int GetInputIndexByName(const string &name) const; | int GetInputIndexByName(const string &name) const; | ||||
@@ -236,12 +245,23 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
std::string GetOpEngineName() const; | std::string GetOpEngineName() const; | ||||
void RegisterSubgraphIrName(const std::string &name, SubgraphType type); | |||||
const std::map<std::string, SubgraphType> &GetSubgraphIrNames() const; | |||||
SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; | |||||
graphStatus AddSubgraphName(const std::string &name); | graphStatus AddSubgraphName(const std::string &name); | ||||
const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const; | const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const; | ||||
std::string GetSubgraphInstanceName(uint32_t index) const; | std::string GetSubgraphInstanceName(uint32_t index) const; | ||||
const std::vector<std::string> &GetSubgraphInstanceNames() const; | const std::vector<std::string> &GetSubgraphInstanceNames() const; | ||||
void AddSubgraphInstanceName(std::string name); | |||||
/// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, | |||||
/// because this kind of functions will only append a new subgraph instance name | |||||
/// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. | |||||
/// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. | |||||
/// \param index | |||||
/// \param name | |||||
/// \return | |||||
graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); | |||||
void RemoveSubgraphInstanceName(const std::string &name); | void RemoveSubgraphInstanceName(const std::string &name); | ||||
protected: | protected: | ||||
@@ -256,7 +276,23 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
GeIrProtoHelper<ge::proto::OpDef> op_def_; | GeIrProtoHelper<ge::proto::OpDef> op_def_; | ||||
std::vector<std::string> subgraph_instance_names_; | std::vector<std::string> subgraph_instance_names_; | ||||
// subgraph names to index, for a `if` operator: | |||||
// then_branch: 0 | |||||
// else_branch: 1 | |||||
// or for a `case` node: | |||||
// branches0: 0 | |||||
// branches1: 1 | |||||
// branches2: 2 | |||||
std::map<std::string, uint32_t> subgraph_names_to_index_; | std::map<std::string, uint32_t> subgraph_names_to_index_; | ||||
// subgraph ir names to type, for a `if` operator: | |||||
// then_branch: static | |||||
// else_branch: dynamic | |||||
// or for a `case` op: | |||||
// branches: dynamic | |||||
std::map<std::string, SubgraphType> subgraph_ir_names_to_type_; | |||||
vector<GeTensorDescPtr> inputs_desc_{}; | vector<GeTensorDescPtr> inputs_desc_{}; | ||||
vector<GeTensorDescPtr> outputs_desc_{}; | vector<GeTensorDescPtr> outputs_desc_{}; | ||||
map<string, uint32_t> output_name_idx_{}; | map<string, uint32_t> output_name_idx_{}; | ||||
@@ -0,0 +1,79 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef COMMON_GRAPH_REF_RELATION_H_ | |||||
#define COMMON_GRAPH_REF_RELATION_H_ | |||||
#include <deque> | |||||
#include <string> | |||||
#include <unordered_map> | |||||
#include <vector> | |||||
#include "graph/compute_graph.h" | |||||
#include "graph/types.h" | |||||
#include "graph/ge_error_codes.h" | |||||
#include "node.h" | |||||
namespace ge { | |||||
enum InOutFlag { | |||||
NODE_IN = 0, // input flag | |||||
NODE_OUT = 1, // output flag | |||||
}; | |||||
struct RefCell { | |||||
std::string node_name; | |||||
ge::NodePtr node = nullptr; | |||||
InOutFlag in_out = NODE_IN; | |||||
int in_out_idx = 0; | |||||
bool operator==(const RefCell &c) const { | |||||
return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; | |||||
} | |||||
RefCell() = default; | |||||
RefCell(std::string name, ge::NodePtr node_ptr, InOutFlag in_out_flag, int idx) { | |||||
node_name = name; | |||||
node = node_ptr; | |||||
in_out = in_out_flag; | |||||
in_out_idx = idx; | |||||
}; | |||||
~RefCell() = default; | |||||
}; | |||||
struct RefCellHash { | |||||
size_t operator()(const RefCell &c) const { | |||||
unsigned long number = reinterpret_cast<unsigned long>(reinterpret_cast<uintptr_t>(c.node.get())); | |||||
string tmp = c.node_name + std::to_string(c.in_out) + std::to_string(c.in_out_idx) + std::to_string(number); | |||||
return std::hash<string>()(tmp); | |||||
} | |||||
}; | |||||
class RefRelations { | |||||
public: | |||||
graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set<RefCell, RefCellHash> &result); | |||||
graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); | |||||
graphStatus Clear(); | |||||
RefRelations(); | |||||
~RefRelations() = default; | |||||
public: | |||||
class Impl; | |||||
std::shared_ptr<Impl> impl_ = nullptr; | |||||
}; | |||||
} // namespace ge | |||||
#endif // COMMON_GRAPH_REF_RELATION_H_ |
@@ -14,8 +14,8 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
#define INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
#ifndef INC_GRAPH_USR_TYPES_H_ | |||||
#define INC_GRAPH_USR_TYPES_H_ | |||||
#include <atomic> | #include <atomic> | ||||
#include <memory> | #include <memory> | ||||
@@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams { | |||||
#undef USR_TYPE_BYTES_DEC | #undef USR_TYPE_BYTES_DEC | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
#endif // INC_GRAPH_USR_TYPES_H_ |
@@ -62,9 +62,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||||
static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector<ComputeGraphPtr> &value); | static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector<ComputeGraphPtr> &value); | ||||
static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); | static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); | ||||
static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector<GeAttrValue::BYTES> &value); | static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector<GeAttrValue::BYTES> &value); | ||||
static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NamedAttrs &value); | |||||
static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value); | |||||
static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, | static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, | ||||
const vector<GeAttrValue::NamedAttrs> &value); | |||||
const vector<GeAttrValue::NAMED_ATTRS> &value); | |||||
static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value); | static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value); | ||||
static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<OpDescPtr> &value); | static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<OpDescPtr> &value); | ||||
@@ -91,9 +91,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||||
static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector<ComputeGraphPtr> &value); | static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector<ComputeGraphPtr> &value); | ||||
static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); | static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); | ||||
static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<GeAttrValue::BYTES> &value); | static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<GeAttrValue::BYTES> &value); | ||||
static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NamedAttrs &value); | |||||
static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value); | |||||
static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, | static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, | ||||
vector<GeAttrValue::NamedAttrs> &value); | |||||
vector<GeAttrValue::NAMED_ATTRS> &value); | |||||
static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<OpDescPtr> &value); | static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<OpDescPtr> &value); | ||||
// Value will be moved | // Value will be moved | ||||
static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | ||||
@@ -95,12 +95,35 @@ | |||||
}; | }; | ||||
namespace ge { | namespace ge { | ||||
enum IOType { kIn, kOut }; | |||||
struct NodeIndexIO { | |||||
NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type) | |||||
: node(std::move(node)), index(index), io_type(io_type) {} | |||||
NodeIndexIO(ge::NodePtr node, int index, IOType io_type) | |||||
: node(std::move(node)), index(static_cast<uint32_t>(index)), io_type(io_type) {} | |||||
~NodeIndexIO() {} | |||||
NodePtr node = nullptr; | |||||
uint32_t index = 0; | |||||
IOType io_type = kOut; | |||||
std::string ToString() const { | |||||
if ((node == nullptr) || (node->GetOwnerComputeGraph() == nullptr)) { | |||||
return ""; | |||||
} | |||||
return node->GetName() + (io_type == kOut ? "_out_" : "_in_") + std::to_string(index); | |||||
} | |||||
}; | |||||
class GraphUtils { | class GraphUtils { | ||||
public: | public: | ||||
static ComputeGraphPtr GetComputeGraph(const Graph &graph); | static ComputeGraphPtr GetComputeGraph(const Graph &graph); | ||||
static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); | static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); | ||||
static graphStatus RecoverGraphOperators(const Graph &graph); | |||||
static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs); | static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs); | ||||
static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); | static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); | ||||
@@ -262,6 +285,108 @@ class GraphUtils { | |||||
static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | ||||
static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | ||||
static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec); | |||||
/// | |||||
/// Get reference-mapping of all data_anchors in graph | |||||
/// @param [in] graph | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
static graphStatus GetRefMapping(const ComputeGraphPtr &graph, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol); | |||||
private: | |||||
/// | |||||
/// Get reference-mapping for in_data_anchors of node | |||||
/// @param [in] node | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
static graphStatus HandleInAnchorMapping(const NodePtr &node, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol); | |||||
/// | |||||
/// Get reference-mapping for out_data_anchors of node | |||||
/// @param [in] node | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
static graphStatus HandleOutAnchorMapping(const NodePtr &node, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol); | |||||
/// | |||||
/// Handle input of subgraph | |||||
/// @param [in] node | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
static graphStatus HandleSubgraphInput(const NodePtr &node, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol); | |||||
/// | |||||
/// Handle input of Merge op | |||||
/// @param [in] node | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
static graphStatus HandleMergeInput(const NodePtr &node, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol); | |||||
/// | |||||
/// Handle output of subgraph | |||||
/// @param [in] node | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
static graphStatus HandleSubgraphOutput(const NodePtr &node, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol); | |||||
/// | |||||
/// Union ref-mapping | |||||
/// @param [in] exist_node_info1 | |||||
/// @param [in] exist_node_info2 | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @param [out] symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol); | |||||
/// | |||||
/// Update symbol mapping with a new reference pair | |||||
/// @param [in] cur_node_info | |||||
/// @param [in] exist_node_info | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol); | |||||
/// | |||||
/// Check if out_data_anchor is reference of input | |||||
/// @param [in] out_data_anchor | |||||
/// @param [out] reuse_in_index | |||||
/// @return bool | |||||
/// | |||||
static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); | |||||
}; | }; | ||||
class ComputeGraphBuilder { | class ComputeGraphBuilder { | ||||
@@ -441,12 +566,12 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { | |||||
private: | private: | ||||
/// | /// | ||||
/// @brief Build inputs | |||||
/// @brief Add data nodes | |||||
/// @param [out] error_code | /// @param [out] error_code | ||||
/// @param [out] error_msg | /// @param [out] error_msg | ||||
/// @return void | /// @return void | ||||
/// | /// | ||||
void BuildInputs(graphStatus &error_code, std::string &error_msg); | |||||
void AddDataNodes(graphStatus &error_code, std::string &error_msg); | |||||
/// | /// | ||||
/// @brief Add data node | /// @brief Add data node | ||||
@@ -455,41 +580,15 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { | |||||
/// @param [out] error_msg | /// @param [out] error_msg | ||||
/// @return void | /// @return void | ||||
/// | /// | ||||
NodePtr AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||||
NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||||
/// | /// | ||||
/// @brief Build outputs | |||||
/// @brief Add RetVal nodes | |||||
/// @param [out] error_code | /// @param [out] error_code | ||||
/// @param [out] error_msg | /// @param [out] error_msg | ||||
/// @return void | /// @return void | ||||
/// | /// | ||||
void BuildOutputs(graphStatus &error_code, std::string &error_msg); | |||||
/// | |||||
/// @brief Add NetOutput node | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return NodePtr | |||||
/// | |||||
NodePtr AddNetOutputNode(graphStatus &error_code, std::string &error_msg); | |||||
/// | |||||
/// @brief Add input/output tensor for NetOutput node | |||||
/// @param [in] out_nodes_info | |||||
/// @param [out] net_output_desc | |||||
/// @return graphStatus | |||||
/// | |||||
graphStatus BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||||
OpDescPtr &net_output_desc); | |||||
/// | |||||
/// @brief Add edge for NetOutput node | |||||
/// @param [in] out_nodes_info | |||||
/// @param [out] net_output_node | |||||
/// @return graphStatus | |||||
/// | |||||
graphStatus AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||||
const NodePtr &net_output_node); | |||||
void AddRetValNodes(graphStatus &error_code, std::string &error_msg); | |||||
std::string name_; | std::string name_; | ||||
NodePtr parent_node_; | NodePtr parent_node_; | ||||
@@ -55,11 +55,44 @@ class NodeUtils { | |||||
static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); | static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); | ||||
static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | ||||
static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | ||||
// check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true; | |||||
// for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2, | |||||
// the out param "is_unknow" will be true too | |||||
static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); | |||||
static std::string GetNodeType(const Node &node); | static std::string GetNodeType(const Node &node); | ||||
static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | ||||
static graphStatus AddSubgraph(Node &node, const ComputeGraphPtr &subgraph); | |||||
static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); | |||||
/// | |||||
/// Check if node is input of subgraph | |||||
/// @param [in] node | |||||
/// @return bool | |||||
/// | |||||
static bool IsSubgraphInput(const NodePtr &node); | |||||
/// | |||||
/// Check if node is output of subgraph | |||||
/// @param [in] node | |||||
/// @return bool | |||||
/// | |||||
static bool IsSubgraphOutput(const NodePtr &node); | |||||
/// | |||||
/// @brief Get subgraph original input node. | |||||
/// @param [in] node | |||||
/// @return Node | |||||
/// | |||||
static NodePtr GetParentInput(const NodePtr &node); | |||||
/// | |||||
/// @brief Get subgraph input is constant. | |||||
/// @param [in] node | |||||
/// @param [out] string | |||||
/// @return bool | |||||
/// | |||||
static bool GetConstOpType(const NodePtr &in_node, std::string &op_type); | |||||
private: | private: | ||||
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | ||||
@@ -81,6 +81,9 @@ class OpDescUtils { | |||||
static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); | static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); | ||||
static graphStatus SetSubgraphInstanceName(const std::string& subgraph_name, | |||||
const std::string& subgraph_instance_name, OpDescPtr& op_desc); | |||||
private: | private: | ||||
static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); | static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); | ||||
static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | ||||
@@ -105,6 +108,14 @@ class OpDescBuilder { | |||||
OpDescBuilder& AddInput(const std::string& name); | OpDescBuilder& AddInput(const std::string& name); | ||||
/// | /// | ||||
/// @brief Add input | |||||
/// @param [in] name | |||||
/// @param [in] tensor | |||||
/// @return OpDescBuilder | |||||
/// | |||||
OpDescBuilder& AddInput(const std::string& name, const GeTensorDesc& tensor); | |||||
/// | |||||
/// @brief Add dynamic input | /// @brief Add dynamic input | ||||
/// @param [in] name | /// @param [in] name | ||||
/// @param [in] num | /// @param [in] num | ||||
@@ -113,6 +124,15 @@ class OpDescBuilder { | |||||
OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); | OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); | ||||
/// | /// | ||||
/// @brief Add dynamic input | |||||
/// @param [in] name | |||||
/// @param [in] num | |||||
/// @param [in] tensor | |||||
/// @return OpDescBuilder | |||||
/// | |||||
OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); | |||||
/// | |||||
/// @brief Add output | /// @brief Add output | ||||
/// @param [in] name | /// @param [in] name | ||||
/// @return OpDescBuilder | /// @return OpDescBuilder | ||||
@@ -120,6 +140,14 @@ class OpDescBuilder { | |||||
OpDescBuilder& AddOutput(const std::string& name); | OpDescBuilder& AddOutput(const std::string& name); | ||||
/// | /// | ||||
/// @brief Add output | |||||
/// @param [in] name | |||||
/// @param [in] tensor | |||||
/// @return OpDescBuilder | |||||
/// | |||||
OpDescBuilder& AddOutput(const std::string& name, const GeTensorDesc& tensor); | |||||
/// | |||||
/// @brief Add dynamic output | /// @brief Add dynamic output | ||||
/// @param [in] name | /// @param [in] name | ||||
/// @param [in] num | /// @param [in] num | ||||
@@ -128,6 +156,15 @@ class OpDescBuilder { | |||||
OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); | OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); | ||||
/// | /// | ||||
/// @brief Add dynamic output | |||||
/// @param [in] name | |||||
/// @param [in] num | |||||
/// @param [in] tensor | |||||
/// @return OpDescBuilder | |||||
/// | |||||
OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); | |||||
/// | |||||
/// @brief Build op_desc | /// @brief Build op_desc | ||||
/// @return OpDescPtr | /// @return OpDescPtr | ||||
/// | /// | ||||
@@ -136,8 +173,8 @@ class OpDescBuilder { | |||||
private: | private: | ||||
std::string name_; | std::string name_; | ||||
std::string type_; | std::string type_; | ||||
std::vector<std::string> inputs_; | |||||
std::vector<std::string> outputs_; | |||||
std::vector<std::pair<std::string, GeTensorDesc>> inputs_; | |||||
std::vector<std::pair<std::string, GeTensorDesc>> outputs_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -34,13 +34,12 @@ ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
ge_protobuf_generate(ge PROTO_ONNX_SRCS PROTO_ONNX_HDRS ${ONNX_PROTO_LIST}) | ge_protobuf_generate(ge PROTO_ONNX_SRCS PROTO_ONNX_HDRS ${ONNX_PROTO_LIST}) | ||||
# need to remove dependencies on pb files later | # need to remove dependencies on pb files later | ||||
file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"*.cc" | "*.cc" | ||||
"utils/*.cc" | "utils/*.cc" | ||||
"opsproto/*.cc" | "opsproto/*.cc" | ||||
"detail/*.cc" | "detail/*.cc" | ||||
"debug/*.cc" | "debug/*.cc" | ||||
"op_imp.cc" | |||||
"option/*.cc" | "option/*.cc" | ||||
) | ) | ||||
@@ -53,7 +53,6 @@ void Anchor::UnlinkAll() noexcept { | |||||
if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { | if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { | ||||
GELOGW("unlink peer_anchor_ptr failed."); | GELOGW("unlink peer_anchor_ptr failed."); | ||||
} | } | ||||
} while (!peer_anchors_.empty()); | } while (!peer_anchors_.empty()); | ||||
} | } | ||||
} | } | ||||
@@ -42,8 +42,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const | |||||
: name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { | : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { | ||||
attrs_.InitDefault(); | attrs_.InitDefault(); | ||||
} | } | ||||
ComputeGraph::~ComputeGraph() {} | ComputeGraph::~ComputeGraph() {} | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; } | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { | ||||
@@ -53,24 +56,50 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesS | |||||
} | } | ||||
return s; | return s; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const { | ||||
vector<NodePtr> all_nodes(nodes_.size()); | |||||
(void)std::copy(nodes_.begin(), nodes_.end(), all_nodes.begin()); | |||||
for (const auto &sub_graph : sub_graph_) { | |||||
if (sub_graph == nullptr) { | |||||
GELOGW("sub graph is nullptr"); | |||||
if (sub_graph_.empty()) { | |||||
return Vistor<NodePtr>(shared_from_this(), nodes_); | |||||
} | |||||
std::vector<std::shared_ptr<ComputeGraph>> subgraphs; | |||||
return AllGraphNodes(subgraphs); | |||||
} | |||||
ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const { | |||||
std::vector<NodePtr> all_nodes; | |||||
std::deque<NodePtr> candidates; | |||||
candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); | |||||
while (!candidates.empty()) { | |||||
NodePtr node = candidates.front(); | |||||
all_nodes.emplace_back(node); | |||||
candidates.pop_front(); | |||||
OpDescPtr op_desc = node->GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
continue; | continue; | ||||
} | } | ||||
for (const auto &node : sub_graph->GetAllNodes()) { | |||||
all_nodes.push_back(node); | |||||
const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||||
for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { | |||||
auto subgraph = GetSubgraph(*name_iter); | |||||
if (subgraph != nullptr) { | |||||
subgraphs.emplace_back(subgraph); | |||||
candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); | |||||
} | |||||
} | } | ||||
} | } | ||||
return Vistor<NodePtr>(shared_from_this(), all_nodes); | return Vistor<NodePtr>(shared_from_this(), all_nodes); | ||||
} | } | ||||
size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const { | ||||
return Vistor<NodePtr>(shared_from_this(), nodes_); | return Vistor<NodePtr>(shared_from_this(), nodes_); | ||||
} | } | ||||
ComputeGraph::Vistor<NodePtr> ComputeGraph::GetInputNodes() const { | ComputeGraph::Vistor<NodePtr> ComputeGraph::GetInputNodes() const { | ||||
return Vistor<NodePtr>(shared_from_this(), input_nodes_); | return Vistor<NodePtr>(shared_from_this(), input_nodes_); | ||||
} | } | ||||
@@ -82,6 +111,7 @@ ComputeGraph::Vistor<NodePtr> ComputeGraph::GetOutputNodes() const { | |||||
} | } | ||||
return Vistor<NodePtr>(shared_from_this(), result); | return Vistor<NodePtr>(shared_from_this(), result); | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const { | ||||
for (const auto &node : nodes_) { | for (const auto &node : nodes_) { | ||||
if (node == nullptr) { | if (node == nullptr) { | ||||
@@ -203,10 +233,6 @@ NodePtr ComputeGraph::AddNodeFront(NodePtr node) { | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
node->GetOpDesc()->SetId(nodes_.size()); | node->GetOpDesc()->SetId(nodes_.size()); | ||||
if (nodes_[0] == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "nodes_ size or nodes_[0] is nullptr"); | |||||
return nullptr; | |||||
} | |||||
if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { | if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { | ||||
(void)nodes_.insert(nodes_.begin() + 1, node); | (void)nodes_.insert(nodes_.begin() + 1, node); | ||||
} else { | } else { | ||||
@@ -248,6 +274,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpD | |||||
GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); | GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); | ||||
return AddNode(node_ptr); | return AddNode(node_ptr); | ||||
} | } | ||||
NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. | |||||
if (op == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null."); | |||||
return nullptr; | |||||
} | |||||
op->SetId(id); | |||||
NodePtr node = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this())); | |||||
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); | |||||
GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); | |||||
nodes_.push_back(node); | |||||
return node; | |||||
} | |||||
NodePtr ComputeGraph::AddInputNode(NodePtr node) { | NodePtr ComputeGraph::AddInputNode(NodePtr node) { | ||||
if (node == nullptr) { | if (node == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "The node ptr should be not null."); | GELOGE(GRAPH_FAILED, "The node ptr should be not null."); | ||||
@@ -259,6 +299,7 @@ NodePtr ComputeGraph::AddInputNode(NodePtr node) { | |||||
} | } | ||||
return node; | return node; | ||||
} | } | ||||
NodePtr ComputeGraph::AddOutputNode(NodePtr node) { | NodePtr ComputeGraph::AddOutputNode(NodePtr node) { | ||||
if (node == nullptr || node->GetOpDesc() == nullptr) { | if (node == nullptr || node->GetOpDesc() == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "The node ptr or opdesc should be not null."); | GELOGE(GRAPH_FAILED, "The node ptr or opdesc should be not null."); | ||||
@@ -336,6 +377,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveN | |||||
} | } | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
// Used in sub_graph scenes | // Used in sub_graph scenes | ||||
graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { | graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { | ||||
if (node == nullptr) { | if (node == nullptr) { | ||||
@@ -372,20 +414,24 @@ graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { | |||||
GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED); | GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED); | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph) { | std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph) { | ||||
if (sub_graph == nullptr) { | if (sub_graph == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); | GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
sub_graph_.push_back(sub_graph); | sub_graph_.push_back(sub_graph); | ||||
names_to_subgraph_[sub_graph->GetName()] = sub_graph; | |||||
return sub_graph; | return sub_graph; | ||||
} | } | ||||
graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph) { | graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph) { | ||||
if (sub_graph == nullptr) { | if (sub_graph == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); | GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
names_to_subgraph_.erase(sub_graph->GetName()); | |||||
auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph); | auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph); | ||||
if (iter != sub_graph_.end()) { | if (iter != sub_graph_.end()) { | ||||
(void)sub_graph_.erase(iter); | (void)sub_graph_.erase(iter); | ||||
@@ -462,8 +508,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph( | ||||
const std::string &name) const { | const std::string &name) const { | ||||
auto iter = names_to_subgraph_.find(name); | |||||
return iter == names_to_subgraph_.end() ? nullptr : iter->second; | |||||
std::shared_ptr<ComputeGraph> parent = parent_graph_.lock(); | |||||
if (parent == nullptr) { | |||||
auto iter = names_to_subgraph_.find(name); | |||||
return iter == names_to_subgraph_.end() ? nullptr : iter->second; | |||||
} else { | |||||
return parent->GetSubgraph(name); | |||||
} | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>> | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>> | ||||
@@ -495,7 +546,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode( | |||||
/// | /// | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) { | ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) { | ||||
for (auto &input : input_nodes_) { | |||||
size_t update_num = 0; | |||||
for (auto &input : nodes_) { | |||||
if (update_num >= input_mapping.size()) { | |||||
break; | |||||
} | |||||
uint32_t cur_index = 0; | uint32_t cur_index = 0; | ||||
if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | ||||
continue; | continue; | ||||
@@ -508,6 +563,7 @@ ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mappi | |||||
GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
update_num++; | |||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
@@ -520,9 +576,9 @@ ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mappi | |||||
/// | /// | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) { | ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) { | ||||
NodePtr net_output = FindNode(kNodeNameNetOutput); | |||||
NodePtr net_output = FindNode(NODE_NAME_NET_OUTPUT); | |||||
if (net_output == nullptr) { | if (net_output == nullptr) { | ||||
GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", kNodeNameNetOutput); | |||||
GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", NODE_NAME_NET_OUTPUT); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
OpDescPtr op_desc = net_output->GetOpDesc(); | OpDescPtr op_desc = net_output->GetOpDesc(); | ||||
@@ -557,13 +613,13 @@ ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_map | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { | ||||
std::vector<NodePtr> node_vec = nodes_; | std::vector<NodePtr> node_vec = nodes_; | ||||
for (const auto &node : GetAllNodes()) { | |||||
for (const auto &node : GetDirectNode()) { | |||||
if (node == nullptr || node->GetOpDesc() == nullptr) { | if (node == nullptr || node->GetOpDesc() == nullptr) { | ||||
GELOGW("node or OpDescPtr is nullptr."); | GELOGW("node or OpDescPtr is nullptr."); | ||||
continue; | continue; | ||||
} | } | ||||
GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should be not null."); return GRAPH_FAILED); | GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should be not null."); return GRAPH_FAILED); | ||||
if (node->GetOpDesc()->GetType() == kRecvType) { | |||||
if (node->GetOpDesc()->GetType() == RECV) { | |||||
auto iter = find(node_vec.begin(), node_vec.end(), node); | auto iter = find(node_vec.begin(), node_vec.end(), node); | ||||
if (iter == node_vec.end()) { | if (iter == node_vec.end()) { | ||||
GELOGW("no node found."); | GELOGW("no node found."); | ||||
@@ -574,7 +630,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE | |||||
auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0)); | auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0)); | ||||
(void)node_vec.insert(dst_iter, node); | (void)node_vec.insert(dst_iter, node); | ||||
} | } | ||||
if (node->GetOpDesc()->GetType() == kSendType) { | |||||
if (node->GetOpDesc()->GetType() == SEND) { | |||||
auto iter = find(node_vec.begin(), node_vec.end(), node); | auto iter = find(node_vec.begin(), node_vec.end(), node); | ||||
if (iter == node_vec.end()) { | if (iter == node_vec.end()) { | ||||
GELOGW("no node found."); | GELOGW("no node found."); | ||||
@@ -602,7 +658,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE | |||||
graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | ||||
std::map<NodePtr, uint32_t> &map_in_edge_num, | std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::vector<NodePtr> &stack) { | std::vector<NodePtr> &stack) { | ||||
GELOGI("Runing_Dfs_Sort"); | |||||
GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); | |||||
// Record the number of non data nodes but no input nodes | // Record the number of non data nodes but no input nodes | ||||
GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); | GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); | ||||
@@ -647,7 +703,7 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||||
graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | ||||
std::map<NodePtr, uint32_t> &map_in_edge_num, | std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::deque<NodePtr> &stack) { | std::deque<NodePtr> &stack) { | ||||
GELOGI("Runing_Bfs_Sort"); | |||||
GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); | |||||
std::vector<NodePtr> stack_input; | std::vector<NodePtr> stack_input; | ||||
std::map<string, NodePtr> breadth_node_map; | std::map<string, NodePtr> breadth_node_map; | ||||
// Record the number of non data nodes but no input nodes | // Record the number of non data nodes but no input nodes | ||||
@@ -708,23 +764,36 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { | ||||
auto ret = TopologicalSortingSubgraph(); | |||||
auto ret = TopologicalSortingGraph(); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Sub graph partition Failed"); | GELOGE(ret, "Sub graph partition Failed"); | ||||
return ret; | return ret; | ||||
} | } | ||||
if (sub_graph_.empty()) { | |||||
return SUCCESS; | |||||
} | |||||
// partition sub graph | // partition sub graph | ||||
for (const auto &sub_graph : GetAllSubgraphs()) { | |||||
ret = sub_graph->TopologicalSortingSubgraph(); | |||||
for (const auto &sub_graph : sub_graph_) { | |||||
ret = sub_graph->TopologicalSortingGraph(); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Sub graph topological sort Failed"); | GELOGE(ret, "Sub graph topological sort Failed"); | ||||
return ret; | return ret; | ||||
} | } | ||||
} | } | ||||
std::vector<std::shared_ptr<ComputeGraph>> subgraphs; | |||||
(void)AllGraphNodes(subgraphs); | |||||
if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original | |||||
GELOGW("Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), subgraphs.size()); | |||||
return SUCCESS; | |||||
} | |||||
sub_graph_.swap(subgraphs); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSortingSubgraph() { | |||||
graphStatus ComputeGraph::TopologicalSortingGraph() { | |||||
std::vector<NodePtr> node_vec; | std::vector<NodePtr> node_vec; | ||||
std::map<NodePtr, uint32_t> map_in_edge_num; | std::map<NodePtr, uint32_t> map_in_edge_num; | ||||
bool use_BFS = false; | bool use_BFS = false; | ||||
@@ -735,7 +804,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog | |||||
use_BFS = true; | use_BFS = true; | ||||
} | } | ||||
} else { | } else { | ||||
GELOGW("Get OPTION_GRAPH_RUN_MODE failed, use BFSTopologicalSorting by default."); | |||||
GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); | |||||
} | } | ||||
if (use_BFS) { | if (use_BFS) { | ||||
@@ -793,8 +862,8 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||||
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | ||||
map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node)); | map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node)); | ||||
if (map_in_edge_num[node] == 0) { | if (map_in_edge_num[node] == 0) { | ||||
if ((node->GetOpDesc()->GetType() != kDataType) && (node->GetOpDesc()->GetType() != kAippDataType) && | |||||
(node->GetOpDesc()->GetType() != kInputType) && (node->GetOpDesc()->GetType() != kAnnDataType)) { | |||||
if ((node->GetOpDesc()->GetType() != DATA) && (node->GetOpDesc()->GetType() != AIPPDATA) && | |||||
(node->GetOpDesc()->GetType() != INPUT_TYPE) && (node->GetOpDesc()->GetType() != ANN_DATA)) { | |||||
// At present, can only judge the isolated point without input and output. | // At present, can only judge the isolated point without input and output. | ||||
// It is impossible to judge the situation with multiple output nodes. | // It is impossible to judge the situation with multiple output nodes. | ||||
if (verify_isolated && GetOutEdgeSize(node) == 0) { | if (verify_isolated && GetOutEdgeSize(node) == 0) { | ||||
@@ -832,6 +901,7 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | ||||
size_t in_edge_size = 0; | size_t in_edge_size = 0; | ||||
if (node == nullptr) { | if (node == nullptr) { | ||||
@@ -884,6 +954,7 @@ size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; } | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | ||||
GELOGI("graph name = %s.", GetName().c_str()); | GELOGI("graph name = %s.", GetName().c_str()); | ||||
for (const auto &node : GetAllNodes()) { | for (const auto &node : GetAllNodes()) { | ||||
@@ -915,6 +986,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
auto next_nodes = node->GetOutAllNodes(); | auto next_nodes = node->GetOutAllNodes(); | ||||
@@ -954,6 +1026,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Isolate | |||||
} | } | ||||
} | } | ||||
} | } | ||||
// If there is an input control side | // If there is an input control side | ||||
auto in_ctrl_anchor = node->GetInControlAnchor(); | auto in_ctrl_anchor = node->GetInControlAnchor(); | ||||
GE_CHECK_NOTNULL(in_ctrl_anchor); | GE_CHECK_NOTNULL(in_ctrl_anchor); | ||||
@@ -991,6 +1064,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Isolate | |||||
return RemoveExtraOutEdge(node); | return RemoveExtraOutEdge(node); | ||||
} | } | ||||
graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { | graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
// Remove redundant output edges | // Remove redundant output edges | ||||
@@ -1041,7 +1115,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferSh | |||||
node_ptr->GetName().c_str()); | node_ptr->GetName().c_str()); | ||||
graphStatus status = node_ptr->InferShapeAndType(); | graphStatus status = node_ptr->InferShapeAndType(); | ||||
GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == kDataType || GRAPH_PARAM_INVALID != status, break, | |||||
GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == DATA || GRAPH_PARAM_INVALID != status, break, | |||||
"Op %s does not have the IMPLEMT_INFERFUNC definition," | "Op %s does not have the IMPLEMT_INFERFUNC definition," | ||||
" and subsequent operators no longer perform shape inference.", | " and subsequent operators no longer perform shape inference.", | ||||
node_ptr->GetName().c_str()); | node_ptr->GetName().c_str()); | ||||
@@ -16,237 +16,41 @@ | |||||
#ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | #ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | ||||
#define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | #define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | ||||
#include <limits.h> | |||||
#include <stdint.h> | |||||
#include <algorithm> | |||||
#include <map> | |||||
#include <memory> | |||||
#include <string> | |||||
#include <vector> | |||||
namespace ge { | namespace ge { | ||||
#define GE_REGISTER_OPTYPE(var_name, str_name) static const char* var_name __attribute__((unused)) = str_name | |||||
#define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name __attribute__((unused)) = str_name | |||||
GE_REGISTER_OPTYPE(DATA, "Data"); | GE_REGISTER_OPTYPE(DATA, "Data"); | ||||
GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); | GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); | ||||
GE_REGISTER_OPTYPE(CONVOLUTION, "Convolution"); | |||||
GE_REGISTER_OPTYPE(CORRELATION, "Correlation"); | |||||
GE_REGISTER_OPTYPE(CORRELATIONV2, "Correlation_V2"); | |||||
GE_REGISTER_OPTYPE(DECONVOLUTION, "Deconvolution"); | |||||
GE_REGISTER_OPTYPE(POOLING, "Pooling"); | |||||
GE_REGISTER_OPTYPE(ELTWISE, "Eltwise"); | |||||
GE_REGISTER_OPTYPE(RELU, "ReLU"); | |||||
GE_REGISTER_OPTYPE(RELU6, "ReLU6"); | |||||
GE_REGISTER_OPTYPE(SIGMOID, "Sigmoid"); | |||||
GE_REGISTER_OPTYPE(ABSVAL, "AbsVal"); | |||||
GE_REGISTER_OPTYPE(TANH, "TanH"); | |||||
GE_REGISTER_OPTYPE(PRELU, "PReLU"); | |||||
GE_REGISTER_OPTYPE(BATCHNORM, "BatchNorm"); | |||||
GE_REGISTER_OPTYPE(FUSIONBATCHNORM, "FusionBatchNorm"); | |||||
GE_REGISTER_OPTYPE(SCALE, "Scale"); | |||||
GE_REGISTER_OPTYPE(FULL_CONNECTION, "FullConnection"); | |||||
GE_REGISTER_OPTYPE(SOFTMAX, "Softmax"); | |||||
GE_REGISTER_OPTYPE(PLUS, "Plus"); | |||||
GE_REGISTER_OPTYPE(ACTIVATION, "Activation"); | |||||
GE_REGISTER_OPTYPE(FLATTEN, "Flatten"); | |||||
GE_REGISTER_OPTYPE(ADD, "Add"); | |||||
GE_REGISTER_OPTYPE(SUB, "Sub"); | |||||
GE_REGISTER_OPTYPE(MUL, "Mul"); | |||||
GE_REGISTER_OPTYPE(MATMUL, "MatMul"); | GE_REGISTER_OPTYPE(MATMUL, "MatMul"); | ||||
GE_REGISTER_OPTYPE(RSQRT, "Rsqrt"); | |||||
GE_REGISTER_OPTYPE(BIASADD, "BiasAdd"); | |||||
GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); | GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); | ||||
GE_REGISTER_OPTYPE(DEPCONVOLUTION, "ConvolutionDepthwise"); | |||||
GE_REGISTER_OPTYPE(DROPOUT, "Dropout"); | |||||
GE_REGISTER_OPTYPE(CONCAT, "Concat"); | |||||
GE_REGISTER_OPTYPE(ROIPOOLING, "ROIPooling"); | |||||
GE_REGISTER_OPTYPE(PROPOSAL, "Proposal"); | |||||
GE_REGISTER_OPTYPE(FSRDETECTIONOUTPUT, "FSRDetectionOutput"); | |||||
GE_REGISTER_OPTYPE(DETECTIONPOSTPROCESS, "Detectpostprocess"); | |||||
GE_REGISTER_OPTYPE(LRN, "LRN"); | |||||
GE_REGISTER_OPTYPE(TRANSDATA, "TransData"); | |||||
GE_REGISTER_OPTYPE(PERMUTE, "Permute"); | GE_REGISTER_OPTYPE(PERMUTE, "Permute"); | ||||
GE_REGISTER_OPTYPE(SSDNORMALIZE, "SSDNormalize"); | |||||
GE_REGISTER_OPTYPE(SSDPRIORBOX, "SSDPriorBox"); | |||||
GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); | GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); | ||||
GE_REGISTER_OPTYPE(SSDDETECTIONOUTPUT, "SSDDetectionOutput"); | |||||
GE_REGISTER_OPTYPE(CHANNELAXPY, "ChannelAxpy"); | |||||
GE_REGISTER_OPTYPE(PSROIPOOLING, "PSROIPooling"); | |||||
GE_REGISTER_OPTYPE(POWER, "Power"); | |||||
GE_REGISTER_OPTYPE(ROIALIGN, "ROIAlign"); | |||||
GE_REGISTER_OPTYPE(PYTHON, "Python"); | |||||
GE_REGISTER_OPTYPE(FREESPACEEXTRACT, "FreespaceExtract"); | |||||
GE_REGISTER_OPTYPE(SPATIALTF, "SpatialTransform"); | |||||
GE_REGISTER_OPTYPE(SHAPE, "Shape"); | |||||
GE_REGISTER_OPTYPE(ARGMAX, "ArgMax"); | |||||
GE_REGISTER_OPTYPE(GATHERND, "GatherNd"); | |||||
GE_REGISTER_OPTYPE(GATHER, "Gather"); | |||||
GE_REGISTER_OPTYPE(REALDIV, "RealDiv"); | |||||
GE_REGISTER_OPTYPE(PACK, "Pack"); | |||||
GE_REGISTER_OPTYPE(SLICE, "Slice"); | |||||
GE_REGISTER_OPTYPE(FLOORDIV, "FloorDiv"); | |||||
GE_REGISTER_OPTYPE(_WHILE, "_While"); | |||||
GE_REGISTER_OPTYPE(WHILE, "While"); | |||||
GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); | |||||
GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); | GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); | ||||
GE_REGISTER_OPTYPE(STRIDEDSLICE, "StridedSlice"); | |||||
GE_REGISTER_OPTYPE(RANGE, "Range"); | |||||
GE_REGISTER_OPTYPE(RPNPROPOSALS, "GenerateRpnProposals"); | |||||
GE_REGISTER_OPTYPE(DECODEBBOX, "DecodeBBox"); | |||||
GE_REGISTER_OPTYPE(PAD, "Pad"); | |||||
GE_REGISTER_OPTYPE(TILE, "Tile"); | |||||
GE_REGISTER_OPTYPE(SIZE, "Size"); | |||||
GE_REGISTER_OPTYPE(CLIPBOXES, "Clipboxes"); | |||||
GE_REGISTER_OPTYPE(FASTRCNNPREDICTIONS, "FastrcnnPredictions"); | |||||
GE_REGISTER_OPTYPE(SPLIT, "Split"); | |||||
GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); | GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); | ||||
GE_REGISTER_OPTYPE(MEAN, "Mean"); | |||||
GE_REGISTER_OPTYPE(GREATER, "Greater"); | |||||
GE_REGISTER_OPTYPE(SWITCH, "Switch"); | GE_REGISTER_OPTYPE(SWITCH, "Switch"); | ||||
GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); | |||||
GE_REGISTER_OPTYPE(MERGE, "Merge"); | GE_REGISTER_OPTYPE(MERGE, "Merge"); | ||||
GE_REGISTER_OPTYPE(REFMERGE, "RefMerge"); | |||||
GE_REGISTER_OPTYPE(ENTER, "Enter"); | |||||
GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); | |||||
GE_REGISTER_OPTYPE(LOOPCOND, "LoopCond"); | |||||
GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); | |||||
GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); | GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); | ||||
GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); | GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); | ||||
GE_REGISTER_OPTYPE(EXIT, "Exit"); | |||||
GE_REGISTER_OPTYPE(REFEXIT, "RefExit"); | |||||
GE_REGISTER_OPTYPE(CONTROLTRIGGER, "ControlTrigger"); | |||||
GE_REGISTER_OPTYPE(TRANSPOSE, "Transpose"); | |||||
GE_REGISTER_OPTYPE(CAST, "Cast"); | |||||
GE_REGISTER_OPTYPE(REGION, "Region"); | |||||
GE_REGISTER_OPTYPE(YOLO, "Yolo"); | |||||
GE_REGISTER_OPTYPE(YOLODETECTIONOUTPUT, "YoloDetectionOutput"); | |||||
GE_REGISTER_OPTYPE(FILL, "Fill"); | |||||
GE_REGISTER_OPTYPE(REVERSE, "Reverse"); | |||||
GE_REGISTER_OPTYPE(UNPACK, "Unpack"); | |||||
GE_REGISTER_OPTYPE(YOLO2REORG, "Yolo2Reorg"); | |||||
GE_REGISTER_OPTYPE(REDUCESUM, "ReduceSum"); | |||||
GE_REGISTER_OPTYPE(CONSTANT, "Const"); | GE_REGISTER_OPTYPE(CONSTANT, "Const"); | ||||
GE_REGISTER_OPTYPE(RESIZEBILINEAR, "ResizeBilinear"); | |||||
GE_REGISTER_OPTYPE(MAXIMUM, "Maximum"); | |||||
GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); | GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); | ||||
GE_REGISTER_OPTYPE(ARG, "_Arg"); | |||||
GE_REGISTER_OPTYPE(FUSEDBATCHNORMGRAD, "FusedBatchNormGrad"); | |||||
GE_REGISTER_OPTYPE(LSTM, "LSTM"); | |||||
GE_REGISTER_OPTYPE(HIGHWAY, "HighWay"); | |||||
GE_REGISTER_OPTYPE(RNN, "RNN"); | |||||
GE_REGISTER_OPTYPE(ATTENTIONDECODER, "AttentionDecoder"); | |||||
GE_REGISTER_OPTYPE(LOGICAL_NOT, "LogicalNot"); | |||||
GE_REGISTER_OPTYPE(LOGICAL_AND, "LogicalAnd"); | |||||
GE_REGISTER_OPTYPE(EQUAL, "Equal"); | |||||
GE_REGISTER_OPTYPE(INTERP, "Interp"); | |||||
GE_REGISTER_OPTYPE(SHUFFLECHANNEL, "ShuffleChannel"); | |||||
GE_REGISTER_OPTYPE(AIPP, "Aipp"); | |||||
GE_REGISTER_OPTYPE(CROPANDRESIZE, "CropAndResize"); | |||||
GE_REGISTER_OPTYPE(UNUSEDCONST, "UnusedConst"); | |||||
GE_REGISTER_OPTYPE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs"); | |||||
GE_REGISTER_OPTYPE(BROADCASTARGS, "BroadcastArgs"); | |||||
GE_REGISTER_OPTYPE(STOPGRADIENT, "StopGradient"); | |||||
GE_REGISTER_OPTYPE(PPREVENTGRADIENT, "PreventGradient"); | |||||
GE_REGISTER_OPTYPE(GUARANTEECONST, "GuaranteeConst"); | |||||
GE_REGISTER_OPTYPE(SPARSETODENSE, "SparseToDense"); | |||||
GE_REGISTER_OPTYPE(NONMAXSUPPRESSION, "NonMaxSuppression"); | |||||
GE_REGISTER_OPTYPE(TOPKV2, "TopKV2"); | |||||
GE_REGISTER_OPTYPE(INVERTPERMUTATION, "InvertPermutation"); | |||||
GE_REGISTER_OPTYPE(MULTINOMIAL, "Multinomial"); | |||||
GE_REGISTER_OPTYPE(REVERSESEQUENCE, "ReverseSequence"); | |||||
GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); | GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); | ||||
GE_REGISTER_OPTYPE(INITDATA, "InitData"); | GE_REGISTER_OPTYPE(INITDATA, "InitData"); | ||||
// ANN specific operator | |||||
GE_REGISTER_OPTYPE(ANN_MEAN, "AnnMean"); | |||||
GE_REGISTER_OPTYPE(ANN_CONVOLUTION, "AnnConvolution"); | |||||
GE_REGISTER_OPTYPE(ANN_DEPCONVOLUTION, "AnnDepthConv"); | |||||
GE_REGISTER_OPTYPE(DIV, "Div"); | |||||
GE_REGISTER_OPTYPE(ANN_FULLCONNECTION, "AnnFullConnection"); | |||||
GE_REGISTER_OPTYPE(ANN_NETOUTPUT, "AnnNetOutput"); | |||||
GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); | GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); | ||||
// Training operator | |||||
GE_REGISTER_OPTYPE(CONVGRADFILTER, "Conv2DBackpropFilter"); | |||||
GE_REGISTER_OPTYPE(CONV2D, "Conv2D"); | |||||
GE_REGISTER_OPTYPE(CONV2DBACKPROPINPUT, "Conv2DBackpropInput"); | |||||
GE_REGISTER_OPTYPE(ACTIVATIONGRAD, "ReluGrad"); | |||||
GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); | GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); | ||||
GE_REGISTER_OPTYPE(AVGPOOLGRAD, "AvgPoolGrad"); | |||||
GE_REGISTER_OPTYPE(SQUARE, "Square"); | |||||
GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); | |||||
GE_REGISTER_OPTYPE(END, "End"); | |||||
GE_REGISTER_OPTYPE(VARIABLE, "Variable"); | GE_REGISTER_OPTYPE(VARIABLE, "Variable"); | ||||
GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); | |||||
/// @ingroup domi_omg | |||||
/// @brief INPUT node type | |||||
static const char* const kInputType = "Input"; | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief AIPP tag, tag for aipp conv operator | |||||
/// | |||||
static const char* const kAippConvFlag = "Aipp_Conv_Flag"; | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief AIPP tag, tag for aipp data operator | |||||
/// | |||||
static const char* const kAippDataFlag = "Aipp_Data_Flag"; | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief AIPP tag, tag for aipp data operator | |||||
/// | |||||
static const char* const kAippDataType = "AippData"; | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief DATA node type | |||||
/// | |||||
static const char* const kDataType = "Data"; | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief Frame operator type | |||||
/// | |||||
static const char* const kFrameworkOpType = "FrameworkOp"; | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief Data node type | |||||
/// | |||||
static const char* const kAnnDataType = "AnnData"; | |||||
static const char* const kAnnNetoutputType = "AnnNetOutput"; | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief Convolution node type | |||||
/// | |||||
static const char* const kNodeNameNetOutput = "Node_Output"; | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief RECV node type | |||||
/// | |||||
static const char* const kRecvType = "Recv"; | |||||
GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief SEND node type | |||||
/// | |||||
static const char* const kSendType = "Send"; | |||||
GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief Convolution node type | |||||
/// | |||||
static const char* const kOpTypeConvolution = "Convolution"; | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief Add convolution node name to hard AIPP | |||||
/// | |||||
static const char* const kAippConvOpNmae = "aipp_conv_op"; | |||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief Operator configuration item separator | |||||
/// | |||||
static const char* const kOpConfDelimiter = ":"; | |||||
GE_REGISTER_OPTYPE(RECV, "Recv"); | |||||
GE_REGISTER_OPTYPE(SEND, "Send"); | |||||
}; // namespace ge | }; // namespace ge | ||||
#endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | #endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ |
@@ -15,11 +15,14 @@ | |||||
*/ | */ | ||||
#include "format_refiner.h" | #include "format_refiner.h" | ||||
#include <deque> | #include <deque> | ||||
#include <iostream> | #include <iostream> | ||||
#include <set> | #include <set> | ||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <unordered_set> | #include <unordered_set> | ||||
#include "graph/ref_relation.h" | |||||
#include "./compute_graph.h" | #include "./compute_graph.h" | ||||
#include "./ge_error_codes.h" | #include "./ge_error_codes.h" | ||||
#include "./graph/ge_tensor.h" | #include "./graph/ge_tensor.h" | ||||
@@ -34,14 +37,41 @@ | |||||
#include "utils/tensor_utils.h" | #include "utils/tensor_utils.h" | ||||
#include "utils/type_utils.h" | #include "utils/type_utils.h" | ||||
using namespace ge; | |||||
using namespace std; | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
static const std::unordered_set<string> kChangeDimNodes = {RESHAPE, PERMUTE, EXPANDDIMS, SQUEEZE}; | static const std::unordered_set<string> kChangeDimNodes = {RESHAPE, PERMUTE, EXPANDDIMS, SQUEEZE}; | ||||
static bool net_format_is_nd = true; | static bool net_format_is_nd = true; | ||||
static Format g_user_set_format = FORMAT_ND; | static Format g_user_set_format = FORMAT_ND; | ||||
static bool is_first_infer = true; | static bool is_first_infer = true; | ||||
static RefRelations reflection_builder; | |||||
} // namespace | } // namespace | ||||
graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection, | |||||
std::deque<ge::NodePtr> &nodes, ge::Format to_be_set_format) { | |||||
for (const auto &cell : reflection) { | |||||
auto node = cell.node; | |||||
auto in_out_idx = cell.in_out_idx; | |||||
GE_CHECK_NOTNULL(node); | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
if (cell.in_out == ge::NODE_IN) { | |||||
auto desc = node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(in_out_idx)); | |||||
desc.SetOriginFormat(to_be_set_format); | |||||
desc.SetFormat(to_be_set_format); | |||||
(void)node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(in_out_idx), desc); | |||||
} else { | |||||
auto desc = node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(in_out_idx)); | |||||
desc.SetOriginFormat(to_be_set_format); | |||||
desc.SetFormat(to_be_set_format); | |||||
(void)node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(in_out_idx), desc); | |||||
} | |||||
nodes.push_back(cell.node); | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) { | if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) { | ||||
@@ -66,7 +96,6 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||||
anchor_points.clear(); | anchor_points.clear(); | ||||
// Get all anchor point nodes and switch nodes | // Get all anchor point nodes and switch nodes | ||||
for (const auto &node_ptr : graph->GetAllNodes()) { | for (const auto &node_ptr : graph->GetAllNodes()) { | ||||
std::vector<bool> is_node_set_format; | |||||
if (node_ptr == nullptr) { | if (node_ptr == nullptr) { | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
@@ -86,7 +115,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||||
for (uint32_t i = 0; i < input_size; i++) { | for (uint32_t i = 0; i < input_size; i++) { | ||||
// Operator pre-set format but not origin format | // Operator pre-set format but not origin format | ||||
auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); | auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); | ||||
// Pre-save data node and default infer fail | |||||
// Pre-save data node (only main graph data) and default infer fail | |||||
if (node_ptr->GetType() == DATA) { | if (node_ptr->GetType() == DATA) { | ||||
data_nodes.push_back(node_ptr); | data_nodes.push_back(node_ptr); | ||||
} | } | ||||
@@ -163,6 +192,16 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||||
} | } | ||||
// Check format whether have been set | // Check format whether have been set | ||||
int idx = peer_out_data_anchor->GetIdx(); | int idx = peer_out_data_anchor->GetIdx(); | ||||
// do peer_out_node name and index as key to lookup reflections | |||||
ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx); | |||||
std::unordered_set<RefCell, RefCellHash> reflection; | |||||
auto status = reflection_builder.LookUpRefRelations(key, reflection); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge", | |||||
(peer_out_data_node->GetName()).c_str(), idx); | |||||
return GRAPH_FAILED; | |||||
} | |||||
auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx)); | auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx)); | ||||
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | ||||
auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | ||||
@@ -181,18 +220,26 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||||
continue; | continue; | ||||
} | } | ||||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||||
ge_tensor_desc.SetFormat(to_be_set_format); | |||||
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||||
if (reflection.empty()) { | |||||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||||
ge_tensor_desc.SetFormat(to_be_set_format); | |||||
(void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||||
// Call operator infer format api (forward) to get out format | |||||
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | |||||
graphStatus status = peer_out_data_node->InferOriginFormat(); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); | |||||
return GRAPH_FAILED; | |||||
// Call operator infer format api (forward) to get out format | |||||
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | |||||
status = peer_out_data_node->InferOriginFormat(); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
nodes.push_back(peer_out_data_node); | |||||
} else { | |||||
auto status = ReflectionProcess(reflection, nodes, to_be_set_format); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "reflection process failed!"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | } | ||||
nodes.push_back(peer_out_data_node); | |||||
} | } | ||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
@@ -213,17 +260,23 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||||
continue; | continue; | ||||
} | } | ||||
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | ||||
if (peer_in_data_anchor == nullptr) { | |||||
GELOGW("Node[%s] some peer_in_anchor is null", (node->GetName()).c_str()); | |||||
continue; | |||||
} | |||||
GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue); | |||||
auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | ||||
if (peer_in_data_node == nullptr || peer_in_data_node->GetOpDesc() == nullptr) { | |||||
GELOGW("Node[%s] peer_in_data_node or peer_in_data_node desc is null", node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue); | |||||
GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue); | |||||
// Check format whether have been set | // Check format whether have been set | ||||
int idx = peer_in_data_anchor->GetIdx(); | int idx = peer_in_data_anchor->GetIdx(); | ||||
// do peer_out_node name and index as key to lookup reflections | |||||
ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx); | |||||
std::unordered_set<RefCell, RefCellHash> reflection; | |||||
auto status = reflection_builder.LookUpRefRelations(key, reflection); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge", | |||||
(peer_in_data_node->GetName()).c_str(), idx); | |||||
return GRAPH_FAILED; | |||||
} | |||||
auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx)); | auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx)); | ||||
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | ||||
auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | ||||
@@ -240,24 +293,33 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||||
GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str()); | GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str()); | ||||
continue; | continue; | ||||
} | } | ||||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||||
ge_tensor_desc.SetFormat(to_be_set_format); | |||||
(void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(idx, ge_tensor_desc); | |||||
/// Because netoutput node added before infer format ,so netoutput is end condition | |||||
/// must set netoutput format , because saved result depend on format | |||||
if (peer_in_data_node_type == NETOUTPUT) { | |||||
continue; | |||||
} | |||||
if (reflection.empty()) { | |||||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||||
ge_tensor_desc.SetFormat(to_be_set_format); | |||||
(void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||||
// Call operator infer format api (forward) to get out format | |||||
GELOGD("call infer format func[Forward]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); | |||||
graphStatus status = peer_in_data_node->InferOriginFormat(); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); | |||||
return GRAPH_FAILED; | |||||
/// Because netoutput node added before infer format ,so netoutput is end condition | |||||
/// must set netoutput format , because saved result depend on format | |||||
if (peer_in_data_node_type == NETOUTPUT) { | |||||
continue; | |||||
} | |||||
// Call operator infer format api (forward) to get out format | |||||
GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); | |||||
status = peer_in_data_node->InferOriginFormat(); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
nodes.push_back(peer_in_data_node); | |||||
} else { | |||||
auto status = ReflectionProcess(reflection, nodes, to_be_set_format); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "reflection process failed!"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | } | ||||
nodes.push_back(peer_in_data_node); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -355,8 +417,15 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||||
GELOGE(GRAPH_FAILED, "input graph is null"); | GELOGE(GRAPH_FAILED, "input graph is null"); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
// build reflection relations of boundary | |||||
(void)reflection_builder.Clear(); | |||||
auto status = reflection_builder.BuildRefRelations(*graph); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
// User set global net format | // User set global net format | ||||
graphStatus status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); | |||||
status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); | |||||
if (status != GRAPH_SUCCESS) { | if (status != GRAPH_SUCCESS) { | ||||
GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!"); | GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!"); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
@@ -18,6 +18,12 @@ | |||||
namespace ge { | namespace ge { | ||||
// Public attribute | // Public attribute | ||||
const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; | |||||
const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; | |||||
const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE = "_unknown_shape_type"; | |||||
const std::string ATTR_NAME_NAME = "name"; | const std::string ATTR_NAME_NAME = "name"; | ||||
const std::string ATTR_NAME_TYPE = "type"; | const std::string ATTR_NAME_TYPE = "type"; | ||||
@@ -42,6 +48,8 @@ const std::string ATTR_NAME_BIAS = "bias"; | |||||
const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | ||||
const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; | |||||
const std::string ATTR_NAME_PAD = "pad"; | const std::string ATTR_NAME_PAD = "pad"; | ||||
const std::string ATTR_NAME_PADS = "pad"; | const std::string ATTR_NAME_PADS = "pad"; | ||||
@@ -83,6 +91,7 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; | |||||
const std::string ATTR_NAME_AXIS = "axis"; | const std::string ATTR_NAME_AXIS = "axis"; | ||||
const std::string ATTR_NAME_BROADCAST = "broadcast"; | const std::string ATTR_NAME_BROADCAST = "broadcast"; | ||||
const std::string ATTR_NAME_OUTPUT = "output"; | |||||
const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | ||||
const std::string ATTR_NAME_TIDX = "t_idx"; | const std::string ATTR_NAME_TIDX = "t_idx"; | ||||
@@ -103,6 +112,13 @@ const std::string ATTR_NAME_TSHAPE = "Tshape"; | |||||
const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | ||||
const std::string ATTR_NAME_AIPP = "aipp"; | const std::string ATTR_NAME_AIPP = "aipp"; | ||||
const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | |||||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; | |||||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; | |||||
const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||||
const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | ||||
const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | ||||
@@ -111,6 +127,7 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; | |||||
const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | ||||
const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | ||||
const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | ||||
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||||
const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | ||||
const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | ||||
@@ -122,9 +139,12 @@ const std::string ATTR_NAME_WEIGHTS = "value"; | |||||
const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | ||||
const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | ||||
const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | ||||
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||||
const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | |||||
const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; | |||||
const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | |||||
const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; | |||||
const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||||
const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||||
// To be deleted | // To be deleted | ||||
const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | ||||
@@ -138,15 +158,13 @@ const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; | |||||
const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | ||||
const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | ||||
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||||
// Refinedet | // Refinedet | ||||
const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | ||||
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||||
const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | ||||
const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | ||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||||
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||||
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||||
// _Arg | // _Arg | ||||
const std::string ATTR_NAME_INDEX = "index"; | const std::string ATTR_NAME_INDEX = "index"; | ||||
@@ -236,6 +254,30 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; | |||||
const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | ||||
const std::string BATCHNORM_ATTR_SCALE = "scale"; | const std::string BATCHNORM_ATTR_SCALE = "scale"; | ||||
const std::string BATCHNORM_ATTR_BIAS = "bias"; | const std::string BATCHNORM_ATTR_BIAS = "bias"; | ||||
const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; | |||||
const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; | |||||
const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; | |||||
// huberloss | |||||
const std::string HUBER_LOSS_ATTR_DELTA = "delta"; | |||||
// SSDRealDivTileMul | |||||
const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; | |||||
// SSDSumMulRealDivMean | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; | |||||
// ConcatFive2Four | |||||
// ConcatFour2Five | |||||
const std::string SSD_BOX_TYPE_NUM = "box_type_num"; | |||||
const std::string SSD_CLASS_NUM = "class_num"; | |||||
const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; | |||||
const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; | |||||
const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; | |||||
const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; | |||||
// Scale | // Scale | ||||
const std::string SCALE_ATTR_SCALE = "scale"; | const std::string SCALE_ATTR_SCALE = "scale"; | ||||
@@ -340,6 +382,7 @@ const std::string SOFTMAX_ATTR_AXIS = "axis"; | |||||
// Permute | // Permute | ||||
const std::string PERMUTE_ATTR_ORDER = "order"; | const std::string PERMUTE_ATTR_ORDER = "order"; | ||||
const std::string PERMUTE_ATTR_PERM = "perm"; | |||||
// SSD Normalize | // SSD Normalize | ||||
const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | ||||
@@ -367,6 +410,10 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; | |||||
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | ||||
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | ||||
// RefinedetDetectionOutput | |||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||||
// PRelu | // PRelu | ||||
const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | ||||
@@ -380,11 +427,16 @@ const std::string POWER_ATTR_NAME_POWER = "power"; | |||||
const std::string POWER_ATTR_NAME_SCALE = "scale"; | const std::string POWER_ATTR_NAME_SCALE = "scale"; | ||||
const std::string POWER_ATTR_NAME_SHIFT = "shift"; | const std::string POWER_ATTR_NAME_SHIFT = "shift"; | ||||
// log | |||||
const std::string LOG_ATTR_NAME_SCALE = "scale"; | |||||
const std::string LOG_ATTR_NAME_SHIFT = "shift"; | |||||
const std::string LOG_ATTR_NAME_BASE = "base"; | |||||
// Pack | // Pack | ||||
const std::string PACK_ATTR_NAME_NUM = "N"; | const std::string PACK_ATTR_NAME_NUM = "N"; | ||||
// Unpack | // Unpack | ||||
const std::string UNPACK_ATTR_NAME_NUM = "num"; | const std::string UNPACK_ATTR_NAME_NUM = "num"; | ||||
const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||||
// Gathernd | // Gathernd | ||||
const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | ||||
const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | ||||
@@ -394,6 +446,13 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; | |||||
const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | ||||
const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | ||||
const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | ||||
const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; | |||||
const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; | |||||
const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; | |||||
// upsample | |||||
const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; | |||||
const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; | |||||
// Relu | // Relu | ||||
const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | ||||
@@ -531,19 +590,41 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape | |||||
const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | ||||
// Rnn | // Rnn | ||||
const std::string RNN_MODE_ = "rnn_"; | |||||
const std::string CNN_RNN = "cnn_rnn"; | |||||
const std::string RNN_MODE_STATIC = "rnn_static"; | |||||
const std::string MUTI_RNN = "multi_rnn"; | const std::string MUTI_RNN = "multi_rnn"; | ||||
const std::string CNN_RNN = "cnn_rnn"; | |||||
const std::string RNN_MODE_ = "rnn_"; | |||||
const std::string CELL_MODE = "mode"; | const std::string CELL_MODE = "mode"; | ||||
const std::string LSTM_CELL = "lstm_cell"; | const std::string LSTM_CELL = "lstm_cell"; | ||||
const std::string GRU_CELL = "gru_cell"; | const std::string GRU_CELL = "gru_cell"; | ||||
const std::string RNN_HT = "ht"; | const std::string RNN_HT = "ht"; | ||||
const std::string RNN_XT_HT = "xt_ht"; | const std::string RNN_XT_HT = "xt_ht"; | ||||
const std::string RNN_BATCH_SIZE = "batch_size"; | const std::string RNN_BATCH_SIZE = "batch_size"; | ||||
const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; | |||||
const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; | |||||
const std::string LSTM_ACTIVATE = "lstm_activate"; | |||||
const std::string LSTM_OUT_MAP = "lstm_out_map"; | |||||
const std::string LSTM_OUT_MODE = "lstm_out_mode"; | |||||
const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; | |||||
const std::string LSTM_TIME_MAJOR = "lstm_time_major"; | |||||
const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; | |||||
// Upsample | // Upsample | ||||
const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | ||||
// PadV2 | |||||
const std::string PADV2_ATTR_NAME_MODE = "mode"; | |||||
const std::string PADV2_ATTR_NAME_PADS = "paddings"; | |||||
const std::string PADV2_ATTR_NAME_T = "T"; | |||||
const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||||
const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; | |||||
// MirrorPad | |||||
const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; | |||||
const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; | |||||
const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||||
const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; | |||||
// Filler | // Filler | ||||
const std::string FILLER_TYPE = "filler_type"; | const std::string FILLER_TYPE = "filler_type"; | ||||
const std::string FILLER_VALUE = "filler_value"; | const std::string FILLER_VALUE = "filler_value"; | ||||
@@ -554,9 +635,6 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group"; | |||||
// TopKV2 | // TopKV2 | ||||
const std::string TOPKV2_ATTR_K = "k"; | const std::string TOPKV2_ATTR_K = "k"; | ||||
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||||
const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||||
// Calibaration | // Calibaration | ||||
const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | ||||
const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | ||||
@@ -611,10 +689,14 @@ const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; | |||||
const std::string ATTR_MODEL_EVENT_NUM = "event_num"; | const std::string ATTR_MODEL_EVENT_NUM = "event_num"; | ||||
const std::string ATTR_MODEL_HUGE_STREAM_LIST = "huge_stream_list"; | |||||
const std::string ATTR_MODEL_LABEL_NUM = "label_num"; | const std::string ATTR_MODEL_LABEL_NUM = "label_num"; | ||||
const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | ||||
const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; | |||||
const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | ||||
const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; | const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; | ||||
@@ -660,8 +742,125 @@ const std::string TARGET_TYPE_TINY = "TINY"; | |||||
const std::string TARGET_TYPE_LITE = "LITE"; | const std::string TARGET_TYPE_LITE = "LITE"; | ||||
// l2_normalize | |||||
const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; | |||||
const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||||
const std::string POOL_PARAMA_ATTR_WINDOW = "window"; | |||||
const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; | |||||
const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; | |||||
const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; | |||||
const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; | |||||
const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; | |||||
// HCOM | |||||
const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||||
const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||||
const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||||
const std::string HCOM_ATTR_GROUP = "group"; | |||||
const std::string HCOM_ATTR_SR_TAG = "sr_tag"; | |||||
const std::string HCOM_ATTR_SRC_RANK = "src_rank"; | |||||
const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; | |||||
const std::string HCOM_ATTR_FUSION = "fusion"; | |||||
const std::string HCOM_ATTR_SHAPE = "shape"; | |||||
const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||||
// SpaceToDepth/DepthToSpace | |||||
const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; | |||||
// SparseSoftmaxCrossEntropyWithLogits | |||||
const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; | |||||
// MaxPoolGradWithArgmax | |||||
const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; | |||||
// AvgPoolGrad | |||||
const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; | |||||
// Pad | |||||
const std::string ATTR_PAD_FORMAT = "attr_pad_format"; | |||||
// Varible | |||||
const std::string VAR_ATTR_FORMAT = "_var_format"; | |||||
const std::string VAR_ATTR_NAME = "var_name"; | |||||
const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; | |||||
const std::string VAR_ATTR_4D_FORMAT = "4D"; | |||||
const std::string VAR_ATTR_5D_FORMAT = "5D"; | |||||
const std::string VAR_ATTR_DATA_TYPE = "data_format"; | |||||
const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; | |||||
const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; | |||||
const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; | |||||
const std::string VAR_ATTR_SHAPE = "shape"; | |||||
const std::string HALF_VAR_NAME_END = "_fp16"; | |||||
const std::string VAR_ATTR_INITED = "var_is_inited"; | |||||
const std::string VAR_ATTR_CONTAINER = "container"; | |||||
const std::string VAR_ATTR_SHARED_NAME = "shared_name"; | |||||
const std::string VAR_ATTR_DTYPE = "dtype"; | |||||
const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||||
const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; | |||||
const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||||
const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||||
const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||||
const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||||
// Assign | |||||
const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; | |||||
// space2bacth batch2space | |||||
const std::string BATCH_SPACE_ATTR_BLOCK = "block"; | |||||
const std::string BATCH_SPACE_ATTR_PADDING = "padding"; | |||||
// depth_to_space space_to_depth | |||||
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||||
// FakeQuantWithMinMaxVars | |||||
const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; | |||||
const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; | |||||
// mobilenet_ssd_conv_fusion | |||||
const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; | |||||
const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; | |||||
const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; | |||||
// lsh project | |||||
const std::string LSH_PROJ_TYPE = "lsh_project_type"; | |||||
// log time stamp | |||||
const std::string LOG_TIME_STAMP_LOGID = "logid"; | |||||
const std::string LOG_TIME_STAMP_NOTIFY = "notify"; | |||||
// ShapeN | |||||
const std::string SHAPEN_ATTR_N = "N"; | |||||
const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; | |||||
const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; | |||||
// GatherV2 attr def | |||||
const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; | |||||
const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; | |||||
const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; | |||||
// Reshape attr def | |||||
const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; | |||||
const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; | |||||
// axis attr def | |||||
const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; | |||||
const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; | |||||
const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; | |||||
const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; | |||||
// For constant folding | |||||
const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; | |||||
const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | ||||
const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC = "continuous_input_alloc"; | |||||
const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | ||||
const std::string ATTR_NAME_REFERENCE = "reference"; | const std::string ATTR_NAME_REFERENCE = "reference"; | ||||
@@ -694,6 +893,8 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; | |||||
const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | ||||
const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | ||||
const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | ||||
const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | |||||
const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | |||||
const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | ||||
const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | ||||
@@ -705,6 +906,7 @@ const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; | |||||
// Function Op | // Function Op | ||||
const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; | const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; | ||||
const std::string ATTR_NAME_PARENT_CONST_TYPE = "_parent_const_type"; | |||||
// Used for mark the active node is for loop, type:bool | // Used for mark the active node is for loop, type:bool | ||||
const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; | const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; | ||||
@@ -719,6 +921,7 @@ const std::string MODEL_ATTR_SESSION_ID = "session_id"; | |||||
// l1 fusion and other fusion in future | // l1 fusion and other fusion in future | ||||
const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | ||||
const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; | |||||
const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | ||||
const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; | const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; | ||||
const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; | const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; | ||||
@@ -730,6 +933,9 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1 | |||||
const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | ||||
const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | ||||
const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | ||||
const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; | |||||
const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; | |||||
const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; | |||||
// Atomic addr clean attrs | // Atomic addr clean attrs | ||||
const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | ||||
@@ -748,6 +954,8 @@ const std::string ATTR_NEED_COMPILE = "_node_need_compile"; | |||||
const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; | const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; | ||||
const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; | |||||
// For inserted op | // For inserted op | ||||
const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | ||||
@@ -764,7 +972,22 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou | |||||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | ||||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | ||||
// functional ops attr | |||||
const std::string ATTR_NAME_WHILE_COND = "cond"; | |||||
const std::string ATTR_NAME_WHILE_BODY = "body"; | |||||
// used for label switch | // used for label switch | ||||
const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | ||||
const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | ||||
const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | |||||
const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | |||||
// used for LX tiling | |||||
const std::string ATTR_NAME_OP_L1_SPACE = "_l1_space"; | |||||
const std::string ATTR_NAME_FUSION_TYPE_LIST = "_fusion_type_list"; | |||||
const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_list_list"; | |||||
const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; | |||||
const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | |||||
const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | |||||
} // namespace ge | } // namespace ge |
@@ -31,19 +31,18 @@ using std::string; | |||||
using std::vector; | using std::vector; | ||||
namespace ge { | namespace ge { | ||||
GeAttrValue::NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | |||||
NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | |||||
GeAttrValue::NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) | |||||
: named_attrs_(owner, proto_msg) {} | |||||
NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) : named_attrs_(owner, proto_msg) {} | |||||
void GeAttrValue::NamedAttrs::SetName(const std::string &name) { | |||||
void NamedAttrs::SetName(const std::string &name) { | |||||
auto proto_msg = named_attrs_.GetProtoMsg(); | auto proto_msg = named_attrs_.GetProtoMsg(); | ||||
if (proto_msg != nullptr) { | if (proto_msg != nullptr) { | ||||
proto_msg->set_name(name); | proto_msg->set_name(name); | ||||
} | } | ||||
} | } | ||||
string GeAttrValue::NamedAttrs::GetName() const { | |||||
string NamedAttrs::GetName() const { | |||||
auto proto_msg = named_attrs_.GetProtoMsg(); | auto proto_msg = named_attrs_.GetProtoMsg(); | ||||
if (proto_msg != nullptr) { | if (proto_msg != nullptr) { | ||||
return proto_msg->name(); | return proto_msg->name(); | ||||
@@ -51,13 +50,13 @@ string GeAttrValue::NamedAttrs::GetName() const { | |||||
return string(); | return string(); | ||||
} | } | ||||
GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const { | |||||
GeAttrValue NamedAttrs::GetItem(const string &key) const { | |||||
GeAttrValue value; | GeAttrValue value; | ||||
GetAttr(key, value); | |||||
(void)GetAttr(key, value); | |||||
return value; | return value; | ||||
} | } | ||||
ProtoAttrMapHelper GeAttrValue::NamedAttrs::MutableAttrMap() { | |||||
ProtoAttrMapHelper NamedAttrs::MutableAttrMap() { | |||||
auto proto_msg = named_attrs_.GetProtoMsg(); | auto proto_msg = named_attrs_.GetProtoMsg(); | ||||
if (proto_msg != nullptr) { | if (proto_msg != nullptr) { | ||||
return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr()); | return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr()); | ||||
@@ -65,7 +64,7 @@ ProtoAttrMapHelper GeAttrValue::NamedAttrs::MutableAttrMap() { | |||||
return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); | return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); | ||||
} | } | ||||
ConstProtoAttrMapHelper GeAttrValue::NamedAttrs::GetAttrMap() const { | |||||
ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const { | |||||
auto proto_msg = named_attrs_.GetProtoMsg(); | auto proto_msg = named_attrs_.GetProtoMsg(); | ||||
if (proto_msg != nullptr) { | if (proto_msg != nullptr) { | ||||
return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr()); | return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr()); | ||||
@@ -515,7 +514,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAtt | |||||
return true; | return true; | ||||
} | } | ||||
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NamedAttrs &value) { | |||||
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) { | |||||
if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { | if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { | ||||
return false; | return false; | ||||
} | } | ||||
@@ -528,7 +527,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue: | |||||
return true; | return true; | ||||
} | } | ||||
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NamedAttrs> &value) { | |||||
bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NAMED_ATTRS> &value) { | |||||
if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, | if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, | ||||
proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { | proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { | ||||
return false; | return false; | ||||
@@ -739,7 +738,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||||
} | } | ||||
bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | ||||
GeAttrValue::NamedAttrs &value) { | |||||
GeAttrValue::NAMED_ATTRS &value) { | |||||
if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { | if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { | ||||
return false; | return false; | ||||
} | } | ||||
@@ -752,7 +751,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||||
} | } | ||||
bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | ||||
vector<GeAttrValue::NamedAttrs> &value) { | |||||
vector<GeAttrValue::NAMED_ATTRS> &value) { | |||||
value.clear(); | value.clear(); | ||||
if (!AttrUtilsHelper::GetValueCheckListType( | if (!AttrUtilsHelper::GetValueCheckListType( | ||||
proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { | proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { | ||||
@@ -760,7 +759,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||||
} | } | ||||
auto &list = proto_attr_val.list(); | auto &list = proto_attr_val.list(); | ||||
for (const auto &item : list.na()) { | for (const auto &item : list.na()) { | ||||
value.emplace_back(GeAttrValue::NamedAttrs()); | |||||
value.emplace_back(GeAttrValue::NAMED_ATTRS()); | |||||
if (value.empty()) { | if (value.empty()) { | ||||
return false; | return false; | ||||
} | } | ||||
@@ -967,7 +966,7 @@ ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc) | |||||
ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) | ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) | ||||
ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) | ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) | ||||
ATTR_UTILS_SET_IMP(Tensor, GeTensor) | ATTR_UTILS_SET_IMP(Tensor, GeTensor) | ||||
ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NamedAttrs) | |||||
ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) | |||||
ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) | ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) | ||||
ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) | ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) | ||||
ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>) | ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>) | ||||
@@ -982,7 +981,7 @@ ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector<GeTensorDesc>) | |||||
ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensorPtr>) | ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensorPtr>) | ||||
ATTR_UTILS_SET_IMP(ListTensor, vector<ConstGeTensorPtr>) | ATTR_UTILS_SET_IMP(ListTensor, vector<ConstGeTensorPtr>) | ||||
ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>) | ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>) | ||||
ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NamedAttrs>) | |||||
ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>) | |||||
ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>) | ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>) | ||||
ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>) | ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>) | ||||
ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) | ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) | ||||
@@ -83,6 +83,12 @@ size_t GeShape::GetDimNum() const { | |||||
auto proto_msg = shape_def_.GetProtoMsg(); | auto proto_msg = shape_def_.GetProtoMsg(); | ||||
if (proto_msg != nullptr) { | if (proto_msg != nullptr) { | ||||
if (proto_msg->dim_size() >= 0) { | if (proto_msg->dim_size() >= 0) { | ||||
// check whether contain -2, if true, return -1 | |||||
for (auto i : proto_msg->dim()) { | |||||
if (i == UNKNOWN_DIM_NUM) { | |||||
return 0; | |||||
} | |||||
} | |||||
return proto_msg->dim_size(); | return proto_msg->dim_size(); | ||||
} else { | } else { | ||||
return 0; | return 0; | ||||
@@ -157,6 +163,10 @@ int64_t GeShape::GetShapeSize() const { | |||||
return 0; | return 0; | ||||
} | } | ||||
for (auto i : proto_msg->dim()) { | for (auto i : proto_msg->dim()) { | ||||
// if unknown shape, return -1 | |||||
if (i == UNKNOWN_DIM || i == UNKNOWN_DIM_NUM) { | |||||
return UNKNOWN_DIM; | |||||
} | |||||
res *= i; | res *= i; | ||||
} | } | ||||
} | } | ||||
@@ -209,6 +219,7 @@ const string TENSOR_UTILS_RC = "rc"; | |||||
const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; | const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; | ||||
const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; | const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; | ||||
const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; | const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; | ||||
const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; | |||||
GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} | GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} | ||||
@@ -396,6 +407,35 @@ GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); } | |||||
void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); } | void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); } | ||||
// set shape with -2, it stand for unknown shape | |||||
void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); } | |||||
// for unknown shape | |||||
graphStatus GeTensorDesc::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range) { | |||||
std::vector<vector<int64_t>> shape_range; | |||||
for (const auto &ele : range) { | |||||
shape_range.emplace_back(std::vector<int64_t>({ele.first, ele.second})); | |||||
} | |||||
auto ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); | |||||
return ret ? GRAPH_SUCCESS : GRAPH_FAILED; | |||||
} | |||||
graphStatus GeTensorDesc::GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const { | |||||
std::vector<vector<int64_t>> shape_range; | |||||
(void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); | |||||
for (const auto &ele : shape_range) { | |||||
// here must be only two elemenet because pair | |||||
if (ele.size() != 2) { | |||||
GELOGE(GRAPH_FAILED, "shape_range must contain only 2 value but really is %lu", ele.size()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
std::pair<int64_t, int64_t> pair({ele[0], ele[1]}); | |||||
range.push_back(pair); | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
GeShape GeTensorDesc::GetOriginShape() const { | GeShape GeTensorDesc::GetOriginShape() const { | ||||
vector<int64_t> origin_shape; | vector<int64_t> origin_shape; | ||||
if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) { | if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) { | ||||
@@ -16,11 +16,12 @@ | |||||
#include "external/graph/graph.h" | #include "external/graph/graph.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "external/graph/operator.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/ge_attr_value.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/debug/ge_op_types.h" | |||||
#include "graph/model.h" | #include "graph/model.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/op_desc_utils.h" | |||||
using std::map; | using std::map; | ||||
using std::pair; | using std::pair; | ||||
@@ -214,6 +215,23 @@ class GraphImpl { | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const { | |||||
for (auto &op : op_list_) { | |||||
auto op_type = op.second.GetOpType(); | |||||
if (op_type == type) { | |||||
ops.push_back(op.second); | |||||
continue; | |||||
} | |||||
if (op_type == ge::FRAMEWORKOP) { | |||||
op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type); | |||||
if (op_type == type) { | |||||
ops.push_back(op.second); | |||||
} | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
void SetNeedIteration(bool need_iteration) { | void SetNeedIteration(bool need_iteration) { | ||||
if (compute_graph_ == nullptr) { | if (compute_graph_ == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null."); | GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null."); | ||||
@@ -222,6 +240,8 @@ class GraphImpl { | |||||
compute_graph_->SetNeedIteration(need_iteration); | compute_graph_->SetNeedIteration(need_iteration); | ||||
} | } | ||||
const std::string &GetName() const { return name_; } | |||||
private: | private: | ||||
std::string name_; | std::string name_; | ||||
std::string output_name_; | std::string output_name_; | ||||
@@ -255,6 +275,11 @@ graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const { | |||||
return impl_->FindOpByName(name, op); | return impl_->FindOpByName(name, op); | ||||
} | } | ||||
graphStatus Graph::FindOpByType(const string &type, std::vector<ge::Operator> &ops) const { | |||||
GE_CHECK_NOTNULL(impl_); | |||||
return impl_->FindOpByType(type, ops); | |||||
} | |||||
Graph &Graph::SetInputs(const vector<ge::Operator> &inputs) { | Graph &Graph::SetInputs(const vector<ge::Operator> &inputs) { | ||||
GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.") | GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.") | ||||
GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0."); | GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0."); | ||||
@@ -331,6 +356,8 @@ graphStatus Graph::LoadFromFile(const string &file_name) { | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string &Graph::GetName() const { return impl_->GetName(); } | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph | ||||
GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { | GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { | ||||
GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph("")); | GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph("")); | ||||
@@ -343,4 +370,15 @@ GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) | |||||
return graph; | return graph; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { | |||||
GE_CHECK_NOTNULL(graph.impl_); | |||||
GE_CHECK_NOTNULL(graph.impl_->compute_graph_); | |||||
graph.impl_->op_list_.clear(); | |||||
for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) { | |||||
graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -16,7 +16,10 @@ | |||||
#include "graph/model_serialize.h" | #include "graph/model_serialize.h" | ||||
#include <google/protobuf/text_format.h> | #include <google/protobuf/text_format.h> | ||||
#include <queue> | |||||
#include <iostream> | #include <iostream> | ||||
#include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
@@ -26,6 +29,7 @@ | |||||
#include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
#include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
using std::map; | |||||
using std::string; | using std::string; | ||||
namespace ge { | namespace ge { | ||||
@@ -121,6 +125,11 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op | |||||
} | } | ||||
} | } | ||||
} | } | ||||
op_def_proto->set_id(op_desc->GetId()); | |||||
for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||||
op_def_proto->add_subgraph_name(name); | |||||
} | |||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
@@ -196,6 +205,14 @@ bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *mode | |||||
GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | ||||
return false; | return false; | ||||
} | } | ||||
for (auto subgraph : compute_graph->GetAllSubgraphs()) { | |||||
if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) { | |||||
GELOGE(GRAPH_FAILED, "Serialize subgraph failed"); | |||||
return false; | |||||
} | |||||
} | |||||
return true; | return true; | ||||
} | } | ||||
@@ -228,6 +245,14 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d | |||||
GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); | GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); | ||||
op_desc->outputs_desc_.push_back(temp_value); | op_desc->outputs_desc_.push_back(temp_value); | ||||
} | } | ||||
op_desc->SetId(op_def_proto.id()); | |||||
uint32_t graph_index = 0; | |||||
for (const std::string &name : op_def_proto.subgraph_name()) { | |||||
op_desc->AddSubgraphName(name); | |||||
op_desc->SetSubgraphInstanceName(graph_index++, name); | |||||
} | |||||
return true; | return true; | ||||
} | } | ||||
@@ -238,7 +263,7 @@ bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op | |||||
GELOGW("UnserializeOpDesc error."); | GELOGW("UnserializeOpDesc error."); | ||||
} | } | ||||
NodePtr node = graph->AddNode(op_desc); | |||||
NodePtr node = graph->AddNode(op_desc, op_desc->GetId()); | |||||
GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr."); | GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr."); | ||||
// Inputs | // Inputs | ||||
@@ -319,6 +344,35 @@ bool ModelSerializeImp::HandleNodeNameRef() { | |||||
return true; | return true; | ||||
} | } | ||||
bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map<string, ComputeGraphPtr> &subgraphs) { | |||||
std::queue<ComputeGraphPtr> all_graphs; | |||||
all_graphs.emplace(compute_graph); | |||||
while (!all_graphs.empty()) { | |||||
ComputeGraphPtr graph = all_graphs.front(); | |||||
all_graphs.pop(); | |||||
for (const NodePtr &node : graph->GetDirectNode()) { | |||||
const OpDescPtr op_desc = node->GetOpDesc(); | |||||
for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||||
auto it = subgraphs.find(name); | |||||
if (it == subgraphs.end()) { | |||||
GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), name.c_str(), | |||||
subgraphs.size()); | |||||
return false; | |||||
} | |||||
ComputeGraphPtr &subgraph = it->second; | |||||
subgraph->SetParentGraph(graph); | |||||
subgraph->SetParentNode(node); | |||||
compute_graph->AddSubgraph(subgraph->GetName(), subgraph); | |||||
all_graphs.emplace(subgraph); | |||||
} | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) { | bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) { | ||||
model.name_ = model_proto.name(); | model.name_ = model_proto.name(); | ||||
model.version_ = model_proto.version(); | model.version_ = model_proto.version(); | ||||
@@ -332,7 +386,31 @@ bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_pr | |||||
if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) { | if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) { | ||||
model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); | model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); | ||||
} | } | ||||
// 0 is main graph, following is subgraph. | |||||
map<string, ComputeGraphPtr> subgraphs; | |||||
for (int idx = 1; idx < graphs_proto.size(); ++idx) { | |||||
ComputeGraphPtr subgraph; | |||||
ModelSerializeImp impl; | |||||
if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) { | |||||
GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed"); | |||||
return false; | |||||
} | |||||
if (!impl.HandleNodeNameRef()) { | |||||
GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); | |||||
return false; | |||||
} | |||||
subgraphs[subgraph->GetName()] = subgraph; | |||||
} | |||||
if (!RebuildOwnership(compute_graph_ptr, subgraphs)) { | |||||
GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed"); | |||||
return false; | |||||
} | |||||
} | } | ||||
if (!HandleNodeNameRef()) { | if (!HandleNodeNameRef()) { | ||||
GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); | GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); | ||||
return false; | return false; | ||||
@@ -61,6 +61,8 @@ const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; | |||||
const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | ||||
const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; | |||||
const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; | const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; | ||||
const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; | const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; | ||||
@@ -227,6 +229,40 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp | |||||
} | } | ||||
} | } | ||||
graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) { | |||||
auto input_name_idx = GetAllInputName(); | |||||
for (unsigned int i = 0; i < num; i++) { | |||||
string input_name = name + std::to_string(i); | |||||
GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, | |||||
"Add input tensor_desc is existed. name[%s]", input_name.c_str()); | |||||
std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | |||||
if (in_desc == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (index > inputs_desc_.size()) { | |||||
GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
(void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc); | |||||
// Update index in input_name_idx | |||||
for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { | |||||
if (it->second >= (index + i)) { | |||||
it->second += 1; | |||||
} | |||||
} | |||||
(void)input_name_idx.insert(make_pair(input_name, i + index)); | |||||
} | |||||
SetAllInputName(input_name_idx); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | ||||
auto input_name_idx = GetAllInputName(); | auto input_name_idx = GetAllInputName(); | ||||
for (unsigned int i = 0; i < num; i++) { | for (unsigned int i = 0; i < num; i++) { | ||||
@@ -239,7 +275,6 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n | |||||
GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed."); | GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed."); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
(void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | ||||
// Update index in input_name_idx | // Update index in input_name_idx | ||||
@@ -634,6 +669,13 @@ graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int n | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) { | |||||
if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { | graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { | ||||
if (is_push_back) { | if (is_push_back) { | ||||
for (unsigned int i = 0; i < num; i++) { | for (unsigned int i = 0; i < num; i++) { | ||||
@@ -1054,6 +1096,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetDstName | |||||
return dst_name; | return dst_name; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector<string> &depend_names) { | |||||
auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); | |||||
if (ret != true) { | |||||
GELOGE(GRAPH_FAILED, "set op_infer_depends fail."); | |||||
} | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetOpInferDepends() const { | |||||
vector<string> depend_names; | |||||
(void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); | |||||
return depend_names; | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector<int64_t> &dst_index) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector<int64_t> &dst_index) { | ||||
auto proto_msg = op_def_.GetProtoMsg(); | auto proto_msg = op_def_.GetProtoMsg(); | ||||
if (proto_msg != nullptr) { | if (proto_msg != nullptr) { | ||||
@@ -1199,20 +1254,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector<std::string> &O | |||||
return subgraph_instance_names_; | return subgraph_instance_names_; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::AddSubgraphInstanceName(std::string name) { | |||||
subgraph_instance_names_.emplace_back(std::move(name)); | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { | ||||
for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { | for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { | ||||
if (*iter == name) { | if (*iter == name) { | ||||
subgraph_instance_names_.erase(iter); | |||||
*iter = ""; | |||||
return; | return; | ||||
} | } | ||||
} | } | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { | ||||
GELOGI("Add subgraph name is %s", name.c_str()); | |||||
auto iter = subgraph_names_to_index_.find(name); | auto iter = subgraph_names_to_index_.find(name); | ||||
if (iter != subgraph_names_to_index_.end()) { | if (iter != subgraph_names_to_index_.end()) { | ||||
GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); | GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); | ||||
@@ -1220,6 +1272,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphNa | |||||
} | } | ||||
auto size = subgraph_names_to_index_.size(); | auto size = subgraph_names_to_index_.size(); | ||||
subgraph_names_to_index_[name] = size; | subgraph_names_to_index_[name] = size; | ||||
subgraph_instance_names_.resize(size + 1); | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -1227,4 +1280,34 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, uint3 | |||||
const { | const { | ||||
return subgraph_names_to_index_; | return subgraph_names_to_index_; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::SetSubgraphInstanceName(uint32_t index, | |||||
const std::string &name) { | |||||
GELOGI("Add sub graph instans name is %s, index is %u", name.c_str(), index); | |||||
if (index >= subgraph_instance_names_.size()) { | |||||
GE_LOGE("The index %u exceeds the max instance coutn %zu", index, subgraph_instance_names_.size()); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
subgraph_instance_names_[index] = name; | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RegisterSubgraphIrName(const string &name, | |||||
SubgraphType type) { | |||||
subgraph_ir_names_to_type_[name] = type; | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, SubgraphType> &OpDesc::GetSubgraphIrNames() | |||||
const { | |||||
return subgraph_ir_names_to_type_; | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY SubgraphType | |||||
OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { | |||||
auto iter = subgraph_ir_names_to_type_.find(name); | |||||
if (iter == subgraph_ir_names_to_type_.end()) { | |||||
return kSubgraphTypeEnd; | |||||
} | |||||
return iter->second; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -15,6 +15,7 @@ | |||||
*/ | */ | ||||
#include "external/graph/operator.h" | #include "external/graph/operator.h" | ||||
#include "external/graph/operator_factory.h" | |||||
#include <stdint.h> | #include <stdint.h> | ||||
#include <algorithm> | #include <algorithm> | ||||
#include <mutex> | #include <mutex> | ||||
@@ -38,6 +39,11 @@ | |||||
#include "utils/tensor_adapter.h" | #include "utils/tensor_adapter.h" | ||||
#include "utils/tensor_utils.h" | #include "utils/tensor_utils.h" | ||||
#include "utils/type_utils.h" | #include "utils/type_utils.h" | ||||
#include <algorithm> | |||||
#include <mutex> | |||||
#include <queue> | |||||
#include <set> | |||||
#include <stdint.h> | |||||
using std::enable_shared_from_this; | using std::enable_shared_from_this; | ||||
using std::make_pair; | using std::make_pair; | ||||
@@ -343,15 +349,71 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
InferenceContextPtr GetInferenceContext() const { return inference_context_; } | InferenceContextPtr GetInferenceContext() const { return inference_context_; } | ||||
void SubgraphRegister(const std::string &name, bool dynamic) { | |||||
op_desc_->RegisterSubgraphIrName(name, dynamic ? kDynamic : kStatic); | |||||
} | |||||
void SubgraphCountRegister(const std::string &name, uint32_t count) { | |||||
if (op_desc_->GetSubgraphTypeByIrName(name) == kStatic) { | |||||
op_desc_->AddSubgraphName(name); | |||||
} else { | |||||
for (uint32_t i = 0; i < count; ++i) { | |||||
op_desc_->AddSubgraphName(name + std::to_string(i)); | |||||
} | |||||
} | |||||
subgraph_names_to_builders_[name].resize(count, nullptr); | |||||
} | |||||
void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { | |||||
auto iter = subgraph_names_to_builders_.find(name); | |||||
if (iter == subgraph_names_to_builders_.end()) { | |||||
GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, invalid name", name.c_str(), index); | |||||
return; | |||||
} | |||||
if (iter->second.size() <= index) { | |||||
GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, excceds the max size %zu", | |||||
name.c_str(), index, iter->second.size()); | |||||
return; | |||||
} | |||||
iter->second[index] = builder; | |||||
} | |||||
SubgraphBuilder GetSubgraphBuilder(const std::string &name, uint32_t index) const { | |||||
auto iter = subgraph_names_to_builders_.find(name); | |||||
if (iter == subgraph_names_to_builders_.end()) { | |||||
GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, invalid name", name.c_str(), index); | |||||
return nullptr; | |||||
} | |||||
if (iter->second.size() <= index) { | |||||
GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, excceds the max size %zu", | |||||
name.c_str(), index, iter->second.size()); | |||||
return nullptr; | |||||
} | |||||
return iter->second[index]; | |||||
} | |||||
std::vector<std::string> GetSubgraphNames() const { | |||||
std::vector<std::string> names; | |||||
for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) { | |||||
names.emplace_back(subgraph_name_to_type.first); | |||||
} | |||||
return names; | |||||
} | |||||
size_t GetSubgraphNamesCount() const { return op_desc_->GetSubgraphIrNames().size(); } | |||||
OpDescPtr op_desc_ = nullptr; | OpDescPtr op_desc_ = nullptr; | ||||
private: | private: | ||||
ge::ConstNodePtr node_{nullptr}; | ge::ConstNodePtr node_{nullptr}; | ||||
ge::InferenceContextPtr inference_context_; | ge::InferenceContextPtr inference_context_; | ||||
GraphBuilderCallback graph_builder_callback_; | |||||
std::map<string, std::vector<OpIO>> output_links_{}; | std::map<string, std::vector<OpIO>> output_links_{}; | ||||
std::map<string, OpIO> input_link_{}; | std::map<string, OpIO> input_link_{}; | ||||
std::vector<std::weak_ptr<OperatorImpl>> control_input_link_{}; | std::vector<std::weak_ptr<OperatorImpl>> control_input_link_{}; | ||||
std::vector<std::weak_ptr<OperatorImpl>> control_output_link_{}; | std::vector<std::weak_ptr<OperatorImpl>> control_output_link_{}; | ||||
std::map<std::string, std::vector<SubgraphBuilder>> subgraph_names_to_builders_; | |||||
}; | }; | ||||
// Used to manage OperatorImpl instances created by ge api. | // Used to manage OperatorImpl instances created by ge api. | ||||
@@ -559,7 +621,6 @@ InferenceContextPtr Operator::GetInferenceContext() const { | |||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); | ||||
return operator_impl_->GetInferenceContext(); | return operator_impl_->GetInferenceContext(); | ||||
} | } | ||||
TensorDesc Operator::GetInputDesc(uint32_t index) const { | TensorDesc Operator::GetInputDesc(uint32_t index) const { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); | ||||
return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); | return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); | ||||
@@ -698,7 +759,7 @@ const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() con | |||||
void Operator::InputRegister(const string &name) { | void Operator::InputRegister(const string &name) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); | |||||
(void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); | |||||
} | } | ||||
void Operator::OptionalInputRegister(const string &name) { | void Operator::OptionalInputRegister(const string &name) { | ||||
@@ -745,6 +806,12 @@ void Operator::DynamicInputRegister(const string &name, const unsigned int num, | |||||
(void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); | (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); | ||||
} | } | ||||
void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) { | |||||
GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr."); | |||||
GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr."); | |||||
operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index); | |||||
} | |||||
int Operator::GetDynamicInputNum(const string &name) const { | int Operator::GetDynamicInputNum(const string &name) const { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); | ||||
@@ -896,6 +963,11 @@ OP_ATTR_GET_IMP(string &, Str) | |||||
OP_ATTR_SET_IMP(const vector<string> &, ListStr) | OP_ATTR_SET_IMP(const vector<string> &, ListStr) | ||||
OP_ATTR_GET_IMP(vector<string> &, ListStr) | OP_ATTR_GET_IMP(vector<string> &, ListStr) | ||||
OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||||
OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||||
OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||||
OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||||
OP_ATTR_REG_IMP(int64_t, Int) | OP_ATTR_REG_IMP(int64_t, Int) | ||||
OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt) | OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt) | ||||
OP_ATTR_REG_IMP(float, Float) | OP_ATTR_REG_IMP(float, Float) | ||||
@@ -905,6 +977,8 @@ OP_ATTR_REG_IMP(const vector<string> &, ListStr) | |||||
OP_ATTR_REG_IMP(bool, Bool) | OP_ATTR_REG_IMP(bool, Bool) | ||||
OP_ATTR_REG_IMP(const vector<bool> &, ListBool) | OP_ATTR_REG_IMP(const vector<bool> &, ListBool) | ||||
OP_ATTR_REG_IMP(const vector<vector<int64_t>> &, ListListInt) | OP_ATTR_REG_IMP(const vector<vector<int64_t>> &, ListListInt) | ||||
OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||||
OP_ATTR_REG_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||||
#undef OP_ATTR_SET_IMP | #undef OP_ATTR_SET_IMP | ||||
#undef OP_ATTR_GET_IMP | #undef OP_ATTR_GET_IMP | ||||
@@ -1114,6 +1188,95 @@ void Operator::AttrRegister(const string &name, const OpBytes &attr_value) { | |||||
} | } | ||||
} | } | ||||
void Operator::SubgraphRegister(const std::string &name, bool dynamic) { | |||||
if (operator_impl_ == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||||
return; | |||||
} | |||||
operator_impl_->SubgraphRegister(name, dynamic ? kDynamic : kStatic); | |||||
} | |||||
void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { | |||||
if (operator_impl_ == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||||
return; | |||||
} | |||||
operator_impl_->SubgraphCountRegister(name, count); | |||||
} | |||||
void Operator::SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { | |||||
if (operator_impl_ == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||||
return; | |||||
} | |||||
operator_impl_->SetSubgraphBuilder(name, index, builder); | |||||
} | |||||
std::vector<std::string> Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } | |||||
SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &name, uint32_t index) const { | |||||
if (operator_impl_ == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "operator impl is nullptr."); | |||||
return nullptr; | |||||
} | |||||
return operator_impl_->GetSubgraphBuilder(name, index); | |||||
} | |||||
SubgraphBuilder Operator::GetSubgraphBuilder(const string &name) const { return GetDynamicSubgraphBuilder(name, 0); } | |||||
Graph Operator::GetSubgraph(const string &name) const { | |||||
if (operator_impl_ == nullptr) { | |||||
GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str()); | |||||
return Graph(""); | |||||
} | |||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(*this); | |||||
if (op_desc == nullptr) { | |||||
GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str()); | |||||
return Graph(""); | |||||
} | |||||
const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); | |||||
auto iter = subgraph_names_to_index.find(name); | |||||
if (iter == subgraph_names_to_index.end()) { | |||||
GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str()); | |||||
return Graph(""); | |||||
} | |||||
auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); | |||||
if (subgraph_instance_name.empty()) { | |||||
GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", name.c_str(), iter->second); | |||||
return Graph(""); | |||||
} | |||||
auto node = operator_impl_->GetNode(); | |||||
if (node == nullptr) { | |||||
GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str()); | |||||
return Graph(""); | |||||
} | |||||
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||||
if (root_graph == nullptr) { | |||||
GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str()); | |||||
return Graph(""); | |||||
} | |||||
auto subgraph = root_graph->GetSubgraph(subgraph_instance_name); | |||||
if (subgraph == nullptr) { | |||||
GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", name.c_str(), | |||||
iter->second, subgraph_instance_name.c_str()); | |||||
return Graph(""); | |||||
} | |||||
return GraphUtils::CreateGraphFromComputeGraph(subgraph); | |||||
} | |||||
Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const { | |||||
return GetSubgraph(name + std::to_string(index)); | |||||
} | |||||
size_t Operator::GetSubgraphNamesCount() const { | |||||
if (operator_impl_ == nullptr) { | |||||
GE_LOGE("Failed to get subgraph names count, the operator impl is null"); | |||||
return 0; | |||||
} | |||||
return operator_impl_->GetSubgraphNamesCount(); | |||||
} | |||||
class GraphBuilderImpl { | class GraphBuilderImpl { | ||||
public: | public: | ||||
explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared<ComputeGraph>(name)) { | explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared<ComputeGraph>(name)) { | ||||
@@ -96,7 +96,6 @@ VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) | |||||
graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { | graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { | ||||
if (operator_creators_ == nullptr) { | if (operator_creators_ == nullptr) { | ||||
GELOGI("operator_creators_ init"); | |||||
operator_creators_.reset(new (std::nothrow) std::map<string, OpCreator>()); | operator_creators_.reset(new (std::nothrow) std::map<string, OpCreator>()); | ||||
} | } | ||||
auto it = operator_creators_->find(operator_type); | auto it = operator_creators_->find(operator_type); | ||||
@@ -0,0 +1,422 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/ref_relation.h" | |||||
#include <unordered_set> | |||||
#include <unordered_map> | |||||
#include "utils/mem_utils.h" | |||||
#include "debug/ge_log.h" | |||||
#include "debug/ge_op_types.h" | |||||
#include "debug/ge_util.h" | |||||
#include "debug/ge_attr_define.h" | |||||
#include "graph/ge_error_codes.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
using namespace std; | |||||
using namespace ge; | |||||
namespace ge { | |||||
namespace { | |||||
const char *kRefIndex = "_parent_node_index"; | |||||
const string kWhile = "While"; | |||||
const string kIf = "If"; | |||||
const string kCase = "Case"; | |||||
const int kMaxElementNum = 100; | |||||
std::unordered_set<string> function_op = {kWhile, kIf, kCase}; | |||||
} // namespace | |||||
/* Impl */ | |||||
class RefRelations::Impl { | |||||
public: | |||||
graphStatus LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) { | |||||
unsigned long number = static_cast<unsigned long>(reinterpret_cast<uintptr_t>(key.node.get())); | |||||
std::string lookup_key = | |||||
key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + std::to_string(number); | |||||
auto iter = look_up_table_.find(lookup_key); | |||||
if (iter != look_up_table_.end()) { | |||||
for (auto &c : iter->second) { | |||||
result.insert(c); | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
GELOGW("can not find any relations! key value is %s", lookup_key.c_str()); | |||||
return GRAPH_SUCCESS; | |||||
}; | |||||
graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); | |||||
graphStatus Clear() { | |||||
GELOGD("Start clear boundary reflections between main graph and sub graph!"); | |||||
look_up_table_.clear(); | |||||
values_.clear(); | |||||
return GRAPH_SUCCESS; | |||||
}; | |||||
private: | |||||
graphStatus BuildLookUpTables(); | |||||
graphStatus BuildRefRelationsForBranch(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||||
vector<vector<RefCell>> &node_refs); | |||||
graphStatus BuildRefRelationsForWhile(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||||
vector<vector<RefCell>> &node_refs); | |||||
graphStatus BuildRelationsWithFuncNodeType(const NodePtr &root_node, | |||||
const vector<vector<NodePtr>> &classed_data_nodes, | |||||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||||
vector<vector<RefCell>> &node_refs); | |||||
void GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes, | |||||
vector<NodePtr> &netoutput_nodes, const std::vector<std::string> &sub_graph_names, | |||||
const std::string &node_type); | |||||
graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph); | |||||
graphStatus ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, vector<vector<NodePtr>> &classed_data_nodes); | |||||
graphStatus ProcessSubgraphNetoutput(const vector<NodePtr> &netoutput_nodes, | |||||
vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes); | |||||
std::unordered_map<string, vector<RefCell>> look_up_table_; | |||||
std::vector<vector<vector<RefCell>>> values_; | |||||
}; | |||||
// Node Level | |||||
graphStatus RefRelations::Impl::BuildRefRelationsForBranch( | |||||
const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||||
GELOGD("Enter BuildRefRelationsForBranch!"); | |||||
size_t ref_i = 0; | |||||
for (const auto &ref_i_data_nodes : classed_data_nodes) { | |||||
vector<RefCell> in_ref_i_all_refs; | |||||
RefCell cell_root; | |||||
cell_root.node_name = root_node->GetName(); | |||||
cell_root.node = root_node; | |||||
cell_root.in_out = NODE_IN; | |||||
cell_root.in_out_idx = ref_i; | |||||
in_ref_i_all_refs.emplace_back(cell_root); | |||||
for (const auto &data : ref_i_data_nodes) { | |||||
RefCell cell_in; | |||||
RefCell cell_out; | |||||
cell_in.node_name = data->GetName(); | |||||
cell_in.node = data; | |||||
cell_in.in_out = NODE_IN; | |||||
cell_in.in_out_idx = 0; | |||||
cell_out.node_name = data->GetName(); | |||||
cell_out.node = data; | |||||
cell_out.in_out = NODE_OUT; | |||||
cell_out.in_out_idx = 0; | |||||
in_ref_i_all_refs.emplace_back(cell_in); | |||||
in_ref_i_all_refs.emplace_back(cell_out); | |||||
} | |||||
node_refs.emplace_back(in_ref_i_all_refs); | |||||
ref_i++; | |||||
} | |||||
size_t ref_o = 0; | |||||
for (const auto &ref_o_net_nodes : classed_netoutput_nodes) { | |||||
vector<RefCell> out_ref_i_all_refs; | |||||
RefCell cell_root; | |||||
cell_root.node_name = root_node->GetName(); | |||||
cell_root.node = root_node; | |||||
cell_root.in_out = NODE_OUT; | |||||
cell_root.in_out_idx = ref_o; | |||||
out_ref_i_all_refs.emplace_back(cell_root); | |||||
for (const auto &ele : ref_o_net_nodes) { | |||||
RefCell cell_netoutput_in; | |||||
RefCell cell_netoutput_out; | |||||
cell_netoutput_in.node_name = (ele.first)->GetName(); | |||||
cell_netoutput_in.node = ele.first; | |||||
cell_netoutput_in.in_out = NODE_IN; | |||||
cell_netoutput_in.in_out_idx = ele.second; | |||||
cell_netoutput_out.node_name = (ele.first)->GetName(); | |||||
cell_netoutput_out.node = ele.first; | |||||
cell_netoutput_out.in_out = NODE_OUT; | |||||
cell_netoutput_out.in_out_idx = ele.second; | |||||
out_ref_i_all_refs.emplace_back(cell_netoutput_in); | |||||
out_ref_i_all_refs.emplace_back(cell_netoutput_out); | |||||
} | |||||
node_refs.emplace_back(out_ref_i_all_refs); | |||||
ref_o++; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus RefRelations::Impl::BuildLookUpTables() { | |||||
for (size_t i = 0; i < values_.size(); i++) { | |||||
vector<vector<RefCell>> &val = values_[i]; | |||||
for (const auto &ele : val) { | |||||
for (const auto &ref_cell : ele) { | |||||
string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + std::to_string(ref_cell.in_out_idx) + | |||||
std::to_string(static_cast<unsigned long>(reinterpret_cast<uintptr_t>(ref_cell.node.get()))); | |||||
look_up_table_[key] = ele; | |||||
} | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus RefRelations::Impl::BuildRefRelationsForWhile( | |||||
const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||||
GELOGD("Enter BuildRefRelations for while op!"); | |||||
// data_nodes has been sorted | |||||
// for while, input num must be same as output num | |||||
auto input_num = root_node->GetAllInDataAnchorsSize(); | |||||
size_t ref_i = 0; | |||||
while (ref_i < input_num) { | |||||
auto &ref_i_data_nodes = classed_data_nodes[ref_i]; | |||||
auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i]; | |||||
vector<RefCell> ref_i_all_refs; | |||||
RefCell cell_root_i; | |||||
RefCell cell_root_o; | |||||
cell_root_i.node_name = root_node->GetName(); | |||||
cell_root_i.node = root_node; | |||||
cell_root_i.in_out = NODE_IN; | |||||
cell_root_i.in_out_idx = ref_i; | |||||
ref_i_all_refs.emplace_back(cell_root_i); | |||||
cell_root_o.node_name = root_node->GetName(); | |||||
cell_root_o.node = root_node; | |||||
cell_root_o.in_out = NODE_OUT; | |||||
cell_root_o.in_out_idx = ref_i; | |||||
ref_i_all_refs.emplace_back(cell_root_o); | |||||
for (const auto &data : ref_i_data_nodes) { | |||||
RefCell cell_in; | |||||
RefCell cell_out; | |||||
cell_in.node_name = data->GetName(); | |||||
cell_in.node = data; | |||||
cell_in.in_out = NODE_IN; | |||||
cell_in.in_out_idx = 0; | |||||
cell_out.node_name = data->GetName(); | |||||
cell_out.node = data; | |||||
cell_out.in_out = NODE_OUT; | |||||
cell_out.in_out_idx = 0; | |||||
ref_i_all_refs.emplace_back(cell_in); | |||||
ref_i_all_refs.emplace_back(cell_out); | |||||
} | |||||
for (const auto &ele : ref_i_net_nodes) { | |||||
RefCell cell_netoutput_in; | |||||
RefCell cell_netoutput_out; | |||||
cell_netoutput_in.node_name = (ele.first)->GetName(); | |||||
cell_netoutput_in.node = ele.first; | |||||
cell_netoutput_in.in_out = NODE_IN; | |||||
cell_netoutput_in.in_out_idx = ele.second; | |||||
cell_netoutput_out.node_name = (ele.first)->GetName(); | |||||
cell_netoutput_out.node = ele.first; | |||||
cell_netoutput_out.in_out = NODE_OUT; | |||||
cell_netoutput_out.in_out_idx = ele.second; | |||||
ref_i_all_refs.emplace_back(cell_netoutput_in); | |||||
ref_i_all_refs.emplace_back(cell_netoutput_out); | |||||
} | |||||
node_refs.emplace_back(ref_i_all_refs); | |||||
ref_i++; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
// build ref relations according to diff func op type | |||||
graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( | |||||
const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||||
const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||||
// data_nodes has been sorted | |||||
auto node_type = root_node->GetType(); | |||||
auto status = GRAPH_SUCCESS; | |||||
if (node_type == kIf || node_type == kCase) { | |||||
status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||||
} else if (node_type == kWhile) { | |||||
status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||||
} else { | |||||
GELOGE(GRAPH_PARAM_INVALID, "Node type [%s] is not supported for build ref relations!", node_type.c_str()); | |||||
status = GRAPH_PARAM_INVALID; | |||||
} | |||||
return status; | |||||
} | |||||
void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes, | |||||
vector<NodePtr> &netoutput_nodes, | |||||
const std::vector<std::string> &sub_graph_names, | |||||
const std::string &node_type) { | |||||
int sub_graph_idx = 0; | |||||
for (const auto &name : sub_graph_names) { | |||||
auto sub_graph = root_graph.GetSubgraph(name); | |||||
for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { | |||||
auto sub_graph_node_type = sub_graph_node->GetType(); | |||||
if (sub_graph_node_type == DATA) { | |||||
data_nodes.emplace_back(sub_graph_node); | |||||
} else if (sub_graph_node_type == NETOUTPUT) { | |||||
// if while, the first subgraph must be cond subgraph. | |||||
// There is no meaning for refs ,so continue | |||||
if (node_type == kWhile && sub_graph_idx == 0) { | |||||
continue; | |||||
} | |||||
netoutput_nodes.emplace_back(sub_graph_node); | |||||
} | |||||
continue; | |||||
} | |||||
sub_graph_idx++; | |||||
} | |||||
} | |||||
graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) { | |||||
auto parent_graph_ptr = graph.GetParentGraph(); | |||||
if (parent_graph_ptr == nullptr) { | |||||
root_graph = graph; | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr); | |||||
if (root_graph_ptr == nullptr) { | |||||
GE_LOGE("Get null root graph"); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
root_graph = *root_graph_ptr; | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, | |||||
vector<vector<NodePtr>> &classed_data_nodes) { | |||||
int max_ref_idx = 0; | |||||
for (const auto &e : data_nodes) { | |||||
int i; | |||||
bool is_exist = true; | |||||
is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i); | |||||
if (!is_exist) { | |||||
GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIndex); | |||||
return GRAPH_FAILED; | |||||
} | |||||
max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx; | |||||
} | |||||
while (!data_nodes.empty()) { | |||||
auto data = data_nodes.back(); | |||||
data_nodes.pop_back(); | |||||
int ref_idx = 0; | |||||
(void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); | |||||
classed_data_nodes[ref_idx].emplace_back(data); | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( | |||||
const vector<NodePtr> &netoutput_nodes, vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) { | |||||
for (const auto &sub_netoutput_node : netoutput_nodes) { | |||||
auto op_desc = sub_netoutput_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { | |||||
auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx()); | |||||
if (in_desc == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it", | |||||
sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
int ref_o; | |||||
if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { | |||||
if (ref_o >= kMaxElementNum) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
classed_netoutput_nodes[ref_o].emplace_back( | |||||
std::pair<NodePtr, size_t>({sub_netoutput_node, static_cast<size_t>(in_data_anchor->GetIdx())})); | |||||
} | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { | |||||
/* First Step: Get root graph */ | |||||
ge::ComputeGraph &root_graph = graph; | |||||
auto status = GetRootGraph(graph, root_graph); | |||||
if (status != GRAPH_SUCCESS) { | |||||
return status; | |||||
} | |||||
for (const auto &node : graph.GetAllNodes()) { | |||||
auto node_type = node->GetType(); | |||||
if (function_op.find(node_type) == function_op.end()) { | |||||
continue; | |||||
} | |||||
std::vector<NodePtr> ref_nodes; | |||||
auto op_desc = node->GetOpDesc(); | |||||
auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||||
vector<NodePtr> data_nodes; | |||||
vector<NodePtr> netoutput_nodes; | |||||
// Get data and netoutput of sub_graph | |||||
GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); | |||||
vector<vector<NodePtr>> classed_data_nodes(kMaxElementNum); // according to ref_idx | |||||
vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(kMaxElementNum); // according to ref_idx | |||||
status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); | |||||
return status; | |||||
} | |||||
// for netoutput | |||||
// check netoutput | |||||
// here main graph output number must be the same as every sub_graph netoutput node | |||||
// key: netoutput node_ptr ,<ref_idx, net_in_idx> | |||||
status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "process netoutput failed!"); | |||||
return status; | |||||
} | |||||
vector<vector<RefCell>> node_refs; | |||||
status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str()); | |||||
return status; | |||||
} | |||||
if (!node_refs.empty()) { | |||||
values_.push_back(node_refs); | |||||
} | |||||
} | |||||
/* Seconde Step: generate map */ | |||||
status = BuildLookUpTables(); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(status, "Build look up tables failed!"); | |||||
return status; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
/* Ref Relations Interface */ | |||||
RefRelations::RefRelations() { | |||||
impl_ = MakeShared<Impl>(); | |||||
if (impl_ == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "MakeShared failed!"); | |||||
return; | |||||
} | |||||
} | |||||
graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) { | |||||
GE_CHECK_NOTNULL(impl_); | |||||
return impl_->LookUpRefRelations(key, result); | |||||
} | |||||
graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) { | |||||
GE_CHECK_NOTNULL(impl_); | |||||
return impl_->BuildRefRelations(root_graph); | |||||
} | |||||
graphStatus RefRelations::Clear() { | |||||
GE_CHECK_NOTNULL(impl_); | |||||
return impl_->Clear(); | |||||
} | |||||
} // namespace ge |
@@ -21,7 +21,7 @@ | |||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "framework/common/types.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
@@ -37,7 +37,6 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
constexpr const char *kRefIndex = "parent_node_index"; | |||||
graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | ||||
@@ -47,6 +46,10 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||||
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | ||||
for (const auto &name : sub_graph_names) { | for (const auto &name : sub_graph_names) { | ||||
if (name.empty()) { | |||||
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
auto sub_graph = root_graph->GetSubgraph(name); | auto sub_graph = root_graph->GetSubgraph(name); | ||||
if (sub_graph == nullptr) { | if (sub_graph == nullptr) { | ||||
GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | ||||
@@ -63,7 +66,7 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||||
node->GetName().c_str()); | node->GetName().c_str()); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if (!AttrUtils::GetInt(node_sub->GetOpDesc(), kRefIndex, ref_i)) { | |||||
if (!AttrUtils::GetInt(node_sub->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||||
GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | ||||
node->GetName().c_str()); | node->GetName().c_str()); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
@@ -76,7 +79,10 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||||
ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); | ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), | |||||
node->GetName().c_str()); | |||||
auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); | auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); | ||||
if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", | GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", | ||||
node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | ||||
@@ -101,6 +107,10 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||||
auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | ||||
for (const auto &name : sub_graph_names) { | for (const auto &name : sub_graph_names) { | ||||
if (name.empty()) { | |||||
GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
auto sub_graph = root_graph->GetSubgraph(name); | auto sub_graph = root_graph->GetSubgraph(name); | ||||
if (sub_graph == nullptr) { | if (sub_graph == nullptr) { | ||||
GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | ||||
@@ -132,11 +142,14 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||||
node->GetName().c_str(), edge_anchor->GetIdx()); | node->GetName().c_str(), edge_anchor->GetIdx()); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", edge_anchor->GetIdx(), | |||||
edge_desc->GetShape().GetDimNum()); | |||||
int ref_i; | int ref_i; | ||||
if (!AttrUtils::GetInt(edge_desc, kRefIndex, ref_i)) { | |||||
if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||||
// if there is no ref index on the TensorDesc, it means the output data will be ignored outer. | // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. | ||||
continue; | continue; | ||||
} | } | ||||
GELOGI("Parent node index of edge desc is %d", ref_i); | |||||
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i)); | auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i)); | ||||
if (output_desc == nullptr) { | if (output_desc == nullptr) { | ||||
GE_LOGE( | GE_LOGE( | ||||
@@ -29,6 +29,7 @@ namespace { | |||||
/// Extra 1 byte store '\0' | /// Extra 1 byte store '\0' | ||||
const int EXTRA_STORE_POINTER_FOR_STRING = 8; | const int EXTRA_STORE_POINTER_FOR_STRING = 8; | ||||
const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9; | const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9; | ||||
const int64_t UNKNOWN_DIM_SIZE = -1; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
@@ -65,6 +66,7 @@ class TensorDescImpl { | |||||
TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {} | TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {} | ||||
Shape shape_; | Shape shape_; | ||||
std::vector<std::pair<int64_t, int64_t>> range_; | |||||
Format format_ = FORMAT_ND; | Format format_ = FORMAT_ND; | ||||
Format origin_format_ = FORMAT_ND; | Format origin_format_ = FORMAT_ND; | ||||
DataType data_type_ = DT_FLOAT; | DataType data_type_ = DT_FLOAT; | ||||
@@ -94,7 +96,16 @@ class ShapeImpl { | |||||
public: | public: | ||||
ShapeImpl() = default; | ShapeImpl() = default; | ||||
~ShapeImpl() = default; | ~ShapeImpl() = default; | ||||
explicit ShapeImpl(const std::vector<int64_t> &dims) : dims_(dims) {} | |||||
explicit ShapeImpl(const std::vector<int64_t> &dims) { | |||||
bool is_unknown_dim_num = false; | |||||
for (const auto &dim : dims) { | |||||
if (dim == UNKNOWN_DIM_NUM) { | |||||
is_unknown_dim_num = true; | |||||
break; | |||||
} | |||||
} | |||||
dims_ = is_unknown_dim_num ? std::vector<int64_t>({UNKNOWN_DIM_NUM}) : dims; | |||||
} | |||||
std::vector<int64_t> dims_; | std::vector<int64_t> dims_; | ||||
}; | }; | ||||
@@ -105,6 +116,11 @@ Shape::Shape(const std::vector<int64_t> &dims) { impl_ = ComGraphMakeShared<Shap | |||||
size_t Shape::GetDimNum() const { | size_t Shape::GetDimNum() const { | ||||
if (impl_ != nullptr) { | if (impl_ != nullptr) { | ||||
for (auto i : impl_->dims_) { | |||||
if (i == UNKNOWN_DIM_NUM) { | |||||
return 0; | |||||
} | |||||
} | |||||
return impl_->dims_.size(); | return impl_->dims_.size(); | ||||
} | } | ||||
return 0; | return 0; | ||||
@@ -146,6 +162,10 @@ int64_t Shape::GetShapeSize() const { | |||||
} | } | ||||
int64_t size = 1; | int64_t size = 1; | ||||
for (auto i : impl_->dims_) { | for (auto i : impl_->dims_) { | ||||
if (i == UNKNOWN_DIM_NUM || i == UNKNOWN_DIM) { | |||||
return UNKNOWN_DIM_SIZE; | |||||
} | |||||
if (!Int64MulNotOverflow(size, i)) { | if (!Int64MulNotOverflow(size, i)) { | ||||
GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); | GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); | ||||
size = 0; | size = 0; | ||||
@@ -217,6 +237,34 @@ void TensorDesc::SetShape(const Shape &shape) { | |||||
} | } | ||||
} | } | ||||
// set shape with -2, it stand for unknown shape | |||||
graphStatus TensorDesc::SetUnknownDimNumShape() { | |||||
if (impl != nullptr) { | |||||
impl->shape_ = Shape({UNKNOWN_DIM_NUM}); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
GELOGE(GRAPH_FAILED, "Set unknown shape failed,because no impl class!"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
// for unknown shape | |||||
graphStatus TensorDesc::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range) { | |||||
if (impl != nullptr) { | |||||
impl->range_ = range; | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
GELOGE(GRAPH_FAILED, "SetShapeRange failed!impl is nullptr!"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
graphStatus TensorDesc::GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const { | |||||
if (impl != nullptr) { | |||||
range = impl->range_; | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
GELOGE(GRAPH_FAILED, "impl is nullptr!"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
Shape TensorDesc::GetOriginShape() const { | Shape TensorDesc::GetOriginShape() const { | ||||
if (impl != nullptr) { | if (impl != nullptr) { | ||||
return impl->origin_shape_; | return impl->origin_shape_; | ||||
@@ -541,6 +589,17 @@ GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_des | |||||
tensor_desc.GetDataType()); | tensor_desc.GetDataType()); | ||||
ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | ||||
ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | ||||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
auto status = tensor_desc.GetShapeRange(shape_range); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Get shape range failed!"); | |||||
return ge_tensor_desc; | |||||
} | |||||
status = ge_tensor_desc.SetShapeRange(shape_range); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Set shape range failed!"); | |||||
return ge_tensor_desc; | |||||
} | |||||
auto size = tensor_desc.GetSize(); | auto size = tensor_desc.GetSize(); | ||||
TensorUtils::SetSize(ge_tensor_desc, size); | TensorUtils::SetSize(ge_tensor_desc, size); | ||||
@@ -554,6 +613,17 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ | |||||
ge_tensor_desc.GetDataType()); | ge_tensor_desc.GetDataType()); | ||||
tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | ||||
tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | ||||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
auto status = ge_tensor_desc.GetShapeRange(shape_range); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Get shape range failed!"); | |||||
return tensor_desc; | |||||
} | |||||
status = tensor_desc.SetShapeRange(shape_range); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Set shape range failed!"); | |||||
return tensor_desc; | |||||
} | |||||
int64_t size = 0; | int64_t size = 0; | ||||
(void)TensorUtils::GetSize(ge_tensor_desc, size); | (void)TensorUtils::GetSize(ge_tensor_desc, size); | ||||
tensor_desc.SetSize(size); | tensor_desc.SetSize(size); | ||||
@@ -28,6 +28,7 @@ | |||||
#include <cstring> | #include <cstring> | ||||
#include <fstream> | #include <fstream> | ||||
#include <iomanip> | #include <iomanip> | ||||
#include <queue> | |||||
#include "./ge_context.h" | #include "./ge_context.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
@@ -390,8 +391,8 @@ GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDa | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if ((RemoveEdge(src, dst) != GRAPH_SUCCESS) || | |||||
(AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS)) { | |||||
(void)RemoveEdge(src, dst); | |||||
if (AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), | GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), | ||||
dst_node->GetName().c_str(), insert_node->GetName().c_str(), dst_node->GetName().c_str()); | dst_node->GetName().c_str(), insert_node->GetName().c_str(), dst_node->GetName().c_str()); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
@@ -399,7 +400,7 @@ GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDa | |||||
OutControlAnchorPtr new_out_ctrl_anchor = insert_node->GetOutControlAnchor(); | OutControlAnchorPtr new_out_ctrl_anchor = insert_node->GetOutControlAnchor(); | ||||
GE_CHECK_NOTNULL(new_out_ctrl_anchor); | GE_CHECK_NOTNULL(new_out_ctrl_anchor); | ||||
for (InControlAnchorPtr peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
for (const InControlAnchorPtr &peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || | if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || | ||||
(AddEdge(new_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { | (AddEdge(new_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { | ||||
GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), | GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), | ||||
@@ -706,7 +707,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn | |||||
GELOGE(GRAPH_FAILED, "File name is too longer!"); | GELOGE(GRAPH_FAILED, "File name is too longer!"); | ||||
return; | return; | ||||
} | } | ||||
std::unique_ptr<char> real_path(new (std::nothrow) char[PATH_MAX]{0}); | |||||
std::unique_ptr<char[]> real_path(new (std::nothrow) char[PATH_MAX]{0}); | |||||
if (real_path == nullptr) { | if (real_path == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "New real_path failed."); | GELOGE(GRAPH_FAILED, "New real_path failed."); | ||||
return; | return; | ||||
@@ -1276,6 +1277,423 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindR | |||||
} | } | ||||
/// | /// | ||||
/// Get reference-mapping of all data_anchors in graph | |||||
/// @param [in] graph | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
for (auto &node : graph->GetAllNodes()) { | |||||
// in_data_anchor | |||||
if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||||
GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
// out_data_anchor | |||||
if (HandleOutAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||||
GE_LOGE("Find ref_mapping for out_data_anchors of node %s failed.", node->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
/// | |||||
/// Get reference-mapping for in_data_anchors of node | |||||
/// @param [in] node | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol) { | |||||
GE_CHECK_NOTNULL(node); | |||||
if (NodeUtils::IsSubgraphOutput(node)) { | |||||
return HandleSubgraphOutput(node, symbol_to_anchors, anchor_to_symbol); | |||||
} | |||||
if (NodeUtils::IsSubgraphInput(node)) { | |||||
return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); | |||||
} | |||||
std::string type = node->GetType(); | |||||
if ((type == MERGE) || (type == STREAMMERGE)) { | |||||
return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); | |||||
} | |||||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
NodeIndexIO cur_node_info = NodeIndexIO(node, in_data_anchor->GetIdx(), kIn); | |||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
if (peer_out_anchor == nullptr) { | |||||
std::string symbol = cur_node_info.ToString(); | |||||
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||||
symbol_to_anchors[symbol] = {cur_node_info}; | |||||
anchor_to_symbol[symbol] = symbol; | |||||
} else { | |||||
NodeIndexIO exist_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); | |||||
if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||||
GE_LOGE("Update symbol mapping failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
/// | |||||
/// Get reference-mapping for out_data_anchors of node | |||||
/// @param [in] node | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol) { | |||||
GE_CHECK_NOTNULL(node); | |||||
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
NodeIndexIO cur_node_info = NodeIndexIO(node, out_data_anchor->GetIdx(), kOut); | |||||
if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { | |||||
continue; | |||||
} | |||||
int32_t reuse_in_index = -1; | |||||
if (IsRefFromInput(out_data_anchor, reuse_in_index)) { | |||||
NodeIndexIO exist_node_info = NodeIndexIO(node, reuse_in_index, kIn); | |||||
if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||||
GE_LOGE("Update symbol mapping failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} else { | |||||
std::string symbol = cur_node_info.ToString(); | |||||
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||||
symbol_to_anchors.emplace(std::make_pair(symbol, std::vector<NodeIndexIO>{cur_node_info})); | |||||
anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
/// | |||||
/// Handle input of subgraph | |||||
/// @param [in] node | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol) { | |||||
GE_CHECK_NOTNULL(node); | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
// Data in subgraph | |||||
uint32_t index = 0; | |||||
if (!ge::AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index)) { | |||||
GE_LOGE("Get attr ATTR_NAME_PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
NodePtr parent_node = node->GetOwnerComputeGraph()->GetParentNode(); | |||||
GE_CHECK_NOTNULL(parent_node); | |||||
InDataAnchorPtr parent_in_anchor = parent_node->GetInDataAnchor(index); | |||||
GE_CHECK_NOTNULL(parent_in_anchor); | |||||
OutDataAnchorPtr peer_out_anchor = parent_in_anchor->GetPeerOutAnchor(); | |||||
if (peer_out_anchor != nullptr) { | |||||
// Data has and only has one input | |||||
NodeIndexIO cur_node_info = NodeIndexIO(node, 0, kIn); | |||||
NodeIndexIO exist_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); | |||||
if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||||
GE_LOGE("Update symbol mapping failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
/// | |||||
/// Handle input of Merge op | |||||
/// @param [in] node | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol) { | |||||
GE_CHECK_NOTNULL(node); | |||||
std::vector<NodeIndexIO> exist_node_infos; | |||||
std::vector<NodeIndexIO> cur_node_infos; | |||||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
if (peer_out_anchor == nullptr) { | |||||
std::string next_name; | |||||
if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) { | |||||
ComputeGraphPtr graph = node->GetOwnerComputeGraph(); | |||||
GE_CHECK_NOTNULL(graph); | |||||
ge::NodePtr next_node = graph->FindNode(next_name); | |||||
GE_CHECK_NOTNULL(next_node); | |||||
// NextIteration has and only has one output | |||||
peer_out_anchor = next_node->GetOutDataAnchor(0); | |||||
GE_CHECK_NOTNULL(peer_out_anchor); | |||||
cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); | |||||
cur_node_infos.emplace_back(NodeIndexIO(next_node, peer_out_anchor->GetIdx(), kOut)); | |||||
} | |||||
} else { | |||||
cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); | |||||
exist_node_infos.emplace_back(NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut)); | |||||
} | |||||
} | |||||
size_t anchor_nums = 0; | |||||
NodeIndexIO max_node_index_io(nullptr, 0, kOut); | |||||
for (auto &temp_node_info : exist_node_infos) { | |||||
auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); | |||||
if (iter1 != anchor_to_symbol.end()) { | |||||
std::string temp_symbol = iter1->second; | |||||
auto iter2 = symbol_to_anchors.find(temp_symbol); | |||||
if (iter2 != symbol_to_anchors.end()) { | |||||
if (iter2->second.size() > anchor_nums) { | |||||
max_node_index_io = temp_node_info; | |||||
anchor_nums = iter2->second.size(); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
std::string symbol; | |||||
for (auto &temp_node_info : exist_node_infos) { | |||||
if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != | |||||
GRAPH_SUCCESS) || | |||||
symbol.empty()) { | |||||
GE_LOGE("Union symbol map anchor1:%s & anchor2:%s.", max_node_index_io.ToString().c_str(), | |||||
temp_node_info.ToString().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
auto iter = symbol_to_anchors.find(symbol); | |||||
if (iter != symbol_to_anchors.end()) { | |||||
for (auto &temp_node_info : cur_node_infos) { | |||||
GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); | |||||
iter->second.emplace_back(temp_node_info); | |||||
anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
/// | |||||
/// Handle output of subgraph | |||||
/// @param [in] node | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol) { | |||||
GE_CHECK_NOTNULL(node); | |||||
ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); | |||||
GE_CHECK_NOTNULL(owner_graph); | |||||
NodePtr parent_node = owner_graph->GetParentNode(); | |||||
GE_CHECK_NOTNULL(parent_node); | |||||
OpDescPtr op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
GE_CHECK_NOTNULL(peer_out_anchor); | |||||
GeTensorDesc in_tensor = op_desc->GetInputDesc(in_data_anchor->GetIdx()); | |||||
uint32_t index = 0; | |||||
if (!ge::AttrUtils::GetInt(in_tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { | |||||
continue; | |||||
} | |||||
GE_CHECK_NOTNULL(parent_node->GetOutDataAnchor(index)); | |||||
// Union symbol of peer_out_anchor & parent_out_anchor | |||||
NodeIndexIO peer_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); | |||||
NodeIndexIO parent_node_info = NodeIndexIO(parent_node, index, kOut); | |||||
std::string symbol; | |||||
if ((UnionSymbolMapping(peer_node_info, parent_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != | |||||
GRAPH_SUCCESS) || | |||||
symbol.empty()) { | |||||
GE_LOGE("Union symbol map anchor1:%s, anchor2:%s.", peer_node_info.ToString().c_str(), | |||||
parent_node_info.ToString().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
NodeIndexIO cur_node_info = NodeIndexIO(node, in_data_anchor->GetIdx(), kIn); | |||||
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||||
symbol_to_anchors[symbol].emplace_back(cur_node_info); | |||||
anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
/// | |||||
/// Union ref-mapping | |||||
/// @param [in] exist_node_info1 | |||||
/// @param [in] exist_node_info2 | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @param [out] symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol) { | |||||
std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; | |||||
std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; | |||||
if (symbol1 == symbol2) { | |||||
symbol = symbol1; | |||||
GELOGI("no need to union."); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
auto iter1 = symbol_to_anchors.find(symbol1); | |||||
auto iter2 = symbol_to_anchors.find(symbol2); | |||||
if ((iter1 == symbol_to_anchors.end()) || (iter2 == symbol_to_anchors.end())) { | |||||
GE_LOGE("symbol %s or %s not exist.", symbol1.c_str(), symbol2.c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
auto &max_iter = (iter1->second.size() > iter2->second.size() ? iter1 : iter2); | |||||
auto &min_iter = (iter1->second.size() > iter2->second.size() ? iter2 : iter1); | |||||
symbol = (iter1->second.size() > iter2->second.size() ? symbol1 : symbol2); | |||||
std::string min_symbol = (iter1->second.size() > iter2->second.size() ? symbol2 : symbol1); | |||||
for (auto &node_index_io : min_iter->second) { | |||||
GELOGD("Update anchor %s, symbol %s.", node_index_io.ToString().c_str(), symbol.c_str()); | |||||
max_iter->second.emplace_back(node_index_io); | |||||
auto iter = anchor_to_symbol.find(node_index_io.ToString()); | |||||
if (iter == anchor_to_symbol.end()) { | |||||
GE_LOGE("anchor %s not exist.", node_index_io.ToString().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (iter->second != min_symbol) { | |||||
GELOGW("not expected symbol of anchor %s, expect %s but %s exactly.", iter->first.c_str(), min_symbol.c_str(), | |||||
iter->second.c_str()); | |||||
} | |||||
iter->second = symbol; | |||||
} | |||||
GELOGI("Union symbol %s and %s succ.", symbol.c_str(), min_symbol.c_str()); | |||||
symbol_to_anchors.erase(min_iter); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
/// | |||||
/// Update symbol mapping with a new reference pair | |||||
/// @param [in] cur_node_info | |||||
/// @param [in] exist_node_info | |||||
/// @param [out] symbol_to_anchors | |||||
/// @param [out] anchor_to_symbol | |||||
/// @return success: GRAPH_SUCESS | |||||
/// | |||||
graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||||
std::map<std::string, std::string> &anchor_to_symbol) { | |||||
auto iter1 = anchor_to_symbol.find(exist_node_info.ToString()); | |||||
if (iter1 == anchor_to_symbol.end()) { | |||||
GE_LOGE("data_anchor %s is not visible before data_anchor %s, maybe TopoSorting is missing.", | |||||
exist_node_info.ToString().c_str(), cur_node_info.ToString().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
std::string symbol = iter1->second; | |||||
auto iter2 = symbol_to_anchors.find(symbol); | |||||
if (iter2 == symbol_to_anchors.end()) { | |||||
GE_LOGE("symbol %s not found.", symbol.c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||||
iter2->second.emplace_back(cur_node_info); | |||||
anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
/// | |||||
/// Check if out_data_anchor is reference of input | |||||
/// @param [in] out_data_anchor | |||||
/// @param [out] reuse_in_index | |||||
/// @return bool | |||||
/// | |||||
bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { | |||||
if (out_data_anchor == nullptr) { | |||||
GELOGW("out_data_anchor is NULL."); | |||||
return false; | |||||
} | |||||
int32_t output_index = out_data_anchor->GetIdx(); | |||||
// pass-through op | |||||
NodePtr node = out_data_anchor->GetOwnerNode(); | |||||
std::string type = node->GetType(); | |||||
const std::set<std::string> pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; | |||||
if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { | |||||
reuse_in_index = output_index; | |||||
GELOGI("Pass-Through node name[%s] index[%u].", node->GetName().c_str(), reuse_in_index); | |||||
return true; | |||||
} | |||||
// Merge op 0th output | |||||
if ((type == MERGE) && (output_index == 0)) { | |||||
reuse_in_index = 0; | |||||
GELOGI("Merge name[%s] output_index[0].", node->GetName().c_str()); | |||||
return true; | |||||
} | |||||
// ref op | |||||
OpDescPtr op_desc = node->GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
GELOGW("op_desc is NULL."); | |||||
return false; | |||||
} | |||||
bool is_ref = false; | |||||
(void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); | |||||
if (is_ref) { | |||||
const string &output_name = op_desc->GetOutputNameByIndex(output_index); | |||||
for (const auto &input_name : op_desc->GetAllInputNames()) { | |||||
if (!input_name.empty() && (output_name == input_name)) { | |||||
reuse_in_index = op_desc->GetInputIndexByName(input_name); | |||||
GELOGI("Reference name[%s] output[%s][%u] ref to input[%s][%d].", op_desc->GetName().c_str(), | |||||
output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); | |||||
return true; | |||||
} | |||||
} | |||||
} | |||||
// reuse input | |||||
auto output_op_desc = op_desc->GetOutputDescPtr(output_index); | |||||
bool reuse_input = false; | |||||
if (output_op_desc != nullptr) { | |||||
if ((TensorUtils::GetReuseInput(*output_op_desc, reuse_input) == GRAPH_SUCCESS) && reuse_input) { | |||||
uint32_t reuse_input_index = 0; | |||||
if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { | |||||
reuse_in_index = static_cast<int32_t>(reuse_input_index); | |||||
GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index, | |||||
reuse_in_index); | |||||
return true; | |||||
} | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
/// | |||||
/// @brief Add node to graph | /// @brief Add node to graph | ||||
/// @param [in] op_desc | /// @param [in] op_desc | ||||
/// @return ComputeGraphBuilder | /// @return ComputeGraphBuilder | ||||
@@ -1561,13 +1979,14 @@ CompleteGraphBuilder &CompleteGraphBuilder::SetOutputMapping(const std::map<uint | |||||
/// | /// | ||||
ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { | ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { | ||||
owner_graph_ = shared_ptr<ComputeGraph>(new (std::nothrow) ComputeGraph(name_)); | owner_graph_ = shared_ptr<ComputeGraph>(new (std::nothrow) ComputeGraph(name_)); | ||||
if (owner_graph_ == nullptr) { | |||||
if ((owner_graph_ == nullptr) || (parent_node_ == nullptr)) { | |||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "graph is NULL."; | |||||
error_msg = "graph / parent_node is NULL."; | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
owner_graph_->SetParentNode(parent_node_); | owner_graph_->SetParentNode(parent_node_); | ||||
owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph()); | |||||
BuildNodes(error_code, error_msg); | BuildNodes(error_code, error_msg); | ||||
if (error_code != GRAPH_SUCCESS) { | if (error_code != GRAPH_SUCCESS) { | ||||
@@ -1584,41 +2003,58 @@ ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
BuildInputs(error_code, error_msg); | |||||
AddDataNodes(error_code, error_msg); | |||||
if (error_code != GRAPH_SUCCESS) { | if (error_code != GRAPH_SUCCESS) { | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
BuildOutputs(error_code, error_msg); | |||||
AddRetValNodes(error_code, error_msg); | |||||
if (error_code != GRAPH_SUCCESS) { | if (error_code != GRAPH_SUCCESS) { | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
if (AddNetOutputNode(error_code, error_msg) == nullptr) { | |||||
// ATTR_NAME_SESSION_GRAPH_ID | |||||
std::string graph_id; | |||||
if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | |||||
error_code = GRAPH_FAILED; | |||||
error_msg = "Get attr session_graph_id failed."; | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | |||||
error_code = GRAPH_FAILED; | |||||
error_msg = "Set attr session_graph_id failed."; | |||||
return nullptr; | |||||
} | |||||
// refresh node name | |||||
for (const NodePtr &node : owner_graph_->GetDirectNode()) { | |||||
if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { | |||||
continue; | |||||
} | |||||
node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); | |||||
} | |||||
return owner_graph_; | return owner_graph_; | ||||
} | } | ||||
/// | /// | ||||
/// @brief Build inputs | |||||
/// @brief Add data nodes | |||||
/// @param [out] error_code | /// @param [out] error_code | ||||
/// @param [out] error_msg | /// @param [out] error_msg | ||||
/// @return void | /// @return void | ||||
/// | /// | ||||
void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &error_msg) { | |||||
void CompleteGraphBuilder::AddDataNodes(graphStatus &error_code, std::string &error_msg) { | |||||
for (auto &input : graph_inputs_) { | for (auto &input : graph_inputs_) { | ||||
NodePtr data_node = AddDateNode(input.first, error_code, error_msg); | |||||
NodePtr data_node = AddDataNode(input.first, error_code, error_msg); | |||||
if (data_node == nullptr) { | if (data_node == nullptr) { | ||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildInputs failed: add node Data:" + std::to_string(input.first) + +" failed."; | |||||
error_msg = "AddDataNodes failed: add node Data:" + std::to_string(input.first) + +" failed."; | |||||
return; | return; | ||||
} | } | ||||
if (owner_graph_->AddInputNode(data_node) == nullptr) { | if (owner_graph_->AddInputNode(data_node) == nullptr) { | ||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildInputs failed: add input node Data:" + std::to_string(input.first) + +" failed."; | |||||
error_msg = "AddDataNodes failed: add input node Data:" + std::to_string(input.first) + +" failed."; | |||||
return; | return; | ||||
} | } | ||||
@@ -1627,7 +2063,7 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err | |||||
std::vector<uint32_t> anchor_indes = input.second.second; | std::vector<uint32_t> anchor_indes = input.second.second; | ||||
if (input_names.size() != anchor_indes.size()) { | if (input_names.size() != anchor_indes.size()) { | ||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildInputs failed: num of input_names and indexs not equal."; | |||||
error_msg = "AddDataNodes failed: num of input_names and indexs not equal."; | |||||
return; | return; | ||||
} | } | ||||
if (input_names.empty()) { | if (input_names.empty()) { | ||||
@@ -1641,29 +2077,29 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err | |||||
auto iter = node_names_.find(input_name); | auto iter = node_names_.find(input_name); | ||||
if (iter == node_names_.end()) { | if (iter == node_names_.end()) { | ||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildInputs failed: node " + input_name + " not exist in graph."; | |||||
error_msg = "AddDataNodes failed: node " + input_name + " not exist in graph."; | |||||
return; | return; | ||||
} | } | ||||
NodePtr in_node = node_names_[input_name]; | NodePtr in_node = node_names_[input_name]; | ||||
if (in_node == nullptr) { | if (in_node == nullptr) { | ||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildInputs failed: node " + input_name + " is NULL."; | |||||
error_msg = "AddDataNodes failed: node " + input_name + " is NULL."; | |||||
return; | return; | ||||
} | } | ||||
if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), in_node->GetInDataAnchor(ind)) != GRAPH_SUCCESS) { | if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), in_node->GetInDataAnchor(ind)) != GRAPH_SUCCESS) { | ||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildInputs failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + | |||||
error_msg = "AddDataNodes failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + | |||||
":" + std::to_string(ind) + " failed."; | ":" + std::to_string(ind) + " failed."; | ||||
return; | return; | ||||
} | } | ||||
} | } | ||||
GELOGD("BuildInputs : Add %u input succ.", input.first); | |||||
GELOGD("AddDataNodes : Add %u input succ.", input.first); | |||||
} | } | ||||
GELOGD("BuildInputs succ."); | |||||
GELOGD("AddDataNodes succ."); | |||||
} | } | ||||
/// | /// | ||||
@@ -1673,13 +2109,13 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err | |||||
/// @param [out] error_msg | /// @param [out] error_msg | ||||
/// @return void | /// @return void | ||||
/// | /// | ||||
NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { | |||||
NodePtr CompleteGraphBuilder::AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { | |||||
std::string data_name = "Data_" + std::to_string(index); | std::string data_name = "Data_" + std::to_string(index); | ||||
OpDescBuilder op_desc_builder(data_name, "Data"); | OpDescBuilder op_desc_builder(data_name, "Data"); | ||||
OpDescPtr op_desc = op_desc_builder.AddInput("x").AddOutput("y").Build(); | OpDescPtr op_desc = op_desc_builder.AddInput("x").AddOutput("y").Build(); | ||||
if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildInputs failed: create op_desc " + data_name + " failed."; | |||||
error_msg = "AddDataNode failed: create op_desc " + data_name + " failed."; | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -1687,7 +2123,7 @@ NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_cod | |||||
if (index_iter != input_mapping_.end()) { | if (index_iter != input_mapping_.end()) { | ||||
if (!ge::AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, index_iter->second)) { | if (!ge::AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, index_iter->second)) { | ||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildInputs failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; | |||||
error_msg = "AddDataNode failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
} | } | ||||
@@ -1695,189 +2131,83 @@ NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_cod | |||||
NodePtr data_node = owner_graph_->AddNode(op_desc); | NodePtr data_node = owner_graph_->AddNode(op_desc); | ||||
if (data_node == nullptr) { | if (data_node == nullptr) { | ||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildInputs failed: add node " + data_name + " failed."; | |||||
error_msg = "AddDataNode failed: add node " + data_name + " failed."; | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
node_names_[data_name] = data_node; | |||||
return data_node; | return data_node; | ||||
} | } | ||||
/// | /// | ||||
/// @brief Build outputs | |||||
/// @brief Add RetVal nodes | |||||
/// @param [out] error_code | /// @param [out] error_code | ||||
/// @param [out] error_msg | /// @param [out] error_msg | ||||
/// @return void | /// @return void | ||||
/// | /// | ||||
void CompleteGraphBuilder::BuildOutputs(graphStatus &error_code, std::string &error_msg) { | |||||
std::map<std::string, std::vector<int32_t>> out_nodes_map; | |||||
std::vector<std::pair<NodePtr, int32_t>> out_nodes_info; | |||||
for (auto &pair : graph_outputs_) { | |||||
std::string output = pair.first; | |||||
int32_t ind = pair.second; | |||||
auto out_iter = node_names_.find(output); | |||||
void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &error_msg) { | |||||
size_t output_num = graph_outputs_.size(); | |||||
for (size_t i = 0; i < output_num; i++) { | |||||
int32_t index = graph_outputs_[i].second; | |||||
auto out_iter = node_names_.find(graph_outputs_[i].first); | |||||
if (out_iter == node_names_.end()) { | if (out_iter == node_names_.end()) { | ||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildOutputs failed: node " + output + " not exist in graph."; | |||||
error_msg = "AddRetValNode failed: node " + graph_outputs_[i].first + " not exist in graph."; | |||||
return; | return; | ||||
} | } | ||||
NodePtr out_node = node_names_[output]; | |||||
if (out_node == nullptr) { | |||||
NodePtr node = out_iter->second; | |||||
if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { | |||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildOutputs failed: node " + output + " is NULL."; | |||||
error_msg = "AddRetValNode failed: node is NULL."; | |||||
return; | return; | ||||
} | } | ||||
OutDataAnchorPtr out_anchor = out_node->GetOutDataAnchor(ind); | |||||
if (out_anchor == nullptr) { | |||||
std::string name = node->GetName() + "_RetVal"; | |||||
OpDescPtr ret_val_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); | |||||
if (ret_val_desc == nullptr) { | |||||
error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
error_msg = "BuildOutputs failed: anchor " + output + ":" + std::to_string(ind) + " is NULL."; | |||||
error_msg = "AddRetValNode " + name + " failed: op_desc is NULL."; | |||||
return; | return; | ||||
} | } | ||||
auto iter = out_nodes_map.find(output); | |||||
if (iter == out_nodes_map.end()) { | |||||
std::vector<int32_t> vec = {ind}; | |||||
out_nodes_map[output] = vec; | |||||
} else { | |||||
out_nodes_map[output].emplace_back(ind); | |||||
ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); | |||||
if ((ret_val_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) || | |||||
(ret_val_desc->AddOutputDesc(tensor) != GRAPH_SUCCESS)) { | |||||
error_code = GRAPH_FAILED; | |||||
error_msg = "AddRetValNode " + name + " failed: add input_desc / output_desc failed."; | |||||
return; | |||||
} | } | ||||
out_nodes_info.emplace_back(std::make_pair(out_node, ind)); | |||||
GELOGD("BuildOutputs : AddOutputAnchor %s:%u succ.", output.c_str(), ind); | |||||
} | |||||
owner_graph_->SetGraphOutNodes(out_nodes_map); | |||||
owner_graph_->SetGraphOutNodesInfo(out_nodes_info); | |||||
GELOGD("BuildOutputs succ."); | |||||
} | |||||
/// | |||||
/// @brief Add NetOutput node | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return NodePtr | |||||
/// | |||||
NodePtr CompleteGraphBuilder::AddNetOutputNode(graphStatus &error_code, std::string &error_msg) { | |||||
std::string log_msg = "AddNetOutputNode name:" + std::string(kNodeNameNetOutput) + ", type:" + NETOUTPUT; | |||||
OpDescPtr net_output_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(kNodeNameNetOutput, NETOUTPUT)); | |||||
if (net_output_desc == nullptr) { | |||||
error_code = GRAPH_FAILED; | |||||
error_msg = log_msg + " failed: op_desc is NULL."; | |||||
return nullptr; | |||||
} | |||||
std::vector<std::pair<NodePtr, int32_t>> out_nodes_info = owner_graph_->GetGraphOutNodesInfo(); | |||||
error_code = BuildInOutForNetOutput(out_nodes_info, net_output_desc); | |||||
if (error_code != GRAPH_SUCCESS) { | |||||
error_msg = log_msg + " failed: add input/output tensor failed."; | |||||
return nullptr; | |||||
} | |||||
NodePtr net_output_node = owner_graph_->AddNode(net_output_desc); | |||||
if (net_output_node == nullptr) { | |||||
error_code = GRAPH_FAILED; | |||||
error_msg = log_msg + " failed: add node failed."; | |||||
return nullptr; | |||||
} | |||||
error_code = AddEdgeForNetOutput(out_nodes_info, net_output_node); | |||||
if (error_code != GRAPH_SUCCESS) { | |||||
error_msg = log_msg + " failed: link edge failed."; | |||||
return nullptr; | |||||
} | |||||
GELOGD("%s succ.", log_msg.c_str()); | |||||
return net_output_node; | |||||
} | |||||
/// | |||||
/// @brief Add input/output tensor for NetOutput node | |||||
/// @param [in] out_nodes_info | |||||
/// @param [out] net_output_desc | |||||
/// @return graphStatus | |||||
/// | |||||
graphStatus CompleteGraphBuilder::BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||||
OpDescPtr &net_output_desc) { | |||||
size_t output_num = out_nodes_info.size(); | |||||
for (size_t i = 0; i < output_num; i++) { | |||||
NodePtr src_node = out_nodes_info[i].first; | |||||
uint32_t src_index = out_nodes_info[i].second; | |||||
if ((src_node == nullptr) || (src_node->GetOpDesc() == nullptr)) { | |||||
GE_LOGE("AddInOutForNetOutputOp failed: src_node is NULL."); | |||||
return GRAPH_FAILED; | |||||
if (!(ge::AttrUtils::SetStr(ret_val_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_RetVal") && | |||||
ge::AttrUtils::SetInt(ret_val_desc, RETVAL_ATTR_NAME_INDEX, i))) { | |||||
error_code = GRAPH_FAILED; | |||||
error_msg = "AddRetValNode " + name + " failed: set FRAMEWORK_ORIGINAL_TYPE / RETVAL_ATTR_NAME_INDEX failed."; | |||||
return; | |||||
} | } | ||||
ge::GeTensorDesc in_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); | |||||
auto iter = output_mapping_.find(i); | auto iter = output_mapping_.find(i); | ||||
if (iter != output_mapping_.end()) { | if (iter != output_mapping_.end()) { | ||||
if (!ge::AttrUtils::SetInt(in_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||||
GE_LOGE("AddInOutForNetOutputOp failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||||
return GRAPH_FAILED; | |||||
if (!ge::AttrUtils::SetInt(ret_val_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||||
error_code = GRAPH_FAILED; | |||||
error_msg = "AddRetValNode " + name + " failed: set attr PARENT_NODE_INDEX failed."; | |||||
return; | |||||
} | } | ||||
} | } | ||||
if (net_output_desc->AddInputDesc(in_desc) != SUCCESS) { | |||||
GE_LOGE("AddInOutForNetOutputOp failed: add input_desc failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
ge::GeTensorDesc out_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); | |||||
TensorUtils::SetOutputTensor(out_desc, true); | |||||
if (net_output_desc->AddOutputDesc(out_desc) != SUCCESS) { | |||||
GE_LOGE("AddInOutForNetOutputOp failed: add output_desc failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
GELOGD("Add input/output tensor for NetOutput node succ."); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
/// | |||||
/// @brief Add edge for NetOutput node | |||||
/// @param [in] out_nodes_info | |||||
/// @param [out] net_output_node | |||||
/// @return graphStatus | |||||
/// | |||||
graphStatus CompleteGraphBuilder::AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||||
const NodePtr &net_output_node) { | |||||
if (net_output_node == nullptr) { | |||||
GE_LOGE("AddEdgeForNetOutputOp failed: NetOutput is NULL."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
size_t out_num = out_nodes_info.size(); | |||||
for (size_t i = 0; i < out_num; i++) { | |||||
NodePtr src_node = out_nodes_info[i].first; | |||||
uint32_t ind = out_nodes_info[i].second; | |||||
if (src_node == nullptr) { | |||||
GE_LOGE("AddEdgeForNetOutputOp failed: src_node is NULL."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (GraphUtils::AddEdge(src_node->GetOutDataAnchor(ind), net_output_node->GetInDataAnchor(i)) != GRAPH_SUCCESS) { | |||||
GE_LOGE("Add data-edge %s:%u->%s:%zu failed.", src_node->GetName().c_str(), ind, | |||||
net_output_node->GetName().c_str(), i); | |||||
return GRAPH_FAILED; | |||||
NodePtr ret_val_node = owner_graph_->AddNode(ret_val_desc); | |||||
if (ret_val_node == nullptr) { | |||||
error_code = GRAPH_FAILED; | |||||
error_msg = "AddRetValNode " + name + " failed: add node failed."; | |||||
return; | |||||
} | } | ||||
} | |||||
std::vector<NodePtr> leaf_nodes; | |||||
for (auto &node : owner_graph_->GetDirectNode()) { | |||||
if (node->GetOutNodes().empty()) { | |||||
leaf_nodes.emplace_back(node); | |||||
} | |||||
} | |||||
for (auto &node : leaf_nodes) { | |||||
if (GraphUtils::AddEdge(node->GetOutControlAnchor(), net_output_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||||
GE_LOGE("Add ctrl-edge %s->%s failed.", node->GetName().c_str(), net_output_node->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
if (GraphUtils::AddEdge(node->GetOutDataAnchor(index), ret_val_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { | |||||
error_code = GRAPH_FAILED; | |||||
error_msg = "AddRetValNode " + name + " failed: add data-edge " + node->GetName() + ":" + std::to_string(index) + | |||||
"->" + ret_val_node->GetName() + ":0 failed."; | |||||
return; | |||||
} | } | ||||
} | } | ||||
GELOGD("Add edge for NetOutput node succ."); | |||||
return GRAPH_SUCCESS; | |||||
GELOGD("AddRetValNodes succ."); | |||||
} | } | ||||
/// | /// | ||||
@@ -1999,4 +2329,60 @@ void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string & | |||||
GELOGD("Build exist nodes succ."); | GELOGD("Build exist nodes succ."); | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec) { | |||||
std::vector<NodePtr> stack_input; | |||||
std::map<NodePtr, uint32_t> map_in_edge_num; | |||||
graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Sort nodes failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1; | |||||
std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index, | |||||
[](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); }); | |||||
std::queue<NodePtr> stack; | |||||
NodePtr cur_node = nullptr; | |||||
std::map<string, NodePtr> name_node_map; | |||||
vector<string> nodes_name; | |||||
while (!stack_input.empty() || !stack.empty()) { | |||||
if (!stack.empty()) { | |||||
cur_node = stack.front(); | |||||
stack.pop(); | |||||
} else { | |||||
cur_node = stack_input.back(); | |||||
stack_input.pop_back(); | |||||
} | |||||
node_vec.emplace_back(cur_node); | |||||
compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map); | |||||
for (const auto &iter : name_node_map) { | |||||
nodes_name.emplace_back(iter.first); | |||||
} | |||||
std::sort(nodes_name.begin(), nodes_name.end()); | |||||
for (const auto &iter : nodes_name) { | |||||
stack.push(name_node_map[iter]); | |||||
} | |||||
name_node_map.clear(); | |||||
nodes_name.clear(); | |||||
} | |||||
// If they are not equal, there is a closed loop | |||||
if (node_vec.size() != compute_graph->nodes_.size()) { | |||||
std::set<Node *> itered_nodes_set; | |||||
for (auto &node : node_vec) { | |||||
itered_nodes_set.insert(node.get()); | |||||
} | |||||
GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", | |||||
compute_graph->nodes_.size(), node_vec.size()); | |||||
for (auto &node : compute_graph->nodes_) { | |||||
if (itered_nodes_set.count(node.get()) == 0) { | |||||
GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); | |||||
} | |||||
} | |||||
return GRAPH_FAILED; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -21,6 +21,7 @@ | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/anchor.h" | #include "graph/anchor.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/types.h" | |||||
#include "utils/tensor_utils.h" | #include "utils/tensor_utils.h" | ||||
#include "utils/type_utils.h" | #include "utils/type_utils.h" | ||||
@@ -28,6 +29,26 @@ namespace ge { | |||||
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{}; | std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{}; | ||||
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{}; | std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{}; | ||||
bool OpShapeIsUnknown(const OpDescPtr &desc) { | |||||
for (const auto &ptr : desc->GetAllInputsDescPtr()) { | |||||
auto ge_shape = ptr->GetShape(); | |||||
for (const auto &dim : ge_shape.GetDims()) { | |||||
if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { | |||||
return true; | |||||
} | |||||
} | |||||
} | |||||
for (const auto &ptr : desc->GetAllOutputsDescPtr()) { | |||||
auto ge_shape = ptr->GetShape(); | |||||
for (const auto &dim : ge_shape.GetDims()) { | |||||
if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { | |||||
return true; | |||||
} | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node, | ||||
const uint32_t &event_id) { | const uint32_t &event_id) { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
@@ -282,18 +303,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||||
GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); | GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); | ||||
continue; | continue; | ||||
} | } | ||||
auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->GetInputDescPtr(peer_anchor->GetIdx()); | |||||
auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); | |||||
if (peer_input_desc == nullptr) { | if (peer_input_desc == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); | GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); | ||||
continue; | continue; | ||||
} | } | ||||
output_tensor.SetOriginFormat(peer_input_desc->GetOriginFormat()); | |||||
output_tensor.SetFormat(peer_input_desc->GetFormat()); | |||||
auto peer_op_desc = peer_anchor->GetOwnerNode()->GetOpDesc(); | |||||
GE_IF_BOOL_EXEC(peer_op_desc == nullptr, GELOGE(GRAPH_FAILED, "peer opdesc is null"); continue); | |||||
GE_IF_BOOL_EXEC(peer_op_desc->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor) != GRAPH_SUCCESS, | |||||
GELOGE(GRAPH_FAILED, "peer opdesc is null"); | |||||
continue); | |||||
GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", | |||||
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(), | |||||
output_tensor.GetDataType(), output_tensor.GetOriginDataType()); | |||||
peer_input_desc->SetShape(output_tensor.GetShape()); | |||||
peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); | |||||
peer_input_desc->SetDataType(output_tensor.GetDataType()); | |||||
peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); | |||||
ge::TensorUtils::SetRealDimCnt(*peer_input_desc, | |||||
static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | |||||
GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", | |||||
peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), | |||||
peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); | |||||
} | } | ||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
@@ -361,6 +387,41 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const | |||||
input_desc->SetShape(shape); | input_desc->SetShape(shape); | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { | |||||
auto desc = node.GetOpDesc(); | |||||
GE_CHECK_NOTNULL(desc); | |||||
auto sub_graph_names = desc->GetSubgraphInstanceNames(); | |||||
if (sub_graph_names.empty()) { | |||||
is_unknow = OpShapeIsUnknown(desc); | |||||
return GRAPH_SUCCESS; | |||||
} else { | |||||
auto owner_graph = node.GetOwnerComputeGraph(); | |||||
GE_CHECK_NOTNULL(owner_graph); | |||||
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||||
if (root_graph == nullptr) { | |||||
GE_LOGE("Node %s gets null root graph", node.GetName().c_str()); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
for (auto &sub_graph_name : sub_graph_names) { | |||||
auto sub_graph = root_graph->GetSubgraph(sub_graph_name); | |||||
GE_CHECK_NOTNULL(sub_graph); | |||||
for (const auto &node_ptr : sub_graph->GetDirectNode()) { | |||||
auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); | |||||
if (status != GRAPH_SUCCESS) { | |||||
GE_LOGE("get node unknown shape status failed!"); | |||||
return status; | |||||
} | |||||
if (is_unknow) { | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
std::string NodeUtils::GetNodeType(const Node &node) { | std::string NodeUtils::GetNodeType(const Node &node) { | ||||
if (node.GetType() != FRAMEWORKOP) { | if (node.GetType() != FRAMEWORKOP) { | ||||
return node.GetType(); | return node.GetType(); | ||||
@@ -381,9 +442,9 @@ ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | |||||
return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); | return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); | ||||
} | } | ||||
graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) { | |||||
graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) { | |||||
if (subgraph == nullptr) { | if (subgraph == nullptr) { | ||||
GE_LOGE("Failed to add subgraph to node %s, null subgraph", node.GetName().c_str()); | |||||
GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); | |||||
return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
} | } | ||||
auto op_desc = node.GetOpDesc(); | auto op_desc = node.GetOpDesc(); | ||||
@@ -395,11 +456,105 @@ graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) | |||||
GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); | GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); | ||||
return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
} | } | ||||
op_desc->AddSubgraphInstanceName(subgraph->GetName()); | |||||
auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); | |||||
return ret; | |||||
} | |||||
subgraph->SetParentNode(node.shared_from_this()); | subgraph->SetParentNode(node.shared_from_this()); | ||||
subgraph->SetParentGraph(node.GetOwnerComputeGraph()); | subgraph->SetParentGraph(node.GetOwnerComputeGraph()); | ||||
root_graph->AddSubgraph(subgraph); | |||||
return root_graph->AddSubgraph(subgraph); | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
/// | |||||
/// Check if node is input of subgraph | |||||
/// @param [in] node | |||||
/// @return bool | |||||
/// | |||||
bool NodeUtils::IsSubgraphInput(const NodePtr &node) { | |||||
if ((node == nullptr) || (node->GetOpDesc() == nullptr) || | |||||
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { | |||||
return false; | |||||
} | |||||
return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); | |||||
} | |||||
/// | |||||
/// Check if node is output of subgraph | |||||
/// @param [in] node | |||||
/// @return bool | |||||
/// | |||||
bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { | |||||
if ((node == nullptr) || (node->GetOpDesc() == nullptr) || | |||||
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { | |||||
return false; | |||||
} | |||||
for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { | |||||
if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
/// | |||||
/// @brief Get subgraph original input node. | |||||
/// @param [in] node | |||||
/// @return Node | |||||
/// | |||||
NodePtr NodeUtils::GetParentInput(const NodePtr &node) { | |||||
GE_CHECK_NOTNULL_EXEC(node, return nullptr); | |||||
uint32_t parent_index = 0; | |||||
if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
return nullptr; | |||||
} | |||||
// Subgraph Data Node, check for constant input. | |||||
const ComputeGraphPtr &graph = node->GetOwnerComputeGraph(); | |||||
GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | |||||
const NodePtr &parent_node = graph->GetParentNode(); | |||||
GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr); | |||||
const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index); | |||||
GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); | |||||
const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); | |||||
return peer_out_anchor->GetOwnerNode(); | |||||
} | |||||
/// | |||||
/// @brief Get subgraph input is constant. | |||||
/// @param [in] node | |||||
/// @param [out] string | |||||
/// @return bool | |||||
/// | |||||
bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) { | |||||
GE_CHECK_NOTNULL_EXEC(in_node, return false); | |||||
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | |||||
op_type = in_node->GetType(); | |||||
return true; | |||||
} | |||||
if (in_node->GetType() == DATA) { | |||||
std::string const_type; | |||||
if (!AttrUtils::GetStr(in_node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) { | |||||
return false; | |||||
} | |||||
if ((const_type == CONSTANT) || (const_type == CONSTANTOP)) { | |||||
op_type = const_type; | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -469,7 +469,7 @@ OpDescUtils::SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights) | |||||
return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
} | } | ||||
ge::GeAttrValue::NamedAttrs named_attrs; | |||||
ge::GeAttrValue::NAMED_ATTRS named_attrs; | |||||
(void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); | (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); | ||||
vector<ge::GeTensorPtr> copy_weights; | vector<ge::GeTensorPtr> copy_weights; | ||||
(void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); | (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); | ||||
@@ -578,7 +578,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei | |||||
/// @return OpDescBuilder | /// @return OpDescBuilder | ||||
/// | /// | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { | ||||
inputs_.emplace_back(name); | |||||
inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); | |||||
return *this; | |||||
} | |||||
/// | |||||
/// @brief Add input | |||||
/// @param [in] name | |||||
/// @param [in] tensor | |||||
/// @return OpDescBuilder | |||||
/// | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name, | |||||
const GeTensorDesc &tensor) { | |||||
inputs_.emplace_back(std::make_pair(name, tensor)); | |||||
return *this; | return *this; | ||||
} | } | ||||
@@ -591,7 +603,22 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, | ||||
uint32_t num) { | uint32_t num) { | ||||
for (uint32_t i = 0; i < num; i++) { | for (uint32_t i = 0; i < num; i++) { | ||||
inputs_.emplace_back(name + std::to_string(i)); | |||||
inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); | |||||
} | |||||
return *this; | |||||
} | |||||
/// | |||||
/// @brief Add dynamic input | |||||
/// @param [in] name | |||||
/// @param [in] num | |||||
/// @param [in] tensor | |||||
/// @return OpDescBuilder | |||||
/// | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput( | |||||
const std::string &name, uint32_t num, const GeTensorDesc &tensor) { | |||||
for (uint32_t i = 0; i < num; i++) { | |||||
inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); | |||||
} | } | ||||
return *this; | return *this; | ||||
} | } | ||||
@@ -602,7 +629,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add | |||||
/// @return OpDescBuilder | /// @return OpDescBuilder | ||||
/// | /// | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { | ||||
outputs_.emplace_back(name); | |||||
outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); | |||||
return *this; | |||||
} | |||||
/// | |||||
/// @brief Add output | |||||
/// @param [in] name | |||||
/// @param [in] tensor | |||||
/// @return OpDescBuilder | |||||
/// | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name, | |||||
const GeTensorDesc &tensor) { | |||||
outputs_.emplace_back(std::make_pair(name, tensor)); | |||||
return *this; | return *this; | ||||
} | } | ||||
@@ -615,7 +654,22 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, | ||||
uint32_t num) { | uint32_t num) { | ||||
for (uint32_t i = 0; i < num; i++) { | for (uint32_t i = 0; i < num; i++) { | ||||
outputs_.emplace_back(name + std::to_string(i)); | |||||
outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); | |||||
} | |||||
return *this; | |||||
} | |||||
/// | |||||
/// @brief Add dynamic output | |||||
/// @param [in] name | |||||
/// @param [in] num | |||||
/// @param [in] tensor | |||||
/// @return OpDescBuilder | |||||
/// | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput( | |||||
const std::string &name, uint32_t num, const GeTensorDesc &tensor) { | |||||
for (uint32_t i = 0; i < num; i++) { | |||||
outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); | |||||
} | } | ||||
return *this; | return *this; | ||||
} | } | ||||
@@ -632,14 +686,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() | |||||
} | } | ||||
for (auto &input : inputs_) { | for (auto &input : inputs_) { | ||||
if (op_desc->AddInputDesc(input, GeTensorDesc()) != GRAPH_SUCCESS) { | |||||
if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Add input_desc failed."); | GELOGE(GRAPH_FAILED, "Add input_desc failed."); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
} | } | ||||
for (auto &output : outputs_) { | for (auto &output : outputs_) { | ||||
if (op_desc->AddOutputDesc(output, GeTensorDesc()) != GRAPH_SUCCESS) { | |||||
if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Add output_desc failed."); | GELOGE(GRAPH_FAILED, "Add output_desc failed."); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -647,4 +701,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() | |||||
return op_desc; | return op_desc; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName( | |||||
const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) { | |||||
const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); | |||||
auto iter = subgraph_names_to_index.find(subgraph_name); | |||||
if (iter == subgraph_names_to_index.end()) { | |||||
GELOGE(GRAPH_PARAM_INVALID, | |||||
"Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists", | |||||
subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||||
subgraph_name.c_str()); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -282,6 +282,7 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||||
case FORMAT_FRACTAL_Z_3D: | case FORMAT_FRACTAL_Z_3D: | ||||
case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | ||||
case FORMAT_NDC1HWC0: | case FORMAT_NDC1HWC0: | ||||
case FORMAT_FRACTAL_Z_C04: | |||||
graph_status = CalcElementCntByDims(dims, element_cnt); | graph_status = CalcElementCntByDims(dims, element_cnt); | ||||
break; | break; | ||||
default: | default: | ||||
@@ -56,6 +56,7 @@ static const std::map<Format, std::string> kFormatToStringMap = { | |||||
{FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | ||||
{FORMAT_CN, "CN"}, | {FORMAT_CN, "CN"}, | ||||
{FORMAT_NC, "NC"}, | {FORMAT_NC, "NC"}, | ||||
{FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, | |||||
{FORMAT_RESERVED, "FORMAT_RESERVED"}, | {FORMAT_RESERVED, "FORMAT_RESERVED"}, | ||||
{FORMAT_ALL, "ALL"}}; | {FORMAT_ALL, "ALL"}}; | ||||
@@ -76,7 +77,8 @@ static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | |||||
"FRACTAL_NZ", | "FRACTAL_NZ", | ||||
"NDC1HWC0", | "NDC1HWC0", | ||||
"FORMAT_FRACTAL_Z_3D", | "FORMAT_FRACTAL_Z_3D", | ||||
"FORMAT_FRACTAL_Z_3D_TRANSPOSE"}; | |||||
"FORMAT_FRACTAL_Z_3D_TRANSPOSE" | |||||
"FORMAT_FRACTAL_ZN_LSTM"}; | |||||
static const std::map<std::string, Format> kDataFormatMap = { | static const std::map<std::string, Format> kDataFormatMap = { | ||||
{"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; | {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; | ||||
@@ -119,6 +121,7 @@ static const std::map<std::string, Format> kStringToFormatMap = { | |||||
{"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, | {"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, | ||||
{"CN", FORMAT_CN}, | {"CN", FORMAT_CN}, | ||||
{"NC", FORMAT_NC}, | {"NC", FORMAT_NC}, | ||||
{"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, | |||||
{"FORMAT_RESERVED", FORMAT_RESERVED}, | {"FORMAT_RESERVED", FORMAT_RESERVED}, | ||||
{"ALL", FORMAT_ALL}}; | {"ALL", FORMAT_ALL}}; | ||||
@@ -13,15 +13,18 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
# libge_compiler.so & libge_train.so | |||||
# libge_compiler.so & libge_runner.so | |||||
# will later be integrated into libgraph_runner.so, works for both training and inference | # will later be integrated into libgraph_runner.so, works for both training and inference | ||||
# compiling proto files generates some warnings, use no-unused-variable to suppress them | # compiling proto files generates some warnings, use no-unused-variable to suppress them | ||||
set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | ||||
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"../proto/fusion_model.proto" | "../proto/fusion_model.proto" | ||||
"../proto/optimizer_priority.proto" | |||||
) | ) | ||||
file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
file(GLOB PROTO_CLIENT_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"../proto/ge_api.proto" | |||||
) | |||||
file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"../proto/om.proto" | "../proto/om.proto" | ||||
"../proto/task.proto" | "../proto/task.proto" | ||||
"../proto/insert_op.proto" | "../proto/insert_op.proto" | ||||
@@ -30,57 +33,46 @@ file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"../proto/op_mapping_info.proto" | "../proto/op_mapping_info.proto" | ||||
) | ) | ||||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ||||
ge_protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | |||||
ge_protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) | ge_protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) | ||||
# include directories | # include directories | ||||
include_directories(${CMAKE_CURRENT_LIST_DIR}) | include_directories(${CMAKE_CURRENT_LIST_DIR}) | ||||
include_directories(${GE_SOURCE_DIR}) | include_directories(${GE_SOURCE_DIR}) | ||||
include_directories(${GE_SOURCE_DIR}/src) | include_directories(${GE_SOURCE_DIR}/src) | ||||
include_directories(${GE_SOURCE_DIR}/inc) | include_directories(${GE_SOURCE_DIR}/inc) | ||||
include_directories(${GE_SOURCE_DIR}/inc/common/util) | |||||
include_directories(${GE_SOURCE_DIR}/inc/external) | include_directories(${GE_SOURCE_DIR}/inc/external) | ||||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | include_directories(${GE_SOURCE_DIR}/inc/external/graph) | ||||
include_directories(${GE_SOURCE_DIR}/inc/framework) | include_directories(${GE_SOURCE_DIR}/inc/framework) | ||||
include_directories(${GE_SOURCE_DIR}/inc/framework/common) | include_directories(${GE_SOURCE_DIR}/inc/framework/common) | ||||
include_directories(${GE_SOURCE_DIR}/inc/runtime) | include_directories(${GE_SOURCE_DIR}/inc/runtime) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | ||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
######### libge_train.so ############# | |||||
######### libge_runner.so ############# | |||||
# need to remove dependencies on pb files later | # need to remove dependencies on pb files later | ||||
file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | ||||
"client/ge_api.cc" | |||||
"common/formats/format_transfers/*.cc" | "common/formats/format_transfers/*.cc" | ||||
"common/formats/formats.cc" | "common/formats/formats.cc" | ||||
"common/formats/utils/formats_trans_utils.cc" | "common/formats/utils/formats_trans_utils.cc" | ||||
"common/fp16_t.cc" | "common/fp16_t.cc" | ||||
"common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
"common/helper/model_cache_helper.cc" | |||||
"common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
"engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
"ge_local_engine/engine/host_cpu_engine.cc" | "ge_local_engine/engine/host_cpu_engine.cc" | ||||
"generator/ge_generator.cc" | "generator/ge_generator.cc" | ||||
"generator/generator_api.cc" | "generator/generator_api.cc" | ||||
"graph/build/graph_builder.cc" | |||||
"graph/build/label_allocator.cc" | |||||
"graph/build/logical_stream_allocator.cc" | |||||
"graph/build/model_builder.cc" | |||||
"graph/build/run_context.cc" | |||||
"graph/build/stream_allocator.cc" | |||||
"graph/build/stream_graph_optimizer.cc" | |||||
"graph/build/task_generator.cc" | |||||
"graph/common/bcast.cc" | |||||
"graph/common/omg_util.cc" | |||||
"graph/common/transop_util.cc" | |||||
"graph/build/*.cc" | |||||
"graph/common/*.cc" | |||||
"graph/execute/graph_execute.cc" | "graph/execute/graph_execute.cc" | ||||
"graph/label/*.cc" | "graph/label/*.cc" | ||||
"graph/load/graph_loader.cc" | "graph/load/graph_loader.cc" | ||||
"graph/load/new_model_manager/cpu_queue_schedule.cc" | |||||
"graph/load/new_model_manager/data_dumper.cc" | |||||
"graph/load/new_model_manager/data_inputer.cc" | |||||
"graph/load/new_model_manager/davinci_model.cc" | |||||
"graph/load/new_model_manager/davinci_model_parser.cc" | |||||
"graph/load/new_model_manager/model_manager.cc" | |||||
"graph/load/new_model_manager/model_output.cc" | |||||
"graph/load/new_model_manager/model_utils.cc" | |||||
"graph/load/new_model_manager/*.cc" | |||||
"graph/load/new_model_manager/task_info/end_graph_task_info.cc" | "graph/load/new_model_manager/task_info/end_graph_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/event_record_task_info.cc" | "graph/load/new_model_manager/task_info/event_record_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/event_wait_task_info.cc" | "graph/load/new_model_manager/task_info/event_wait_task_info.cc" | ||||
@@ -89,8 +81,10 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/hccl_task_info.cc" | "graph/load/new_model_manager/task_info/hccl_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" | "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/kernel_task_info.cc" | "graph/load/new_model_manager/task_info/kernel_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_goto_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/label_set_task_info.cc" | "graph/load/new_model_manager/task_info/label_set_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
@@ -99,15 +93,9 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | ||||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | ||||
"graph/load/new_model_manager/task_info/task_info.cc" | "graph/load/new_model_manager/task_info/task_info.cc" | ||||
"graph/load/new_model_manager/tbe_handle_store.cc" | |||||
"graph/load/output/output.cc" | "graph/load/output/output.cc" | ||||
"graph/manager/graph_context.cc" | |||||
"graph/manager/graph_manager.cc" | |||||
"graph/manager/graph_manager_utils.cc" | |||||
"graph/manager/graph_mem_allocator.cc" | |||||
"graph/manager/graph_var_manager.cc" | |||||
"graph/manager/*.cc" | |||||
"graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
"graph/manager/trans_var_data_utils.cc" | |||||
"graph/manager/util/debug.cc" | "graph/manager/util/debug.cc" | ||||
"graph/manager/util/hcom_util.cc" | "graph/manager/util/hcom_util.cc" | ||||
"graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
@@ -115,27 +103,10 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/optimize/graph_optimize.cc" | "graph/optimize/graph_optimize.cc" | ||||
"graph/optimize/optimizer/allreduce_fusion_pass.cc" | "graph/optimize/optimizer/allreduce_fusion_pass.cc" | ||||
"graph/optimize/summary_optimize.cc" | "graph/optimize/summary_optimize.cc" | ||||
"graph/partition/dynamic_shape_partition.cc" | |||||
"graph/partition/engine_place.cc" | "graph/partition/engine_place.cc" | ||||
"graph/partition/graph_partition.cc" | "graph/partition/graph_partition.cc" | ||||
"graph/passes/addn_pass.cc" | |||||
"graph/passes/aicpu_constant_folding_pass.cc" | |||||
"graph/passes/assert_pass.cc" | |||||
"graph/passes/atomic_addr_clean_pass.cc" | |||||
"graph/passes/base_pass.cc" | |||||
"graph/passes/cast_remove_pass.cc" | |||||
"graph/passes/cast_translate_pass.cc" | |||||
"graph/passes/common_subexpression_elimination_pass.cc" | |||||
"graph/passes/compile_nodes_pass.cc" | |||||
"graph/passes/constant_folding_pass.cc" | |||||
"graph/passes/constant_fuse_same_pass.cc" | |||||
"graph/passes/control_op_attr_pass.cc" | |||||
"graph/passes/control_trigger_pass.cc" | |||||
"graph/passes/dimension_adjust_pass.cc" | |||||
"graph/passes/dimension_compute_pass.cc" | |||||
"graph/passes/dropout_pass.cc" | |||||
"graph/passes/end_graph_pass.cc" | |||||
"graph/passes/enter_pass.cc" | |||||
"graph/passes/flow_ctrl_pass.cc" | |||||
"graph/passes/*.cc" | |||||
"graph/passes/folding_kernel/add_kernel.cc" | "graph/passes/folding_kernel/add_kernel.cc" | ||||
"graph/passes/folding_kernel/broadcast_args_kernel.cc" | "graph/passes/folding_kernel/broadcast_args_kernel.cc" | ||||
"graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" | "graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" | ||||
@@ -171,51 +142,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/folding_kernel/sub_kernel.cc" | "graph/passes/folding_kernel/sub_kernel.cc" | ||||
"graph/passes/folding_kernel/transdata_kernel.cc" | "graph/passes/folding_kernel/transdata_kernel.cc" | ||||
"graph/passes/folding_kernel/unpack_kernel.cc" | "graph/passes/folding_kernel/unpack_kernel.cc" | ||||
"graph/passes/folding_pass.cc" | |||||
"graph/passes/get_original_format_pass.cc" | |||||
"graph/passes/guarantee_const_pass.cc" | |||||
"graph/passes/hccl_memcpy_pass.cc" | |||||
"graph/passes/identify_reference_pass.cc" | |||||
"graph/passes/identity_pass.cc" | |||||
"graph/passes/infershape_pass.cc" | |||||
"graph/passes/isolated_op_remove_pass.cc" | |||||
"graph/passes/iterator_op_pass.cc" | |||||
"graph/passes/link_gen_mask_nodes_pass.cc" | |||||
"graph/passes/merge_pass.cc" | |||||
"graph/passes/multi_batch_pass.cc" | |||||
"graph/passes/net_output_pass.cc" | |||||
"graph/passes/next_iteration_pass.cc" | |||||
"graph/passes/no_use_reshape_remove_pass.cc" | |||||
"graph/passes/pass_manager.cc" | |||||
"graph/passes/pass_utils.cc" | |||||
"graph/passes/permute_pass.cc" | |||||
"graph/passes/placeholder_with_default_pass.cc" | |||||
"graph/passes/prevent_gradient_pass.cc" | |||||
"graph/passes/print_op_pass.cc" | |||||
"graph/passes/prune_pass.cc" | |||||
"graph/passes/reshape_remove_pass.cc" | |||||
"graph/passes/resource_pair_add_control_pass.cc" | |||||
"graph/passes/resource_pair_remove_control_pass.cc" | |||||
"graph/passes/same_transdata_breadth_fusion_pass.cc" | |||||
"graph/passes/save_pass.cc" | |||||
"graph/passes/shape_operate_op_remove_pass.cc" | |||||
"graph/passes/snapshot_pass.cc" | |||||
"graph/passes/stop_gradient_pass.cc" | |||||
"graph/passes/switch_logic_remove_pass.cc" | |||||
"graph/passes/switch_op_pass.cc" | |||||
"graph/passes/switch_pass.cc" | |||||
"graph/passes/transop_breadth_fusion_pass.cc" | |||||
"graph/passes/transop_depth_fusion_pass.cc" | |||||
"graph/passes/transop_nearby_allreduce_fusion_pass.cc" | |||||
"graph/passes/transop_without_reshape_fusion_pass.cc" | |||||
"graph/passes/transpose_transdata_pass.cc" | |||||
"graph/passes/unused_const_pass.cc" | |||||
"graph/passes/unused_op_remove_pass.cc" | |||||
"graph/passes/var_is_initialized_op_pass.cc" | |||||
"graph/passes/variable_format_pass.cc" | |||||
"graph/passes/variable_op_pass.cc" | |||||
"graph/passes/variable_prepare_op_pass.cc" | |||||
"graph/passes/variable_ref_delete_op_pass.cc" | |||||
"graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
"graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
@@ -231,22 +157,17 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
) | ) | ||||
######### libge_train.so ############# | |||||
add_library(ge_train SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||||
target_compile_definitions(ge_train PRIVATE | |||||
######### libge_runner.so ############# | |||||
add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS} ${PROTO_HEADER_HDRS}) | |||||
target_compile_definitions(ge_runner PRIVATE | |||||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
DAVINCI_SUPPORT_PROFILING | DAVINCI_SUPPORT_PROFILING | ||||
REUSE_MEMORY=1 | REUSE_MEMORY=1 | ||||
DAVINCI_TRAIN | |||||
DAVINCI_CLOUD | |||||
FMK_SUPPORT_DEBUG | |||||
PLATFORM_CLOUD) | |||||
target_link_libraries(ge_train | |||||
DAVINCI_CLOUD) | |||||
target_link_libraries(ge_runner | |||||
graph | graph | ||||
ge_common | ge_common | ||||
"-Wl,--whole-archive" | |||||
ge_memory | ge_memory | ||||
"-Wl,--no-whole-archive" | |||||
${PROTOBUF_LIBRARY} | ${PROTOBUF_LIBRARY} | ||||
${register} | ${register} | ||||
${c_sec} | ${c_sec} | ||||
@@ -267,33 +188,18 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"common/formats/utils/formats_trans_utils.cc" | "common/formats/utils/formats_trans_utils.cc" | ||||
"common/fp16_t.cc" | "common/fp16_t.cc" | ||||
"common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
"common/helper/model_cache_helper.cc" | |||||
"common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
"engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
"ge_local_engine/engine/host_cpu_engine.cc" | "ge_local_engine/engine/host_cpu_engine.cc" | ||||
"generator/ge_generator.cc" | "generator/ge_generator.cc" | ||||
"generator/generator_api.cc" | "generator/generator_api.cc" | ||||
"graph/build/graph_builder.cc" | |||||
"graph/build/label_allocator.cc" | |||||
"graph/build/logical_stream_allocator.cc" | |||||
"graph/build/model_builder.cc" | |||||
"graph/build/run_context.cc" | |||||
"graph/build/stream_allocator.cc" | |||||
"graph/build/stream_graph_optimizer.cc" | |||||
"graph/build/task_generator.cc" | |||||
"graph/common/bcast.cc" | |||||
"graph/common/omg_util.cc" | |||||
"graph/common/transop_util.cc" | |||||
"graph/build/*.cc" | |||||
"graph/common/*.cc" | |||||
"graph/execute/graph_execute.cc" | "graph/execute/graph_execute.cc" | ||||
"graph/label/*.cc" | "graph/label/*.cc" | ||||
"graph/load/graph_loader.cc" | "graph/load/graph_loader.cc" | ||||
"graph/load/new_model_manager/cpu_queue_schedule.cc" | |||||
"graph/load/new_model_manager/data_dumper.cc" | |||||
"graph/load/new_model_manager/data_inputer.cc" | |||||
"graph/load/new_model_manager/davinci_model.cc" | |||||
"graph/load/new_model_manager/davinci_model_parser.cc" | |||||
"graph/load/new_model_manager/model_manager.cc" | |||||
"graph/load/new_model_manager/model_output.cc" | |||||
"graph/load/new_model_manager/model_utils.cc" | |||||
"graph/load/new_model_manager/*.cc" | |||||
"graph/load/new_model_manager/task_info/end_graph_task_info.cc" | "graph/load/new_model_manager/task_info/end_graph_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/event_record_task_info.cc" | "graph/load/new_model_manager/task_info/event_record_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/event_wait_task_info.cc" | "graph/load/new_model_manager/task_info/event_wait_task_info.cc" | ||||
@@ -301,8 +207,10 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" | "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" | "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/kernel_task_info.cc" | "graph/load/new_model_manager/task_info/kernel_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_goto_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/label_set_task_info.cc" | "graph/load/new_model_manager/task_info/label_set_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
@@ -311,41 +219,18 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | ||||
"graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | ||||
"graph/load/new_model_manager/task_info/task_info.cc" | "graph/load/new_model_manager/task_info/task_info.cc" | ||||
"graph/load/new_model_manager/tbe_handle_store.cc" | |||||
"graph/load/output/output.cc" | "graph/load/output/output.cc" | ||||
"graph/manager/graph_context.cc" | |||||
"graph/manager/graph_manager.cc" | |||||
"graph/manager/graph_manager_utils.cc" | |||||
"graph/manager/graph_mem_allocator.cc" | |||||
"graph/manager/graph_var_manager.cc" | |||||
"graph/manager/*.cc" | |||||
"graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
"graph/manager/trans_var_data_utils.cc" | |||||
"graph/manager/util/debug.cc" | "graph/manager/util/debug.cc" | ||||
"graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
"graph/manager/util/variable_accelerate_ctrl.cc" | "graph/manager/util/variable_accelerate_ctrl.cc" | ||||
"graph/optimize/graph_optimize.cc" | "graph/optimize/graph_optimize.cc" | ||||
"graph/optimize/summary_optimize.cc" | "graph/optimize/summary_optimize.cc" | ||||
"graph/partition/dynamic_shape_partition.cc" | |||||
"graph/partition/engine_place.cc" | "graph/partition/engine_place.cc" | ||||
"graph/partition/graph_partition.cc" | "graph/partition/graph_partition.cc" | ||||
"graph/passes/addn_pass.cc" | |||||
"graph/passes/aicpu_constant_folding_pass.cc" | |||||
"graph/passes/assert_pass.cc" | |||||
"graph/passes/atomic_addr_clean_pass.cc" | |||||
"graph/passes/base_pass.cc" | |||||
"graph/passes/cast_remove_pass.cc" | |||||
"graph/passes/cast_translate_pass.cc" | |||||
"graph/passes/common_subexpression_elimination_pass.cc" | |||||
"graph/passes/compile_nodes_pass.cc" | |||||
"graph/passes/constant_folding_pass.cc" | |||||
"graph/passes/constant_fuse_same_pass.cc" | |||||
"graph/passes/control_op_attr_pass.cc" | |||||
"graph/passes/control_trigger_pass.cc" | |||||
"graph/passes/dimension_adjust_pass.cc" | |||||
"graph/passes/dimension_compute_pass.cc" | |||||
"graph/passes/dropout_pass.cc" | |||||
"graph/passes/end_graph_pass.cc" | |||||
"graph/passes/enter_pass.cc" | |||||
"graph/passes/flow_ctrl_pass.cc" | |||||
"graph/passes/*.cc" | |||||
"graph/passes/folding_kernel/add_kernel.cc" | "graph/passes/folding_kernel/add_kernel.cc" | ||||
"graph/passes/folding_kernel/broadcast_args_kernel.cc" | "graph/passes/folding_kernel/broadcast_args_kernel.cc" | ||||
"graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" | "graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" | ||||
@@ -380,87 +265,33 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/folding_kernel/strided_slice_kernel.cc" | "graph/passes/folding_kernel/strided_slice_kernel.cc" | ||||
"graph/passes/folding_kernel/sub_kernel.cc" | "graph/passes/folding_kernel/sub_kernel.cc" | ||||
"graph/passes/folding_kernel/transdata_kernel.cc" | "graph/passes/folding_kernel/transdata_kernel.cc" | ||||
"graph/passes/folding_kernel/transpose_kernel.cc" | |||||
"graph/passes/folding_kernel/unpack_kernel.cc" | "graph/passes/folding_kernel/unpack_kernel.cc" | ||||
"graph/passes/folding_pass.cc" | |||||
"graph/passes/get_original_format_pass.cc" | |||||
"graph/passes/guarantee_const_pass.cc" | |||||
"graph/passes/hccl_memcpy_pass.cc" | |||||
"graph/passes/identify_reference_pass.cc" | |||||
"graph/passes/identity_pass.cc" | |||||
"graph/passes/infershape_pass.cc" | |||||
"graph/passes/isolated_op_remove_pass.cc" | |||||
"graph/passes/iterator_op_pass.cc" | |||||
"graph/passes/link_gen_mask_nodes_pass.cc" | |||||
"graph/passes/merge_pass.cc" | |||||
"graph/passes/multi_batch_pass.cc" | |||||
"graph/passes/net_output_pass.cc" | |||||
"graph/passes/next_iteration_pass.cc" | |||||
"graph/passes/no_use_reshape_remove_pass.cc" | |||||
"graph/passes/pass_manager.cc" | |||||
"graph/passes/pass_utils.cc" | |||||
"graph/passes/permute_pass.cc" | |||||
"graph/passes/placeholder_with_default_pass.cc" | |||||
"graph/passes/prevent_gradient_pass.cc" | |||||
"graph/passes/print_op_pass.cc" | |||||
"graph/passes/prune_pass.cc" | |||||
"graph/passes/reshape_remove_pass.cc" | |||||
"graph/passes/resource_pair_add_control_pass.cc" | |||||
"graph/passes/resource_pair_remove_control_pass.cc" | |||||
"graph/passes/same_transdata_breadth_fusion_pass.cc" | |||||
"graph/passes/save_pass.cc" | |||||
"graph/passes/shape_operate_op_remove_pass.cc" | |||||
"graph/passes/snapshot_pass.cc" | |||||
"graph/passes/stop_gradient_pass.cc" | |||||
"graph/passes/switch_logic_remove_pass.cc" | |||||
"graph/passes/switch_op_pass.cc" | |||||
"graph/passes/switch_pass.cc" | |||||
"graph/passes/transop_breadth_fusion_pass.cc" | |||||
"graph/passes/transop_depth_fusion_pass.cc" | |||||
"graph/passes/transop_nearby_allreduce_fusion_pass.cc" | |||||
"graph/passes/transop_without_reshape_fusion_pass.cc" | |||||
"graph/passes/transpose_transdata_pass.cc" | |||||
"graph/passes/unused_const_pass.cc" | |||||
"graph/passes/unused_op_remove_pass.cc" | |||||
"graph/passes/var_is_initialized_op_pass.cc" | |||||
"graph/passes/variable_format_pass.cc" | |||||
"graph/passes/variable_op_pass.cc" | |||||
"graph/passes/variable_prepare_op_pass.cc" | |||||
"graph/passes/variable_ref_delete_op_pass.cc" | |||||
"graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
"graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
"graph/preprocess/multi_batch_copy_graph.cc" | "graph/preprocess/multi_batch_copy_graph.cc" | ||||
"init/gelib.cc" | "init/gelib.cc" | ||||
"ir_build/atc_ir_common.cc" | |||||
"ir_build/ge_ir_build.cc" | |||||
"model/ge_model.cc" | "model/ge_model.cc" | ||||
"omm/csa_interact.cc" | "omm/csa_interact.cc" | ||||
"opskernel_manager/ops_kernel_manager.cc" | "opskernel_manager/ops_kernel_manager.cc" | ||||
"session/inner_session.cc" | "session/inner_session.cc" | ||||
"session/session_manager.cc" | "session/session_manager.cc" | ||||
"single_op/single_op.cc" | |||||
"single_op/single_op_manager.cc" | |||||
"single_op/single_op_model.cc" | |||||
"single_op/stream_resource.cc" | |||||
"single_op/task/build_task_utils.cc" | |||||
"single_op/task/op_task.cc" | |||||
"single_op/task/tbe_task_builder.cc" | |||||
########################################## | |||||
# "ir_build/ge_ir_build.cc" | |||||
# "offline/atc_ir_common.cc" | |||||
"single_op/*.cc" | |||||
"single_op/task/*.cc" | |||||
) | ) | ||||
add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | ||||
target_compile_definitions(ge_compiler PRIVATE | target_compile_definitions(ge_compiler PRIVATE | ||||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
DAVINCI_SUPPORT_PROFILING | |||||
REUSE_MEMORY=1 | REUSE_MEMORY=1 | ||||
FMK_HOST_INFER | |||||
PLATFORM_CLOUD) | |||||
FMK_HOST_INFER) | |||||
target_link_libraries(ge_compiler | target_link_libraries(ge_compiler | ||||
graph | graph | ||||
ge_common | ge_common | ||||
"-Wl,--whole-archive" | |||||
ge_memory | ge_memory | ||||
"-Wl,--no-whole-archive" | |||||
${PROTOBUF_LIBRARY} | ${PROTOBUF_LIBRARY} | ||||
${register} | ${register} | ||||
${c_sec} | ${c_sec} | ||||
@@ -469,5 +300,6 @@ target_link_libraries(ge_compiler | |||||
${msprof} | ${msprof} | ||||
${runtime} | ${runtime} | ||||
${resouce} | ${resouce} | ||||
${error_manager} | |||||
rt | rt | ||||
dl) | dl) |
@@ -13,21 +13,21 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
# libge_client.so & libge_client_train.so | |||||
# libge_client.so | |||||
# add all proto files, generate corresponding .h and .cc files | # add all proto files, generate corresponding .h and .cc files | ||||
set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | ||||
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"../../proto/ge_api.proto" | "../../proto/ge_api.proto" | ||||
) | ) | ||||
file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"../../proto/ge_ir.proto" | "../../proto/ge_ir.proto" | ||||
"../../proto/task.proto" | "../../proto/task.proto" | ||||
"../../proto/om.proto" | "../../proto/om.proto" | ||||
"../../proto/insert_op.proto" | "../../proto/insert_op.proto" | ||||
) | ) | ||||
file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"ge_api.cc" | "ge_api.cc" | ||||
) | ) | ||||
@@ -49,30 +49,9 @@ include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
######### libge_client_train.so ############# | |||||
add_library(ge_client_train SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||||
target_compile_definitions(ge_client_train PRIVATE | |||||
Werror | |||||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
REUSE_MEMORY=1 | |||||
PLATFORM_CLOUD | |||||
DAVINCI_CLOUD) | |||||
target_link_libraries(ge_client_train | |||||
graph | |||||
ge_train | |||||
ge_common | |||||
${PROTOBUF_LIBRARY} | |||||
${register} | |||||
${c_sec} | |||||
${slog} | |||||
${mmpa} | |||||
${runtime} | |||||
rt | |||||
dl) | |||||
############ libge_client.so ################ | ############ libge_client.so ################ | ||||
add_library(ge_client SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | add_library(ge_client SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | ||||
target_compile_definitions(ge_client_train PRIVATE | |||||
target_compile_definitions(ge_client PRIVATE | |||||
Werror | Werror | ||||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
REUSE_MEMORY=1 | REUSE_MEMORY=1 | ||||
@@ -32,17 +32,18 @@ | |||||
using domi::GetContext; | using domi::GetContext; | ||||
using domi::OpRegistry; | using domi::OpRegistry; | ||||
using domi::RealPath; | |||||
using domi::StringUtils; | |||||
using std::map; | using std::map; | ||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
namespace ge { | |||||
static const int32_t kMaxStrLen = 128; | |||||
namespace { | |||||
const int32_t kMaxStrLen = 128; | |||||
} | |||||
static bool kGeInitialized = false; | static bool kGeInitialized = false; | ||||
static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use | static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use | ||||
namespace ge { | |||||
void GetOpsProtoPath(std::string &opsproto_path) { | void GetOpsProtoPath(std::string &opsproto_path) { | ||||
GELOGI("Enter get ops proto path schedule"); | GELOGI("Enter get ops proto path schedule"); | ||||
const char *path_env = std::getenv("ASCEND_OPP_PATH"); | const char *path_env = std::getenv("ASCEND_OPP_PATH"); | ||||
@@ -394,8 +395,8 @@ Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc | |||||
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | ||||
} | } | ||||
Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<TensorInfo> &inputs, | |||||
std::vector<TensorInfo> &outputs, std::function<void(Status)> callback) { | |||||
Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs, | |||||
RunAsyncCallback callback) { | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | ||||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | ||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed"); | GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed"); | ||||
@@ -405,8 +406,7 @@ Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<TensorInfo> & | |||||
GELOGW( | GELOGW( | ||||
"The callback function will not be checked. Please ensure that the implementation of the function is trusted."); | "The callback function will not be checked. Please ensure that the implementation of the function is trusted."); | ||||
Status ret = | |||||
ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, outputs, callback); | |||||
Status ret = ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, callback); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "SessionManager RunGraphAsync failed"); | GELOGE(ret, "SessionManager RunGraphAsync failed"); | ||||
return FAILED; | return FAILED; | ||||
@@ -28,7 +28,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"debug/memory_dumper.cc" | "debug/memory_dumper.cc" | ||||
"fmk_error_codes.cc" | "fmk_error_codes.cc" | ||||
"formats/format_transfers/datatype_transfer.cc" | "formats/format_transfers/datatype_transfer.cc" | ||||
"formats/format_transfers/format_transfer.cc" | |||||
"formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" | "formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" | ||||
"formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" | "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" | ||||
"formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" | "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" | ||||
@@ -41,6 +40,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" | "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" | ||||
"formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" | "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" | ||||
"formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" | "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" | ||||
"formats/format_transfers/format_transfer_nchw_fz_c04.cc" | |||||
"formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" | "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" | ||||
"formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" | "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" | ||||
"formats/format_transfers/format_transfer_transpose.cc" | "formats/format_transfers/format_transfer_transpose.cc" | ||||
@@ -54,6 +54,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"helper/om_file_helper.cc" | "helper/om_file_helper.cc" | ||||
"math/fp16_math.cc" | "math/fp16_math.cc" | ||||
"model_parser/base.cc" | "model_parser/base.cc" | ||||
"model_saver.cc" | |||||
"op/attr_value_util.cc" | "op/attr_value_util.cc" | ||||
"op/ge_op_utils.cc" | "op/ge_op_utils.cc" | ||||
"properties_manager.cc" | "properties_manager.cc" | ||||
@@ -61,9 +62,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"thread_pool.cc" | "thread_pool.cc" | ||||
"types.cc" | "types.cc" | ||||
"util.cc" | "util.cc" | ||||
"model_saver.cc" | |||||
############################### | |||||
"op/attr_define.cc" | |||||
) | ) | ||||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ||||
@@ -73,6 +71,7 @@ include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||||
include_directories(${CMAKE_CURRENT_LIST_DIR}/op) | include_directories(${CMAKE_CURRENT_LIST_DIR}/op) | ||||
include_directories(${GE_SOURCE_DIR}/src/ge) | include_directories(${GE_SOURCE_DIR}/src/ge) | ||||
include_directories(${GE_SOURCE_DIR}/inc) | include_directories(${GE_SOURCE_DIR}/inc) | ||||
include_directories(${GE_SOURCE_DIR}/inc/common/util) | |||||
include_directories(${GE_SOURCE_DIR}/inc/external) | include_directories(${GE_SOURCE_DIR}/inc/external) | ||||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | include_directories(${GE_SOURCE_DIR}/inc/external/graph) | ||||
include_directories(${GE_SOURCE_DIR}/inc/framework) | include_directories(${GE_SOURCE_DIR}/inc/framework) | ||||
@@ -96,5 +95,6 @@ target_link_libraries(ge_common | |||||
${slog} | ${slog} | ||||
${mmpa} | ${mmpa} | ||||
${resource} | ${resource} | ||||
${error_manager} | |||||
rt | rt | ||||
dl) | dl) |
@@ -17,7 +17,6 @@ | |||||
#include "common/auth/file_saver.h" | #include "common/auth/file_saver.h" | ||||
#include <fcntl.h> | #include <fcntl.h> | ||||
#include <securec.h> | #include <securec.h> | ||||
#include <unistd.h> | #include <unistd.h> | ||||
#include <cstdlib> | #include <cstdlib> | ||||
@@ -29,10 +28,6 @@ | |||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
using domi::CreateDirectory; | |||||
using domi::ModelEncryptType; | |||||
using ge::ModelBufferData; | |||||
namespace { | namespace { | ||||
const int kFileOpSuccess = 0; | const int kFileOpSuccess = 0; | ||||
} // namespace | } // namespace | ||||
@@ -270,4 +265,4 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(co | |||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -26,30 +26,26 @@ | |||||
#include "graph/buffer.h" | #include "graph/buffer.h" | ||||
#include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
using domi::ModelFileHeader; | |||||
using domi::ModelPartition; | |||||
using domi::ModelPartitionTable; | |||||
struct PROC_PARAM { | struct PROC_PARAM { | ||||
uint8_t *model_name; | uint8_t *model_name; | ||||
/* ISV Ek buffer */ | |||||
// ISV Ek buffer | |||||
uint8_t *model_key; | uint8_t *model_key; | ||||
uint32_t model_key_len; | uint32_t model_key_len; | ||||
/* ISV root certificate buffer */ | |||||
// ISV root certificate buffer | |||||
uint8_t *root_cert; | uint8_t *root_cert; | ||||
uint32_t root_cert_len; | uint32_t root_cert_len; | ||||
/* ISV private key buffer */ | |||||
// ISV private key buffer | |||||
uint8_t *pri_key; | uint8_t *pri_key; | ||||
uint32_t pri_key_len; | uint32_t pri_key_len; | ||||
/* Raw AI Module Image buffer */ | |||||
// Raw AI Module Image buffer | |||||
uint8_t *ai_image; | uint8_t *ai_image; | ||||
uint32_t ai_image_len; | uint32_t ai_image_len; | ||||
/* ISV HW key buffer */ | |||||
// ISV HW key buffer | |||||
uint8_t *hw_key; | uint8_t *hw_key; | ||||
uint32_t hw_key_len; | uint32_t hw_key_len; | ||||
}; | }; | ||||
@@ -66,11 +62,11 @@ using std::string; | |||||
class FileSaver { | class FileSaver { | ||||
public: | public: | ||||
/** | |||||
* @ingroup domi_common | |||||
* @brief save model, no encryption | |||||
* @return Status result | |||||
*/ | |||||
/// | |||||
/// @ingroup domi_common | |||||
/// @brief save model, no encryption | |||||
/// @return Status result | |||||
/// | |||||
static Status SaveToFile(const string &file_path, const ge::ModelData &model, | static Status SaveToFile(const string &file_path, const ge::ModelData &model, | ||||
const ModelFileHeader *model_file_header = nullptr); | const ModelFileHeader *model_file_header = nullptr); | ||||
@@ -84,26 +80,26 @@ class FileSaver { | |||||
static Status SaveToFile(const string &file_path, const void *data, int len); | static Status SaveToFile(const string &file_path, const void *data, int len); | ||||
protected: | protected: | ||||
/** | |||||
* @ingroup domi_common | |||||
* @brief Check validity of the file path | |||||
* @return Status result | |||||
*/ | |||||
/// | |||||
/// @ingroup domi_common | |||||
/// @brief Check validity of the file path | |||||
/// @return Status result | |||||
/// | |||||
static Status CheckPath(const string &file_path); | static Status CheckPath(const string &file_path); | ||||
static Status WriteData(const void *data, uint32_t size, int32_t fd); | static Status WriteData(const void *data, uint32_t size, int32_t fd); | ||||
static Status OpenFile(int32_t &fd, const std::string &file_path); | static Status OpenFile(int32_t &fd, const std::string &file_path); | ||||
/** | |||||
* @ingroup domi_common | |||||
* @brief save model to file | |||||
* @param [in] file_path file output path | |||||
* @param [in] file_header file header info | |||||
* @param [in] data model data | |||||
* @param [in] len model length | |||||
* @return Status result | |||||
*/ | |||||
/// | |||||
/// @ingroup domi_common | |||||
/// @brief save model to file | |||||
/// @param [in] file_path file output path | |||||
/// @param [in] file_header file header info | |||||
/// @param [in] data model data | |||||
/// @param [in] len model length | |||||
/// @return Status result | |||||
/// | |||||
static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, | static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, | ||||
int len); | int len); | ||||
@@ -16,6 +16,7 @@ | |||||
#include "framework/omg/omg_inner_types.h" | #include "framework/omg/omg_inner_types.h" | ||||
using ge::OmgContext; | |||||
namespace domi { | namespace domi { | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { | ||||
static OmgContext context; | static OmgContext context; | ||||
@@ -155,7 +155,7 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { | |||||
void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | ||||
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | ||||
bool enum2str) { | bool enum2str) { | ||||
if (nullptr == field || nullptr == reflection) { | |||||
if ((field == nullptr) || (reflection == nullptr)) { | |||||
Message2Json(message, black_fields, json, enum2str); | Message2Json(message, black_fields, json, enum2str); | ||||
return; | return; | ||||
} | } | ||||
@@ -28,7 +28,9 @@ | |||||
using std::string; | using std::string; | ||||
static const int kInvalidFd = (-1); | |||||
namespace { | |||||
const int kInvalidFd = (-1); | |||||
} // namespace | |||||
namespace ge { | namespace ge { | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::MemoryDumper() : fd_(kInvalidFd) {} | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::MemoryDumper() : fd_(kInvalidFd) {} | ||||
@@ -16,7 +16,7 @@ | |||||
#include "common/formats/format_transfers/datatype_transfer.h" | #include "common/formats/format_transfers/datatype_transfer.h" | ||||
#include <stdint.h> | |||||
#include <cstdint> | |||||
#include <map> | #include <map> | ||||
#include <utility> | #include <utility> | ||||
@@ -27,8 +27,6 @@ | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "securec.h" | #include "securec.h" | ||||
using ge::fp16_t; | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -134,10 +132,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
} | } | ||||
auto trans_mode = iter->second; | auto trans_mode = iter->second; | ||||
if (args.src_data_size == 0) { | |||||
GELOGE(PARAM_INVALID, "Invalid src data size %zu", args.src_data_size); | |||||
return PARAM_INVALID; | |||||
} | |||||
int size = GetSizeByDataType(args.dst_data_type); | int size = GetSizeByDataType(args.dst_data_type); | ||||
if (size <= 0) { | if (size <= 0) { | ||||
GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | ||||
@@ -149,6 +143,12 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
size_t total_size = static_cast<size_t>(args.src_data_size * size); | size_t total_size = static_cast<size_t>(args.src_data_size * size); | ||||
result.length = total_size; | |||||
if (total_size == 0) { | |||||
GELOGI("In TransDataType, total_size is zero, has no data."); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | ||||
@@ -162,7 +162,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
result.data = dst; | result.data = dst; | ||||
result.length = total_size; | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -21,7 +21,7 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <vector> | #include <vector> | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
#include "external/graph/types.h" | #include "external/graph/types.h" | ||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
@@ -1,69 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include <map> | |||||
#include <utility> | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "graph/utils/type_utils.h" | |||||
namespace ge { | |||||
namespace formats { | |||||
namespace { | |||||
struct FormatTransferRegistry { | |||||
Status RegisterBuilder(Format src, Format dst, FormatTransferBuilder builder) { | |||||
src_dst_builder[src][dst] = std::move(builder); | |||||
return SUCCESS; | |||||
} | |||||
std::map<Format, std::map<Format, FormatTransferBuilder>> src_dst_builder; | |||||
}; | |||||
FormatTransferRegistry &GetFormatTransferRegistry() { | |||||
static FormatTransferRegistry registry; | |||||
return registry; | |||||
} | |||||
} // namespace | |||||
std::shared_ptr<FormatTransfer> BuildFormatTransfer(const TransArgs &args) { | |||||
auto registry = GetFormatTransferRegistry(); | |||||
auto dst_builder = registry.src_dst_builder.find(args.src_format); | |||||
if (dst_builder == registry.src_dst_builder.end()) { | |||||
return nullptr; | |||||
} | |||||
auto builder_iter = dst_builder->second.find(args.dst_format); | |||||
if (builder_iter == dst_builder->second.end()) { | |||||
return nullptr; | |||||
} | |||||
return builder_iter->second(); | |||||
} | |||||
bool FormatTransferExists(const TransArgs &args) { | |||||
auto registry = GetFormatTransferRegistry(); | |||||
auto dst_builder = registry.src_dst_builder.find(args.src_format); | |||||
if (dst_builder == registry.src_dst_builder.end()) { | |||||
return false; | |||||
} | |||||
return dst_builder->second.count(args.dst_format) > 0; | |||||
} | |||||
FormatTransferRegister::FormatTransferRegister(FormatTransferBuilder builder, Format src, Format dst) { | |||||
(void)GetFormatTransferRegistry().RegisterBuilder(src, dst, std::move(builder)); | |||||
// RegisterBuilder() always return success, no need to check value | |||||
} | |||||
} // namespace formats | |||||
} // namespace ge |
@@ -27,7 +27,9 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
namespace { | namespace { | ||||
bool CheckDataTypeSupported(const DataType &data_type) { return (data_type == DT_FLOAT || data_type == DT_FLOAT16); } | |||||
bool CheckDataTypeSupported(const DataType &data_type) { | |||||
return (data_type == DT_FLOAT || data_type == DT_FLOAT16 || data_type == DT_INT8); | |||||
} | |||||
Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | ||||
auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
@@ -51,10 +53,11 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / kCubeSize + 1 || | |||||
auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||||
if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / cube_size + 1 || | |||||
src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | ||||
src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != kCubeSize || | |||||
src_shape.at(kC1hwncoc0C0) != kCubeSize) { | |||||
src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | |||||
src_shape.at(kC1hwncoc0C0) != cube_size) { | |||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | ||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -78,6 +81,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||||
auto c0 = args.src_shape.at(kC1hwncoc0C0); | auto c0 = args.src_shape.at(kC1hwncoc0C0); | ||||
auto co = args.src_shape.at(kC1hwncoc0Co); | auto co = args.src_shape.at(kC1hwncoc0Co); | ||||
auto c = args.dst_shape.at(kHwcnC); | auto c = args.dst_shape.at(kHwcnC); | ||||
auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||||
int64_t cn = c * n; | int64_t cn = c * n; | ||||
int64_t wcn = w * cn; | int64_t wcn = w * cn; | ||||
int64_t coc0 = co * c0; | int64_t coc0 = co * c0; | ||||
@@ -93,8 +97,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||||
int64_t c_head_addr = w_head_addr + c_idx * n; | int64_t c_head_addr = w_head_addr + c_idx * n; | ||||
for (int64_t n_idx = 0; n_idx < n; n_idx++) { | for (int64_t n_idx = 0; n_idx < n; n_idx++) { | ||||
int64_t dst_idx = c_head_addr + n_idx; | int64_t dst_idx = c_head_addr + n_idx; | ||||
int64_t c1_idx = c_idx / kCubeSize; | |||||
int64_t c0_idx = c_idx % kCubeSize; | |||||
int64_t c1_idx = c_idx / cube_size; | |||||
int64_t c0_idx = c_idx % cube_size; | |||||
int64_t co_idx = c0_idx; | int64_t co_idx = c0_idx; | ||||
int64_t src_idx = c1_idx * hwncoc0 + h_idx * wncoc0 + w_idx * ncoc0 + n_idx * coc0 + co_idx * c0 + c0_idx; | int64_t src_idx = c1_idx * hwncoc0 + h_idx * wncoc0 + w_idx * ncoc0 + n_idx * coc0 + co_idx * c0 + c0_idx; | ||||
auto src_offset = src_idx * size; | auto src_offset = src_idx * size; | ||||
@@ -130,6 +134,11 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -19,7 +19,7 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -88,6 +88,11 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -18,7 +18,7 @@ | |||||
#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWCN_FRACTAL_Z_3D_H_ | #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWCN_FRACTAL_Z_3D_H_ | ||||
#include <vector> | #include <vector> | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -89,6 +89,11 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -18,7 +18,7 @@ | |||||
#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWNC_FRACTAL_Z_3D_TRANSPOSE_H_ | #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWNC_FRACTAL_Z_3D_TRANSPOSE_H_ | ||||
#include <vector> | #include <vector> | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -116,6 +116,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -184,6 +189,11 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||||
Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -19,7 +19,7 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -119,6 +119,11 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = total_ele_cnt * size; | int64_t dst_size = total_ele_cnt * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -194,6 +199,11 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -259,6 +269,11 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -19,7 +19,7 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -117,6 +117,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -189,6 +194,11 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -19,7 +19,7 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -133,6 +133,12 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -19,7 +19,7 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -133,6 +133,12 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -140,6 +146,7 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||||
GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
@@ -19,7 +19,7 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||