Merge pull request !6 from yanghaoran/mastertags/v0.2.0-alpha
@@ -75,17 +75,16 @@ elseif(DEFINED ENV{D_LINK_PATH}) | |||||
find_library(resource libresource.so ${GE_LIB_PATH}) | find_library(resource libresource.so ${GE_LIB_PATH}) | ||||
else() | else() | ||||
# Ascend mode | # Ascend mode | ||||
set(HIAI_INSTALLED_DIR /usr/local/HiAI) | |||||
set(HIAI_DRIVER_DIR ${HIAI_INSTALLED_DIR}/driver/lib64) | |||||
set(HIAI_RUNTIME_DIR ${HIAI_INSTALLED_DIR}/runtime/lib64) | |||||
set(HIAI_INSTALLED_DIR /usr/local/Ascend) | |||||
set(HIAI_DRIVER_DIR ${HIAI_INSTALLED_DIR}/driver/lib64/common) | |||||
set(HIAI_RUNTIME_DIR ${HIAI_INSTALLED_DIR}/fwkacllib/lib64) | |||||
find_library(c_sec libc_sec.so ${HIAI_DRIVER_DIR}) | find_library(c_sec libc_sec.so ${HIAI_DRIVER_DIR}) | ||||
find_library(slog libslog.so ${HIAI_DRIVER_DIR}) | find_library(slog libslog.so ${HIAI_DRIVER_DIR}) | ||||
find_library(mmpa libmmpa.so ${HIAI_DRIVER_DIR}) | find_library(mmpa libmmpa.so ${HIAI_DRIVER_DIR}) | ||||
find_library(msprof libmsprof.so ${HIAI_DRIVER_DIR}) | |||||
find_library(cce libcce.so ${HIAI_RUNTIME_DIR}) | |||||
find_library(hccl libhccl.so ${HIAI_RUNTIME_DIR}) | find_library(hccl libhccl.so ${HIAI_RUNTIME_DIR}) | ||||
find_library(runtime libruntime.so ${HIAI_RUNTIME_DIR}) | find_library(runtime libruntime.so ${HIAI_RUNTIME_DIR}) | ||||
find_library(msprof libmsprof.so ${HIAI_RUNTIME_DIR}) | |||||
find_library(register libregister.so ${HIAI_RUNTIME_DIR}) | find_library(register libregister.so ${HIAI_RUNTIME_DIR}) | ||||
find_library(resource libresource.so ${HIAI_RUNTIME_DIR}) | find_library(resource libresource.so ${HIAI_RUNTIME_DIR}) | ||||
endif() | endif() | ||||
@@ -18,16 +18,15 @@ | |||||
#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ | #define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ | ||||
#include <stdint.h> | #include <stdint.h> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/buffer.h" | #include "graph/buffer.h" | ||||
#include "runtime/rt_model.h" | #include "runtime/rt_model.h" | ||||
using std::string; | using std::string; | ||||
namespace ge { | namespace ge { | ||||
/*lint -e148*/ | |||||
struct RunContext { | struct RunContext { | ||||
rtModel_t model; | rtModel_t model; | ||||
rtStream_t stream; | rtStream_t stream; | ||||
@@ -37,10 +36,12 @@ struct RunContext { | |||||
uint64_t weightMemSize; | uint64_t weightMemSize; | ||||
uint8_t *weightMemBase; | uint8_t *weightMemBase; | ||||
ge::Buffer weightsBuffer; | ge::Buffer weightsBuffer; | ||||
std::vector<rtStream_t> graphStreamList; // all streams of graph which are sort by ge stream id(0,1,...) | |||||
std::vector<rtEvent_t> graphEventList; // all events of graph which are sort by ge event id(0,1,...) | |||||
std::vector<rtStream_t> graphStreamList; // all streams of graph, order by ge stream id(0,1,...) | |||||
std::vector<rtEvent_t> graphEventList; // all events of graph, order by ge event id(0,1,...) | |||||
}; | }; | ||||
/*lint +e148*/ | |||||
struct Task { | struct Task { | ||||
uint32_t id; | uint32_t id; | ||||
uint16_t type; | uint16_t type; | ||||
@@ -49,10 +50,11 @@ struct Task { | |||||
}; | }; | ||||
struct OpInfo { | struct OpInfo { | ||||
string engine; // engine name | |||||
string opKernelLib; // opsKernelStore name | |||||
string engine; // which engin | |||||
/*lint -e148*/ | |||||
string opKernelLib; // which opsKernelStore | |||||
int computeCost; // compute cost | int computeCost; // compute cost | ||||
bool flagPartial; // whether to support related shape | |||||
bool flagPartial; // whether to support is related to shape | |||||
bool flagAsync; // Whether to support asynchronous | bool flagAsync; // Whether to support asynchronous | ||||
bool isAtomic; // whether to support atomic addr clean | bool isAtomic; // whether to support atomic addr clean | ||||
string opFileName; // op file name | string opFileName; // op file name | ||||
@@ -50,6 +50,16 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||||
Status AddGraph(uint32_t graphId, const Graph &graph); | Status AddGraph(uint32_t graphId, const Graph &graph); | ||||
/// | /// | ||||
/// @ingroup client | |||||
/// @brief add a graph with a specific graphId and graphOptions | |||||
/// @param [in] graphId graph id | |||||
/// @param [in] graph the graph | |||||
/// @param [in] options graph options | |||||
/// @return Status result of function | |||||
/// | |||||
Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options); | |||||
/// | |||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
/// @brief remove a graph of the session with specific session id | /// @brief remove a graph of the session with specific session id | ||||
/// @param [in] graphId graph id | /// @param [in] graphId graph id | ||||
@@ -50,7 +50,7 @@ const char *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize"; | |||||
// its value should be int32_t type, default value is "1" | // its value should be int32_t type, default value is "1" | ||||
const std::string STREAM_NUM = "ge.streamNum"; | const std::string STREAM_NUM = "ge.streamNum"; | ||||
// Configure add head stream to model, | |||||
// Configure add head stream to model. | |||||
// its value should be "0" or "1", default value is "0" | // its value should be "0" or "1", default value is "0" | ||||
const std::string HEAD_STREAM = "ge.headStream"; | const std::string HEAD_STREAM = "ge.headStream"; | ||||
@@ -138,7 +138,7 @@ const std::string GE_FE_FLAG = "ge.feFlag"; | |||||
// this option is to obtain stream max parallel num | // this option is to obtain stream max parallel num | ||||
const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; | const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; | ||||
// configure outputDatatype to setting net output type | |||||
// congigure outputDatatype to setting net output type | |||||
const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; | const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; | ||||
// configure whether to enable hcom parallel by session constructor options param, | // configure whether to enable hcom parallel by session constructor options param, | ||||
@@ -149,7 +149,7 @@ const std::string HCOM_PARALLEL = "ge.hcomParallel"; | |||||
// example: GA|RL, support configure multiple, split by | | // example: GA|RL, support configure multiple, split by | | ||||
const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; | const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; | ||||
// Configure core type "VectorEngine", default value is "AICoreEngine" | |||||
// Configure core type "VectorEngine", default value is "AIcoreEngine" | |||||
const std::string CORE_TYPE = "ge.engineType"; | const std::string CORE_TYPE = "ge.engineType"; | ||||
// Configure soc version , example: "Ascend310" | // Configure soc version , example: "Ascend310" | ||||
@@ -165,6 +165,10 @@ const char *const OPTION_GE_MAX_DUMP_FILE_NUM = "ge.maxDumpFileNum"; | |||||
const char *const OPTION_GE_MAX_DUMP_FILE_SIZE = "ge.maxDumpFileSize"; | const char *const OPTION_GE_MAX_DUMP_FILE_SIZE = "ge.maxDumpFileSize"; | ||||
const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; | const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; | ||||
// Configure for print op pass | |||||
// Its value should be "0" or "1", default value is "1" | |||||
const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; | |||||
// Graph run mode | // Graph run mode | ||||
enum GraphRunMode { PREDICTION = 0, TRAIN }; | enum GraphRunMode { PREDICTION = 0, TRAIN }; | ||||
@@ -28,29 +28,29 @@ namespace ge { | |||||
class InferenceContext; | class InferenceContext; | ||||
using InferenceContextPtr = std::shared_ptr<InferenceContext>; | using InferenceContextPtr = std::shared_ptr<InferenceContext>; | ||||
class ShapeAndTypeImpl; | |||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ShapeAndType { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ShapeAndType { | ||||
public: | public: | ||||
ShapeAndType() = default; | |||||
ShapeAndType(); | |||||
~ShapeAndType() = default; | ~ShapeAndType() = default; | ||||
ShapeAndType(const Shape &shape, DataType data_type); | |||||
ShapeAndType(const Shape &shape, DataType dataType); | |||||
void SetShape(const Shape &shape); | void SetShape(const Shape &shape); | ||||
void SetType(DataType data_type); | |||||
void SetType(DataType dataType); | |||||
const Shape &GetShape() const; | |||||
Shape GetShape() const; | |||||
DataType GetDataType() const; | DataType GetDataType() const; | ||||
private: | private: | ||||
Shape shape_; | |||||
DataType data_type_ = DT_UNDEFINED; | |||||
std::shared_ptr<ShapeAndTypeImpl> shape_and_type_impl_; | |||||
}; | }; | ||||
class InferenceContextImpl; | |||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | ||||
public: | public: | ||||
InferenceContext() = default; | |||||
~InferenceContext() = default; | ~InferenceContext() = default; | ||||
InferenceContext(const InferenceContext &context) = delete; | InferenceContext(const InferenceContext &context) = delete; | ||||
InferenceContext(const InferenceContext &&context) = delete; | InferenceContext(const InferenceContext &&context) = delete; | ||||
@@ -58,22 +58,19 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | |||||
InferenceContext &operator=(const InferenceContext &&context) = delete; | InferenceContext &operator=(const InferenceContext &&context) = delete; | ||||
void SetInputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types); | void SetInputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types); | ||||
const std::vector<std::vector<ShapeAndType>> &GetInputHandleShapesAndTypes() const; | const std::vector<std::vector<ShapeAndType>> &GetInputHandleShapesAndTypes() const; | ||||
const std::vector<std::vector<ShapeAndType>> &GetOutputHandleShapesAndTypes() const; | const std::vector<std::vector<ShapeAndType>> &GetOutputHandleShapesAndTypes() const; | ||||
void SetOutputHandleShapesAndTypes(const std::vector<std::vector<ShapeAndType>> &shapes_and_types); | void SetOutputHandleShapesAndTypes(const std::vector<std::vector<ShapeAndType>> &shapes_and_types); | ||||
void SetOutputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types); | void SetOutputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types); | ||||
void SetMarks(const std::vector<std::string> &marks); | void SetMarks(const std::vector<std::string> &marks); | ||||
const std::vector<std::string> &GetMarks() const; | const std::vector<std::string> &GetMarks() const; | ||||
static std::unique_ptr<InferenceContext> Create(); | |||||
private: | private: | ||||
// For deliver to op in pair, help to support dynamic shape | |||||
std::vector<std::string> marks_; | |||||
std::vector<std::vector<ShapeAndType>> input_handle_shapes_and_types_; | |||||
std::vector<std::vector<ShapeAndType>> output_handle_shapes_and_types_; | |||||
InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||||
std::shared_ptr<InferenceContextImpl> inference_context_impl_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ | #endif // INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ |
@@ -24,9 +24,8 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "external/graph/ge_error_codes.h" | #include "external/graph/ge_error_codes.h" | ||||
#include "external/graph//inference_context.h" | |||||
#include "external/graph//tensor.h" | |||||
#include "external/graph//usr_types.h" | |||||
#include "external/graph/inference_context.h" | |||||
#include "external/graph/tensor.h" | |||||
#ifndef USER_GE_LOGI | #ifndef USER_GE_LOGI | ||||
#define USER_GE_LOGI(...) | #define USER_GE_LOGI(...) | ||||
@@ -182,9 +181,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
// Bytes type | // Bytes type | ||||
graphStatus GetAttr(const string &name, OpBytes &attr_value) const; | graphStatus GetAttr(const string &name, OpBytes &attr_value) const; | ||||
Operator &SetAttr(const string &name, const UsrQuantizeFactorParams &attr_value); | |||||
graphStatus GetAttr(const string &name, UsrQuantizeFactorParams &attr_value) const; | |||||
Operator &SetAttr(const string &name, const std::vector<std::vector<int64_t>> &attr_value); | Operator &SetAttr(const string &name, const std::vector<std::vector<int64_t>> &attr_value); | ||||
graphStatus GetAttr(const string &name, std::vector<std::vector<int64_t>> &attr_value) const; | graphStatus GetAttr(const string &name, std::vector<std::vector<int64_t>> &attr_value) const; | ||||
@@ -235,11 +231,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
graphStatus VerifyAll(); | graphStatus VerifyAll(); | ||||
// Only has one output index = 0 | // Only has one output index = 0 | ||||
Operator &SetInput(const string &dst_name, uint32_t dst_index, | |||||
const Operator &src_oprt); | |||||
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); | |||||
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, | |||||
const string &name); | |||||
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); | |||||
private: | private: | ||||
Operator &SetInput(const string &dst_name, const OutHandler &out_handler); | Operator &SetInput(const string &dst_name, const OutHandler &out_handler); | ||||
@@ -26,9 +26,10 @@ | |||||
#include "external/graph/types.h" | #include "external/graph/types.h" | ||||
namespace ge { | namespace ge { | ||||
class ShapeImpl; | |||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape { | ||||
public: | public: | ||||
Shape() = default; | |||||
Shape(); | |||||
~Shape() = default; | ~Shape() = default; | ||||
explicit Shape(const std::vector<int64_t> &dims); | explicit Shape(const std::vector<int64_t> &dims); | ||||
@@ -40,7 +41,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape { | |||||
int64_t GetShapeSize() const; | int64_t GetShapeSize() const; | ||||
private: | private: | ||||
std::vector<int64_t> dims_; | |||||
std::shared_ptr<ShapeImpl> impl_; | |||||
}; | }; | ||||
class TensorDescImpl; | class TensorDescImpl; | ||||
@@ -66,10 +67,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { | |||||
void SetFormat(Format format); | void SetFormat(Format format); | ||||
Shape GetOriginShape() const; | Shape GetOriginShape() const; | ||||
void SetOriginShape(const Shape &origin_shape); | |||||
void SetOriginShape(const Shape &originShape); | |||||
Format GetOriginFormat() const; | Format GetOriginFormat() const; | ||||
void SetOriginFormat(Format origin_format); | |||||
void SetOriginFormat(Format originFormat); | |||||
DataType GetDataType() const; | DataType GetDataType() const; | ||||
void SetDataType(DataType dt); | void SetDataType(DataType dt); | ||||
@@ -82,7 +83,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { | |||||
int64_t GetSize() const; | int64_t GetSize() const; | ||||
int64_t GetRealDimCnt() const; | int64_t GetRealDimCnt() const; | ||||
void SetRealDimCnt(const int64_t real_dim_cnt); | |||||
void SetRealDimCnt(const int64_t realDimCnt); | |||||
private: | private: | ||||
std::shared_ptr<TensorDescImpl> impl; | std::shared_ptr<TensorDescImpl> impl; | ||||
@@ -67,33 +67,33 @@ enum DataType { | |||||
inline int GetSizeByDataType(DataType data_type) { | inline int GetSizeByDataType(DataType data_type) { | ||||
static int data_type_size[DT_UNDEFINED] = { | static int data_type_size[DT_UNDEFINED] = { | ||||
4, // DT_FLOAT = 0, float type | |||||
2, // DT_FLOAT16 = 1, fp16 type | |||||
1, // DT_INT8 = 2, int8 type | |||||
4, // DT_INT32 = 3, | |||||
1, // DT_UINT8 = 4, uint8 type | |||||
-1, | |||||
2, // DT_INT16 = 6, int16 type | |||||
2, // DT_UINT16 = 7, uint16 type | |||||
4, // DT_UINT32 = 8, unsigned int32 | |||||
8, // DT_INT64 = 9, int64 type | |||||
8, // DT_UINT64 = 10, unsigned int64 | |||||
8, // DT_DOUBLE = 11, double type | |||||
1, // DT_BOOL = 12, bool type | |||||
-1, // DT_STRING = 13, string type | |||||
1, // DT_DUAL_SUB_INT8 = 14, dual output int8 type | |||||
1, // DT_DUAL_SUB_UINT8 = 15, dual output uint8 type | |||||
8, // DT_COMPLEX64 = 16, complex64 type | |||||
16, // DT_COMPLEX128 = 17, complex128 type | |||||
1, // DT_QINT8 = 18, qint8 type | |||||
2, // DT_QINT16 = 19, qint16 type | |||||
4, // DT_QINT32 = 20, qint32 type | |||||
1, // DT_QUINT8 = 21, quint8 type | |||||
2, // DT_QUINT16 = 22, quint16 type | |||||
-1, // DT_RESOURCE = 23, resource type | |||||
-1, // DT_STRING_REF = 24, string ref type | |||||
5, // DT_DUAL = 25, dual output type (float + int8) | |||||
// DT_UNDEFINED Used to indicate a DataType field has not been set. | |||||
4, // DT_FLOAT = 0, float type | |||||
2, // DT_FLOAT16 = 1, fp16 type | |||||
1, // DT_INT8 = 2, int8 type | |||||
4, // DT_INT32 = 3, | |||||
1, // DT_UINT8 = 4, uint8 type | |||||
-1, | |||||
2, // DT_INT16 = 6, int16 type | |||||
2, // DT_UINT16 = 7, uint16 type | |||||
4, // DT_UINT32 = 8, unsigned int32 | |||||
8, // DT_INT64 = 9, int64 type | |||||
8, // DT_UINT64 = 10, unsigned int64 | |||||
8, // DT_DOUBLE = 11, double type | |||||
1, // DT_BOOL = 12, bool type | |||||
-1, // DT_STRING = 13, string type | |||||
1, // DT_DUAL_SUB_INT8 = 14, dual output int8 type | |||||
1, // DT_DUAL_SUB_UINT8 = 15, dual output uint8 type | |||||
8, // DT_COMPLEX64 = 16, complex64 type | |||||
16, // DT_COMPLEX128 = 17, complex128 type | |||||
1, // DT_QINT8 = 18, qint8 type | |||||
2, // DT_QINT16 = 19, qint16 type | |||||
4, // DT_QINT32 = 20, qint32 type | |||||
1, // DT_QUINT8 = 21, quint8 type | |||||
2, // DT_QUINT16 = 22, quint16 type | |||||
-1, // DT_RESOURCE = 23, resource type | |||||
-1, // DT_STRING_REF = 24, string ref type | |||||
5, // DT_DUAL = 25, dual output type (float + int8) | |||||
// DT_UNDEFINED Used to indicate a DataType field has not been set. | |||||
}; | }; | ||||
if (data_type >= DT_UNDEFINED) { | if (data_type >= DT_UNDEFINED) { | ||||
return -1; | return -1; | ||||
@@ -152,10 +152,11 @@ enum DeviceType { | |||||
CPU = 1, | CPU = 1, | ||||
}; | }; | ||||
class TensorTypeImpl; | |||||
struct TensorType { | struct TensorType { | ||||
explicit TensorType(DataType dt) { dt_vec_.push_back(dt); } | |||||
explicit TensorType(DataType dt); | |||||
TensorType(const std::initializer_list<DataType> &types) { dt_vec_ = types; } | |||||
TensorType(const std::initializer_list<DataType> &types); | |||||
static TensorType ALL() { | static TensorType ALL() { | ||||
return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, | return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, | ||||
@@ -204,7 +205,7 @@ struct TensorType { | |||||
static TensorType FLOAT() { return TensorType{DT_FLOAT, DT_FLOAT16}; } | static TensorType FLOAT() { return TensorType{DT_FLOAT, DT_FLOAT16}; } | ||||
std::vector<DataType> dt_vec_; | |||||
std::shared_ptr<TensorTypeImpl> tensor_type_impl_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -17,7 +17,6 @@ | |||||
#ifndef INC_EXTERNAL_REGISTER_REGISTER_H_ | #ifndef INC_EXTERNAL_REGISTER_REGISTER_H_ | ||||
#define INC_EXTERNAL_REGISTER_REGISTER_H_ | #define INC_EXTERNAL_REGISTER_REGISTER_H_ | ||||
#include <google/protobuf/message.h> | |||||
#include <functional> | #include <functional> | ||||
#include <initializer_list> | #include <initializer_list> | ||||
#include <map> | #include <map> | ||||
@@ -33,12 +32,12 @@ | |||||
#include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
#include "register/register_types.h" | #include "register/register_types.h" | ||||
using std::unique_ptr; | |||||
using std::map; | |||||
using std::make_shared; | using std::make_shared; | ||||
using std::to_string; | |||||
using std::string; | |||||
using std::map; | |||||
using std::pair; | using std::pair; | ||||
using std::string; | |||||
using std::to_string; | |||||
using std::unique_ptr; | |||||
using std::vector; | using std::vector; | ||||
namespace ge { | namespace ge { | ||||
@@ -46,55 +45,17 @@ class Operator; | |||||
class TensorDesc; | class TensorDesc; | ||||
class Tensor; | class Tensor; | ||||
class TBEPluginManager; | class TBEPluginManager; | ||||
} | |||||
} // namespace ge | |||||
namespace domi { | namespace domi { | ||||
struct OpOutput { | |||||
ge::Operator op; | |||||
// The output name of op | |||||
std::string outputName; | |||||
}; | |||||
struct InferShapeContext { | |||||
ge::Operator op; | |||||
// Input name, input | |||||
std::map<std::string, OpOutput> inputs; | |||||
}; | |||||
struct InferShapeOutput { | |||||
std::vector<ge::TensorDesc> outputDescs; | |||||
std::vector<uint32_t> realDimCnt; | |||||
}; | |||||
enum OmgMoveTypeToAttr { | |||||
OMG_MOVE_TYPE_DTYPE = 0, | |||||
OMG_MOVE_TYPE_VALUE, | |||||
OMG_MOVE_TYPE_SHAPE, | |||||
OMG_MOVE_TYPE_FORMAT, | |||||
OMG_MOVE_TYPE_AXIS, | |||||
OMG_MOVE_TYPE_SCALAR_VALUE, | |||||
OMG_REMOVE_TYPE_WITH_COND = 1000, | |||||
}; | |||||
struct MoveInputToAttrStu { | |||||
int inputIdx; | |||||
std::string attrName; | |||||
OmgMoveTypeToAttr moveType; | |||||
bool attrValue; | |||||
}; | |||||
Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | ||||
Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | ||||
std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value, | std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value, | ||||
int in_pos = -1, int out_pos = -1); | int in_pos = -1, int out_pos = -1); | ||||
using google::protobuf::Message; | using google::protobuf::Message; | ||||
class OpRegistrationDataImpl; | |||||
using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | ||||
using InferShapeFunc = std::function<domi::Status(const ge::Operator &, std::vector<ge::TensorDesc> &)>; | |||||
using InferShapeFuncV2 = std::function<domi::Status(const InferShapeContext &, InferShapeOutput &)>; | |||||
using GetWorkspaceSizeFunc = std::function<domi::Status(const ge::Operator &, std::vector<int64_t> &)>; | |||||
using UpdateOpDescFunc = std::function<domi::Status(ge::Operator &)>; | |||||
using BuildTeBinFunc = std::function<domi::Status(const ge::Operator &, TEBinInfo &)>; | |||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | ||||
public: | public: | ||||
@@ -110,64 +71,18 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||||
OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); | OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); | ||||
OpRegistrationData &InferShapeAndTypeFn(const InferShapeFunc &inferShapeFn); | |||||
OpRegistrationData &InferShapeAndTypeFn(const InferShapeFuncV2 &inferShapeFn); | |||||
OpRegistrationData &UpdateOpDescFn(const UpdateOpDescFunc &updateOpDescFn); | |||||
OpRegistrationData &GetWorkspaceSizeFn(const GetWorkspaceSizeFunc &getWorkspaceSizeFn); | |||||
OpRegistrationData &TEBinBuildFn(const BuildTeBinFunc &buildTeBinFn); | |||||
OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | ||||
OpRegistrationData &Formats(const std::initializer_list<domi::tagDomiTensorFormat> &input_formats, | |||||
const std::initializer_list<domi::tagDomiTensorFormat> &output_formats); | |||||
OpRegistrationData &WeightFormats(const std::initializer_list<domi::tagDomiTensorFormat> &weight_formats); | |||||
OpRegistrationData &InputFormat(const std::initializer_list<std::initializer_list<ge::Format>> &inputFormats); | |||||
OpRegistrationData &OutputFormat(const std::initializer_list<std::initializer_list<ge::Format>> &outputFormats); | |||||
OpRegistrationData &InputDataType(const std::initializer_list<std::initializer_list<ge::DataType>> &inputDataTypes); | |||||
OpRegistrationData &OutputDataType(const std::initializer_list<std::initializer_list<ge::DataType>> &outputDataTypes); | |||||
OpRegistrationData &InputLimitedTensorDescInfo( | |||||
const std::initializer_list<std::initializer_list<ge::TensorDescInfo>> &limitedTensorDescs); | |||||
OpRegistrationData &OutputLimitedTensorDescInfo( | |||||
const std::initializer_list<std::initializer_list<ge::TensorDescInfo>> &limitedTensorDescs); | |||||
OpRegistrationData &MoveInputToAttr(int inputIdx, const std::string &attrName, OmgMoveTypeToAttr moveType); | |||||
OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | ||||
domi::ImplyType GetImplyType() const; | |||||
std::string GetOmOptype() const; | |||||
std::set<std::string> GetOriginOpTypeSet() const; | |||||
domi::FrameworkType GetFrameworkType() const; | |||||
ParseParamFunc GetParseParamFn() const; | |||||
private: | private: | ||||
domi::FrameworkType fmk_type_; // Framework type | |||||
std::set<std::string> ori_optype_set_; // OP type in the original model, there may be multiple | |||||
std::string om_optype_; // OP type in OM model | |||||
domi::ImplyType imply_type_; // Execution type | |||||
std::vector<domi::tagDomiTensorFormat> input_formats_; // Data formats supported by operator input | |||||
std::vector<domi::tagDomiTensorFormat> output_formats_; // Data formats supported by operator output | |||||
std::vector<domi::tagDomiTensorFormat> weight_formats_; // Data format supported by operator weight | |||||
ParseParamFunc parseParamFn_; // ParseParam function | |||||
InferShapeFunc inferShapeFn_; // InferShape function | |||||
InferShapeFuncV2 inferShapeFnV2_; // InferShape function | |||||
GetWorkspaceSizeFunc getWorkspaceSizeFn_; // GetWorkspaceSizeFunc function | |||||
UpdateOpDescFunc updateOpDescFn_; | |||||
BuildTeBinFunc buildTeBinFn_; | |||||
// Input formats list supported by tbe operators | |||||
std::vector<std::vector<ge::Format>> supportedInputFormats_; | |||||
// Output formats list supported by tbe operators | |||||
std::vector<std::vector<ge::Format>> supportedOutputFormats_; | |||||
// Input datatypes list supported by tbe operators | |||||
std::vector<std::vector<ge::DataType>> supportedInputDataTypes_; | |||||
// Output datatypes list supported by tbe operators | |||||
std::vector<std::vector<ge::DataType>> supportedOutputDataTypes_; | |||||
// Input tensordesinfo list supported by tbe operator | |||||
std::vector<std::vector<ge::TensorDescInfo>> inputLimitedTensorDescs_; | |||||
// Output tensordesinfo list supported by tbe operator | |||||
std::vector<std::vector<ge::TensorDescInfo>> outputLimitedTensorDescs_; | |||||
std::vector<MoveInputToAttrStu> moveInputToAttrVec_; | |||||
std::shared_ptr<OpRegistrationDataImpl> impl_; | |||||
friend class OpRegistry; | friend class OpRegistry; | ||||
friend class OpRegistrationTbe; | friend class OpRegistrationTbe; | ||||
friend class ge::TBEPluginManager; | friend class ge::TBEPluginManager; | ||||
@@ -181,19 +96,12 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||||
#define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name) | #define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name) | ||||
#define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name) | #define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name) | ||||
#define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \ | |||||
static OpReceiver register_op##ctr \ | |||||
__attribute__((unused)) = \ | |||||
OpRegistrationData(name) | |||||
#define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \ | |||||
static OpReceiver register_op##ctr __attribute__((unused)) = OpRegistrationData(name) | |||||
} // namespace domi | } // namespace domi | ||||
namespace ge { | namespace ge { | ||||
using OpOutput = domi::OpOutput; | |||||
using InferShapeContext = domi::InferShapeContext; | |||||
using InferShapeOutput = domi::InferShapeOutput; | |||||
using OmgMoveTypeToAttr = domi::OmgMoveTypeToAttr; | |||||
using MoveInputToAttrStu = domi::MoveInputToAttrStu; | |||||
using OpRegistrationData = domi::OpRegistrationData; | using OpRegistrationData = domi::OpRegistrationData; | ||||
using OpReceiver = domi::OpReceiver; | using OpReceiver = domi::OpReceiver; | ||||
} | |||||
} // namespace ge | |||||
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ |
@@ -31,12 +31,6 @@ enum FrameworkType { | |||||
FMK_TYPE_A_NN, | FMK_TYPE_A_NN, | ||||
FMK_TYPE_RESERVED, | FMK_TYPE_RESERVED, | ||||
}; | }; | ||||
struct TEBinInfo { | |||||
std::string bin_file_path; | |||||
std::string json_file_path; | |||||
std::string ddk_version; | |||||
}; | |||||
} // namespace domi | } // namespace domi | ||||
#endif // INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ | #endif // INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ |
@@ -44,6 +44,8 @@ inline bool IsLogEnable(int module_name, int log_level) noexcept { | |||||
return false; | return false; | ||||
} | } | ||||
/*lint --emacro((773),GE_TIMESTAMP_START)*/ | |||||
/*lint -esym(773,GE_TIMESTAMP_START)*/ | |||||
#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | #define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | ||||
#define GE_TIMESTAMP_END(stage, stage_name) \ | #define GE_TIMESTAMP_END(stage, stage_name) \ | ||||
@@ -29,18 +29,7 @@ | |||||
using cce::CC_STATUS_SUCCESS; | using cce::CC_STATUS_SUCCESS; | ||||
using cce::ccStatus_t; | using cce::ccStatus_t; | ||||
#if !defined(__ANDROID__) && !defined(ANDROID) | |||||
#define DOMI_LOGE(...) DAV_LOGE("DOMI", __VA_ARGS__) | |||||
#else | |||||
#include <android/log.h> | |||||
#if defined(BUILD_VERSION_PERF) | |||||
#define DOMI_LOGE(fmt, ...) | |||||
#else | |||||
// The Android system has strict log control. Do not modify the log. | |||||
#define DOMI_LOGE(fmt, ...) \ | |||||
__android_log_print(ANDROID_LOG_ERROR, "NPU_FMK", "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#endif | |||||
#endif | |||||
#define GE_LOGE(...) DAV_LOGE("GE", __VA_ARGS__) | |||||
// ge marco | // ge marco | ||||
#define GE_LOGI_IF(condition, ...) \ | #define GE_LOGI_IF(condition, ...) \ | ||||
@@ -53,9 +42,9 @@ using cce::ccStatus_t; | |||||
GELOGW(__VA_ARGS__); \ | GELOGW(__VA_ARGS__); \ | ||||
} | } | ||||
#define GE_LOGE_IF(condition, ...) \ | |||||
if ((condition)) { \ | |||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
#define GE_LOGE_IF(condition, ...) \ | |||||
if ((condition)) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
} | } | ||||
// If expr is not SUCCESS, print the log and return the same value | // If expr is not SUCCESS, print the log and return the same value | ||||
@@ -63,7 +52,7 @@ using cce::ccStatus_t; | |||||
do { \ | do { \ | ||||
const ge::Status _status = (expr); \ | const ge::Status _status = (expr); \ | ||||
if (_status != ge::SUCCESS) { \ | if (_status != ge::SUCCESS) { \ | ||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
@@ -73,7 +62,7 @@ using cce::ccStatus_t; | |||||
do { \ | do { \ | ||||
const ge::Status _status = (expr); \ | const ge::Status _status = (expr); \ | ||||
if (_status != ge::SUCCESS) { \ | if (_status != ge::SUCCESS) { \ | ||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
@@ -102,11 +91,25 @@ using cce::ccStatus_t; | |||||
(void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | ||||
(void)msg.append( \ | (void)msg.append( \ | ||||
ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | ||||
GELOGE(ge::FAILED, "%s", msg.c_str()); \ | |||||
GE_LOGE("%s", msg.c_str()); \ | |||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
} while (0); | } while (0); | ||||
// If expr is not true, print the Info log and return the specified status | |||||
#define GE_CHK_BOOL_RET_STATUS_LOGI(expr, _status, ...) \ | |||||
do { \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
std::string msg; \ | |||||
(void)msg.append(StringUtils::FormatString(__VA_ARGS__)); \ | |||||
(void)msg.append( \ | |||||
StringUtils::FormatString(" Check result false, status: 0x%X %s", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||||
GELOGI("%s", msg.c_str()); \ | |||||
return _status; \ | |||||
} \ | |||||
} while (0); | |||||
// If expr is not true, print the log and return the specified status | // If expr is not true, print the log and return the specified status | ||||
#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ | #define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ | ||||
do { \ | do { \ | ||||
@@ -121,7 +124,7 @@ using cce::ccStatus_t; | |||||
{ \ | { \ | ||||
bool b = (expr); \ | bool b = (expr); \ | ||||
if (!b) { \ | if (!b) { \ | ||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
exec_expr; \ | exec_expr; \ | ||||
} \ | } \ | ||||
}; | }; | ||||
@@ -145,12 +148,22 @@ using cce::ccStatus_t; | |||||
} \ | } \ | ||||
}; | }; | ||||
// If expr is not true, print the log and execute a custom statement | |||||
#define GE_CHK_BOOL_TRUE_EXEC_INFO(expr, exec_expr, ...) \ | |||||
{ \ | |||||
bool b = (expr); \ | |||||
if (b) { \ | |||||
GELOGI(__VA_ARGS__); \ | |||||
exec_expr; \ | |||||
} \ | |||||
}; | |||||
// If expr is true, print logs and execute custom statements | // If expr is true, print logs and execute custom statements | ||||
#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ | #define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ | ||||
{ \ | { \ | ||||
bool b = (expr); \ | bool b = (expr); \ | ||||
if (b) { \ | if (b) { \ | ||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
exec_expr; \ | exec_expr; \ | ||||
} \ | } \ | ||||
}; | }; | ||||
@@ -164,12 +177,23 @@ using cce::ccStatus_t; | |||||
} \ | } \ | ||||
}; | }; | ||||
// If expr is not SUCCESS, print the log and execute the expression + return | |||||
#define GE_CHK_BOOL_TRUE_RET_VOID(expr, exec_expr, ...) \ | |||||
{ \ | |||||
bool b = (expr); \ | |||||
if (b) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
exec_expr; \ | |||||
return; \ | |||||
} \ | |||||
}; | |||||
// If expr is not SUCCESS, print the log and execute the expression + return _status | // If expr is not SUCCESS, print the log and execute the expression + return _status | ||||
#define GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(expr, _status, exec_expr, ...) \ | #define GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(expr, _status, exec_expr, ...) \ | ||||
{ \ | { \ | ||||
bool b = (expr); \ | bool b = (expr); \ | ||||
if (b) { \ | if (b) { \ | ||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
exec_expr; \ | exec_expr; \ | ||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
@@ -186,52 +210,62 @@ using cce::ccStatus_t; | |||||
// -----------------runtime related macro definitions------------------------------- | // -----------------runtime related macro definitions------------------------------- | ||||
// If expr is not RT_ERROR_NONE, print the log | // If expr is not RT_ERROR_NONE, print the log | ||||
#define GE_CHK_RT(expr) \ | |||||
do { \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
GELOGE(ge::RT_FAILED, "Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
} \ | |||||
#define GE_CHK_RT(expr) \ | |||||
do { \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
} \ | |||||
} while (0); | } while (0); | ||||
// If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression | // If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression | ||||
#define GE_CHK_RT_EXEC(expr, exec_expr) \ | |||||
{ \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
GELOGE(ge::RT_FAILED, "Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
exec_expr; \ | |||||
} \ | |||||
#define GE_CHK_RT_EXEC(expr, exec_expr) \ | |||||
{ \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
exec_expr; \ | |||||
} \ | |||||
} | } | ||||
// If expr is not RT_ERROR_NONE, print the log and return | // If expr is not RT_ERROR_NONE, print the log and return | ||||
#define GE_CHK_RT_RET(expr) \ | |||||
do { \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
GELOGE(ge::RT_FAILED, "Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
return ge::RT_FAILED; \ | |||||
} \ | |||||
#define GE_CHK_RT_RET(expr) \ | |||||
do { \ | |||||
rtError_t _rt_ret = (expr); \ | |||||
if (_rt_ret != RT_ERROR_NONE) { \ | |||||
GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
return ge::RT_FAILED; \ | |||||
} \ | |||||
} while (0); | } while (0); | ||||
// ------------------------cce related macro definitions---------------------------- | // ------------------------cce related macro definitions---------------------------- | ||||
// If expr is not CC_STATUS_SUCCESS, print the log | // If expr is not CC_STATUS_SUCCESS, print the log | ||||
#define GE_CHK_CCE(expr) \ | |||||
do { \ | |||||
ccStatus_t _cc_ret = (expr); \ | |||||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
GELOGE(ge::CCE_FAILED, "Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
} \ | |||||
#define GE_CHK_CCE(expr) \ | |||||
do { \ | |||||
ccStatus_t _cc_ret = (expr); \ | |||||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
} \ | |||||
} while (0); | |||||
// If expr is not CC_STATUS_SUCCESS, print the log and execute the exec_expr expression | |||||
#define GE_CHK_CCE_EXEC(expr, exec_expr) \ | |||||
do { \ | |||||
ccStatus_t _cc_ret = (expr); \ | |||||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
exec_expr; \ | |||||
} \ | |||||
} while (0); | } while (0); | ||||
// If expr is not CC_STATUS_SUCCESS, print the log and return | // If expr is not CC_STATUS_SUCCESS, print the log and return | ||||
#define GE_CHK_CCE_RET(expr) \ | |||||
do { \ | |||||
ccStatus_t _cc_ret = (expr); \ | |||||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
GELOGE(ge::CCE_FAILED, "Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
return ge::CCE_FAILED; \ | |||||
} \ | |||||
#define GE_CHK_CCE_RET(expr) \ | |||||
do { \ | |||||
ccStatus_t _cc_ret = (expr); \ | |||||
if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
return ge::CCE_FAILED; \ | |||||
} \ | |||||
} while (0); | } while (0); | ||||
// If expr is true, execute exec_expr without printing logs | // If expr is true, execute exec_expr without printing logs | ||||
@@ -247,8 +281,37 @@ using cce::ccStatus_t; | |||||
try { \ | try { \ | ||||
exec_expr0; \ | exec_expr0; \ | ||||
} catch (const std::bad_alloc &) { \ | } catch (const std::bad_alloc &) { \ | ||||
GELOGE(ge::FAILED, "Make shared failed"); \ | |||||
GE_LOGE("Make shared failed"); \ | |||||
exec_expr1; \ | exec_expr1; \ | ||||
} | } | ||||
#define GE_CHECK_INT32_MUL_OVERFLOW(a, b, ...) \ | |||||
do { \ | |||||
if ((a) > 0) { \ | |||||
if ((b) > 0) { \ | |||||
if ((a) > (INT32_MAX / (b))) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} else { \ | |||||
if ((b) < (INT32_MIN / (a))) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} \ | |||||
} else { \ | |||||
if ((b) > 0) { \ | |||||
if ((a) < (INT32_MAX / (b))) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} else { \ | |||||
if (((a) != 0) && ((b) < (INT32_MAX / (a)))) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} \ | |||||
} \ | |||||
} while (0); | |||||
#endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | #endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ |
@@ -1,4 +1,4 @@ | |||||
/** | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | * Copyright 2019-2020 Huawei Technologies Co., Ltd | ||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
@@ -1,4 +1,4 @@ | |||||
/** | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | * Copyright 2019-2020 Huawei Technologies Co., Ltd | ||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
@@ -20,4 +20,4 @@ | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "register/register_types.h" | #include "register/register_types.h" | ||||
#endif // INC_FRAMEWORK_COMMON_FMK_TYPES_H_ | |||||
#endif // INC_FRAMEWORK_COMMON_FMK_TYPES_H_ |
@@ -14,79 +14,78 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
/*lint -e* */ | |||||
#ifndef INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | #ifndef INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | ||||
#define INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | #define INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | ||||
#include <map> | #include <map> | ||||
#include <string> | #include <string> | ||||
#include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
namespace ge { | namespace ge { | ||||
// System ID | // System ID | ||||
enum SystemIdType { kSysidGE = 8 }; | |||||
enum SystemIdType { SYSID_GE = 8 }; | |||||
// Runtime location | // Runtime location | ||||
enum LogRuntime { | enum LogRuntime { | ||||
KRtHost = 0b01, | |||||
kRtDevice = 0b10, | |||||
RT_HOST = 0b01, | |||||
RT_DEVICE = 0b10, | |||||
}; | }; | ||||
// Sub model | // Sub model | ||||
enum SubModuleId { | enum SubModuleId { | ||||
kCommonModule = 0, | |||||
kClientModule = 1, | |||||
kInitModule = 2, | |||||
kSessionModule = 3, | |||||
kGraphModule = 4, | |||||
kEngineMOdule = 5, | |||||
kOpsModule = 6, | |||||
kPluginModule = 7, | |||||
kRuntimeModule = 8, | |||||
kExecutorModule = 9, | |||||
kGeneratorModule = 10, | |||||
COMMON_MODULE = 0, | |||||
CLIENT_MODULE = 1, | |||||
INIT_MODULE = 2, | |||||
SESSION_MODULE = 3, | |||||
GRAPH_MODULE = 4, | |||||
ENGINE_MODULE = 5, | |||||
OPS_MODULE = 6, | |||||
PLUGIN_MODULE = 7, | |||||
RUNTIME_MODULE = 8, | |||||
EXECUTOR_MODULE = 9, | |||||
GENERATOR_MODULE = 10, | |||||
}; | }; | ||||
// Error code type | // Error code type | ||||
enum ErrorCodeType { | enum ErrorCodeType { | ||||
kErrorCode = 0b01, | |||||
kExceptionCode = 0b10, | |||||
ERROR_CODE = 0b01, | |||||
EXCEPTION_CODE = 0b10, | |||||
}; | }; | ||||
// Error level | // Error level | ||||
enum ErrorLevel { | enum ErrorLevel { | ||||
kCommonLevel = 0b000, | |||||
kSuggestionLevel = 0b001, | |||||
kMinorLevel = 0b010, | |||||
kMajorLevel = 0b011, | |||||
kCriticalLevel = 0b100, | |||||
COMMON_LEVEL = 0b000, | |||||
SUGGESTION_LEVEL = 0b001, | |||||
MINOR_LEVEL = 0b010, | |||||
MAJOR_LEVEL = 0b011, | |||||
CRITICAL_LEVEL = 0b100, | |||||
}; | }; | ||||
// The error code is defined by the following macros | |||||
// Each module defines error codes using the following macros | |||||
#define GE_ERRORNO_COMMON(name, value, desc) \ | #define GE_ERRORNO_COMMON(name, value, desc) \ | ||||
GE_ERRORNO(KRtHost, kErrorCode, kCommonLevel, kSysidGE, kCommonModule, name, value, desc) | |||||
GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, COMMON_MODULE, name, value, desc) | |||||
#define GE_ERRORNO_CLIENT(name, value, desc) \ | #define GE_ERRORNO_CLIENT(name, value, desc) \ | ||||
GE_ERRORNO(KRtHost, kErrorCode, kCommonLevel, kSysidGE, kClientModule, name, value, desc) | |||||
GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, CLIENT_MODULE, name, value, desc) | |||||
#define GE_ERRORNO_INIT(name, value, desc) \ | #define GE_ERRORNO_INIT(name, value, desc) \ | ||||
GE_ERRORNO(KRtHost, kErrorCode, kCommonLevel, kSysidGE, kInitModule, name, value, desc) | |||||
GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, INIT_MODULE, name, value, desc) | |||||
#define GE_ERRORNO_SESSION(name, value, desc) \ | #define GE_ERRORNO_SESSION(name, value, desc) \ | ||||
GE_ERRORNO(KRtHost, kErrorCode, kCommonLevel, kSysidGE, kSessionModule, name, value, desc) | |||||
GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, SESSION_MODULE, name, value, desc) | |||||
#define GE_ERRORNO_GRAPH(name, value, desc) \ | #define GE_ERRORNO_GRAPH(name, value, desc) \ | ||||
GE_ERRORNO(KRtHost, kErrorCode, kCommonLevel, kSysidGE, kGraphModule, name, value, desc) | |||||
GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, GRAPH_MODULE, name, value, desc) | |||||
#define GE_ERRORNO_ENGINE(name, value, desc) \ | #define GE_ERRORNO_ENGINE(name, value, desc) \ | ||||
GE_ERRORNO(KRtHost, kErrorCode, kCommonLevel, kSysidGE, kEngineMOdule, name, value, desc) | |||||
GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, ENGINE_MODULE, name, value, desc) | |||||
#define GE_ERRORNO_OPS(name, value, desc) \ | #define GE_ERRORNO_OPS(name, value, desc) \ | ||||
GE_ERRORNO(KRtHost, kErrorCode, kCommonLevel, kSysidGE, kOpsModule, name, value, desc) | |||||
GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, OPS_MODULE, name, value, desc) | |||||
#define GE_ERRORNO_PLUGIN(name, value, desc) \ | #define GE_ERRORNO_PLUGIN(name, value, desc) \ | ||||
GE_ERRORNO(KRtHost, kErrorCode, kCommonLevel, kSysidGE, kPluginModule, name, value, desc) | |||||
GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, PLUGIN_MODULE, name, value, desc) | |||||
#define GE_ERRORNO_RUNTIME(name, value, desc) \ | #define GE_ERRORNO_RUNTIME(name, value, desc) \ | ||||
GE_ERRORNO(KRtHost, kErrorCode, kCommonLevel, kSysidGE, kRuntimeModule, name, value, desc) | |||||
GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, RUNTIME_MODULE, name, value, desc) | |||||
#define GE_ERRORNO_EXECUTOR(name, value, desc) \ | #define GE_ERRORNO_EXECUTOR(name, value, desc) \ | ||||
GE_ERRORNO(kRtDevice, kErrorCode, kCommonLevel, kSysidGE, kExecutorModule, name, value, desc) | |||||
GE_ERRORNO(RT_DEVICE, ERROR_CODE, COMMON_LEVEL, SYSID_GE, EXECUTOR_MODULE, name, value, desc) | |||||
#define GE_ERRORNO_GENERATOR(name, value, desc) \ | #define GE_ERRORNO_GENERATOR(name, value, desc) \ | ||||
GE_ERRORNO(KRtHost, kErrorCode, kCommonLevel, kSysidGE, kGeneratorModule, name, value, desc) | |||||
GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, GENERATOR_MODULE, name, value, desc) | |||||
// Get the description of the error code | |||||
// Get error code description | |||||
#define GE_GET_ERRORNO_STR(value) ge::StatusFactory::Instance()->GetErrDesc(value) | #define GE_GET_ERRORNO_STR(value) ge::StatusFactory::Instance()->GetErrDesc(value) | ||||
// Common module error code definition | // Common module error code definition | ||||
@@ -206,10 +205,9 @@ GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_GET_GRAPH_REBUILD_FAILED, 60, | |||||
GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61, | GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61, | ||||
"Failed set graph finish rebuild in node searcher."); // 1343242301 | "Failed set graph finish rebuild in node searcher."); // 1343242301 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302 | GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302 | ||||
// Optimize errocode | // Optimize errocode | ||||
GE_ERRORNO_GRAPH(TO_BE_DELETED, 63, "The node of the graph to be deleted."); // 1343242303 | |||||
GE_ERRORNO_GRAPH(NOT_CHANGED, 64, "The node of the graph no changed."); // 1343242304 | |||||
GE_ERRORNO_GRAPH(TO_BE_DELETED, 200, "The node of the graph to be deleted."); | |||||
GE_ERRORNO_GRAPH(NOT_CHANGED, 201, "NThe node of the graph not changed."); | |||||
// Engine_manager module error code definition | // Engine_manager module error code definition | ||||
GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336 | GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336 | ||||
@@ -137,7 +137,7 @@ class ModelListener { | |||||
struct Options { | struct Options { | ||||
int64_t session_id; | int64_t session_id; | ||||
int32_t device_id; | int32_t device_id; | ||||
int64_t job_id; | |||||
std::string job_id; | |||||
bool isUseHcom; | bool isUseHcom; | ||||
bool deployMode; | bool deployMode; | ||||
bool isAICPUMode; | bool isAICPUMode; | ||||
@@ -149,5 +149,4 @@ struct Options { | |||||
int32_t physical_device_id; | int32_t physical_device_id; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ | #endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ |
@@ -23,11 +23,6 @@ | |||||
namespace ge { | namespace ge { | ||||
class GflagsUtils { | class GflagsUtils { | ||||
public: | public: | ||||
/// | |||||
/// @brief Determines whether the parameter is true | |||||
/// @param name name parameter name | |||||
/// @return true or false | |||||
/// | |||||
static bool IsSetCommandTrue(const char *name) { | static bool IsSetCommandTrue(const char *name) { | ||||
std::string out; | std::string out; | ||||
return gflags::GetCommandLineOption(name, &out) && out == "true"; | return gflags::GetCommandLineOption(name, &out) && out == "true"; | ||||
@@ -19,6 +19,7 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <memory> | |||||
#include "common/fmk_types.h" | #include "common/fmk_types.h" | ||||
#include "common/helper/om_file_helper.h" | #include "common/helper/om_file_helper.h" | ||||
@@ -35,8 +35,8 @@ struct ModelPartition { | |||||
}; | }; | ||||
struct OmFileContext { | struct OmFileContext { | ||||
vector<ModelPartition> partition_datas_; | |||||
vector<char> partition_table_; | |||||
std::vector<ModelPartition> partition_datas_; | |||||
std::vector<char> partition_table_; | |||||
uint32_t model_data_len_; | uint32_t model_data_len_; | ||||
}; | }; | ||||
@@ -78,7 +78,7 @@ class OmFileSaveHelper { | |||||
Status AddPartition(ModelPartition &partition); | Status AddPartition(ModelPartition &partition); | ||||
const vector<ModelPartition> &GetModelPartitions() const; | |||||
const std::vector<ModelPartition> &GetModelPartitions() const; | |||||
Status SaveModel(const SaveParam &save_param, const char *target_file); | Status SaveModel(const SaveParam &save_param, const char *target_file); | ||||
@@ -88,4 +88,5 @@ class OmFileSaveHelper { | |||||
OmFileContext context_; | OmFileContext context_; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
/*lint +e148*/ | |||||
#endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ | #endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ |
@@ -28,11 +28,14 @@ | |||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
using std::vector; | |||||
namespace ge { | namespace ge { | ||||
// Size of RC memory alignment, 2M | // Size of RC memory alignment, 2M | ||||
const size_t ALIGN_SIZE = 2097152; | |||||
const uint32_t RC_VALUE_DEFAULT = 1; | |||||
const uint32_t RC_VALUE_MAC = 32; | |||||
constexpr size_t ALIGN_SIZE = 2097152; | |||||
constexpr uint32_t RC_VALUE_DEFAULT = 1; | |||||
constexpr uint32_t RC_VALUE_MAX = 32; | |||||
// RC data type classification | // RC data type classification | ||||
enum RCType { | enum RCType { | ||||
@@ -100,7 +103,7 @@ class L2CacheOptimize { | |||||
void HandOPoutput(ge::NodePtr node, vector<int64_t> &outputList, vector<RCMemoryBlock> &blocks); | void HandOPoutput(ge::NodePtr node, vector<int64_t> &outputList, vector<RCMemoryBlock> &blocks); | ||||
// maximum common divisor | // maximum common divisor | ||||
uint32_t Measure(uint32_t x, uint32_t y) const { | |||||
uint32_t Measure(uint32_t x, uint32_t y) { | |||||
if (x == 0 || y == 0) return RC_VALUE_DEFAULT; | if (x == 0 || y == 0) return RC_VALUE_DEFAULT; | ||||
uint32_t z = y; | uint32_t z = y; | ||||
while (x % y != 0) { | while (x % y != 0) { | ||||
@@ -0,0 +1,806 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | |||||
#define INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | |||||
#include <string> | |||||
#include "framework/common/fmk_types.h" | |||||
namespace domi { | |||||
// Public Attribute | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WEIGHT_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IS_QUANTIZE_FACTOR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ALPHA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BETA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADMODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADMODES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FILTER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BIAS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BIAS_TERM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WINDOWS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_GLOBAL_POOLING; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CEIL_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDE_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RELU_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ALGO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FILTER_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_K; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_NORM_REGION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_LOCAL_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_ALPHA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_BETA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BROADCAST; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TIDX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TPADDINGS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_IMG_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_IMG_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NET_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NET_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TMULTIPLES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTIPLES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_T; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_N; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TSHAPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NAN_OPT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AIPP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string NEW_AIPP_CONV_OP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||||
static const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | |||||
static const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_BATCH_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INPUT_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_NODE_DEF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_OP_DEF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INPUT_TENSOR_DESC; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INFERRED_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WEIGHTS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DIM_ALIGN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||||
// To be deleted | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_TO_BE_DELETED; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_LOC_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_CONF_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_OCR_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | |||||
// Refinedet | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_LOC_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_CONF_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIORBOX_CONCAT; | |||||
// _Arg | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INDEX; | |||||
// _RetVal | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETVAL_ATTR_NAME_INDEX; | |||||
// Data | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DATA_ATTR_NAME_DATA_TYPE; | |||||
// Send | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SEND_ATTR_EVENT_ID; | |||||
// Recv | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RECV_ATTR_EVENT_ID; | |||||
// convolution | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_COEF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATIONS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_ALGO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_GROUP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_PAD_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_STRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_DILATION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_NUM_OUTPUT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_KERNEL; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_FILTER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_BIAS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_RELU_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_ADJ; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_TARGET_SHAPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_BEFORE_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_HAS_BIAS; | |||||
// Pooling | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_NAN_OPT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_PAD_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_GLOBAL_POOLING; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_WINDOW; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_STRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_CEIL_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_DATA_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_BEFORE_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_NAME_ALGO; | |||||
// Eltwise | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_COEFF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_WEIGHT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_RELU_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_ALPHA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_BETA; | |||||
// BatchNorm | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_EPSILON; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_ESTIMATED_MEAN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_BIAS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_IS_TRAINING; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||||
// Huberloss | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HUBER_LOSS_ATTR_DELTA; | |||||
// SSDRealDivTileMul | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||||
// SSDSumMulRealDivMean | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||||
SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||||
/// ConcatFive2Four | |||||
/// ConcatFour2Five | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_CLASS_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TRANS_FOR_LOSS_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOX_TYPE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_HIGH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_WIDTH; | |||||
// Scale | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SCALE_ATTR_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SCALE_ATTR_BIAS; | |||||
// FullConnection | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_FILTER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_BIAS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_RELU_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_ATTR_NAME_ALGO; | |||||
// SoftmaxOpParams | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_ALGO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_MODE; | |||||
// SparseSoftmaxCrossEntropy | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING; | |||||
// Activation | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ACTIVATION_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ACTIVATION_ATTR_COEF; | |||||
// Concat | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_ATTR_NAME_AXIS; | |||||
// Const | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_DATA_TRANSTYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | |||||
// Roipooling | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLED_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLED_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO; | |||||
// DetectionOutput | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_TOP_K; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IMG_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IMG_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE; | |||||
// Ssd DetectionOutput | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_ETA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||||
DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K; | |||||
// Refinedet DetectionOutput | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE; | |||||
// yolo DetectionOutput | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_ClASSES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BIASES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_RELATIVE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION; | |||||
// DetectionPostprocess | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_CLS_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_POST_NMS_TOPN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT; | |||||
// Spatialtransfrom | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_OUTPUT_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_OUTPUT_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM; | |||||
// Proposal | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_BASE_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_MIN_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_RATIO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_NMS_THRESH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_TOP_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_IMG_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_IMG_W; | |||||
// Softmax | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_AXIS; | |||||
// Permute | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_ATTR_ORDER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_ATTR_PERM; | |||||
// SSD Normalize | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_EPS; | |||||
// Flatten | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_ATTR_END_AXIS; | |||||
// SsdPRIORBOX | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_FLIP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_CLIP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_IMG_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_IMG_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_STEP_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_STEP_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_OFFSET; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_VARIANCE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||||
// RefinedetPRIORBOX | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | |||||
// PRelu | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PRELU_ATTR_CHANNEL_SHARED; | |||||
// Psroi pooling | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_OUTPUT_DIM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_GROUP_SIZE; | |||||
// Power | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_POWER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_SHIFT; | |||||
// Log | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_SHIFT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_BASE; | |||||
// Pack | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PACK_ATTR_NAME_NUM; | |||||
// Dynamic stitch | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||||
// Unpack | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UNPACK_ATTR_NAME_NUM; | |||||
// Gathernd | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERND_ATTR_NAME_TINDICES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERND_ATTR_NAME_TPARAMS; | |||||
// Argmax | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_TOPK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_OUTMAX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||||
// Upsample | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||||
// Relu | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NEGATIVE_SLOPE; | |||||
// FreeSpaceExtract | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT; | |||||
// split | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_SLICE_POINT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_SIZE_SPLIT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_NUM_SPLIT; | |||||
// Tvm | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_MAGIC; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_BLOCKDIM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_METADATA; | |||||
// Squeeze | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_ATTR_DIMS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_OP_NAME; | |||||
// Stride slice | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_BEGIN_MASK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_END_MASK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; | |||||
// Slice | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SLICE_ATTR_NAME_BEGINS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SLICE_ATTR_NAME_SIZES; | |||||
// Roialign | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_SPATIAL_SCALE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_NAME_POOLED_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_NAME_POOLED_W; | |||||
// Generate_rpn_proposal | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||||
GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||||
GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; | |||||
// Decode_bbox | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DECODE_BBOX_ATTR_DECODECLIP; | |||||
// Cast | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CAST_ATTR_DSTT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CAST_ATTR_SRCT; | |||||
// Fastrcnnn predications | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES; | |||||
// REORG | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REORG_ATTR_STRIDE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REORG_ATTR_REVERSE; | |||||
// MERGE | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MERGE_DEAD_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MERGE_PRENODE_FLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TO_BE_OUTPUT; | |||||
static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||||
// Concatv2 | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_V2_ATTR_TIDX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_V2_ATTR_N; | |||||
// SUM | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_TIDX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_KEEP_DIMS; | |||||
// ResizeBilinear | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_HEIGHT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_WIDTH; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_PAD_END; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ALPHA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_BETA; | |||||
// RetinaNet | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETINANET_ANCHOR_FUSION; | |||||
// MatMul | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_TRANSPOSE_X; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_TRANSPOSE_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_HAS_BIAS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_ATTR_IS_TRAINING; | |||||
// Flatten | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_START_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_END_AXIS; | |||||
// Reshape | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NUM_AXES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_SHAPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_ALPHA; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_BETA; | |||||
// Frameoworkop | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string T_IN_DATATYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string T_OUT_DATATYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_N; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_C; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_H; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_W; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_PAD_DEPTH_CONV; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_PAD_CONV; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BEFORE_PAD; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ANN_MEAN_KEEPDIMS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_ATTR_PADDINGDS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_ATTR_CONSTANT_VALUE; | |||||
// ConvGradFilter | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE; | |||||
// ConvGradInput | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE; | |||||
// Rnn | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_MODE_STATIC; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MUTI_RNN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CELL_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CNN_RNN; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_CELL; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GRU_CELL; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_HT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_XT_HT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_BATCH_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_CELL_CLIP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_PROJ_CLIP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_ACTIVATE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_OUT_MAP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_OUT_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_STATE_OUT_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_TIME_MAJOR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||||
// Upsample | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE; | |||||
// PadV2 | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_PADS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_T; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||||
// MirrorPad | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_PADS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||||
// Filler | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FILLER_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FILLER_VALUE; | |||||
// Shufflechannel | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHUFFLE_CHANNEL_GROUP; | |||||
// TopKV2 | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TOPKV2_ATTR_K; | |||||
// Calibaration | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_H_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_W_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_TOP_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_BOTTOM_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_RIGHT_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_LEFT_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IS_CONST; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_GROUP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATION_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_EPSILON; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_POOLING_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CLASS_NUM; | |||||
// model | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TARGET_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_STREAM_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | |||||
// Public Attribute | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IMPLY_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BYTE_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_INFERENCE_ID; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_OPDEF; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_SCOPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OPATTR; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RELUFLAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SEQLEN_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_X_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CONT_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_XSTATIC_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_MINI; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_TINY; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_LITE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STREAM_LABEL; | |||||
// L2_normalize | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string L2_NORMALIZE_ATTR_AXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string L2_NORMALIZE_ATTR_EPS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_WINDOW; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_NAN_OP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||||
// HCOM | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_ROOT_RANK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_REDUCE_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_RANK_SIZE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_REDUCTION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_GROUP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SR_TAG; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SRC_RANK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_DEST_RANK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SHAPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_DATA_TYPE; | |||||
// Log time stamp | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_TIME_STAMP_LOGID; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_TIME_STAMP_NOTIFY; | |||||
// SpaceToDepth/DepthToSpace | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BLOCK_SIZE; | |||||
// SparseSoftmaxCrossEntropyWithLogits | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||||
// MaxPoolGradWithArgmax | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||||
// AvgPoolGrad | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||||
// Pad | |||||
extern const std::string ATTR_PAD_FORMAT; | |||||
// Varible | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_4D_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_5D_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_DATA_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_OUT_INDEX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SHAPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HALF_VAR_NAME_END; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_CONTAINER; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SHARED_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_DTYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_ADDR_OFFSET; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SRC_VAR_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_SAVE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_RESTORE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REF_VAR_SRC_VAR_NAME; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||||
// Assign | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ASSIGN_VALIDATE_SHAPE; | |||||
// ShapeN | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_N; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_IN_TYPE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_OUT_TYPE; | |||||
// Space2bacth batch2space | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCH_SPACE_ATTR_BLOCK; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCH_SPACE_ATTR_PADDING; | |||||
// Depth_to_space space_to_depth | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||||
// FakeQuantWithMinMaxVars | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||||
// Mobilenet_ssd_conv_fusion | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||||
// Lsh project | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSH_PROJ_TYPE; | |||||
// Control flow | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||||
// GatherV2 attr def | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TAXIS; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TINDICES; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||||
// Reshape attr def | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||||
// Axis attr def | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AXIS_ORG_OP; | |||||
// The node link with SparseSoftmaxCrossEntropyWithLogits | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LINK_WITH_SPARE; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||||
extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ |
@@ -17,19 +17,17 @@ | |||||
#ifndef INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | #ifndef INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | ||||
#define INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | #define INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | ||||
#include <google/protobuf/map.h> | |||||
#include <string> | |||||
#include <unordered_map> | #include <unordered_map> | ||||
#include "common/types.h" | |||||
#include <google/protobuf/map.h> | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "common/types.h" | |||||
#include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
using domi::AttrDef; | using domi::AttrDef; | ||||
using domi::OpDef; | |||||
using domi::AttrDef_ListValue; | using domi::AttrDef_ListValue; | ||||
using domi::ModelDef; | using domi::ModelDef; | ||||
using domi::NamedAttrs; | using domi::NamedAttrs; | ||||
using domi::OpDef; | |||||
namespace ge { | namespace ge { | ||||
using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; | using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; | ||||
@@ -172,7 +172,7 @@ class OpUtils { | |||||
/// | /// | ||||
/// @ingroup domi_omg | /// @ingroup domi_omg | ||||
/// @brief Convert the convolution‘s weight data from [h, w, c, k] to [k, c, h, w] | |||||
/// @brief Convert the convolution¡®s weight data from [h, w, c, k] to [k, c, h, w] | |||||
/// @param [in] input Weight data in HWCK format | /// @param [in] input Weight data in HWCK format | ||||
/// @param [in] H value of H dimension | /// @param [in] H value of H dimension | ||||
/// @param [in] W value of W dimension | /// @param [in] W value of W dimension | ||||
@@ -183,7 +183,7 @@ class OpUtils { | |||||
static void TransDataHWCK2KCHW(const void *input, int64_t H, int64_t W, int64_t C, int64_t K, void **output); | static void TransDataHWCK2KCHW(const void *input, int64_t H, int64_t W, int64_t C, int64_t K, void **output); | ||||
/// | /// | ||||
/// @ingroup domi_omg | /// @ingroup domi_omg | ||||
/// @brief Converts the convolution‘s weight data from [k, c, h, w] to [h, w, c, k]. | |||||
/// @brief Converts the convolution¡®s weight data from [k, c, h, w] to [h, w, c, k]. | |||||
/// @param [in] input Weight data in HWCK format | /// @param [in] input Weight data in HWCK format | ||||
/// @param [in] K value of K dimension | /// @param [in] K value of K dimension | ||||
/// @param [in] C value of C dimension | /// @param [in] C value of C dimension | ||||
@@ -198,7 +198,7 @@ class OpUtils { | |||||
/// training network | /// training network | ||||
/// @param [in] model_tensor input and output tensor information | /// @param [in] model_tensor input and output tensor information | ||||
/// @param [out] cc_tensor Tensor in CCE format after conversion | /// @param [out] cc_tensor Tensor in CCE format after conversion | ||||
//// | |||||
/// | |||||
static Status InitFilterTensorDescriptor(const ge::GeTensorDesc &model_tensor, ccFilterDescriptor_t &cc_tensor); | static Status InitFilterTensorDescriptor(const ge::GeTensorDesc &model_tensor, ccFilterDescriptor_t &cc_tensor); | ||||
static void SetTensorDescriptorAllOffsetQuantizeInfo(const GeTensorDesc &tensor, ccTensorDescriptor_t cc_tensor); | static void SetTensorDescriptorAllOffsetQuantizeInfo(const GeTensorDesc &tensor, ccTensorDescriptor_t cc_tensor); | ||||
@@ -23,7 +23,7 @@ | |||||
#include <stdint.h> | #include <stdint.h> | ||||
namespace domi { | namespace domi { | ||||
// General | |||||
// general | |||||
const float DEFAULT_ALPHA_VALUE = 1.0; | const float DEFAULT_ALPHA_VALUE = 1.0; | ||||
const float DEFAULT_BETA_VALUE = 0.0; | const float DEFAULT_BETA_VALUE = 0.0; | ||||
const uint32_t NORMAL_INPUT_NUM = 1; | const uint32_t NORMAL_INPUT_NUM = 1; | ||||
@@ -37,7 +37,7 @@ const int NORMAL_DEVICE_DATA_TYPE = static_cast<const int>(cce::CC_DATA_HALF); | |||||
const int DEFAULT_POOLING_MODE = static_cast<const int>(cce::CC_POOLING_MAX); | const int DEFAULT_POOLING_MODE = static_cast<const int>(cce::CC_POOLING_MAX); | ||||
const uint32_t DEFAULT_REAL_DIM_CNT = 4; | const uint32_t DEFAULT_REAL_DIM_CNT = 4; | ||||
// Const | |||||
// const | |||||
const uint32_t CONST_OP_INPUT_NUM = 0; | const uint32_t CONST_OP_INPUT_NUM = 0; | ||||
const uint32_t CONST_OP_NORMAL_WEIGHT_SIZE = 1; | const uint32_t CONST_OP_NORMAL_WEIGHT_SIZE = 1; | ||||
@@ -56,7 +56,7 @@ const int32_t FUSEDBATCHNORMGRAD_WORKSPACE_NUM = 1; | |||||
const int32_t FUSEDBATCHNORMGRAD_INPUT_NUM = 5; | const int32_t FUSEDBATCHNORMGRAD_INPUT_NUM = 5; | ||||
const int32_t FUSEDBATCHNORMGRAD_OUTPUT_NUM = 3; | const int32_t FUSEDBATCHNORMGRAD_OUTPUT_NUM = 3; | ||||
// Conv | |||||
// conv | |||||
const uint32_t CONVOLUTION_WORKSPACE_NUM = 1; | const uint32_t CONVOLUTION_WORKSPACE_NUM = 1; | ||||
const uint32_t CONVOLUTION_PAD_SIZE = 4; | const uint32_t CONVOLUTION_PAD_SIZE = 4; | ||||
const uint32_t CONVOLUTION_STRIDE_SIZE = 2; | const uint32_t CONVOLUTION_STRIDE_SIZE = 2; | ||||
@@ -104,7 +104,7 @@ const float LRN_DEFAULT_BETA = 0.75; | |||||
/// | /// | ||||
/// @ingroup domi_common | /// @ingroup domi_common | ||||
/// @brief default value of roipooling | |||||
/// @brief roipooling default value | |||||
/// | /// | ||||
const uint32_t ROIPOOLING_DEFAULT_POOLED_H = 0; | const uint32_t ROIPOOLING_DEFAULT_POOLED_H = 0; | ||||
const uint32_t ROIPOOLING_DEFAULT_POOLED_W = 0; | const uint32_t ROIPOOLING_DEFAULT_POOLED_W = 0; | ||||
@@ -115,7 +115,7 @@ const int32_t ROIPOOLING_DEFAULT_SAMPLING_RATIO = -1; | |||||
const int32_t DETECTIONOUTPUT_INPUT_SIZE = 3; | const int32_t DETECTIONOUTPUT_INPUT_SIZE = 3; | ||||
const int32_t DETECTIONOUTPUT_OUTPUT_SIZE = 2; | const int32_t DETECTIONOUTPUT_OUTPUT_SIZE = 2; | ||||
const int32_t DETECTIONOUTPUT_WORKSPACE_NUM = 1; | const int32_t DETECTIONOUTPUT_WORKSPACE_NUM = 1; | ||||
const int DETECTIONOUTPUT_CLASS_NUM = 20; | |||||
const int DETECTIONOUTPUT_CLASS_NUM = 20; // Number of background categories | |||||
const int DETECTIONOUTPUT_NUM_CLASSES_DEFAULT_VALUE = 21; | const int DETECTIONOUTPUT_NUM_CLASSES_DEFAULT_VALUE = 21; | ||||
const float DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3; | const float DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3; | ||||
const float DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.8; | const float DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.8; | ||||
@@ -392,9 +392,9 @@ const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_CANDIDATE_BIAS = 14; | |||||
const uint32_t ATTENTION_DECODER_WEIGHT_EMBEDDING = 15; | const uint32_t ATTENTION_DECODER_WEIGHT_EMBEDDING = 15; | ||||
const uint32_t ATTENTION_DECODER_WEIGHT_ATTENVA = 16; | const uint32_t ATTENTION_DECODER_WEIGHT_ATTENVA = 16; | ||||
const uint32_t ATTENTION_DECODER_WEIGHT_DECODER_INITIAL = 17; | const uint32_t ATTENTION_DECODER_WEIGHT_DECODER_INITIAL = 17; | ||||
// Attention decoder weight size | // Attention decoder weight size | ||||
const uint32_t ATTENTION_DECODER_WEIGHT_SIZE = 18; | const uint32_t ATTENTION_DECODER_WEIGHT_SIZE = 18; | ||||
const uint32_t ATTENTION_DECODER_INPUT_SIZE = 2; | const uint32_t ATTENTION_DECODER_INPUT_SIZE = 2; | ||||
const uint32_t ATTENTION_DECODER_WORKSPACE_NUM = 1; | const uint32_t ATTENTION_DECODER_WORKSPACE_NUM = 1; | ||||
const uint32_t ATTENTION_DECODER_INPUT_DECODER_INPUTS = 0; | const uint32_t ATTENTION_DECODER_INPUT_DECODER_INPUTS = 0; | ||||
@@ -24,7 +24,7 @@ | |||||
/// Acquire Resource 1 | /// Acquire Resource 1 | ||||
/// MAKE_GUARD([&] { Release Resource 1 }) | /// MAKE_GUARD([&] { Release Resource 1 }) | ||||
/// Acquire Resource 2 | /// Acquire Resource 2 | ||||
/// MAKE_GUARD([&] { Release Resource 2 }) | |||||
// MAKE_GUARD([&] { Release Resource 2 }) | |||||
#define GE_MAKE_GUARD(var, callback) ge::ScopeGuard make_guard_##var(callback) | #define GE_MAKE_GUARD(var, callback) ge::ScopeGuard make_guard_##var(callback) | ||||
#define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | #define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | ||||
@@ -20,7 +20,6 @@ | |||||
#include <limits.h> | #include <limits.h> | ||||
#include <linux/limits.h> | #include <linux/limits.h> | ||||
#include <stdint.h> | #include <stdint.h> | ||||
#include <algorithm> | #include <algorithm> | ||||
#include <map> | #include <map> | ||||
#include <memory> | #include <memory> | ||||
@@ -49,7 +48,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_S | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; | ||||
// public property names which are supported | |||||
// Supported public properties name | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_LOG_PATH; // Log path | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_LOG_PATH; // Log path | ||||
@@ -1033,11 +1032,14 @@ struct BasicInfo { | |||||
uint32_t workspace_size; // workspace | uint32_t workspace_size; // workspace | ||||
uint32_t total_size; // total memory size | uint32_t total_size; // total memory size | ||||
}; | }; | ||||
#pragma pack() // Cancels single-byte alignment | #pragma pack() // Cancels single-byte alignment | ||||
} // namespace ge | } // namespace ge | ||||
namespace domi { | namespace domi { | ||||
/// @brief Data structure definition related to task sinking | /// @brief Data structure definition related to task sinking | ||||
/// Build model | |||||
enum BuildMode { | enum BuildMode { | ||||
GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | ||||
GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | ||||
@@ -65,117 +65,160 @@ | |||||
if (var) GE_CHK_CCE(ccDestroyFilterDescriptor(&var)); \ | if (var) GE_CHK_CCE(ccDestroyFilterDescriptor(&var)); \ | ||||
}); | }); | ||||
// For propagating errors when calling a function. | |||||
#define GE_RETURN_IF_ERROR(expr) \ | |||||
do { \ | |||||
const ::ge::Status _status = (expr); \ | |||||
if (_status) return _status; \ | |||||
} while (0) | |||||
#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ | #define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ | ||||
do { \ | do { \ | ||||
const ::ge::Status _status = (expr); \ | const ::ge::Status _status = (expr); \ | ||||
if (_status) { \ | if (_status) { \ | ||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return _status; \ | return _status; \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
// check whether the parameter is true. If it is, return FAILED and record the error log | |||||
#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ | |||||
do { \ | |||||
if (condition) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} while (0) | |||||
// Check if the parameter is false. If yes, return FAILED and record the error log | // Check if the parameter is false. If yes, return FAILED and record the error log | ||||
#define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ | #define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ | ||||
do { \ | do { \ | ||||
bool _condition = (condition); \ | bool _condition = (condition); \ | ||||
if (!_condition) { \ | if (!_condition) { \ | ||||
GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::FAILED; \ | return ge::FAILED; \ | ||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
// Checks whether the parameter is true. If so, returns PARAM_INVALID and records the error log | |||||
#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ | |||||
do { \ | |||||
if (condition) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | |||||
// Check if the parameter is false. If yes, return PARAM_INVALID and record the error log | |||||
#define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \ | |||||
do { \ | |||||
bool _condition = (condition); \ | |||||
if (!_condition) { \ | |||||
GE_LOGE(__VA_ARGS__); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | |||||
// Check if the parameter is null. If yes, return PARAM_INVALID and record the error | // Check if the parameter is null. If yes, return PARAM_INVALID and record the error | ||||
#define GE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::PARAM_INVALID, "param[#val] must not be null."); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GE_LOGE(param[#val] must not be null.); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the parameter is null. If yes, return PARAM_INVALID and record the error | // Check if the parameter is null. If yes, return PARAM_INVALID and record the error | ||||
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::PARAM_INVALID, "param[#val] must not be null."); \ | |||||
return; \ | |||||
} \ | |||||
#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GE_LOGE(param[#val] must not be null.); \ | |||||
return; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | // Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | ||||
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::PARAM_INVALID, "param[#val] must not be null."); \ | |||||
exec_expr; \ | |||||
} \ | |||||
#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GE_LOGE(param[#val] must not be null.); \ | |||||
exec_expr; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check whether the parameter is null. If yes, return directly and record the error log | // Check whether the parameter is null. If yes, return directly and record the error log | ||||
#define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::PARAM_INVALID, "param[#val] must not be null."); \ | |||||
return; \ | |||||
} \ | |||||
#define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GE_LOGE(param[#val] must not be null.); \ | |||||
return; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the parameter is null. If yes, return false and record the error log | // Check if the parameter is null. If yes, return false and record the error log | ||||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::FAILED, "param[#val] must not be null."); \ | |||||
return false; \ | |||||
} \ | |||||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GE_LOGE(param[#val] must not be null.); \ | |||||
return false; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the parameter is out of bounds | // Check if the parameter is out of bounds | ||||
#define GE_CHECK_SIZE(size) \ | |||||
do { \ | |||||
if (size == 0) { \ | |||||
GELOGE(ge::PARAM_INVALID, "param[#size] is out of range"); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_SIZE(size) \ | |||||
do { \ | |||||
if (size == 0) { \ | |||||
GE_LOGE(param[#size] is out of range); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Macros that define the size variable | // Macros that define the size variable | ||||
#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ | |||||
uint32_t _var_name; \ | |||||
do { \ | |||||
uint32_t _expr_size = (_expr); \ | |||||
uint32_t _sizeof_size = (_sizeof); \ | |||||
if (_expr_size > (0xffffffff) / _sizeof_size) { \ | |||||
GELOGE(ge::PARAM_INVALID, "byte_size: [#_var_name] is out of range"); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
_var_name = _sizeof_size * _expr_size; \ | |||||
#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ | |||||
uint32_t _var_name; \ | |||||
do { \ | |||||
uint32_t _expr_size = (_expr); \ | |||||
uint32_t _sizeof_size = (_sizeof); \ | |||||
if (_expr_size > (0xffffffff) / _sizeof_size) { \ | |||||
GE_LOGE(byte size : #_var_name is out of range); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
_var_name = _sizeof_size * _expr_size; \ | |||||
} while (0); | } while (0); | ||||
// Check if the container is empty | // Check if the container is empty | ||||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
do { \ | |||||
if (vector.empty()) { \ | |||||
GELOGE(ge::FAILED, "param[#vector] is empty !"); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
do { \ | |||||
if (vector.empty()) { \ | |||||
GE_LOGE(param[#vector] is empty !); \ | |||||
return ge::FAILED; \ | |||||
} \ | |||||
} while (0) | |||||
#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size <= 0) { \ | |||||
GE_LOGE(param[#size] is not a positive number); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the value on the left is greater than or equal to the value on the right | // Check if the value on the left is greater than or equal to the value on the right | ||||
#define GE_CHECK_GE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs < rhs) { \ | |||||
GELOGE(ge::PARAM_INVALID, "param[#lhs] is less than[#rhs]"); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_GE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs < rhs) { \ | |||||
GE_LOGE(param[#lhs] is less than[#rhs]); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check if the value on the left is less than or equal to the value on the right | // Check if the value on the left is less than or equal to the value on the right | ||||
#define GE_CHECK_LE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs > rhs) { \ | |||||
GELOGE(ge::PARAM_INVALID, "param[#lhs] is greater than[#rhs]"); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_LE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs > rhs) { \ | |||||
GE_LOGE(param[#lhs] is greater than[#rhs]); \ | |||||
return ge::PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
#define GE_DELETE_NEW_SINGLE(var) \ | #define GE_DELETE_NEW_SINGLE(var) \ | ||||
@@ -194,13 +237,15 @@ | |||||
} \ | } \ | ||||
}; | }; | ||||
/// | |||||
/// @ingroup domi_common | |||||
/// @brief version of om.proto file | |||||
/// | |||||
/** | |||||
* @ingroup domi_common | |||||
* @brief version of om.proto file | |||||
*/ | |||||
static constexpr int32_t OM_PROTO_VERSION = 2; | static constexpr int32_t OM_PROTO_VERSION = 2; | ||||
// Finding an Integer Ceiling Value Without Precision Loss | |||||
/** | |||||
* Finding an Integer Ceiling Value Without Precision Loss | |||||
*/ | |||||
#define CEIL(N, n) (((N) + (n)-1) / (n)) | #define CEIL(N, n) (((N) + (n)-1) / (n)) | ||||
namespace ge { | namespace ge { | ||||
@@ -21,7 +21,7 @@ | |||||
#if !defined(__ANDROID__) && !defined(ANDROID) | #if !defined(__ANDROID__) && !defined(ANDROID) | ||||
#include "toolchain/slog.h" | #include "toolchain/slog.h" | ||||
#else | #else | ||||
#include<android/log.h> | |||||
#include <android/log.h> | |||||
#endif | #endif | ||||
#ifdef _MSC_VER | #ifdef _MSC_VER | ||||
@@ -31,16 +31,11 @@ | |||||
#endif | #endif | ||||
#if !defined(__ANDROID__) && !defined(ANDROID) | #if !defined(__ANDROID__) && !defined(ANDROID) | ||||
#define DAV_LOGI(MOD_NAME, fmt, ...) \ | |||||
dlog_info(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_LOGW(MOD_NAME, fmt, ...) \ | |||||
dlog_warn(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_LOGE(MOD_NAME, fmt, ...) \ | |||||
dlog_error(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_LOGD(MOD_NAME, fmt, ...) \ | |||||
dlog_debug(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_EVENT(MOD_NAME, fmt, ...) \ | |||||
dlog_event(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_LOGI(MOD_NAME, fmt, ...) dlog_info(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_LOGW(MOD_NAME, fmt, ...) dlog_warn(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_LOGE(MOD_NAME, fmt, ...) dlog_error(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_LOGD(MOD_NAME, fmt, ...) dlog_debug(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_EVENT(MOD_NAME, fmt, ...) dlog_event(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#else | #else | ||||
#define DAV_LOGI(MOD_NAME, fmt, ...) \ | #define DAV_LOGI(MOD_NAME, fmt, ...) \ | ||||
__android_log_print(ANDROID_LOG_INFO, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | __android_log_print(ANDROID_LOG_INFO, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | ||||
@@ -28,23 +28,23 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace model_runner { | namespace model_runner { | ||||
enum TaskInfoType { | enum TaskInfoType { | ||||
kCce = 0, | |||||
kTbe, | |||||
kAiCpu, | |||||
kLabelSet, | |||||
kLabelSwitch, | |||||
kLabelGoto, | |||||
kEventRecord, | |||||
kEventWait, | |||||
kFusionStart, | |||||
kFusionEnd, | |||||
kHccl, | |||||
kProfilerTrace, | |||||
kMemcpyAsync, | |||||
kStreamSwitch, | |||||
kStreamActive, | |||||
CCE = 0, | |||||
TBE, | |||||
AICPU, | |||||
LABEL_SET, | |||||
LABEL_SWITCH, | |||||
LABEL_GOTO, | |||||
EVENT_RECORD, | |||||
EVENT_WAIT, | |||||
FUSION_START, | |||||
FUSION_END, | |||||
HCCL, | |||||
PROFILER_TRACE, | |||||
MEMCPY_ASYNC, | |||||
STREAM_SWITCH, | |||||
STREAM_ACTIVE, | |||||
// Insert new task type here | // Insert new task type here | ||||
kReserved = 23 | |||||
REVSERVED = 23 | |||||
}; | }; | ||||
class TaskInfo { | class TaskInfo { | ||||
@@ -66,7 +66,7 @@ class CceTaskInfo : public TaskInfo { | |||||
CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim, | CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim, | ||||
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, | const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, | ||||
const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable) | const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable) | ||||
: TaskInfo(stream_id, TaskInfoType::kCce), | |||||
: TaskInfo(stream_id, TaskInfoType::CCE), | |||||
ctx_(ctx), | ctx_(ctx), | ||||
stub_func_(stub_func), | stub_func_(stub_func), | ||||
block_dim_(block_dim), | block_dim_(block_dim), | ||||
@@ -106,7 +106,7 @@ class TbeTaskInfo : public TaskInfo { | |||||
uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size, | uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size, | ||||
const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs, | const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs, | ||||
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs) | const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs) | ||||
: TaskInfo(stream_id, TaskInfoType::kTbe), | |||||
: TaskInfo(stream_id, TaskInfoType::TBE), | |||||
stub_func_(stub_func), | stub_func_(stub_func), | ||||
block_dim_(block_dim), | block_dim_(block_dim), | ||||
args_(args), | args_(args), | ||||
@@ -155,7 +155,7 @@ class AicpuTaskInfo : public TaskInfo { | |||||
public: | public: | ||||
AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def, | AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def, | ||||
const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs) | const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs) | ||||
: TaskInfo(stream_id, TaskInfoType::kAiCpu), | |||||
: TaskInfo(stream_id, TaskInfoType::AICPU), | |||||
so_name_(so_name), | so_name_(so_name), | ||||
kernel_name_(kernel_name), | kernel_name_(kernel_name), | ||||
node_def_(node_def), | node_def_(node_def), | ||||
@@ -192,21 +192,21 @@ class LabelTaskInfo : public TaskInfo { | |||||
class LabelSetTaskInfo : public LabelTaskInfo { | class LabelSetTaskInfo : public LabelTaskInfo { | ||||
public: | public: | ||||
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | ||||
: LabelTaskInfo(stream_id, TaskInfoType::kLabelSet, label_id) {} | |||||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} | |||||
~LabelSetTaskInfo() override {} | ~LabelSetTaskInfo() override {} | ||||
}; | }; | ||||
class LabelSwitchTaskInfo : public LabelTaskInfo { | class LabelSwitchTaskInfo : public LabelTaskInfo { | ||||
public: | public: | ||||
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) | LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) | ||||
: LabelTaskInfo(stream_id, TaskInfoType::kLabelSwitch, label_id) {} | |||||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} | |||||
~LabelSwitchTaskInfo() override {} | ~LabelSwitchTaskInfo() override {} | ||||
}; | }; | ||||
class LabelGotoTaskInfo : public LabelTaskInfo { | class LabelGotoTaskInfo : public LabelTaskInfo { | ||||
public: | public: | ||||
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | ||||
: LabelTaskInfo(stream_id, TaskInfoType::kLabelGoto, label_id) {} | |||||
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} | |||||
~LabelGotoTaskInfo() override {} | ~LabelGotoTaskInfo() override {} | ||||
}; | }; | ||||
@@ -225,26 +225,26 @@ class EventTaskInfo : public TaskInfo { | |||||
class EventRecordTaskInfo : public EventTaskInfo { | class EventRecordTaskInfo : public EventTaskInfo { | ||||
public: | public: | ||||
EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id) | EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id) | ||||
: EventTaskInfo(stream_id, TaskInfoType::kEventRecord, event_id) {} | |||||
: EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {} | |||||
~EventRecordTaskInfo() override {} | ~EventRecordTaskInfo() override {} | ||||
}; | }; | ||||
class EventWaitTaskInfo : public EventTaskInfo { | class EventWaitTaskInfo : public EventTaskInfo { | ||||
public: | public: | ||||
EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id) | EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id) | ||||
: EventTaskInfo(stream_id, TaskInfoType::kEventWait, event_id) {} | |||||
: EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {} | |||||
~EventWaitTaskInfo() override {} | ~EventWaitTaskInfo() override {} | ||||
}; | }; | ||||
class FusionStartTaskInfo : public TaskInfo { | class FusionStartTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::kFusionStart) {} | |||||
FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {} | |||||
~FusionStartTaskInfo() override {} | ~FusionStartTaskInfo() override {} | ||||
}; | }; | ||||
class FusionEndTaskInfo : public TaskInfo { | class FusionEndTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::kFusionEnd) {} | |||||
FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {} | |||||
~FusionEndTaskInfo() override {} | ~FusionEndTaskInfo() override {} | ||||
}; | }; | ||||
@@ -256,7 +256,7 @@ class HcclTaskInfo : public TaskInfo { | |||||
int64_t op_type, int64_t data_type, std::function<bool(void *, void *)> hcom_bind_model, | int64_t op_type, int64_t data_type, std::function<bool(void *, void *)> hcom_bind_model, | ||||
std::function<bool(void *)> hcom_unbind_model, | std::function<bool(void *)> hcom_unbind_model, | ||||
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task) | std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task) | ||||
: TaskInfo(stream_id, TaskInfoType::kHccl), | |||||
: TaskInfo(stream_id, TaskInfoType::HCCL), | |||||
hccl_type_(hccl_type), | hccl_type_(hccl_type), | ||||
input_data_addr_(input_data_addr), | input_data_addr_(input_data_addr), | ||||
output_data_addr_(output_data_addr), | output_data_addr_(output_data_addr), | ||||
@@ -313,7 +313,7 @@ class HcclTaskInfo : public TaskInfo { | |||||
class ProfilerTraceTaskInfo : public TaskInfo { | class ProfilerTraceTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) | ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) | ||||
: TaskInfo(stream_id, TaskInfoType::kProfilerTrace), log_id_(log_id), notify_(notify), flat_(flat) {} | |||||
: TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {} | |||||
~ProfilerTraceTaskInfo() override {} | ~ProfilerTraceTaskInfo() override {} | ||||
uint64_t log_id() const { return log_id_; } | uint64_t log_id() const { return log_id_; } | ||||
@@ -329,7 +329,7 @@ class ProfilerTraceTaskInfo : public TaskInfo { | |||||
class MemcpyAsyncTaskInfo : public TaskInfo { | class MemcpyAsyncTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind) | MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind) | ||||
: TaskInfo(stream_id, TaskInfoType::kMemcpyAsync), | |||||
: TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC), | |||||
dst_(dst), | dst_(dst), | ||||
dst_max_(dst_max), | dst_max_(dst_max), | ||||
src_(src), | src_(src), | ||||
@@ -355,7 +355,7 @@ class StreamSwitchTaskInfo : public TaskInfo { | |||||
public: | public: | ||||
StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond, | StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond, | ||||
int64_t data_type) | int64_t data_type) | ||||
: TaskInfo(stream_id, TaskInfoType::kStreamSwitch), | |||||
: TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH), | |||||
true_stream_id_(true_stream_id), | true_stream_id_(true_stream_id), | ||||
input_addr_(input_addr), | input_addr_(input_addr), | ||||
value_addr_(value_addr), | value_addr_(value_addr), | ||||
@@ -380,7 +380,7 @@ class StreamSwitchTaskInfo : public TaskInfo { | |||||
class StreamActiveTaskInfo : public TaskInfo { | class StreamActiveTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id) | StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id) | ||||
: TaskInfo(stream_id, TaskInfoType::kStreamActive), active_stream_id_(active_stream_id) {} | |||||
: TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {} | |||||
~StreamActiveTaskInfo() override {} | ~StreamActiveTaskInfo() override {} | ||||
uint32_t active_stream_id() const { return active_stream_id_; } | uint32_t active_stream_id() const { return active_stream_id_; } | ||||
@@ -24,6 +24,7 @@ extern "C" { | |||||
#endif | #endif | ||||
typedef uint32_t Status_t; | typedef uint32_t Status_t; | ||||
using Status_t = uint32_t; | |||||
typedef void *OpAttr_t; | typedef void *OpAttr_t; | ||||
typedef void *OpTensor_t; | typedef void *OpTensor_t; | ||||
@@ -39,4 +39,4 @@ class MemoryAssigner { | |||||
ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ | |||||
#endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ |
@@ -24,7 +24,6 @@ | |||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "framework/common/fmk_error_codes.h" | #include "framework/common/fmk_error_codes.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
@@ -40,10 +39,10 @@ using std::unordered_map; | |||||
using std::vector; | using std::vector; | ||||
namespace ge { | namespace ge { | ||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief run model | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief run model | |||||
*/ | |||||
enum RunMode { | enum RunMode { | ||||
kGeOmModel = 0, // generate offline model file | kGeOmModel = 0, // generate offline model file | ||||
kModelToJson = 1, // convert to JSON file | kModelToJson = 1, // convert to JSON file | ||||
@@ -119,12 +118,20 @@ struct OmgContext { | |||||
} // namespace ge | } // namespace ge | ||||
namespace domi { | namespace domi { | ||||
/// | |||||
/// @ingroup domi_omg | |||||
/// @brief get OMG context | |||||
/// @return OmgContext context | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief get OMG context | |||||
* @return OmgContext context | |||||
*/ | |||||
ge::OmgContext &GetContext(); | ge::OmgContext &GetContext(); | ||||
struct TEBinInfo { | |||||
// It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. | |||||
// To be compatible with use cases written by previous users, fields are not deleted.(2018.11.21) | |||||
std::string bin_file_path; | |||||
std::string json_file_path; | |||||
std::string ddk_version; | |||||
}; | |||||
} // namespace domi | } // namespace domi | ||||
#endif // INC_FRAMEWORK_OMG_OMG_INNER_TYPES_H_ | #endif // INC_FRAMEWORK_OMG_OMG_INNER_TYPES_H_ |
@@ -23,7 +23,7 @@ | |||||
#include <string> | #include <string> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include <deque> | |||||
#include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
#include "graph/anchor.h" | #include "graph/anchor.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
@@ -1,4 +1,4 @@ | |||||
/** | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | * Copyright 2019-2020 Huawei Technologies Co., Ltd | ||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
@@ -149,7 +149,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | ||||
@@ -468,9 +467,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | ||||
GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; | |||||
GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | ||||
GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; | |||||
GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; | |||||
// Decode_bbox | // Decode_bbox | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DECODE_BBOX_ATTR_DECODECLIP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DECODE_BBOX_ATTR_DECODECLIP; | ||||
@@ -767,7 +766,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; | ||||
@@ -776,3 +774,4 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_ | |||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | ||||
/*lint +e618*/ |
@@ -27,15 +27,15 @@ class GEContext { | |||||
graphStatus GetOption(const std::string &key, std::string &option); | graphStatus GetOption(const std::string &key, std::string &option); | ||||
uint64_t SessionId(); | uint64_t SessionId(); | ||||
uint32_t DeviceId(); | uint32_t DeviceId(); | ||||
uint64_t JobId(); | |||||
uint64_t TraceId(); | |||||
void Init(); | void Init(); | ||||
void SetCtxDeviceId(uint32_t device_id); | void SetCtxDeviceId(uint32_t device_id); | ||||
private: | private: | ||||
uint64_t session_id_ = 0; | uint64_t session_id_ = 0; | ||||
uint32_t device_id_ = 0; | uint32_t device_id_ = 0; | ||||
uint64_t job_id_ = 0; | |||||
}; | |||||
uint64_t trace_id_ = 0; | |||||
}; // class GEContext | |||||
/// Get context | /// Get context | ||||
/// @return | /// @return | ||||
@@ -23,20 +23,22 @@ | |||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
using std::string; | |||||
using std::map; | using std::map; | ||||
using std::string; | |||||
namespace ge { | namespace ge { | ||||
class GEThreadLocalContext { | class GEThreadLocalContext { | ||||
public: | public: | ||||
graphStatus GetOption(const string &key, string &option); | graphStatus GetOption(const string &key, string &option); | ||||
void SetGlobalOption(map<std::string, string> options_map); | |||||
void SetGraphOption(map<std::string, string> options_map); | |||||
void SetSessionOption(map<std::string, string> options_map); | void SetSessionOption(map<std::string, string> options_map); | ||||
void SetGlobalOption(map<std::string, string> options_map); | |||||
private: | private: | ||||
map<string, string> graph_options_; | |||||
map<string, string> session_options_; | map<string, string> session_options_; | ||||
map<string, string> global_options_; | map<string, string> global_options_; | ||||
}; | |||||
}; // class GEThreadLocalContext | |||||
GEThreadLocalContext &GetThreadLocalContext(); | GEThreadLocalContext &GetThreadLocalContext(); | ||||
} // namespace ge | } // namespace ge | ||||
@@ -31,6 +31,8 @@ using std::map; | |||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
/*lint -e148*/ | |||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { | ||||
public: | public: | ||||
Model(); | Model(); | ||||
@@ -65,7 +67,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { | |||||
graphStatus Save(Buffer &buffer) const; | graphStatus Save(Buffer &buffer) const; | ||||
graphStatus SaveToFile(const string &file_name) const; | graphStatus SaveToFile(const string &file_name) const; | ||||
// Model will be rewritten | |||||
// Model will be rewrite | |||||
static graphStatus Load(const uint8_t *data, size_t len, Model &model); | static graphStatus Load(const uint8_t *data, size_t len, Model &model); | ||||
graphStatus Load(ge::proto::ModelDef &model_def); | graphStatus Load(ge::proto::ModelDef &model_def); | ||||
graphStatus LoadFromFile(const string &file_name); | graphStatus LoadFromFile(const string &file_name); | ||||
@@ -89,6 +91,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { | |||||
std::string platform_version_{""}; | std::string platform_version_{""}; | ||||
Graph graph_; | Graph graph_; | ||||
}; | }; | ||||
/*lint +e148*/ | |||||
} // namespace ge | } // namespace ge | ||||
using ModelPtr = std::shared_ptr<ge::Model>; | using ModelPtr = std::shared_ptr<ge::Model>; | ||||
@@ -20,14 +20,14 @@ | |||||
#include <map> | #include <map> | ||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <unordered_set> | |||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include <unordered_set> | |||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
#include "utils/attr_utils.h" | |||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "graph/range_vistor.h" | #include "graph/range_vistor.h" | ||||
#include "utils/attr_utils.h" | |||||
namespace ge { | namespace ge { | ||||
class ComputeGraph; | class ComputeGraph; | ||||
@@ -20,7 +20,6 @@ | |||||
#include <atomic> | #include <atomic> | ||||
#include <memory> | #include <memory> | ||||
#include <vector> | #include <vector> | ||||
namespace ge { | namespace ge { | ||||
#define USR_TYPE_DEC(type, name) \ | #define USR_TYPE_DEC(type, name) \ | ||||
inline void set_##name(const type &value) { name = value; } \ | inline void set_##name(const type &value) { name = value; } \ |
@@ -15,10 +15,8 @@ | |||||
*/ | */ | ||||
#include "graph/anchor.h" | #include "graph/anchor.h" | ||||
#include <algorithm> | #include <algorithm> | ||||
#include <cstring> | #include <cstring> | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
@@ -53,6 +51,7 @@ void Anchor::UnlinkAll() noexcept { | |||||
if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { | if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { | ||||
GELOGW("unlink peer_anchor_ptr failed."); | GELOGW("unlink peer_anchor_ptr failed."); | ||||
} | } | ||||
} while (!peer_anchors_.empty()); | } while (!peer_anchors_.empty()); | ||||
} | } | ||||
} | } | ||||
@@ -70,10 +69,10 @@ graphStatus Anchor::Unlink(const AnchorPtr &peer) { | |||||
GE_IF_BOOL_EXEC(it == peer_anchors_.end(), GELOGW("this anchor is not connected to peer"); return GRAPH_FAILED); | GE_IF_BOOL_EXEC(it == peer_anchors_.end(), GELOGW("this anchor is not connected to peer"); return GRAPH_FAILED); | ||||
auto it_peer = | auto it_peer = | ||||
std::find_if(peer->peer_anchors_.begin(), peer->peer_anchors_.end(), [this](const std::weak_ptr<Anchor> &an) { | |||||
auto anchor = an.lock(); | |||||
return Equal(anchor); | |||||
}); | |||||
std::find_if(peer->peer_anchors_.begin(), peer->peer_anchors_.end(), [this](const std::weak_ptr<Anchor> &an) { | |||||
auto anchor = an.lock(); | |||||
return Equal(anchor); | |||||
}); | |||||
GE_CHK_BOOL_RET_STATUS(it_peer != peer->peer_anchors_.end(), GRAPH_FAILED, "peer is not connected to this anchor"); | GE_CHK_BOOL_RET_STATUS(it_peer != peer->peer_anchors_.end(), GRAPH_FAILED, "peer is not connected to this anchor"); | ||||
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "external/graph/attr_value.h" | #include "external/graph/attr_value.h" | ||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "graph/buffer.h" | #include "graph/buffer.h" | ||||
#include "proto/ge_ir.pb.h" | #include "proto/ge_ir.pb.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -15,7 +15,9 @@ | |||||
*/ | */ | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include <deque> | #include <deque> | ||||
#include "./format_refiner.h" | #include "./format_refiner.h" | ||||
#include "./ge_context.h" | #include "./ge_context.h" | ||||
#include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
@@ -95,7 +97,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(co | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreEqual( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreEqual( | ||||
const ComputeGraph &r_graph) const { | |||||
const ComputeGraph &r_graph) const { | |||||
// ProtoMsgOwner <::google::protobuf::Message> is temporarily ignored | // ProtoMsgOwner <::google::protobuf::Message> is temporarily ignored | ||||
if ((this->attrs_.protoMsg_ != nullptr) && (r_graph.attrs_.protoMsg_ != nullptr)) { | if ((this->attrs_.protoMsg_ != nullptr) && (r_graph.attrs_.protoMsg_ != nullptr)) { | ||||
const auto &proto_attr_map = *(this->attrs_.protoMsg_); | const auto &proto_attr_map = *(this->attrs_.protoMsg_); | ||||
@@ -122,7 +124,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreE | |||||
/// Since there may be different input nodes | /// Since there may be different input nodes | ||||
/// chosen by user in the same graph, special judgment is needed | /// chosen by user in the same graph, special judgment is needed | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNodePtrIsEqual( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNodePtrIsEqual( | ||||
const std::vector<NodePtr> &left_nodes, const std::vector<NodePtr> &right_nodes) const { | |||||
const std::vector<NodePtr> &left_nodes, const std::vector<NodePtr> &right_nodes) const { | |||||
const auto left_nodes_size = left_nodes.size(); | const auto left_nodes_size = left_nodes.size(); | ||||
const auto right_nodes_size = right_nodes.size(); | const auto right_nodes_size = right_nodes.size(); | ||||
if (left_nodes_size != right_nodes_size) { | if (left_nodes_size != right_nodes_size) { | ||||
@@ -151,7 +153,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNod | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( | ||||
const ComputeGraph &r_graph) const { | |||||
const ComputeGraph &r_graph) const { | |||||
return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.sub_graph_.size()") && | return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.sub_graph_.size()") && | ||||
IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && | IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && | ||||
VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && | VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && | ||||
@@ -472,14 +474,14 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||||
} | } | ||||
} | } | ||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor | |||||
: node->GetOutControlAnchor()->GetPeerAnchors()) { | |||||
GE_CHECK_NOTNULL(peer_in_anchor); | |||||
auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); | |||||
if (iter != map_in_edge_num.end() && --iter->second == 0) { | |||||
stack.push_back(peer_in_anchor->GetOwnerNode()); | |||||
} | |||||
}) | |||||
node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor | |||||
: node->GetOutControlAnchor()->GetPeerAnchors()) { | |||||
GE_CHECK_NOTNULL(peer_in_anchor); | |||||
auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); | |||||
if (iter != map_in_edge_num.end() && --iter->second == 0) { | |||||
stack.push_back(peer_in_anchor->GetOwnerNode()); | |||||
} | |||||
}) | |||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
@@ -521,28 +523,30 @@ graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||||
graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::map<string, NodePtr> &breadth_node_map) { | std::map<string, NodePtr> &breadth_node_map) { | ||||
for (const auto &anchor : node->GetAllOutDataAnchors()) { | |||||
for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { | |||||
auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); | |||||
if (iter != map_in_edge_num.end() && --iter->second == 0) { | |||||
(void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); | |||||
} | |||||
for (const auto &anchor : node->GetAllOutDataAnchors()) { | |||||
for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { | |||||
auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); | |||||
if (iter != map_in_edge_num.end() && 0 == --iter->second) { | |||||
(void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); | |||||
} | } | ||||
for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { | |||||
auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); | |||||
if (iter != map_in_edge_num.end() && --iter->second == 0) { | |||||
(void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); | |||||
} | |||||
} | |||||
for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { | |||||
auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); | |||||
if (iter != map_in_edge_num.end() && 0 == --iter->second) { | |||||
(void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); | |||||
} | } | ||||
} | } | ||||
GE_IF_BOOL_EXEC( | |||||
node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor | |||||
: node->GetOutControlAnchor()->GetPeerAnchors()) { | |||||
auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); | |||||
if (iter != map_in_edge_num.end() && --iter->second == 0) { | |||||
(void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); | |||||
} | |||||
}) | |||||
} | |||||
GE_IF_BOOL_EXEC( | |||||
node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor | |||||
: node->GetOutControlAnchor()->GetPeerAnchors()) { | |||||
auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); | |||||
if (iter != map_in_edge_num.end() && 0 == --iter->second) { | |||||
(void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); | |||||
} | |||||
}) | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -636,7 +640,7 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||||
/// 2. Compare two indices, if not match, swap the positions of two inputs | /// 2. Compare two indices, if not match, swap the positions of two inputs | ||||
/// *: Remind: stack is reverse-order | /// *: Remind: stack is reverse-order | ||||
for (size_t i = 0; i < stack.size(); ++i) { | for (size_t i = 0; i < stack.size(); ++i) { | ||||
// [stack: should not be null] | |||||
//[stack: should not be null] | |||||
for (size_t j = i + 1; j < stack.size(); ++j) { | for (size_t j = i + 1; j < stack.size(); ++j) { | ||||
// If not found in 'inputs_order_', skip it | // If not found in 'inputs_order_', skip it | ||||
auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); | auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); | ||||
@@ -721,7 +725,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||||
} | } | ||||
} | } | ||||
GE_IF_BOOL_EXEC(node->GetOutControlAnchor() == nullptr, GELOGE(GRAPH_FAILED, "Out control anchor is null"); | GE_IF_BOOL_EXEC(node->GetOutControlAnchor() == nullptr, GELOGE(GRAPH_FAILED, "Out control anchor is null"); | ||||
return); | |||||
return ); | |||||
for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { | for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { | ||||
GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | ||||
GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | ||||
@@ -745,7 +749,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Isolate | |||||
GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor) == GRAPH_SUCCESS, | GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor) == GRAPH_SUCCESS, | ||||
return GRAPH_FAILED, "remove edge failed"); | return GRAPH_FAILED, "remove edge failed"); | ||||
GE_IF_BOOL_EXEC(pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANT || | GE_IF_BOOL_EXEC(pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANT || | ||||
pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANTOP, | |||||
pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANTOP, | |||||
continue); | continue); | ||||
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | ||||
for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | ||||
@@ -16,7 +16,6 @@ | |||||
#ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | #ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | ||||
#define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | #define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | ||||
#include <limits.h> | #include <limits.h> | ||||
#include <stdint.h> | #include <stdint.h> | ||||
#include <algorithm> | #include <algorithm> | ||||
@@ -26,7 +25,7 @@ | |||||
#include <vector> | #include <vector> | ||||
namespace ge { | namespace ge { | ||||
#define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name __attribute__((unused)) = str_name | |||||
#define GE_REGISTER_OPTYPE(var_name, str_name) static const char* var_name __attribute__((unused)) = str_name | |||||
GE_REGISTER_OPTYPE(DATA, "Data"); | GE_REGISTER_OPTYPE(DATA, "Data"); | ||||
GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); | GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); | ||||
@@ -249,5 +248,5 @@ static const char* const kAippConvOpNmae = "aipp_conv_op"; | |||||
/// @brief Operator configuration item separator | /// @brief Operator configuration item separator | ||||
/// | /// | ||||
static const char* const kOpConfDelimiter = ":"; | static const char* const kOpConfDelimiter = ":"; | ||||
}; // namespace ge | |||||
}; // namespace ge | |||||
#endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | #endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ |
@@ -39,125 +39,126 @@ | |||||
#endif | #endif | ||||
#define GE_RETURN_IF_ERROR(expr) \ | #define GE_RETURN_IF_ERROR(expr) \ | ||||
do { \ | |||||
const ::ge::optStatus _status = (expr); \ | |||||
if (_status) return _status; \ | |||||
do { \ | |||||
const ::ge::optStatus _status = (expr); \ | |||||
if (_status) return _status; \ | |||||
} while (0) | } while (0) | ||||
#define GE_RETURN_WITH_LOG_IF_INFO(expr, ...) \ | #define GE_RETURN_WITH_LOG_IF_INFO(expr, ...) \ | ||||
do { \ | |||||
const ::ge::optStatus _status = (expr); \ | |||||
if (_status) { \ | |||||
GELOGI(__VA_ARGS__); \ | |||||
return _status; \ | |||||
} \ | |||||
do { \ | |||||
const ::ge::optStatus _status = (expr); \ | |||||
if (_status) { \ | |||||
GELOGI(__VA_ARGS__); \ | |||||
return _status; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Verify whether the parameter is true. If yes, return graph failed and record the error log | // Verify whether the parameter is true. If yes, return graph failed and record the error log | ||||
#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ | #define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ | ||||
do { \ | |||||
if (condition) { \ | |||||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
return ge::GRAPH_FAILED; \ | |||||
} \ | |||||
do { \ | |||||
if (condition) { \ | |||||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
return ge::GRAPH_FAILED; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Verify whether the parameter is false. If yes, return graph failed and record the error log | // Verify whether the parameter is false. If yes, return graph failed and record the error log | ||||
#define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ | #define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ | ||||
do { \ | |||||
bool _condition = (condition); \ | |||||
if (!_condition) { \ | |||||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
return ge::GRAPH_FAILED; \ | |||||
} \ | |||||
do { \ | |||||
bool _condition = (condition); \ | |||||
if (!_condition) { \ | |||||
GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
return ge::GRAPH_FAILED; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Verify whether the parameter is true. If yes, return GRAPH_PARAM_INVALID and record the error log | // Verify whether the parameter is true. If yes, return GRAPH_PARAM_INVALID and record the error log | ||||
#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ | |||||
do { \ | |||||
if (condition) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ | |||||
do { \ | |||||
if (condition) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Verify whether the parameter is false. If yes, return GRAPH_PARAM_INVALID and record the error log | // Verify whether the parameter is false. If yes, return GRAPH_PARAM_INVALID and record the error log | ||||
#define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \ | #define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \ | ||||
do { \ | |||||
bool _condition = (condition); \ | |||||
if (!_condition) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
do { \ | |||||
bool _condition = (condition); \ | |||||
if (!_condition) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log | // Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log | ||||
#define GE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log | // Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log | ||||
#define GE_CHECK_NOTNULL_EXEC(val, expr) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ | |||||
expr; \ | |||||
} \ | |||||
#define GE_CHECK_NOTNULL_EXEC(val, expr) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ | |||||
expr; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Verify whether the parameter is null. If yes, return false and record the error log | // Verify whether the parameter is null. If yes, return false and record the error log | ||||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::GRAPH_FAILED, "param[%s] must not be null.", #val); \ | |||||
return false; \ | |||||
} \ | |||||
#define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if (val == nullptr) { \ | |||||
GELOGE(ge::GRAPH_FAILED, "param[%s] must not be null.", #val); \ | |||||
return false; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check whether the parameter is out of range | // Check whether the parameter is out of range | ||||
#define GE_CHECK_SIZE(size) \ | |||||
do { \ | |||||
if (size == 0) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_SIZE(size) \ | |||||
do { \ | |||||
if (size == 0) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
/// | /// | ||||
/// @ingroup GE_common | /// @ingroup GE_common | ||||
/// eg:GE_DEFINE_BYTE_SIZE(filter_byte, filter.data().size(), sizeof(float)); | |||||
/// | /// | ||||
#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ | |||||
uint32_t _var_name; \ | |||||
do { \ | |||||
uint32_t _expr_size = (_expr); \ | |||||
uint32_t _sizeof_size = (_sizeof); \ | |||||
if (_expr_size > (0xffffffff) / _sizeof_size) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "byte size : %s is out of range", #_var_name); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
_var_name = _sizeof_size * _expr_size; \ | |||||
#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ | |||||
uint32_t _var_name; \ | |||||
do { \ | |||||
uint32_t _expr_size = (_expr); \ | |||||
uint32_t _sizeof_size = (_sizeof); \ | |||||
if (_expr_size > (0xffffffff) / _sizeof_size) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "byte size : %s is out of range", #_var_name); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
_var_name = _sizeof_size * _expr_size; \ | |||||
} while (0); | } while (0); | ||||
// Check whether the container is empty | // Check whether the container is empty | ||||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
do { \ | |||||
if (vector.empty()) { \ | |||||
GELOGE(ge::GRAPH_FAILED, "param[#vector] is empty", #vector); \ | |||||
return ge::GRAPH_FAILED; \ | |||||
} \ | |||||
#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
do { \ | |||||
if (vector.empty()) { \ | |||||
GELOGE(ge::GRAPH_FAILED, "param[#vector] is empty", #vector); \ | |||||
return ge::GRAPH_FAILED; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check whether the container is empty and return the specified status code | // Check whether the container is empty and return the specified status code | ||||
#define GE_CHECK_VECTOR_NOT_EMPTY_RET_STATUS(vector, _status) \ | #define GE_CHECK_VECTOR_NOT_EMPTY_RET_STATUS(vector, _status) \ | ||||
do { \ | |||||
if (vector.empty()) { \ | |||||
GELOGE(_status, "param[%s] is empty", #vector); \ | |||||
return _status; \ | |||||
} \ | |||||
do { \ | |||||
if (vector.empty()) { \ | |||||
GELOGE(_status, "param[%s] is empty", #vector); \ | |||||
return _status; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
/// | /// | ||||
@@ -166,102 +167,102 @@ | |||||
/// It is usually placed under private | /// It is usually placed under private | ||||
/// | /// | ||||
#define GE_DISALLOW_COPY_AND_ASSIGN(TypeName) \ | #define GE_DISALLOW_COPY_AND_ASSIGN(TypeName) \ | ||||
TypeName(const TypeName &) = delete; \ | |||||
TypeName(const TypeName &) = delete; \ | |||||
void operator=(const TypeName &) = delete | void operator=(const TypeName &) = delete | ||||
/// Check whether the size is 0 or out of range | /// Check whether the size is 0 or out of range | ||||
/// @param:size:Size to be verified | /// @param:size:Size to be verified | ||||
#define GE_CHECK_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size == 0 || size >= UINT_MAX / 4) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size == 0 || size >= UINT_MAX / 4) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
#define GE_CHECK_SHORT_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size == 0 || size >= UINT_MAX / 2) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_SHORT_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size == 0 || size >= UINT_MAX / 2) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size <= 0) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not a positive number", #size); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size <= 0) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not a positive number", #size); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
#define GE_CHECK_POSITIVE_SHORT_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size <= 0 || size == 0 || size >= UINT_MAX / 4) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_POSITIVE_SHORT_SIZE_RANGE(size) \ | |||||
do { \ | |||||
if (size <= 0 || size == 0 || size >= UINT_MAX / 4) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Verify that the value on the left is greater than or equal to the value on the right | // Verify that the value on the left is greater than or equal to the value on the right | ||||
#define GE_CHECK_GE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs < rhs) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is less than[%s]", #lhs, #rhs); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_GE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs < rhs) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is less than[%s]", #lhs, #rhs); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check whether the parameters are equal | // Check whether the parameters are equal | ||||
#define GE_CHECK_EQ(val1, val2) \ | |||||
do { \ | |||||
if (val1 != val2) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not equals to[%s]", #val1, #val2); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_EQ(val1, val2) \ | |||||
do { \ | |||||
if (val1 != val2) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not equals to[%s]", #val1, #val2); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Verify that the value on the left is less than or equal to the value on the right | // Verify that the value on the left is less than or equal to the value on the right | ||||
#define GE_CHECK_LE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs > rhs) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is greater than[%s]", #lhs, #rhs); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_LE(lhs, rhs) \ | |||||
do { \ | |||||
if (lhs > rhs) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is greater than[%s]", #lhs, #rhs); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// Check whether the parameters are equal | // Check whether the parameters are equal | ||||
#define GE_CHECK_EQ_WITH_LOG(val1, val2, ...) \ | |||||
do { \ | |||||
if (val1 != val2) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
#define GE_CHECK_EQ_WITH_LOG(val1, val2, ...) \ | |||||
do { \ | |||||
if (val1 != val2) { \ | |||||
GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ | |||||
return ge::GRAPH_PARAM_INVALID; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
// If expr is false, the custom statement is executed | // If expr is false, the custom statement is executed | ||||
#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | #define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | ||||
do { \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
exec_expr; \ | |||||
} \ | |||||
do { \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
exec_expr; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
#define GE_DELETE_NEW_SINGLE(var) \ | #define GE_DELETE_NEW_SINGLE(var) \ | ||||
do { \ | |||||
if (var != nullptr) { \ | |||||
delete var; \ | |||||
var = nullptr; \ | |||||
} \ | |||||
do { \ | |||||
if (var != nullptr) { \ | |||||
delete var; \ | |||||
var = nullptr; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
#define GE_DELETE_NEW_ARRAY(var) \ | #define GE_DELETE_NEW_ARRAY(var) \ | ||||
do { \ | |||||
if (var != nullptr) { \ | |||||
delete[] var; \ | |||||
var = nullptr; \ | |||||
} \ | |||||
do { \ | |||||
if (var != nullptr) { \ | |||||
delete[] var; \ | |||||
var = nullptr; \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
template <typename T, typename... Args> | template <typename T, typename... Args> | ||||
@@ -31,9 +31,9 @@ | |||||
namespace ge { | namespace ge { | ||||
std::unordered_set<std::string> control_anchor; | std::unordered_set<std::string> control_anchor; | ||||
std::vector<string> types = { | std::vector<string> types = { | ||||
"DT_FLOAT", "DT_FLOAT16", "DT_INT8", "DT_INT32", "DT_UINT8", "", | |||||
"DT_INT16", "DT_UINT16", "DT_UINT32", "DT_INT64", "DT_UINT64", "DT_DOUBLE", | |||||
"DT_BOOL", "DT_DUAL", "DT_DUAL_SUB_INT8", "DT_DUAL_SUB_UINT8", "DT_UNDEFINED"}; | |||||
"DT_FLOAT", "DT_FLOAT16", "DT_INT8", "DT_INT32", "DT_UINT8", "", | |||||
"DT_INT16", "DT_UINT16", "DT_UINT32", "DT_INT64", "DT_UINT64", "DT_DOUBLE", | |||||
"DT_BOOL", "DT_DUAL", "DT_DUAL_SUB_INT8", "DT_DUAL_SUB_UINT8", "DT_UNDEFINED"}; | |||||
std::vector<string> formats = {"FORMAT_NCHW", | std::vector<string> formats = {"FORMAT_NCHW", | ||||
"FORMAT_NHWC", | "FORMAT_NHWC", | ||||
@@ -92,7 +92,7 @@ void GraphDebugPrinter::DumpNodeToDot(const NodePtr node, std::ostringstream &ou | |||||
auto input_anchors = node->GetAllInDataAnchors(); | auto input_anchors = node->GetAllInDataAnchors(); | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return); | |||||
GE_CHECK_NOTNULL_EXEC(op_desc, return ); | |||||
if (!input_anchors.empty()) { | if (!input_anchors.empty()) { | ||||
out_ << TAB << TAB << "<tr>"; | out_ << TAB << TAB << "<tr>"; | ||||
} | } | ||||
@@ -138,7 +138,7 @@ void GraphDebugPrinter::DumpEdgeToDot(const NodePtr node, std::ostringstream &ou | |||||
} | } | ||||
auto all_out_anchor = node->GetAllOutDataAnchors(); | auto all_out_anchor = node->GetAllOutDataAnchors(); | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return); | |||||
GE_CHECK_NOTNULL_EXEC(op_desc, return ); | |||||
for (const auto &anchor : all_out_anchor) { | for (const auto &anchor : all_out_anchor) { | ||||
auto src_anchor = anchor; | auto src_anchor = anchor; | ||||
auto src_node_name = node->GetName(); | auto src_node_name = node->GetName(); | ||||
@@ -170,12 +170,12 @@ void GraphDebugPrinter::DumpEdgeToDot(const NodePtr node, std::ostringstream &ou | |||||
if (flag != DOT_NOT_SHOW_EDGE_LABEL && in_data_anchor) { | if (flag != DOT_NOT_SHOW_EDGE_LABEL && in_data_anchor) { | ||||
string label; | string label; | ||||
auto src_ops = src_anchor->GetOwnerNode()->GetOpDesc(); | auto src_ops = src_anchor->GetOwnerNode()->GetOpDesc(); | ||||
GE_CHECK_NOTNULL_EXEC(src_ops, return); | |||||
GE_CHECK_NOTNULL_EXEC(src_ops, return ); | |||||
auto src_shape = src_ops->GetOutputDesc(src_anchor->GetIdx()).GetShape(); | auto src_shape = src_ops->GetOutputDesc(src_anchor->GetIdx()).GetShape(); | ||||
auto dim = src_shape.GetDims(); | auto dim = src_shape.GetDims(); | ||||
std::ostringstream tensor_info; | std::ostringstream tensor_info; | ||||
if (dim.size() > 0) { | if (dim.size() > 0) { | ||||
for (unsigned int i = 0; i < dim.size(); i++) { | |||||
for (size_t i = 0; i < dim.size(); i++) { | |||||
if (i != dim.size() - 1) { | if (i != dim.size() - 1) { | ||||
tensor_info << dim[i] << "x"; | tensor_info << dim[i] << "x"; | ||||
} else { | } else { | ||||
@@ -186,7 +186,7 @@ void GraphDebugPrinter::DumpEdgeToDot(const NodePtr node, std::ostringstream &ou | |||||
tensor_info << "?"; | tensor_info << "?"; | ||||
} | } | ||||
auto src_tensor_desc = src_ops->GetOutputDescPtr(src_anchor->GetIdx()); | auto src_tensor_desc = src_ops->GetOutputDescPtr(src_anchor->GetIdx()); | ||||
GE_CHECK_NOTNULL_EXEC(src_tensor_desc, return); | |||||
GE_CHECK_NOTNULL_EXEC(src_tensor_desc, return ); | |||||
auto format = src_tensor_desc->GetFormat(); | auto format = src_tensor_desc->GetFormat(); | ||||
auto datatype = src_tensor_desc->GetDataType(); | auto datatype = src_tensor_desc->GetDataType(); | ||||
tensor_info << " : " << formats[format] << " : " << types[datatype]; | tensor_info << " : " << formats[format] << " : " << types[datatype]; | ||||
@@ -67,6 +67,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||||
anchor_points.clear(); | anchor_points.clear(); | ||||
// Get all anchor point nodes and switch nodes | // Get all anchor point nodes and switch nodes | ||||
for (const auto &node_ptr : graph->GetAllNodes()) { | for (const auto &node_ptr : graph->GetAllNodes()) { | ||||
std::vector<bool> is_node_set_format; | |||||
if (node_ptr == nullptr) { | if (node_ptr == nullptr) { | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
@@ -166,7 +167,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||||
if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | ||||
auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | ||||
if (dim_num == 0) { | if (dim_num == 0) { | ||||
GELOGI("node name:%s idx:%d out is scalar. stop back infer!", peer_out_data_node->GetName().c_str(), idx); | |||||
GELOGD("node name:%s idx:%d out is scalar. stop back infer!", peer_out_data_node->GetName().c_str(), idx); | |||||
continue; | continue; | ||||
} | } | ||||
/// Check whether node to change dims () | /// Check whether node to change dims () | ||||
@@ -175,7 +176,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||||
auto iter1 = kChangeDimNodes.find(peer_out_data_node_type); | auto iter1 = kChangeDimNodes.find(peer_out_data_node_type); | ||||
// 4 means dims num | // 4 means dims num | ||||
if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { | if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { | ||||
GELOGI("Node[%s] is change dim node and shape is smaller than 4. do not modify format", | |||||
GELOGD("Node[%s] is change dim node and shape is smaller than 4. do not modify format", | |||||
(peer_out_data_node->GetName()).c_str()); | (peer_out_data_node->GetName()).c_str()); | ||||
continue; | continue; | ||||
} | } | ||||
@@ -235,7 +236,7 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||||
auto iter1 = kChangeDimNodes.find(peer_in_data_node_type); | auto iter1 = kChangeDimNodes.find(peer_in_data_node_type); | ||||
// 4 means dims num | // 4 means dims num | ||||
if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { | if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { | ||||
GELOGI("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str()); | |||||
GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str()); | |||||
continue; | continue; | ||||
} | } | ||||
ge_tensor_desc.SetOriginFormat(to_be_set_format); | ge_tensor_desc.SetOriginFormat(to_be_set_format); | ||||
@@ -292,7 +293,7 @@ graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_ | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
GELOGD("Enter DataNodeFormatProcess"); | GELOGD("Enter DataNodeFormatProcess"); | ||||
std::vector<ge::NodePtr> uninferred_data_nodes; | |||||
std::vector<ge::NodePtr> uninfered_data_nodes; | |||||
// Check and renew data nodes format | // Check and renew data nodes format | ||||
for (const auto &data_node : data_nodes) { | for (const auto &data_node : data_nodes) { | ||||
GE_CHECK_NOTNULL(data_node); | GE_CHECK_NOTNULL(data_node); | ||||
@@ -301,10 +302,10 @@ graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_ | |||||
GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0)); | GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0)); | ||||
auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat(); | auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat(); | ||||
if (curr_format != FORMAT_ND) { | if (curr_format != FORMAT_ND) { | ||||
// Data format has been inferred , continue | |||||
// Data format has been infered , continue | |||||
continue; | continue; | ||||
} | } | ||||
// Set format for un-inferred data node | |||||
// Set format for un-infered data node | |||||
auto input_descs = op_desc->GetAllInputsDescPtr(); | auto input_descs = op_desc->GetAllInputsDescPtr(); | ||||
auto output_descs = op_desc->GetAllOutputsDescPtr(); | auto output_descs = op_desc->GetAllOutputsDescPtr(); | ||||
@@ -320,10 +321,10 @@ graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_ | |||||
output_desc->SetFormat(data_format); | output_desc->SetFormat(data_format); | ||||
} | } | ||||
} | } | ||||
uninferred_data_nodes.push_back(data_node); | |||||
uninfered_data_nodes.push_back(data_node); | |||||
} | } | ||||
// Reinfer format from uninfered data nodes | // Reinfer format from uninfered data nodes | ||||
for (const auto &node : uninferred_data_nodes) { | |||||
for (const auto &node : uninfered_data_nodes) { | |||||
if (node == nullptr) { | if (node == nullptr) { | ||||
continue; | continue; | ||||
} | } | ||||
@@ -341,7 +342,7 @@ graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_ | |||||
graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) { | graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) { | ||||
GELOGI("Enter InferOrigineFormat process!"); | GELOGI("Enter InferOrigineFormat process!"); | ||||
// True: inferred false:no-inferred | |||||
// True: infered false:no-infered | |||||
std::unordered_map<ge::NodePtr, bool> node_status; | std::unordered_map<ge::NodePtr, bool> node_status; | ||||
std::vector<ge::NodePtr> anchor_points; | std::vector<ge::NodePtr> anchor_points; | ||||
std::vector<ge::NodePtr> data_nodes; | std::vector<ge::NodePtr> data_nodes; | ||||
@@ -373,7 +374,7 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||||
} | } | ||||
} | } | ||||
/// According to discuss with sys-enginer, data node default format is ND.Its format | /// According to discuss with sys-enginer, data node default format is ND.Its format | ||||
/// should be set by inferred.But if some data-node can not be got by infer, set context's | |||||
/// should be set by infered.But if some data-node can not be got by infer, set context's | |||||
/// format for these data nodes. | /// format for these data nodes. | ||||
/// Notice: ignore 5D formats | /// Notice: ignore 5D formats | ||||
auto data_format = graph->GetDataFormat(); | auto data_format = graph->GetDataFormat(); | ||||
@@ -21,7 +21,6 @@ | |||||
#include <string> | #include <string> | ||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <vector> | #include <vector> | ||||
#include "./compute_graph.h" | #include "./compute_graph.h" | ||||
#include "./external/graph/types.h" | #include "./external/graph/types.h" | ||||
#include "./ge_error_codes.h" | #include "./ge_error_codes.h" | ||||
@@ -1,4 +1,4 @@ | |||||
/** | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | * Copyright 2019-2020 Huawei Technologies Co., Ltd | ||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "external/graph/graph.h" | #include "external/graph/graph.h" | ||||
#include "utils/attr_utils.h" | #include "utils/attr_utils.h" | ||||
@@ -35,7 +34,7 @@ namespace ge { | |||||
GeAttrValue::NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | GeAttrValue::NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | ||||
GeAttrValue::NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) | GeAttrValue::NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) | ||||
: named_attrs_(owner, proto_msg) {} | |||||
: named_attrs_(owner, proto_msg) {} // lint !e1744 | |||||
void GeAttrValue::NamedAttrs::SetName(const std::string &name) { | void GeAttrValue::NamedAttrs::SetName(const std::string &name) { | ||||
auto proto_msg = named_attrs_.GetProtoMsg(); | auto proto_msg = named_attrs_.GetProtoMsg(); | ||||
@@ -155,29 +154,29 @@ class GeAttrValueImp { | |||||
}; | }; | ||||
map<proto::AttrDef::ValueCase, GeAttrValue::ValueType> GeAttrValueImp::attr_val_one_type_map_ = { | map<proto::AttrDef::ValueCase, GeAttrValue::ValueType> GeAttrValueImp::attr_val_one_type_map_ = { | ||||
{proto::AttrDef::kI, GeAttrValue::VT_INT}, | |||||
{proto::AttrDef::kF, GeAttrValue::VT_FLOAT}, | |||||
{proto::AttrDef::kB, GeAttrValue::VT_BOOL}, | |||||
{proto::AttrDef::kS, GeAttrValue::VT_STRING}, | |||||
{proto::AttrDef::kT, GeAttrValue::VT_TENSOR}, | |||||
{proto::AttrDef::kTd, GeAttrValue::VT_TENSOR_DESC}, | |||||
{proto::AttrDef::kG, GeAttrValue::VT_GRAPH}, | |||||
{proto::AttrDef::kBt, GeAttrValue::VT_BYTES}, | |||||
{proto::AttrDef::kFunc, GeAttrValue::VT_NAMED_ATTRS}, | |||||
{proto::AttrDef::kListListInt, GeAttrValue::VT_LIST_LIST_INT}, | |||||
{proto::AttrDef::kDt, GeAttrValue::VT_DATA_TYPE}, | |||||
{proto::AttrDef::kI, GeAttrValue::VT_INT}, | |||||
{proto::AttrDef::kF, GeAttrValue::VT_FLOAT}, | |||||
{proto::AttrDef::kB, GeAttrValue::VT_BOOL}, | |||||
{proto::AttrDef::kS, GeAttrValue::VT_STRING}, | |||||
{proto::AttrDef::kT, GeAttrValue::VT_TENSOR}, | |||||
{proto::AttrDef::kTd, GeAttrValue::VT_TENSOR_DESC}, | |||||
{proto::AttrDef::kG, GeAttrValue::VT_GRAPH}, | |||||
{proto::AttrDef::kBt, GeAttrValue::VT_BYTES}, | |||||
{proto::AttrDef::kFunc, GeAttrValue::VT_NAMED_ATTRS}, | |||||
{proto::AttrDef::kListListInt, GeAttrValue::VT_LIST_LIST_INT}, | |||||
{proto::AttrDef::kDt, GeAttrValue::VT_DATA_TYPE}, | |||||
}; | }; | ||||
map<proto::AttrDef_ListValue_ListValueType, GeAttrValue::ValueType> GeAttrValueImp::attr_val_list_type_map_ = { | map<proto::AttrDef_ListValue_ListValueType, GeAttrValue::ValueType> GeAttrValueImp::attr_val_list_type_map_ = { | ||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_INT, GeAttrValue::VT_LIST_INT}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT, GeAttrValue::VT_LIST_FLOAT}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_BOOL, GeAttrValue::VT_LIST_BOOL}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_STRING, GeAttrValue::VT_LIST_STRING}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, GeAttrValue::VT_LIST_TENSOR}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, GeAttrValue::VT_LIST_TENSOR_DESC}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, GeAttrValue::VT_LIST_GRAPH}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, GeAttrValue::VT_LIST_BYTES}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, GeAttrValue::VT_LIST_NAMED_ATTRS}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, GeAttrValue::VT_LIST_DATA_TYPE}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_INT, GeAttrValue::VT_LIST_INT}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT, GeAttrValue::VT_LIST_FLOAT}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_BOOL, GeAttrValue::VT_LIST_BOOL}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_STRING, GeAttrValue::VT_LIST_STRING}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, GeAttrValue::VT_LIST_TENSOR}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, GeAttrValue::VT_LIST_TENSOR_DESC}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, GeAttrValue::VT_LIST_GRAPH}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, GeAttrValue::VT_LIST_BYTES}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, GeAttrValue::VT_LIST_NAMED_ATTRS}, | |||||
{proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, GeAttrValue::VT_LIST_DATA_TYPE}, | |||||
}; | }; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue::GeAttrValue() { value_.InitDefault(); } | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue::GeAttrValue() { value_.InitDefault(); } | ||||
@@ -240,7 +239,7 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR) | |||||
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>) | ||||
ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) | ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) | ||||
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>) | ||||
ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) | |||||
ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524 | |||||
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>) | ||||
ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) | ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) | ||||
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>) | ||||
@@ -254,9 +253,11 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES) | |||||
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>) | ||||
ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) | ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) | ||||
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>) | ||||
/*lint -e665*/ | |||||
ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>) | ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>) | ||||
ATTR_VALUE_SET_GET_IMP(vector<DataType>) | |||||
ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) | |||||
/*lint +e665*/ | |||||
ATTR_VALUE_SET_GET_IMP(vector<DataType>) // lint !e665 | |||||
ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665 | |||||
#undef ATTR_VALUE_SET_GET_IMP | #undef ATTR_VALUE_SET_GET_IMP | ||||
@@ -275,8 +276,8 @@ class AttrUtilsHelper { | |||||
} | } | ||||
inline static bool GetValueCheckListType( | inline static bool GetValueCheckListType( | ||||
const proto::AttrDef &attr_def, proto::AttrDef_ListValue_ListValueType proto_list_case, | |||||
const std::function<bool(const proto::AttrDef &proto_attr_val)> item_check_fun) { | |||||
const proto::AttrDef &attr_def, proto::AttrDef_ListValue_ListValueType proto_list_case, | |||||
const std::function<bool(const proto::AttrDef &proto_attr_val)> item_check_fun) { | |||||
if (attr_def.value_case() != proto::AttrDef::kList) { | if (attr_def.value_case() != proto::AttrDef::kList) { | ||||
GELOGW("Check ListType Failed, value_case %u", attr_def.value_case()); | GELOGW("Check ListType Failed, value_case %u", attr_def.value_case()); | ||||
return false; | return false; | ||||
@@ -636,9 +637,8 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::DataType | |||||
#define ATTR_VALUE_IMP_GET_LIST(ValType, proto_list_case, protoItem) \ | #define ATTR_VALUE_IMP_GET_LIST(ValType, proto_list_case, protoItem) \ | ||||
bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, vector<ValType> &value) { \ | bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, vector<ValType> &value) { \ | ||||
value.clear(); \ | value.clear(); \ | ||||
if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, \ | |||||
proto::AttrDef_ListValue_ListValueType_##proto_list_case, \ | |||||
ListValueItemCheck(protoItem))) { \ | |||||
if (!AttrUtilsHelper::GetValueCheckListType( \ | |||||
proto_attr_val, proto::AttrDef_ListValue_ListValueType_##proto_list_case, ListValueItemCheck(protoItem))) { \ | |||||
return false; \ | return false; \ | ||||
} \ | } \ | ||||
auto &list = proto_attr_val.list(); \ | auto &list = proto_attr_val.list(); \ | ||||
@@ -673,7 +673,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||||
bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | ||||
vector<GeTensorDesc> &value) { | vector<GeTensorDesc> &value) { | ||||
if (!AttrUtilsHelper::GetValueCheckListType( | if (!AttrUtilsHelper::GetValueCheckListType( | ||||
proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, ListValueItemCheck(td))) { | |||||
proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, ListValueItemCheck(td))) { | |||||
return false; | return false; | ||||
} | } | ||||
auto &list = proto_attr_val.list(); | auto &list = proto_attr_val.list(); | ||||
@@ -693,8 +693,8 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||||
if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { | if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { | ||||
return false; | return false; | ||||
} | } | ||||
value = std::shared_ptr<GeTensor>( | |||||
new (std::nothrow) GeTensor(proto_owner, const_cast<proto::AttrDef &>(proto_attr_val).mutable_t())); | |||||
value = std::shared_ptr<GeTensor>(new (std::nothrow) | |||||
GeTensor(proto_owner, const_cast<proto::AttrDef &>(proto_attr_val).mutable_t())); | |||||
GE_CHK_BOOL_RET_STATUS(value != nullptr, false, "value is nullptr"); | GE_CHK_BOOL_RET_STATUS(value != nullptr, false, "value is nullptr"); | ||||
return true; | return true; | ||||
} | } | ||||
@@ -757,7 +757,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||||
vector<GeAttrValue::NamedAttrs> &value) { | vector<GeAttrValue::NamedAttrs> &value) { | ||||
value.clear(); | value.clear(); | ||||
if (!AttrUtilsHelper::GetValueCheckListType( | if (!AttrUtilsHelper::GetValueCheckListType( | ||||
proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { | |||||
proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { | |||||
return false; | return false; | ||||
} | } | ||||
auto &list = proto_attr_val.list(); | auto &list = proto_attr_val.list(); | ||||
@@ -931,7 +931,7 @@ bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const string &name) { | |||||
#define ATTR_UTILS_SET_IMP(FuncName, Type) \ | #define ATTR_UTILS_SET_IMP(FuncName, Type) \ | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Set##FuncName( \ | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Set##FuncName( \ | ||||
AttrHolderAdapter &&obj, const string &name, const Type &value) { \ | |||||
AttrHolderAdapter &&obj, const string &name, const Type &value) { \ | |||||
proto::AttrDef *proto_attr_val = nullptr; \ | proto::AttrDef *proto_attr_val = nullptr; \ | ||||
if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ | if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ | ||||
return false; \ | return false; \ | ||||
@@ -15,12 +15,10 @@ | |||||
*/ | */ | ||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include <cstdlib> | #include <cstdlib> | ||||
#include <cstring> | #include <cstring> | ||||
#include <iostream> | #include <iostream> | ||||
#include <map> | #include <map> | ||||
#include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -36,37 +34,37 @@ namespace ge { | |||||
static const char *const kKeyDataTypeSelfDefined = "__tensor_desc_data_type__"; | static const char *const kKeyDataTypeSelfDefined = "__tensor_desc_data_type__"; | ||||
static const std::map<DataType, ::ge::proto::DataType> kDataTypeMap = { | static const std::map<DataType, ::ge::proto::DataType> kDataTypeMap = { | ||||
{DT_UNDEFINED, proto::DT_UNDEFINED}, | |||||
{DT_FLOAT, proto::DT_FLOAT}, | |||||
{DT_FLOAT16, proto::DT_FLOAT16}, | |||||
{DT_INT8, proto::DT_INT8}, | |||||
{DT_UINT8, proto::DT_UINT8}, | |||||
{DT_INT16, proto::DT_INT16}, | |||||
{DT_UINT16, proto::DT_UINT16}, | |||||
{DT_INT32, proto::DT_INT32}, | |||||
{DT_INT64, proto::DT_INT64}, | |||||
{DT_UINT32, proto::DT_UINT32}, | |||||
{DT_UINT64, proto::DT_UINT64}, | |||||
{DT_BOOL, proto::DT_BOOL}, | |||||
{DT_DOUBLE, proto::DT_DOUBLE}, | |||||
{DT_DUAL, proto::DT_DUAL}, | |||||
{DT_DUAL_SUB_INT8, proto::DT_DUAL_SUB_INT8}, | |||||
{DT_DUAL_SUB_UINT8, proto::DT_DUAL_SUB_UINT8}, | |||||
{DT_COMPLEX64, proto::DT_COMPLEX64}, | |||||
{DT_COMPLEX128, proto::DT_COMPLEX128}, | |||||
{DT_QINT8, proto::DT_QINT8}, | |||||
{DT_QINT16, proto::DT_QINT16}, | |||||
{DT_QINT32, proto::DT_QINT32}, | |||||
{DT_QUINT8, proto::DT_QUINT8}, | |||||
{DT_QUINT16, proto::DT_QUINT16}, | |||||
{DT_RESOURCE, proto::DT_RESOURCE}, | |||||
{DT_STRING_REF, proto::DT_STRING_REF}, | |||||
{DT_STRING, proto::DT_STRING}, | |||||
{DT_UNDEFINED, proto::DT_UNDEFINED}, | |||||
{DT_FLOAT, proto::DT_FLOAT}, | |||||
{DT_FLOAT16, proto::DT_FLOAT16}, | |||||
{DT_INT8, proto::DT_INT8}, | |||||
{DT_UINT8, proto::DT_UINT8}, | |||||
{DT_INT16, proto::DT_INT16}, | |||||
{DT_UINT16, proto::DT_UINT16}, | |||||
{DT_INT32, proto::DT_INT32}, | |||||
{DT_INT64, proto::DT_INT64}, | |||||
{DT_UINT32, proto::DT_UINT32}, | |||||
{DT_UINT64, proto::DT_UINT64}, | |||||
{DT_BOOL, proto::DT_BOOL}, | |||||
{DT_DOUBLE, proto::DT_DOUBLE}, | |||||
{DT_DUAL, proto::DT_DUAL}, | |||||
{DT_DUAL_SUB_INT8, proto::DT_DUAL_SUB_INT8}, | |||||
{DT_DUAL_SUB_UINT8, proto::DT_DUAL_SUB_UINT8}, | |||||
{DT_COMPLEX64, proto::DT_COMPLEX64}, | |||||
{DT_COMPLEX128, proto::DT_COMPLEX128}, | |||||
{DT_QINT8, proto::DT_QINT8}, | |||||
{DT_QINT16, proto::DT_QINT16}, | |||||
{DT_QINT32, proto::DT_QINT32}, | |||||
{DT_QUINT8, proto::DT_QUINT8}, | |||||
{DT_QUINT16, proto::DT_QUINT16}, | |||||
{DT_RESOURCE, proto::DT_RESOURCE}, | |||||
{DT_STRING_REF, proto::DT_STRING_REF}, | |||||
{DT_STRING, proto::DT_STRING}, | |||||
}; | }; | ||||
static const std::map<DataType, int> kDataTypeSelfDefinedMap = { | static const std::map<DataType, int> kDataTypeSelfDefinedMap = { | ||||
{DT_DUAL, 13}, {DT_DUAL_SUB_INT8, 14}, {DT_DUAL_SUB_UINT8, 15}, {DT_COMPLEX64, 16}, {DT_COMPLEX128, 17}, | |||||
{DT_QINT8, 18}, {DT_QINT16, 19}, {DT_QINT32, 20}, {DT_QUINT8, 21}, {DT_QUINT16, 22}, | |||||
{DT_DUAL, 13}, {DT_DUAL_SUB_INT8, 14}, {DT_DUAL_SUB_UINT8, 15}, {DT_COMPLEX64, 16}, {DT_COMPLEX128, 17}, | |||||
{DT_QINT8, 18}, {DT_QINT16, 19}, {DT_QINT32, 20}, {DT_QUINT8, 21}, {DT_QUINT16, 22}, | |||||
}; | }; | ||||
GeShape::GeShape() { shape_def_.InitDefault(); } | GeShape::GeShape() { shape_def_.InitDefault(); } | ||||
@@ -287,35 +285,32 @@ bool GeTensorDesc::GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_des | |||||
const auto &r_tensor_descriptor = r_ge_tensor_desc.tensor_descriptor_.GetProtoMsg(); | const auto &r_tensor_descriptor = r_ge_tensor_desc.tensor_descriptor_.GetProtoMsg(); | ||||
if ((tensor_descriptor != nullptr) && (r_tensor_descriptor != nullptr)) { | if ((tensor_descriptor != nullptr) && (r_tensor_descriptor != nullptr)) { | ||||
// Message TensorDescriptor in ge_ir.proto | // Message TensorDescriptor in ge_ir.proto | ||||
return (IsEqual(tensor_descriptor->name(), r_tensor_descriptor->name(), "TensorDescriptor.name()") && | |||||
IsEqual(tensor_descriptor->dtype(), r_tensor_descriptor->dtype(), "TensorDescriptor.dtype()") && | |||||
// Message ShapeDef in ge_ir.proto | |||||
IsEqual(ToString(tensor_descriptor->shape().dim()), ToString(r_tensor_descriptor->shape().dim()), | |||||
"TensorDescriptor.shape().dim()") && | |||||
IsEqual(tensor_descriptor->layout(), r_tensor_descriptor->layout(), "TensorDescriptor.layout()") && | |||||
IsEqual(tensor_descriptor->has_out_attr(), r_tensor_descriptor->has_out_attr(), | |||||
"TensorDescriptor.has_out_attr()") && | |||||
IsEqual(tensor_descriptor->size(), r_tensor_descriptor->size(), "TensorDescriptor.size()") && | |||||
IsEqual(tensor_descriptor->weight_size(), r_tensor_descriptor->weight_size(), | |||||
"TensorDescriptor.weight_size()") && | |||||
IsEqual(tensor_descriptor->reuse_input(), r_tensor_descriptor->reuse_input(), | |||||
"TensorDescriptor.reuse_input()") && | |||||
IsEqual(tensor_descriptor->output_tensor(), r_tensor_descriptor->output_tensor(), | |||||
"TensorDescriptor.output_tensor()") && | |||||
IsEqual(tensor_descriptor->device_type(), r_tensor_descriptor->device_type(), | |||||
"TensorDescriptor.device_type()") && | |||||
IsEqual(tensor_descriptor->input_tensor(), r_tensor_descriptor->input_tensor(), | |||||
"TensorDescriptor.input_tensor()") && | |||||
IsEqual(tensor_descriptor->real_dim_cnt(), r_tensor_descriptor->real_dim_cnt(), | |||||
"TensorDescriptor.real_dim_cnt()") && | |||||
IsEqual(tensor_descriptor->reuse_input_index(), r_tensor_descriptor->reuse_input_index(), | |||||
"TensorDescriptor.reuse_input_index()") && | |||||
IsEqual(tensor_descriptor->data_offset(), r_tensor_descriptor->data_offset(), | |||||
"TensorDescriptor.data_offset()") && | |||||
IsEqual(tensor_descriptor->cmps_size(), r_tensor_descriptor->cmps_size(), "TensorDescriptor.cmps_size()") && | |||||
IsEqual(tensor_descriptor->cmps_tab(), r_tensor_descriptor->cmps_tab(), "TensorDescriptor.cmps_tab()") && | |||||
IsEqual(tensor_descriptor->cmps_tab_offset(), r_tensor_descriptor->cmps_tab_offset(), | |||||
"TensorDescriptor.cmps_tab_offset()")); | |||||
return ( | |||||
IsEqual(tensor_descriptor->name(), r_tensor_descriptor->name(), "TensorDescriptor.name()") && | |||||
IsEqual(tensor_descriptor->dtype(), r_tensor_descriptor->dtype(), "TensorDescriptor.dtype()") && | |||||
// Message ShapeDef in ge_ir.proto | |||||
IsEqual(ToString(tensor_descriptor->shape().dim()), ToString(r_tensor_descriptor->shape().dim()), | |||||
"TensorDescriptor.shape().dim()") && | |||||
IsEqual(tensor_descriptor->layout(), r_tensor_descriptor->layout(), "TensorDescriptor.layout()") && | |||||
IsEqual(tensor_descriptor->has_out_attr(), r_tensor_descriptor->has_out_attr(), | |||||
"TensorDescriptor.has_out_attr()") && | |||||
IsEqual(tensor_descriptor->size(), r_tensor_descriptor->size(), "TensorDescriptor.size()") && | |||||
IsEqual(tensor_descriptor->weight_size(), r_tensor_descriptor->weight_size(), "TensorDescriptor.weight_size()") && | |||||
IsEqual(tensor_descriptor->reuse_input(), r_tensor_descriptor->reuse_input(), "TensorDescriptor.reuse_input()") && | |||||
IsEqual(tensor_descriptor->output_tensor(), r_tensor_descriptor->output_tensor(), | |||||
"TensorDescriptor.output_tensor()") && | |||||
IsEqual(tensor_descriptor->device_type(), r_tensor_descriptor->device_type(), "TensorDescriptor.device_type()") && | |||||
IsEqual(tensor_descriptor->input_tensor(), r_tensor_descriptor->input_tensor(), | |||||
"TensorDescriptor.input_tensor()") && | |||||
IsEqual(tensor_descriptor->real_dim_cnt(), r_tensor_descriptor->real_dim_cnt(), | |||||
"TensorDescriptor.real_dim_cnt()") && | |||||
IsEqual(tensor_descriptor->reuse_input_index(), r_tensor_descriptor->reuse_input_index(), | |||||
"TensorDescriptor.reuse_input_index()") && | |||||
IsEqual(tensor_descriptor->data_offset(), r_tensor_descriptor->data_offset(), "TensorDescriptor.data_offset()") && | |||||
IsEqual(tensor_descriptor->cmps_size(), r_tensor_descriptor->cmps_size(), "TensorDescriptor.cmps_size()") && | |||||
IsEqual(tensor_descriptor->cmps_tab(), r_tensor_descriptor->cmps_tab(), "TensorDescriptor.cmps_tab()") && | |||||
IsEqual(tensor_descriptor->cmps_tab_offset(), r_tensor_descriptor->cmps_tab_offset(), | |||||
"TensorDescriptor.cmps_tab_offset()")); | |||||
} else { | } else { | ||||
return ((tensor_descriptor == nullptr) && (r_tensor_descriptor == nullptr)); | return ((tensor_descriptor == nullptr) && (r_tensor_descriptor == nullptr)); | ||||
} | } | ||||
@@ -575,9 +570,7 @@ GeTensorDesc &GeTensor::DescReference() const { | |||||
return __desc_; | return __desc_; | ||||
} | } | ||||
void GeTensor::SetTensorDesc(const GeTensorDesc &tensor_desc) { | |||||
DescReference() = tensor_desc; | |||||
} | |||||
void GeTensor::SetTensorDesc(const GeTensorDesc &tensor_desc) { DescReference() = tensor_desc; } | |||||
const Buffer GeTensor::GetData() const { | const Buffer GeTensor::GetData() const { | ||||
auto proto_msg = tensor_def_.GetProtoMsg(); | auto proto_msg = tensor_def_.GetProtoMsg(); | ||||
@@ -741,10 +734,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetOutputTensor | |||||
} | } | ||||
static map<uint32_t, string> device_to_str_map{ | static map<uint32_t, string> device_to_str_map{ | ||||
{0, "NPU"}, {1, "CPU"}, | |||||
{0, "NPU"}, | |||||
{1, "CPU"}, | |||||
}; | }; | ||||
static map<string, uint32_t> str_to_device_map{ | static map<string, uint32_t> str_to_device_map{ | ||||
{"NPU", 0}, {"CPU", 1}, | |||||
{"NPU", 0}, | |||||
{"CPU", 1}, | |||||
}; | }; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDeviceType(const GeTensorDesc &tensor_desc, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDeviceType(const GeTensorDesc &tensor_desc, | ||||
@@ -901,7 +896,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsInfo(GeT | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool TensorUtils::HasAlloffsetQuantizeInfo( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool TensorUtils::HasAlloffsetQuantizeInfo( | ||||
const GeTensorDesc &tensor_desc) { | |||||
const GeTensorDesc &tensor_desc) { | |||||
return tensor_desc.HasAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO); | return tensor_desc.HasAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO); | ||||
} | } | ||||
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "external/graph/graph.h" | #include "external/graph/graph.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "external/graph/operator.h" | #include "external/graph/operator.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -15,39 +15,98 @@ | |||||
*/ | */ | ||||
#include "external/graph/inference_context.h" | #include "external/graph/inference_context.h" | ||||
#include "debug/ge_util.h" | |||||
namespace ge { | namespace ge { | ||||
ShapeAndType::ShapeAndType(const Shape &shape, DataType data_type) : shape_(shape), data_type_(data_type) {} | |||||
class ShapeAndTypeImpl { | |||||
public: | |||||
ShapeAndTypeImpl() = default; | |||||
~ShapeAndTypeImpl() = default; | |||||
void ShapeAndType::SetShape(const Shape &shape) { shape_ = shape; } | |||||
ShapeAndTypeImpl(const Shape &shape, DataType data_type) : shape_(shape), data_type_(data_type) {} | |||||
void ShapeAndType::SetType(DataType data_type) { data_type_ = data_type; } | |||||
Shape shape_; | |||||
DataType data_type_ = DT_UNDEFINED; | |||||
}; | |||||
const Shape &ShapeAndType::GetShape() const { return shape_; } | |||||
class InferenceContextImpl { | |||||
public: | |||||
InferenceContextImpl() = default; | |||||
~InferenceContextImpl() = default; | |||||
DataType ShapeAndType::GetDataType() const { return data_type_; } | |||||
// For deliver to op in pair, help to support dynamic shape | |||||
std::vector<std::string> marks_; | |||||
std::vector<std::vector<ShapeAndType>> input_handle_shapes_and_types_; | |||||
std::vector<std::vector<ShapeAndType>> output_handle_shapes_and_types_; | |||||
}; | |||||
ShapeAndType::ShapeAndType() { shape_and_type_impl_ = ComGraphMakeShared<ShapeAndTypeImpl>(); } | |||||
ShapeAndType::ShapeAndType(const Shape &shape, DataType data_type) { | |||||
shape_and_type_impl_ = ComGraphMakeShared<ShapeAndTypeImpl>(shape, data_type); | |||||
} | |||||
void ShapeAndType::SetShape(const Shape &shape) { | |||||
if (shape_and_type_impl_ != nullptr) { | |||||
shape_and_type_impl_->shape_ = shape; | |||||
} | |||||
} | |||||
void ShapeAndType::SetType(DataType data_type) { | |||||
if (shape_and_type_impl_ != nullptr) { | |||||
shape_and_type_impl_->data_type_ = data_type; | |||||
} | |||||
} | |||||
Shape ShapeAndType::GetShape() const { | |||||
if (shape_and_type_impl_ != nullptr) { | |||||
return shape_and_type_impl_->shape_; | |||||
} | |||||
return Shape(); | |||||
} | |||||
DataType ShapeAndType::GetDataType() const { | |||||
if (shape_and_type_impl_ != nullptr) { | |||||
return shape_and_type_impl_->data_type_; | |||||
} | |||||
return DT_UNDEFINED; | |||||
} | |||||
InferenceContext::InferenceContext(std::unique_ptr<InferenceContextImpl> &impl) { | |||||
inference_context_impl_ = std::move(impl); | |||||
} | |||||
std::unique_ptr<InferenceContext> InferenceContext::Create() { | |||||
std::unique_ptr<InferenceContextImpl> impl = | |||||
std::unique_ptr<InferenceContextImpl>(new (std::nothrow) InferenceContextImpl()); | |||||
if (impl == nullptr) { | |||||
return nullptr; | |||||
} | |||||
return std::unique_ptr<InferenceContext>(new (std::nothrow) InferenceContext(impl)); | |||||
} | |||||
void InferenceContext::SetInputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types) { | void InferenceContext::SetInputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types) { | ||||
input_handle_shapes_and_types_.swap(shapes_and_types); | |||||
inference_context_impl_->input_handle_shapes_and_types_.swap(shapes_and_types); | |||||
} | } | ||||
const std::vector<std::vector<ShapeAndType>> &InferenceContext::GetInputHandleShapesAndTypes() const { | const std::vector<std::vector<ShapeAndType>> &InferenceContext::GetInputHandleShapesAndTypes() const { | ||||
return input_handle_shapes_and_types_; | |||||
return inference_context_impl_->input_handle_shapes_and_types_; | |||||
} | } | ||||
const std::vector<std::vector<ShapeAndType>> &InferenceContext::GetOutputHandleShapesAndTypes() const { | const std::vector<std::vector<ShapeAndType>> &InferenceContext::GetOutputHandleShapesAndTypes() const { | ||||
return output_handle_shapes_and_types_; | |||||
return inference_context_impl_->output_handle_shapes_and_types_; | |||||
} | } | ||||
void InferenceContext::SetOutputHandleShapesAndTypes(const std::vector<std::vector<ShapeAndType>> &shapes_and_types) { | void InferenceContext::SetOutputHandleShapesAndTypes(const std::vector<std::vector<ShapeAndType>> &shapes_and_types) { | ||||
output_handle_shapes_and_types_ = shapes_and_types; | |||||
inference_context_impl_->output_handle_shapes_and_types_ = shapes_and_types; | |||||
} | } | ||||
void InferenceContext::SetOutputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types) { | void InferenceContext::SetOutputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types) { | ||||
output_handle_shapes_and_types_.swap(shapes_and_types); | |||||
inference_context_impl_->output_handle_shapes_and_types_.swap(shapes_and_types); | |||||
} | } | ||||
void InferenceContext::SetMarks(const std::vector<std::string> &marks) { marks_ = marks; } | |||||
void InferenceContext::SetMarks(const std::vector<std::string> &marks) { inference_context_impl_->marks_ = marks; } | |||||
const std::vector<std::string> &InferenceContext::GetMarks() const { return marks_; } | |||||
const std::vector<std::string> &InferenceContext::GetMarks() const { return inference_context_impl_->marks_; } | |||||
} // namespace ge | } // namespace ge |
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "graph/model.h" | #include "graph/model.h" | ||||
#include <fcntl.h> | #include <fcntl.h> | ||||
#include <google/protobuf/io/coded_stream.h> | #include <google/protobuf/io/coded_stream.h> | ||||
#include <google/protobuf/io/zero_copy_stream.h> | #include <google/protobuf/io/zero_copy_stream.h> | ||||
@@ -28,7 +27,6 @@ | |||||
#include <cstring> | #include <cstring> | ||||
#include <fstream> | #include <fstream> | ||||
#include <iomanip> | #include <iomanip> | ||||
#include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -191,7 +191,7 @@ bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *mode | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor( | ||||
GeTensorPtr &tensor, proto::TensorDef &tensor_proto) { | |||||
GeTensorPtr &tensor, proto::TensorDef &tensor_proto) { | |||||
tensor = std::shared_ptr<GeTensor>(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto)); | tensor = std::shared_ptr<GeTensor>(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto)); | ||||
if (tensor == nullptr) { | if (tensor == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "tensor is nullptr"); | GELOGE(GRAPH_FAILED, "tensor is nullptr"); | ||||
@@ -208,14 +208,14 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d | |||||
// Input tensor | // Input tensor | ||||
for (auto &input_desc : *op_def_proto.mutable_input_desc()) { | for (auto &input_desc : *op_def_proto.mutable_input_desc()) { | ||||
std::shared_ptr<GeTensorDesc> temp_value = | std::shared_ptr<GeTensorDesc> temp_value = | ||||
std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc)); | |||||
std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc)); | |||||
GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); | GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); | ||||
op_desc->inputs_desc_.push_back(temp_value); | op_desc->inputs_desc_.push_back(temp_value); | ||||
} | } | ||||
// Output tensor | // Output tensor | ||||
for (auto &output_desc : *op_def_proto.mutable_output_desc()) { | for (auto &output_desc : *op_def_proto.mutable_output_desc()) { | ||||
std::shared_ptr<GeTensorDesc> temp_value = | std::shared_ptr<GeTensorDesc> temp_value = | ||||
std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc)); | |||||
std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc)); | |||||
GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); | GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); | ||||
op_desc->outputs_desc_.push_back(temp_value); | op_desc->outputs_desc_.push_back(temp_value); | ||||
} | } | ||||
@@ -265,13 +265,13 @@ bool ModelSerializeImp::HandleNodeNameRef() { | |||||
item.dst_node_name.c_str(), item.dst_in_index); | item.dst_node_name.c_str(), item.dst_in_index); | ||||
return false; | return false; | ||||
} | } | ||||
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); | |||||
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 | |||||
} else { | } else { | ||||
// Control edge | // Control edge | ||||
auto src_anchor = src_node_it->second->GetOutControlAnchor(); | auto src_anchor = src_node_it->second->GetOutControlAnchor(); | ||||
auto dst_anchor = item.dst_node->GetInControlAnchor(); | auto dst_anchor = item.dst_node->GetInControlAnchor(); | ||||
if (src_anchor != nullptr && dst_anchor != nullptr) { | if (src_anchor != nullptr && dst_anchor != nullptr) { | ||||
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); | |||||
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -15,9 +15,7 @@ | |||||
*/ | */ | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include <utility> | #include <utility> | ||||
#include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "external/graph/operator_factory.h" | #include "external/graph/operator_factory.h" | ||||
@@ -533,7 +531,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::IsAllInNodesSeen( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::IsAllInNodesSeen( | ||||
std::unordered_set<Node *> &nodes_seen) const { | |||||
std::unordered_set<Node *> &nodes_seen) const { | |||||
for (const auto &in_anchor : in_data_anchors_) { | for (const auto &in_anchor : in_data_anchors_) { | ||||
GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); | GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); | ||||
auto out_anchor = in_anchor->GetPeerOutAnchor(); | auto out_anchor = in_anchor->GetPeerOutAnchor(); | ||||
@@ -736,10 +734,10 @@ graphStatus Node::Verify() const { | |||||
continue; | continue; | ||||
} | } | ||||
GE_CHK_BOOL_RET_STATUS( | GE_CHK_BOOL_RET_STATUS( | ||||
op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type || | |||||
op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || | |||||
in_anchor_ptr->GetPeerAnchors().size() > 0, | |||||
GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); | |||||
op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type || | |||||
op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || | |||||
in_anchor_ptr->GetPeerAnchors().size() > 0, | |||||
GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); | |||||
} | } | ||||
string frameworkop_type = "FrameworkOp"; | string frameworkop_type = "FrameworkOp"; | ||||
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "external/graph/operator.h" | #include "external/graph/operator.h" | ||||
@@ -33,6 +32,7 @@ using std::shared_ptr; | |||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
/*lint -save -e521 -e681 -e732 -e737*/ | |||||
namespace ge { | namespace ge { | ||||
const std::string ATTR_NAME_ID = "id"; | const std::string ATTR_NAME_ID = "id"; | ||||
@@ -302,29 +302,28 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual( | |||||
if ((op_def != nullptr) && (r_op_def != nullptr)) { | if ((op_def != nullptr) && (r_op_def != nullptr)) { | ||||
// Message OpDef in ge_ir.proto | // Message OpDef in ge_ir.proto | ||||
return ( | return ( | ||||
IsEqual(op_def->name(), r_op_def->name(), "OpDef_.name()") && | |||||
IsEqual(op_def->type(), r_op_def->type(), "OpDef_.type()") && | |||||
IsEqual(ToString(op_def->input()), ToString(r_op_def->input()), "OpDef_.input()") && | |||||
IsEqual(op_def->has_out_attr(), r_op_def->has_out_attr(), "OpDef_.has_out_attr()") && | |||||
IsEqual(op_def->stream_id(), r_op_def->stream_id(), "OpDef_.stream_id()") && | |||||
IsEqual(ToString(op_def->input_name()), ToString(r_op_def->input_name()), "OpDef_.input_name()") && | |||||
IsEqual(ToString(op_def->src_name()), ToString(r_op_def->src_name()), "OpDef_.src_name()") && | |||||
IsEqual(ToString(op_def->dst_name()), ToString(r_op_def->dst_name()), "OpDef_.dst_name()") && | |||||
IsEqual(ToString(op_def->src_index()), ToString(r_op_def->src_index()), "OpDef_.src_index()") && | |||||
IsEqual(ToString(op_def->dst_index()), ToString(r_op_def->dst_index()), "OpDef_.dst_index()") && | |||||
IsEqual(ToString(op_def->input_i()), ToString(r_op_def->input_i()), "OpDef_.input_i()") && | |||||
IsEqual(ToString(op_def->output_i()), ToString(r_op_def->output_i()), "OpDef_.output_i()") && | |||||
IsEqual(ToString(op_def->workspace()), ToString(r_op_def->workspace()), "OpDef_.workspace()") && | |||||
IsEqual(ToString(op_def->workspace_bytes()), ToString(r_op_def->workspace_bytes()), | |||||
"OpDef_.workspace_bytes()") && | |||||
IsEqual(ToString(op_def->is_input_const()), ToString(r_op_def->is_input_const()), "OpDef_.is_input_const()")); | |||||
IsEqual(op_def->name(), r_op_def->name(), "OpDef_.name()") && | |||||
IsEqual(op_def->type(), r_op_def->type(), "OpDef_.type()") && | |||||
IsEqual(ToString(op_def->input()), ToString(r_op_def->input()), "OpDef_.input()") && | |||||
IsEqual(op_def->has_out_attr(), r_op_def->has_out_attr(), "OpDef_.has_out_attr()") && | |||||
IsEqual(op_def->stream_id(), r_op_def->stream_id(), "OpDef_.stream_id()") && | |||||
IsEqual(ToString(op_def->input_name()), ToString(r_op_def->input_name()), "OpDef_.input_name()") && | |||||
IsEqual(ToString(op_def->src_name()), ToString(r_op_def->src_name()), "OpDef_.src_name()") && | |||||
IsEqual(ToString(op_def->dst_name()), ToString(r_op_def->dst_name()), "OpDef_.dst_name()") && | |||||
IsEqual(ToString(op_def->src_index()), ToString(r_op_def->src_index()), "OpDef_.src_index()") && | |||||
IsEqual(ToString(op_def->dst_index()), ToString(r_op_def->dst_index()), "OpDef_.dst_index()") && | |||||
IsEqual(ToString(op_def->input_i()), ToString(r_op_def->input_i()), "OpDef_.input_i()") && | |||||
IsEqual(ToString(op_def->output_i()), ToString(r_op_def->output_i()), "OpDef_.output_i()") && | |||||
IsEqual(ToString(op_def->workspace()), ToString(r_op_def->workspace()), "OpDef_.workspace()") && | |||||
IsEqual(ToString(op_def->workspace_bytes()), ToString(r_op_def->workspace_bytes()), "OpDef_.workspace_bytes()") && | |||||
IsEqual(ToString(op_def->is_input_const()), ToString(r_op_def->is_input_const()), "OpDef_.is_input_const()")); | |||||
} else { | } else { | ||||
return ((op_def == nullptr) && (r_op_def == nullptr)); | return ((op_def == nullptr) && (r_op_def == nullptr)); | ||||
} | } | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescGenTensorDescsAreEqual( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescGenTensorDescsAreEqual( | ||||
const OpDesc &r_op_desc) const { | |||||
const OpDesc &r_op_desc) const { | |||||
// 1.Verify inputs and outputs desc size | // 1.Verify inputs and outputs desc size | ||||
const auto inputs_desc_size = this->inputs_desc_.size(); | const auto inputs_desc_size = this->inputs_desc_.size(); | ||||
const auto r_inputs_desc_size = r_op_desc.inputs_desc_.size(); | const auto r_inputs_desc_size = r_op_desc.inputs_desc_.size(); | ||||
@@ -20,14 +20,16 @@ | |||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
using namespace std; | |||||
namespace ge { | namespace ge { | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
BroadCastInfer(const std::function<std::vector<int64_t>()>& get_in1_shape, | |||||
const std::function<std::vector<int64_t>()>& get_in2_shape, | |||||
const std::function<void(const std::vector<int64_t>& outShape)>& set_out_shape) { | |||||
BroadCastInfer(const function<vector<int64_t>()>& get_in1_shape, const function<vector<int64_t>()>& get_in2_shape, | |||||
const function<void(const vector<int64_t>& outShape)>& set_out_shape) { | |||||
auto x1_shape = get_in1_shape(); | auto x1_shape = get_in1_shape(); | ||||
auto x2_shape = get_in2_shape(); | auto x2_shape = get_in2_shape(); | ||||
std::vector<int64_t> y_shape; | |||||
vector<int64_t> y_shape; | |||||
if (x1_shape.empty()) { | if (x1_shape.empty()) { | ||||
y_shape = x2_shape; | y_shape = x2_shape; | ||||
@@ -48,7 +50,7 @@ BroadCastInfer(const std::function<std::vector<int64_t>()>& get_in1_shape, | |||||
int x2_shape_size = static_cast<int>(x2_shape.size()); | int x2_shape_size = static_cast<int>(x2_shape.size()); | ||||
for (int i = 0; i < x2_shape_size; i++) { | for (int i = 0; i < x2_shape_size; i++) { | ||||
bool shapeFlag = | bool shapeFlag = | ||||
((x1_shape[i + len_diff] != x2_shape[i]) && (std::min(x1_shape[i + len_diff], x2_shape[i]) != 1)); | |||||
((x1_shape[i + len_diff] != x2_shape[i]) && (std::min(x1_shape[i + len_diff], x2_shape[i]) != 1)); | |||||
if (shapeFlag) { | if (shapeFlag) { | ||||
GE_LOGE("operands could not be broadcast together"); | GE_LOGE("operands could not be broadcast together"); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
@@ -62,7 +64,7 @@ BroadCastInfer(const std::function<std::vector<int64_t>()>& get_in1_shape, | |||||
int x1_shape_size = static_cast<int>(x1_shape.size()); | int x1_shape_size = static_cast<int>(x1_shape.size()); | ||||
for (int i = 0; i < x1_shape_size; i++) { | for (int i = 0; i < x1_shape_size; i++) { | ||||
bool shapeFlag = | bool shapeFlag = | ||||
((x1_shape[i] != x2_shape[i - len_diff]) && (std::min(x1_shape[i], x2_shape[i - len_diff]) != 1)); | |||||
((x1_shape[i] != x2_shape[i - len_diff]) && (std::min(x1_shape[i], x2_shape[i - len_diff]) != 1)); | |||||
if (shapeFlag) { | if (shapeFlag) { | ||||
GE_LOGE("operands could not be broadcast together"); | GE_LOGE("operands could not be broadcast together"); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
@@ -73,4 +75,5 @@ BroadCastInfer(const std::function<std::vector<int64_t>()>& get_in1_shape, | |||||
set_out_shape(y_shape); | set_out_shape(y_shape); | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -22,10 +22,12 @@ | |||||
#include <queue> | #include <queue> | ||||
#include <set> | #include <set> | ||||
//#include "./array_ops.h" | |||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "external/graph/attr_value.h" | #include "external/graph/attr_value.h" | ||||
#include "external/graph/types.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
@@ -33,6 +35,7 @@ | |||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "graph/operator_factory.h" | #include "graph/operator_factory.h" | ||||
#include "graph/usr_types.h" | |||||
#include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
#include "utils/op_desc_utils.h" | #include "utils/op_desc_utils.h" | ||||
#include "utils/tensor_adapter.h" | #include "utils/tensor_adapter.h" | ||||
@@ -74,6 +77,29 @@ class OpIO { | |||||
int index_; | int index_; | ||||
std::shared_ptr<OperatorImpl> owner_; | std::shared_ptr<OperatorImpl> owner_; | ||||
}; | }; | ||||
class TensorTypeImpl { | |||||
public: | |||||
TensorTypeImpl() = default; | |||||
~TensorTypeImpl() = default; | |||||
std::vector<DataType> dt_vec_; | |||||
}; | |||||
TensorType::TensorType(DataType dt) { | |||||
tensor_type_impl_ = ComGraphMakeShared<TensorTypeImpl>(); | |||||
if (tensor_type_impl_ != nullptr) { | |||||
tensor_type_impl_->dt_vec_.push_back(dt); | |||||
} | |||||
} | |||||
TensorType::TensorType(const std::initializer_list<DataType> &types) { | |||||
tensor_type_impl_ = ComGraphMakeShared<TensorTypeImpl>(); | |||||
if (tensor_type_impl_ != nullptr) { | |||||
tensor_type_impl_->dt_vec_ = types; | |||||
} | |||||
} | |||||
class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | ||||
friend class GraphBuilderImpl; | friend class GraphBuilderImpl; | ||||
friend class OpDescUtils; | friend class OpDescUtils; | ||||
@@ -128,8 +154,15 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
OpIO op_dst(dst_name, dst_index, shared_from_this()); | OpIO op_dst(dst_name, dst_index, shared_from_this()); | ||||
src_op_impl->UpdateLinkMapImpl(src_name, op_dst); | src_op_impl->UpdateLinkMapImpl(src_name, op_dst); | ||||
auto output_desc = src_op_impl->GetOutputDesc(src_name); | |||||
auto input_desc = op_desc_->GetInputDesc(dst_name); | |||||
if (input_desc.GetFormat() == FORMAT_RESERVED) { | |||||
output_desc.SetFormat(FORMAT_ND); | |||||
} else { | |||||
output_desc.SetFormat(input_desc.GetFormat()); | |||||
} | |||||
// Fix for linking opdesc | // Fix for linking opdesc | ||||
if (op_desc_->UpdateInputDesc(dst_name, src_op_impl->GetOutputDesc(src_name)) != GRAPH_SUCCESS) { | |||||
if (op_desc_->UpdateInputDesc(dst_name, output_desc) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Update inputdesc failed,dst name is %s, src name is %s", dst_name.c_str(), | GELOGE(GRAPH_FAILED, "Update inputdesc failed,dst name is %s, src name is %s", dst_name.c_str(), | ||||
src_name.c_str()); | src_name.c_str()); | ||||
return; | return; | ||||
@@ -146,10 +179,11 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
int dst_index = op_desc_->GetInputIndexByName(dst_name); | int dst_index = op_desc_->GetInputIndexByName(dst_name); | ||||
GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), | GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), | ||||
op_desc_->GetName().c_str()); | op_desc_->GetName().c_str()); | ||||
GE_CHK_BOOL_EXEC(out_handler->GetOwner() != nullptr && out_handler->GetOwner()->GetOpDescImpl() != nullptr, return, | |||||
"out_handler invalid. name[%s]", dst_name.c_str()); | |||||
auto out_op_impl = out_handler->GetOwner(); | |||||
GE_CHK_BOOL_EXEC(out_op_impl && out_op_impl->GetOpDescImpl(), return, "out_handler invalid. name[%s]", | |||||
dst_name.c_str()); | |||||
bool is_const = false; | bool is_const = false; | ||||
if (out_handler->GetOwner()->GetOpDescImpl()->GetType() == CONSTANT) { | |||||
if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { | |||||
is_const = true; | is_const = true; | ||||
} | } | ||||
auto is_input_const = op_desc_->GetIsInputConst(); | auto is_input_const = op_desc_->GetIsInputConst(); | ||||
@@ -160,14 +194,19 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
op_desc_->SetIsInputConst(is_input_const); | op_desc_->SetIsInputConst(is_input_const); | ||||
OpIO in_handler(dst_name, dst_index, shared_from_this()); | OpIO in_handler(dst_name, dst_index, shared_from_this()); | ||||
auto out_op_impl = out_handler->GetOwner(); | |||||
GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed."); | |||||
GE_CHK_BOOL_EXEC(!!out_op_impl, return, "Get out_handler's impl failed."); | |||||
out_op_impl->UpdateLinkMapImpl(src_name, in_handler); | out_op_impl->UpdateLinkMapImpl(src_name, in_handler); | ||||
GE_CHK_BOOL_EXEC( | |||||
op_desc_->UpdateInputDesc(dst_name, out_handler->GetOwner()->GetOutputDesc(src_name)) == GRAPH_SUCCESS, return, | |||||
"Update input desc failed,dst name is %s,src name is %s", dst_name.c_str(), | |||||
src_name.c_str()); // fix for linking opdesc | |||||
auto src_output_desc = out_op_impl->GetOutputDesc(src_name); | |||||
auto dst_input_desc = op_desc_->GetInputDesc(dst_name); | |||||
if (dst_input_desc.GetFormat() == FORMAT_RESERVED) { | |||||
src_output_desc.SetFormat(FORMAT_ND); | |||||
} else { | |||||
src_output_desc.SetFormat(dst_input_desc.GetFormat()); | |||||
} | |||||
GE_CHK_BOOL_EXEC(op_desc_->UpdateInputDesc(dst_name, src_output_desc) == GRAPH_SUCCESS, return, | |||||
"Update input desc failed,dst name is %s,src name is %s", dst_name.c_str(), | |||||
src_name.c_str()); // fix for linking opdesc | |||||
} | } | ||||
void AddControlInputImp(const ge::Operator &src_oprt) { | void AddControlInputImp(const ge::Operator &src_oprt) { | ||||
@@ -382,7 +421,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOpera | |||||
return Operator("default"); | return Operator("default"); | ||||
} | } | ||||
OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr); | OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr); | ||||
return operator_impl_ptr->ToOperator(); | |||||
return operator_impl_ptr->ToOperator(); // lint !e514 | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) { | ||||
@@ -617,26 +656,26 @@ GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const { | |||||
// According to op get the attrs name and type | // According to op get the attrs name and type | ||||
namespace { | namespace { | ||||
const std::map<GeAttrValue::ValueType, std::string> kAttrTypesMap = { | const std::map<GeAttrValue::ValueType, std::string> kAttrTypesMap = { | ||||
{GeAttrValue::VT_NONE, "VT_STRING"}, | |||||
{GeAttrValue::VT_STRING, "VT_STRING"}, | |||||
{GeAttrValue::VT_FLOAT, "VT_FLOAT"}, | |||||
{GeAttrValue::VT_BOOL, "VT_BOOL"}, | |||||
{GeAttrValue::VT_INT, "VT_INT"}, | |||||
{GeAttrValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"}, | |||||
{GeAttrValue::VT_TENSOR, "VT_TENSOR"}, | |||||
{GeAttrValue::VT_BYTES, "VT_BYTES"}, | |||||
{GeAttrValue::VT_GRAPH, "VT_GRAPH"}, | |||||
{GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"}, | |||||
{GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"}, | |||||
{GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"}, | |||||
{GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"}, | |||||
{GeAttrValue::VT_LIST_BOOL, "VT_LIST_BOOL"}, | |||||
{GeAttrValue::VT_LIST_INT, "VT_LIST_INT"}, | |||||
{GeAttrValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"}, | |||||
{GeAttrValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"}, | |||||
{GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"}, | |||||
{GeAttrValue::VT_GRAPH, "VT_GRAPH"}, | |||||
{GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"}, | |||||
{GeAttrValue::VT_NONE, "VT_STRING"}, | |||||
{GeAttrValue::VT_STRING, "VT_STRING"}, | |||||
{GeAttrValue::VT_FLOAT, "VT_FLOAT"}, | |||||
{GeAttrValue::VT_BOOL, "VT_BOOL"}, | |||||
{GeAttrValue::VT_INT, "VT_INT"}, | |||||
{GeAttrValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"}, | |||||
{GeAttrValue::VT_TENSOR, "VT_TENSOR"}, | |||||
{GeAttrValue::VT_BYTES, "VT_BYTES"}, | |||||
{GeAttrValue::VT_GRAPH, "VT_GRAPH"}, | |||||
{GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"}, | |||||
{GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"}, | |||||
{GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"}, | |||||
{GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"}, | |||||
{GeAttrValue::VT_LIST_BOOL, "VT_LIST_BOOL"}, | |||||
{GeAttrValue::VT_LIST_INT, "VT_LIST_INT"}, | |||||
{GeAttrValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"}, | |||||
{GeAttrValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"}, | |||||
{GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"}, | |||||
{GeAttrValue::VT_GRAPH, "VT_GRAPH"}, | |||||
{GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"}, | |||||
}; | }; | ||||
} // namespace | } // namespace | ||||
const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() const { | const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() const { | ||||
@@ -665,32 +704,32 @@ const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() con | |||||
void Operator::InputRegister(const string &name) { | void Operator::InputRegister(const string &name) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
(void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); | |||||
operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); | |||||
} | } | ||||
void Operator::OptionalInputRegister(const string &name) { | void Operator::OptionalInputRegister(const string &name) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
(void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, | (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, | ||||
GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); | |||||
GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); | |||||
} | } | ||||
void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) { | void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
operator_impl_->GetOpDescImpl()->AddInferFunc(func); | |||||
(void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); | |||||
} | } | ||||
void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) { | void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); | |||||
(void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); | |||||
} | } | ||||
void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) { | void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); | |||||
(void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); | |||||
} | } | ||||
void Operator::OutputRegister(const string &name) { | void Operator::OutputRegister(const string &name) { | ||||
@@ -734,7 +773,7 @@ int Operator::GetDynamicOutputNum(const string &name) const { | |||||
void Operator::RequiredAttrRegister(const string &name) { | void Operator::RequiredAttrRegister(const string &name) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
(void)operator_impl_->GetOpDescImpl()->AddRequiredAttr(name); | |||||
operator_impl_->GetOpDescImpl()->AddRequiredAttr(name); | |||||
} | } | ||||
graphStatus Operator::VerifyAll() { | graphStatus Operator::VerifyAll() { | ||||
@@ -960,26 +999,6 @@ graphStatus Operator::GetAttr(const string &name, OpBytes &attr_value) const { | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
Operator &Operator::SetAttr(const string &name, const UsrQuantizeFactorParams &attr_value) { | |||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr, name %s.", name.c_str()); | |||||
QuantizeFactorParams def_quant; | |||||
GE_CHK_BOOL_EXEC(TypeUtils::Usr2DefQuantizeFactorParams(attr_value, def_quant) == GRAPH_SUCCESS, return *this, | |||||
"trans para fail"); | |||||
GE_CHK_BOOL_EXEC(OpDescUtils::SetQuantizeFactorParams(operator_impl_->GetOpDescImpl(), def_quant) == GRAPH_SUCCESS, | |||||
return *this, "operator set QuantizeFactorParams fail"); | |||||
return *this; | |||||
} | |||||
graphStatus Operator::GetAttr(const string &name, UsrQuantizeFactorParams &attr_value) const { | |||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||||
QuantizeFactorParams def_quant; | |||||
GE_CHK_BOOL_EXEC(OpDescUtils::GetQuantizeFactorParams(operator_impl_->GetOpDescImpl(), def_quant) == GRAPH_SUCCESS, | |||||
return GRAPH_FAILED, "operator get QuantizeFactorParams fail"); | |||||
GE_CHK_BOOL_EXEC(TypeUtils::Def2UsrQuantizeFactorParams(def_quant, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, | |||||
"trans para fail"); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
Operator &Operator::SetAttr(const string &name, ge::AttrValue &&attrValue) { | Operator &Operator::SetAttr(const string &name, ge::AttrValue &&attrValue) { | ||||
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); | ||||
(void)operator_impl_->SetAttr(name, std::move(attrValue.impl->geAttrValue_)); | (void)operator_impl_->SetAttr(name, std::move(attrValue.impl->geAttrValue_)); | ||||
@@ -1099,7 +1118,6 @@ class GraphBuilderImpl { | |||||
explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared<ComputeGraph>(name)) { | explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared<ComputeGraph>(name)) { | ||||
if (graph_ == nullptr) { | if (graph_ == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); | GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); | ||||
graph_ = nullptr; | |||||
return; | return; | ||||
} | } | ||||
} | } | ||||
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "graph/operator_factory_impl.h" | #include "graph/operator_factory_impl.h" | ||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -97,6 +96,7 @@ VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) | |||||
graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { | graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { | ||||
if (operator_creators_ == nullptr) { | if (operator_creators_ == nullptr) { | ||||
GELOGI("operator_creators_ init"); | |||||
operator_creators_.reset(new (std::nothrow) std::map<string, OpCreator>()); | operator_creators_.reset(new (std::nothrow) std::map<string, OpCreator>()); | ||||
} | } | ||||
auto it = operator_creators_->find(operator_type); | auto it = operator_creators_->find(operator_type); | ||||
@@ -33,7 +33,9 @@ OpsProtoManager *OpsProtoManager::Instance() { | |||||
} | } | ||||
bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &options) { | bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &options) { | ||||
/*lint -e1561*/ | |||||
auto proto_iter = options.find("ge.opsProtoLibPath"); | auto proto_iter = options.find("ge.opsProtoLibPath"); | ||||
/*lint +e1561*/ | |||||
if (proto_iter == options.end()) { | if (proto_iter == options.end()) { | ||||
GELOGW("ge.opsProtoLibPath option not set, return."); | GELOGW("ge.opsProtoLibPath option not set, return."); | ||||
return false; | return false; | ||||
@@ -21,6 +21,10 @@ | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
namespace ge { | namespace ge { | ||||
namespace { | |||||
const int64_t kMinTrainingTraceJobId = 256; | |||||
const int kDecimal = 10; | |||||
} // namespace | |||||
GEContext &GetContext() { | GEContext &GetContext() { | ||||
static GEContext ge_context{}; | static GEContext ge_context{}; | ||||
return ge_context; | return ge_context; | ||||
@@ -58,12 +62,21 @@ void GEContext::Init() { | |||||
string job_id; | string job_id; | ||||
(void)GetOption("ge.exec.jobId", job_id); | (void)GetOption("ge.exec.jobId", job_id); | ||||
try { | |||||
job_id_ = static_cast<uint64_t>(std::stoi(job_id.c_str())); | |||||
} catch (std::invalid_argument &) { | |||||
GELOGW("%s transform to int failed.", job_id.c_str()); | |||||
} catch (std::out_of_range &) { | |||||
GELOGW("%s transform to int failed.", job_id.c_str()); | |||||
std::string s_job_id = ""; | |||||
for (auto c : job_id) { | |||||
if (c >= '0' && c <= '9') { | |||||
s_job_id += c; | |||||
} | |||||
} | |||||
if (s_job_id == "") { | |||||
trace_id_ = kMinTrainingTraceJobId; | |||||
return; | |||||
} | |||||
int64_t d_job_id = std::strtoll(s_job_id.c_str(), nullptr, kDecimal); | |||||
if (d_job_id < kMinTrainingTraceJobId) { | |||||
trace_id_ = d_job_id + kMinTrainingTraceJobId; | |||||
} else { | |||||
trace_id_ = d_job_id; | |||||
} | } | ||||
} | } | ||||
@@ -71,7 +84,7 @@ uint64_t GEContext::SessionId() { return session_id_; } | |||||
uint32_t GEContext::DeviceId() { return device_id_; } | uint32_t GEContext::DeviceId() { return device_id_; } | ||||
uint64_t GEContext::JobId() { return job_id_; } | |||||
uint64_t GEContext::TraceId() { return trace_id_; } | |||||
void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } | void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } | ||||
} // namespace ge | } // namespace ge |
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "./ge_local_context.h" | #include "./ge_local_context.h" | ||||
#include <utility> | #include <utility> | ||||
namespace ge { | namespace ge { | ||||
@@ -26,9 +25,14 @@ thread_local GEThreadLocalContext thread_context; | |||||
GEThreadLocalContext &GetThreadLocalContext() { return thread_context; } | GEThreadLocalContext &GetThreadLocalContext() { return thread_context; } | ||||
graphStatus GEThreadLocalContext::GetOption(const string &key, string &option) { | graphStatus GEThreadLocalContext::GetOption(const string &key, string &option) { | ||||
auto iter = session_options_.find(key); | |||||
if (iter != session_options_.end()) { | |||||
option = iter->second; | |||||
auto graph_iter = graph_options_.find(key); | |||||
if (graph_iter != graph_options_.end()) { | |||||
option = graph_iter->second; | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
auto session_iter = session_options_.find(key); | |||||
if (session_iter != session_options_.end()) { | |||||
option = session_iter->second; | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
auto global_iter = global_options_.find(key); | auto global_iter = global_options_.find(key); | ||||
@@ -48,4 +52,9 @@ void GEThreadLocalContext::SetSessionOption(map<string, string> options_map) { | |||||
session_options_.clear(); | session_options_.clear(); | ||||
session_options_ = std::move(options_map); | session_options_ = std::move(options_map); | ||||
} | } | ||||
void GEThreadLocalContext::SetGraphOption(map<std::string, string> options_map) { | |||||
graph_options_.clear(); | |||||
graph_options_ = std::move(options_map); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -40,7 +40,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||||
return; | return; | ||||
} | } | ||||
ge::OpDescPtr op_desc = node->GetOpDesc(); | ge::OpDescPtr op_desc = node->GetOpDesc(); | ||||
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return); | |||||
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); | |||||
std::string str; | std::string str; | ||||
if (!op_desc->GetAllInputsDescPtr().empty()) { | if (!op_desc->GetAllInputsDescPtr().empty()) { | ||||
std::string input_desc_str = "input shape: "; | std::string input_desc_str = "input shape: "; | ||||
@@ -118,16 +118,16 @@ graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator & | |||||
InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | ||||
const NodePtr &node) { | const NodePtr &node) { | ||||
auto ctx = std::shared_ptr<InferenceContext>(new (std::nothrow) InferenceContext()); | |||||
if (ctx == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext"); | |||||
return nullptr; | |||||
} | |||||
if (node == nullptr) { | if (node == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "node is null"); | GELOGE(GRAPH_FAILED, "node is null"); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(ctx); | |||||
InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create()); | |||||
if (inference_context == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext"); | |||||
return nullptr; | |||||
} | |||||
auto all_in_data_anchors = node->GetAllInDataAnchors(); | auto all_in_data_anchors = node->GetAllInDataAnchors(); | ||||
std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size()); | std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size()); | ||||
std::vector<std::string> marks; | std::vector<std::string> marks; | ||||
@@ -169,9 +169,9 @@ InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, Inf | |||||
} | } | ||||
if (has_input_shapes_and_types) { | if (has_input_shapes_and_types) { | ||||
ctx->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types)); | |||||
inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types)); | |||||
} | } | ||||
ctx->SetMarks(marks); | |||||
inference_context->SetMarks(marks); | |||||
return inference_context; | return inference_context; | ||||
} | } | ||||
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "external/graph/tensor.h" | #include "external/graph/tensor.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
@@ -91,41 +90,72 @@ class TensorImpl { | |||||
GeTensor ge_tensor; | GeTensor ge_tensor; | ||||
}; | }; | ||||
Shape::Shape(const std::vector<int64_t> &dims) : dims_(dims) {} | |||||
class ShapeImpl { | |||||
public: | |||||
ShapeImpl() = default; | |||||
~ShapeImpl() = default; | |||||
explicit ShapeImpl(const std::vector<int64_t> &dims) : dims_(dims) {} | |||||
std::vector<int64_t> dims_; | |||||
}; | |||||
Shape::Shape() { impl_ = ComGraphMakeShared<ShapeImpl>(); } | |||||
size_t Shape::GetDimNum() const { return dims_.size(); } | |||||
Shape::Shape(const std::vector<int64_t> &dims) { impl_ = ComGraphMakeShared<ShapeImpl>(dims); } | |||||
size_t Shape::GetDimNum() const { | |||||
if (impl_ != nullptr) { | |||||
return impl_->dims_.size(); | |||||
} | |||||
return 0; | |||||
} | |||||
int64_t Shape::GetDim(size_t idx) const { | int64_t Shape::GetDim(size_t idx) const { | ||||
if (idx >= dims_.size()) { | |||||
return 0; | |||||
if (impl_ != nullptr) { | |||||
if (idx >= impl_->dims_.size()) { | |||||
return 0; | |||||
} | |||||
return impl_->dims_[idx]; | |||||
} | } | ||||
return dims_[idx]; | |||||
return 0; | |||||
} | } | ||||
graphStatus Shape::SetDim(size_t idx, int64_t value) { | graphStatus Shape::SetDim(size_t idx, int64_t value) { | ||||
if (idx >= dims_.size()) { | |||||
return GRAPH_FAILED; | |||||
if (impl_ != nullptr) { | |||||
if (idx >= impl_->dims_.size()) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
impl_->dims_[idx] = value; | |||||
return GRAPH_SUCCESS; | |||||
} | } | ||||
dims_[idx] = value; | |||||
return GRAPH_SUCCESS; | |||||
return GRAPH_FAILED; | |||||
} | } | ||||
std::vector<int64_t> Shape::GetDims() const { return dims_; } | |||||
std::vector<int64_t> Shape::GetDims() const { | |||||
vector<int64_t> dims; | |||||
if (impl_ != nullptr) { | |||||
return impl_->dims_; | |||||
} | |||||
return dims; | |||||
} | |||||
int64_t Shape::GetShapeSize() const { | int64_t Shape::GetShapeSize() const { | ||||
if (dims_.empty()) { | |||||
return 0; | |||||
} | |||||
int64_t size = 1; | |||||
for (auto i : dims_) { | |||||
if (!Int64MulNotOverflow(size, i)) { | |||||
GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); | |||||
size = 0; | |||||
return size; | |||||
if (impl_ != nullptr) { | |||||
if (impl_->dims_.empty()) { | |||||
return 0; | |||||
} | } | ||||
size *= i; | |||||
int64_t size = 1; | |||||
for (auto i : impl_->dims_) { | |||||
if (!Int64MulNotOverflow(size, i)) { | |||||
GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); | |||||
size = 0; | |||||
return size; | |||||
} | |||||
size *= i; | |||||
} | |||||
return size; | |||||
} | } | ||||
return size; | |||||
return 0; | |||||
} | } | ||||
TensorDesc::TensorDesc() { impl = ComGraphMakeShared<TensorDescImpl>(); } | TensorDesc::TensorDesc() { impl = ComGraphMakeShared<TensorDescImpl>(); } | ||||
@@ -486,6 +516,7 @@ graphStatus Tensor::IsValid() { | |||||
GELOGW("mul overflow: %lu, %u", shape_size, type_length); | GELOGW("mul overflow: %lu, %u", shape_size, type_length); | ||||
} else { | } else { | ||||
if (shape_size * type_length != data_size) { | if (shape_size * type_length != data_size) { | ||||
// [Just log] Constructor | |||||
GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | ||||
data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
} | } | ||||
@@ -15,9 +15,7 @@ | |||||
*/ | */ | ||||
#include "utils/anchor_utils.h" | #include "utils/anchor_utils.h" | ||||
#include <algorithm> | #include <algorithm> | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -32,12 +32,12 @@ const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, | |||||
namespace ge { | namespace ge { | ||||
// Part 1: from IR convert to ONNX Protobuf | // Part 1: from IR convert to ONNX Protobuf | ||||
static const std::map<ge::DataType, onnx::TensorProto_DataType> kGeDataTypeToOnnxMap = { | static const std::map<ge::DataType, onnx::TensorProto_DataType> kGeDataTypeToOnnxMap = { | ||||
{DT_INT64, onnx::TensorProto_DataType_INT64}, {DT_UINT64, onnx::TensorProto_DataType_UINT64}, | |||||
{DT_FLOAT, onnx::TensorProto_DataType_FLOAT}, {DT_INT32, onnx::TensorProto_DataType_INT32}, | |||||
{DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8}, | |||||
{DT_UINT8, onnx::TensorProto_DataType_UINT8}, {DT_INT16, onnx::TensorProto_DataType_INT16}, | |||||
{DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16}, | |||||
{DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL}, | |||||
{DT_INT64, onnx::TensorProto_DataType_INT64}, {DT_UINT64, onnx::TensorProto_DataType_UINT64}, | |||||
{DT_FLOAT, onnx::TensorProto_DataType_FLOAT}, {DT_INT32, onnx::TensorProto_DataType_INT32}, | |||||
{DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8}, | |||||
{DT_UINT8, onnx::TensorProto_DataType_UINT8}, {DT_INT16, onnx::TensorProto_DataType_INT16}, | |||||
{DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16}, | |||||
{DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL}, | |||||
}; | }; | ||||
onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) { | onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) { | ||||
@@ -693,12 +693,12 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr | |||||
// Part 2: from ONNX Protobuf convert to IR | // Part 2: from ONNX Protobuf convert to IR | ||||
static std::map<onnx::TensorProto_DataType, ge::DataType> onnxDataTypeToGeMap = { | static std::map<onnx::TensorProto_DataType, ge::DataType> onnxDataTypeToGeMap = { | ||||
{onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64}, | |||||
{onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32}, | |||||
{onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8}, | |||||
{onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16}, | |||||
{onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16}, | |||||
{onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL}, | |||||
{onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64}, | |||||
{onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32}, | |||||
{onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8}, | |||||
{onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16}, | |||||
{onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16}, | |||||
{onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL}, | |||||
}; | }; | ||||
ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) { | ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) { | ||||
@@ -949,7 +949,7 @@ bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_ | |||||
auto size_out = attr.i(); | auto size_out = attr.i(); | ||||
for (int64_t i = 0; i < size_out; i++) { | for (int64_t i = 0; i < size_out; i++) { | ||||
GeTensorDesc ge_tensor_desc; | GeTensorDesc ge_tensor_desc; | ||||
if (op_desc->AddOutputDesc(ge_tensor_desc) != GRAPH_SUCCESS) { | |||||
if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) { | |||||
GELOGW("add inputdesc failed"); | GELOGW("add inputdesc failed"); | ||||
continue; | continue; | ||||
} | } | ||||
@@ -176,7 +176,7 @@ graphStatus GraphUtils::ReplaceEdgeDst(const OutControlAnchorPtr &src, const InC | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNodeBetweenDataAnchors( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNodeBetweenDataAnchors( | ||||
const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, const NodePtr &new_node) { | |||||
const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, const NodePtr &new_node) { | |||||
GE_CHECK_NOTNULL(src); | GE_CHECK_NOTNULL(src); | ||||
GE_CHECK_NOTNULL(dst); | GE_CHECK_NOTNULL(dst); | ||||
GE_CHECK_NOTNULL(new_node); | GE_CHECK_NOTNULL(new_node); | ||||
@@ -213,10 +213,10 @@ GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const | |||||
/// Add two edges to the new node, respectively connecting the SRC and DST | /// Add two edges to the new node, respectively connecting the SRC and DST | ||||
/// associated with the original edge | /// associated with the original edge | ||||
/// A ---> B transferred to A ---> N ---> B | |||||
/// A ---> B transfered to A ---> N ---> B | |||||
graphStatus InsertTransNode(ComputeGraph &compute_graph, const InDataAnchorPtr &in_data_anchor, | graphStatus InsertTransNode(ComputeGraph &compute_graph, const InDataAnchorPtr &in_data_anchor, | ||||
const std::vector<OpDescPtr> &vec_op_desc) { | const std::vector<OpDescPtr> &vec_op_desc) { | ||||
for (auto &op_desc : vec_op_desc) { | |||||
for (const auto &op_desc : vec_op_desc) { | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
auto ret = op_desc->AddInputDesc(GeTensorDesc()); | auto ret = op_desc->AddInputDesc(GeTensorDesc()); | ||||
@@ -275,9 +275,11 @@ graphStatus InsertTransNode(ComputeGraph &compute_graph, const InDataAnchorPtr & | |||||
int64_t output_format = 0; | int64_t output_format = 0; | ||||
if (!AttrUtils::GetInt(op_desc, "input_format", input_format)) { | if (!AttrUtils::GetInt(op_desc, "input_format", input_format)) { | ||||
GELOGW("get attr input_format failed"); | GELOGW("get attr input_format failed"); | ||||
continue; | |||||
} | } | ||||
if (!AttrUtils::GetInt(op_desc, "output_format", output_format)) { | if (!AttrUtils::GetInt(op_desc, "output_format", output_format)) { | ||||
GELOGW("get attr output_format failed"); | GELOGW("get attr output_format failed"); | ||||
continue; | |||||
} | } | ||||
GE_CHECK_NOTNULL(node_to_insert->GetInDataAnchor(0)->GetPeerOutAnchor()); | GE_CHECK_NOTNULL(node_to_insert->GetInDataAnchor(0)->GetPeerOutAnchor()); | ||||
@@ -299,11 +301,11 @@ graphStatus InsertTransNode(ComputeGraph &compute_graph, const InDataAnchorPtr & | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTransNode( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTransNode( | ||||
ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, const std::vector<OpDescPtr> &vec_op_desc) { | |||||
ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, const std::vector<OpDescPtr> &vec_op_desc) { | |||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
GE_CHECK_NOTNULL(in_data_anchor); | GE_CHECK_NOTNULL(in_data_anchor); | ||||
graphStatus ret = | graphStatus ret = | ||||
ge::InsertTransNode(*compute_graph, in_data_anchor, vec_op_desc) == GRAPH_SUCCESS ? GRAPH_SUCCESS : GRAPH_FAILED; | |||||
ge::InsertTransNode(*compute_graph, in_data_anchor, vec_op_desc) == GRAPH_SUCCESS ? GRAPH_SUCCESS : GRAPH_FAILED; | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -335,6 +337,10 @@ void GraphUtils::RecordOriginalNames(std::vector<ge::NodePtr> original_nodes, co | |||||
for (const auto &node_tmp : original_nodes) { | for (const auto &node_tmp : original_nodes) { | ||||
std::vector<std::string> names_tmp; | std::vector<std::string> names_tmp; | ||||
ge::OpDescPtr opdesc_tmp = node_tmp->GetOpDesc(); | ge::OpDescPtr opdesc_tmp = node_tmp->GetOpDesc(); | ||||
if (opdesc_tmp == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "Node %s get opdesc is nullptr", node_tmp->GetName().c_str()); | |||||
continue; | |||||
} | |||||
(void)ge::AttrUtils::GetListStr(opdesc_tmp, "original_op_names", names_tmp); | (void)ge::AttrUtils::GetListStr(opdesc_tmp, "original_op_names", names_tmp); | ||||
if (names_tmp.size() != 0) { | if (names_tmp.size() != 0) { | ||||
original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); | original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); | ||||
@@ -355,7 +361,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::RecordOriginalNa | |||||
GE_CHK_BOOL_EXEC(node != nullptr, return, "node is null."); | GE_CHK_BOOL_EXEC(node != nullptr, return, "node is null."); | ||||
std::vector<std::string> original_names; | std::vector<std::string> original_names; | ||||
if (names_tmp.size() != 0) { | if (names_tmp.size() != 0) { | ||||
original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); | |||||
(void)original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); | |||||
} else { | } else { | ||||
std::string tmp; | std::string tmp; | ||||
original_names.push_back(tmp); | original_names.push_back(tmp); | ||||
@@ -367,7 +373,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::RecordOriginalNa | |||||
// Check global_step Node has IsVariable and Read. | // Check global_step Node has IsVariable and Read. | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::CheckGlobalStepNode(const ge::NodePtr &node) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::CheckGlobalStepNode(const ge::NodePtr &node) { | ||||
GE_CHK_BOOL_EXEC( | GE_CHK_BOOL_EXEC( | ||||
node != nullptr, { return false; }, "node is null."); | |||||
node != nullptr, { return false; }, "node is null."); | |||||
bool has_variable = false; | bool has_variable = false; | ||||
bool has_cond_read = false; | bool has_cond_read = false; | ||||
for (const auto &out : node->GetOutDataNodes()) { | for (const auto &out : node->GetOutDataNodes()) { | ||||
@@ -382,21 +388,22 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::CheckGlobalStepN | |||||
// Check origin ComputeGraph is TrainGraph. | // Check origin ComputeGraph is TrainGraph. | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::CheckIsTrainGraph( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::CheckIsTrainGraph( | ||||
const ge::ComputeGraphPtr &compute_graph) { | |||||
const ge::ComputeGraphPtr &compute_graph) { | |||||
GE_CHK_BOOL_EXEC( | GE_CHK_BOOL_EXEC( | ||||
compute_graph != nullptr, { return false; }, "compute_graph is nullptr"); | |||||
compute_graph != nullptr, { return false; }, "compute_graph is nullptr"); | |||||
bool is_iterator_v2 = false; | bool is_iterator_v2 = false; | ||||
bool is_train_graph = false; | bool is_train_graph = false; | ||||
for (const auto &node : compute_graph->GetDirectNode()) { | for (const auto &node : compute_graph->GetDirectNode()) { | ||||
if (node->GetType() == "ApplyMomentum") { | |||||
if ((node->GetType() == "ApplyMomentum") || (node->GetType() == "ApplyGradientDescent")) { | |||||
GELOGI("graph needs iteration."); | |||||
return true; | return true; | ||||
} | } | ||||
// Check global_step has IsVariable and Read. | // Check global_step has IsVariable and Read. | ||||
if ((node->GetType() == "Variable") && (node->GetName() == "global_step")) { | if ((node->GetType() == "Variable") && (node->GetName() == "global_step")) { | ||||
is_train_graph = CheckGlobalStepNode(node); | is_train_graph = CheckGlobalStepNode(node); | ||||
} else if ((node->GetType() == "FrameworkOp") && (node->GetName() == "IteratorGetNext")) { | } else if ((node->GetType() == "FrameworkOp") && (node->GetName() == "IteratorGetNext")) { | ||||
// Train Graph must has GetNext. | |||||
// Train Graph must have GetNext. | |||||
is_iterator_v2 = true; | is_iterator_v2 = true; | ||||
} | } | ||||
if (is_iterator_v2 && is_train_graph) { | if (is_iterator_v2 && is_train_graph) { | ||||
@@ -410,7 +417,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::CheckIsTrainGrap | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::MatchDumpStr(const std::string &suffix) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::MatchDumpStr(const std::string &suffix) { | ||||
char *dump_level = std::getenv(kDumpGraphLevel); | char *dump_level = std::getenv(kDumpGraphLevel); | ||||
int64_t dump_graph_level = | int64_t dump_graph_level = | ||||
(dump_level != nullptr) ? std::strtol(dump_level, nullptr, kBaseOfIntegerValue) : kDumpLevel2; | |||||
(dump_level != nullptr) ? std::strtol(dump_level, nullptr, kBaseOfIntegerValue) : kDumpLevel2; | |||||
if (dump_graph_level == kDumpLevel1) { | if (dump_graph_level == kDumpLevel1) { | ||||
return false; | return false; | ||||
} | } | ||||
@@ -499,6 +507,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(cons | |||||
ge::Model model; | ge::Model model; | ||||
// Get Model object from ModelDef by deserialize ModelDef | // Get Model object from ModelDef by deserialize ModelDef | ||||
if (model.Load(model_def) == GRAPH_SUCCESS) { | if (model.Load(model_def) == GRAPH_SUCCESS) { | ||||
GE_CHK_BOOL_EXEC(GraphUtils::GetComputeGraph(model.GetGraph()) != nullptr, return false, | |||||
"Get computer graph is nullptr"); | |||||
compute_graph = *(GraphUtils::GetComputeGraph(model.GetGraph())); | compute_graph = *(GraphUtils::GetComputeGraph(model.GetGraph())); | ||||
return true; | return true; | ||||
} else { | } else { | ||||
@@ -509,7 +519,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(cons | |||||
// Printing protocol messages in text format is useful for debugging and human editing of messages. | // Printing protocol messages in text format is useful for debugging and human editing of messages. | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToTextFile( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToTextFile( | ||||
const google::protobuf::Message &proto, const char *real_path) { | |||||
const google::protobuf::Message &proto, const char *real_path) { | |||||
#ifdef FMK_SUPPORT_DUMP | #ifdef FMK_SUPPORT_DUMP | ||||
const int FILE_AUTHORITY = 0600; | const int FILE_AUTHORITY = 0600; | ||||
int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, FILE_AUTHORITY); | int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, FILE_AUTHORITY); | ||||
@@ -563,7 +573,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::ReadProtoFromTextFile( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::ReadProtoFromTextFile( | ||||
const char *file, google::protobuf::Message *proto) { | |||||
const char *file, google::protobuf::Message *proto) { | |||||
if (file == nullptr || proto == nullptr) { | if (file == nullptr || proto == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "incorrect parameter. file path or message is invalid"); | GELOGE(GRAPH_FAILED, "incorrect parameter. file path or message is invalid"); | ||||
return false; | return false; | ||||
@@ -587,7 +597,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn | |||||
#ifdef FMK_SUPPORT_DUMP | #ifdef FMK_SUPPORT_DUMP | ||||
char *dump_ge_graph = std::getenv(kDumpGeGraph); | char *dump_ge_graph = std::getenv(kDumpGeGraph); | ||||
int64_t dump_ge_graph_level = | int64_t dump_ge_graph_level = | ||||
(dump_ge_graph != nullptr) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : OnnxUtils::NO_DUMP; | |||||
(dump_ge_graph != nullptr) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : OnnxUtils::NO_DUMP; | |||||
if ((dump_ge_graph_level == OnnxUtils::NO_DUMP) || (dump_ge_graph_level >= OnnxUtils::DUMP_LEVEL_END)) { | if ((dump_ge_graph_level == OnnxUtils::NO_DUMP) || (dump_ge_graph_level >= OnnxUtils::DUMP_LEVEL_END)) { | ||||
GELOGD("Skip DumpGEGraphToOnnx with dump_ge_graph_level %ld.", dump_ge_graph_level); | GELOGD("Skip DumpGEGraphToOnnx with dump_ge_graph_level %ld.", dump_ge_graph_level); | ||||
return; | return; | ||||
@@ -1029,8 +1039,8 @@ GraphUtils::ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ReplaceNodeAnchors( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ReplaceNodeAnchors( | ||||
const NodePtr &new_node, const NodePtr &old_node, const std::initializer_list<int> inputs_map, | |||||
const std::initializer_list<int> outputs_map) { | |||||
const NodePtr &new_node, const NodePtr &old_node, const std::initializer_list<int> inputs_map, | |||||
const std::initializer_list<int> outputs_map) { | |||||
return ReplaceNodeAnchors(new_node, old_node, std::vector<int>(inputs_map), std::vector<int>(outputs_map)); | return ReplaceNodeAnchors(new_node, old_node, std::vector<int>(inputs_map), std::vector<int>(outputs_map)); | ||||
} | } | ||||
@@ -15,7 +15,6 @@ | |||||
*/ | */ | ||||
#include "utils/node_utils.h" | #include "utils/node_utils.h" | ||||
#include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -86,6 +85,7 @@ graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int dep | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
cur_ptr = src->GetOutDataNodes().at(0); | cur_ptr = src->GetOutDataNodes().at(0); | ||||
GE_CHECK_NOTNULL(cur_ptr); | |||||
} | } | ||||
dst = cur_ptr; | dst = cur_ptr; | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
@@ -289,8 +289,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||||
auto peer_op_desc = peer_anchor->GetOwnerNode()->GetOpDesc(); | auto peer_op_desc = peer_anchor->GetOwnerNode()->GetOpDesc(); | ||||
GE_IF_BOOL_EXEC(peer_op_desc == nullptr, GELOGE(GRAPH_FAILED, "peer opdesc is null"); continue); | GE_IF_BOOL_EXEC(peer_op_desc == nullptr, GELOGE(GRAPH_FAILED, "peer opdesc is null"); continue); | ||||
GE_IF_BOOL_EXEC(peer_op_desc->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor) != GRAPH_SUCCESS, | GE_IF_BOOL_EXEC(peer_op_desc->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor) != GRAPH_SUCCESS, | ||||
GELOGE(GRAPH_FAILED, "peer opdesc is null"); | |||||
continue); | |||||
GELOGE(GRAPH_FAILED, "peer opdesc is null"); | |||||
continue); | |||||
} | } | ||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
@@ -309,7 +309,7 @@ bool NodeUtils::IsInNodesEmpty(const Node &node) { | |||||
if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) { | if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) { | ||||
auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors(); | auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors(); | ||||
for (auto &out_control_anchor : peer_out_control_anchors) { | |||||
for (const auto &out_control_anchor : peer_out_control_anchors) { | |||||
if (out_control_anchor != nullptr) { | if (out_control_anchor != nullptr) { | ||||
if (out_control_anchor->GetOwnerNode() != nullptr) { | if (out_control_anchor->GetOwnerNode() != nullptr) { | ||||
return false; | return false; | ||||
@@ -30,6 +30,7 @@ | |||||
using std::vector; | using std::vector; | ||||
/*lint -e512 -e737 -e752*/ | |||||
namespace ge { | namespace ge { | ||||
const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; | const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; | ||||
static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; | static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; | ||||
@@ -134,11 +135,11 @@ graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, Quantize | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { | OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { | ||||
GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); | GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); | ||||
return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); | |||||
return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732 | |||||
} | } | ||||
graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { | graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { | ||||
return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); | |||||
return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732 | |||||
} | } | ||||
GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { | GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { | ||||
@@ -163,7 +164,7 @@ graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) { | |||||
GELOGE(GRAPH_FAILED, "weight is null"); | GELOGE(GRAPH_FAILED, "weight is null"); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED; | |||||
return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED; // lint !e737 | |||||
} | } | ||||
graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) { | graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) { | ||||
@@ -180,7 +181,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUt | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights( | ||||
const ge::ConstNodePtr &node) { | |||||
const ge::ConstNodePtr &node) { | |||||
if (node == nullptr) { | if (node == nullptr) { | ||||
return vector<ge::ConstGeTensorPtr>(); | return vector<ge::ConstGeTensorPtr>(); | ||||
} | } | ||||
@@ -188,7 +189,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUt | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputNode( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputNode( | ||||
const ge::Node &node) { | |||||
const ge::Node &node) { | |||||
vector<ge::NodePtr> ret; | vector<ge::NodePtr> ret; | ||||
auto in_anchors = node.GetAllInDataAnchors(); | auto in_anchors = node.GetAllInDataAnchors(); | ||||
for (const auto &in_anchor : in_anchors) { | for (const auto &in_anchor : in_anchors) { | ||||
@@ -207,7 +208,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData( | ||||
const vector<ge::NodePtr> &input_nodes) { | |||||
const vector<ge::NodePtr> &input_nodes) { | |||||
vector<ConstGeTensorPtr> ret; | vector<ConstGeTensorPtr> ret; | ||||
for (const auto &input_node : input_nodes) { | for (const auto &input_node : input_nodes) { | ||||
auto temp_weight = MutableWeights(input_node->GetOpDesc()); | auto temp_weight = MutableWeights(input_node->GetOpDesc()); | ||||
@@ -229,12 +230,12 @@ size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) { | |||||
continue; | continue; | ||||
} | } | ||||
} | } | ||||
return input_num; | |||||
return input_num; // lint !e712 | |||||
} else { | } else { | ||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
node.GetInDataNodes().size() < GetConstInputs(node).size(), | |||||
GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size()); | |||||
return 0); | |||||
node.GetInDataNodes().size() < GetConstInputs(node).size(), | |||||
GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size()); | |||||
return 0); | |||||
return node.GetInDataNodes().size() - GetConstInputs(node).size(); | return node.GetInDataNodes().size() - GetConstInputs(node).size(); | ||||
} | } | ||||
} | } | ||||
@@ -334,7 +335,7 @@ bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) { | |||||
bool ret = false; | bool ret = false; | ||||
if (index < node.GetAllInDataAnchors().size()) { | if (index < node.GetAllInDataAnchors().size()) { | ||||
if (NodeUtils::IsAnchorStatusSet(node)) { | if (NodeUtils::IsAnchorStatusSet(node)) { | ||||
ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); | |||||
ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); // lint !e712 | |||||
} else { | } else { | ||||
for (const auto &anchor : node.GetAllInDataAnchors()) { | for (const auto &anchor : node.GetAllInDataAnchors()) { | ||||
if (anchor->GetIdx() != static_cast<int>(index)) { | if (anchor->GetIdx() != static_cast<int>(index)) { | ||||
@@ -363,13 +364,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput | |||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs( | ||||
const ge::ConstNodePtr &node) { | |||||
if (node == nullptr) { return vector<ge::NodePtr>(); } | |||||
const ge::ConstNodePtr &node) { | |||||
if (node == nullptr) { | |||||
return vector<ge::NodePtr>(); | |||||
} | |||||
return GetConstInputs(*node); | return GetConstInputs(*node); | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUtils::GetNonConstTensorDesc( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUtils::GetNonConstTensorDesc( | ||||
const ge::ConstNodePtr &node) { | |||||
const ge::ConstNodePtr &node) { | |||||
if (node == nullptr || node->GetOpDesc() == nullptr) { | if (node == nullptr || node->GetOpDesc() == nullptr) { | ||||
return vector<ge::GeTensorDesc>(); | return vector<ge::GeTensorDesc>(); | ||||
} | } | ||||
@@ -377,7 +380,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt | |||||
if (NodeUtils::IsAnchorStatusSet(*node)) { | if (NodeUtils::IsAnchorStatusSet(*node)) { | ||||
for (const auto &in_anchor : node->GetAllInDataAnchors()) { | for (const auto &in_anchor : node->GetAllInDataAnchors()) { | ||||
if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { | if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { | ||||
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
(void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
} | } | ||||
} | } | ||||
} else { | } else { | ||||
@@ -387,7 +390,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt | |||||
continue; | continue; | ||||
} | } | ||||
if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { | if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { | ||||
ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
(void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -571,3 +574,4 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
} // namespace ge | } // namespace ge | ||||
/*lint +e512 +e737 +e752*/ |
@@ -22,7 +22,6 @@ | |||||
#include <sstream> | #include <sstream> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "securec.h" | #include "securec.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -260,7 +260,6 @@ target_link_libraries(ge_train | |||||
${hccl} | ${hccl} | ||||
${msprof} | ${msprof} | ||||
${runtime} | ${runtime} | ||||
${cce} | |||||
${resouce} | ${resouce} | ||||
rt | rt | ||||
dl) | dl) | ||||
@@ -468,7 +467,6 @@ target_link_libraries(ge | |||||
${mmpa} | ${mmpa} | ||||
${msprof} | ${msprof} | ||||
${runtime} | ${runtime} | ||||
${cce} | |||||
${resouce} | ${resouce} | ||||
rt | rt | ||||
dl) | dl) |
@@ -47,8 +47,6 @@ include_directories(${GE_SOURCE_DIR}/inc/graph) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/json/include) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/protobuf/src) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
@@ -70,7 +68,6 @@ target_link_libraries(ge_client_train | |||||
${slog} | ${slog} | ||||
${mmpa} | ${mmpa} | ||||
${runtime} | ${runtime} | ||||
${cce} | |||||
rt | rt | ||||
dl) | dl) | ||||
@@ -91,6 +88,5 @@ target_link_libraries(ge_client | |||||
${slog} | ${slog} | ||||
${mmpa} | ${mmpa} | ||||
${runtime} | ${runtime} | ||||
${cce} | |||||
rt | rt | ||||
dl) | dl) |
@@ -15,22 +15,19 @@ | |||||
*/ | */ | ||||
#include "ge/ge_api.h" | #include "ge/ge_api.h" | ||||
#include <iostream> | #include <iostream> | ||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/ge/datatype_util.h" | |||||
#include "common/ge/tbe_plugin_manager.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/detail/model_serialize_imp.h" | |||||
#include "common/ge/datatype_util.h" | |||||
#include "proto/ge_api.pb.h" | |||||
#include "graph/model_serialize.h" | #include "graph/model_serialize.h" | ||||
#include "graph/opsproto_manager.h" | |||||
#include "graph/detail/model_serialize_imp.h" | |||||
#include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
#include "graph/utils/type_utils.h" | |||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
#include "proto/ge_api.pb.h" | |||||
#include "register/op_registry.h" | |||||
#include "session/session_manager.h" | #include "session/session_manager.h" | ||||
#include "graph/opsproto_manager.h" | |||||
#include "graph/utils/type_utils.h" | |||||
#include "register/op_registry.h" | |||||
using domi::GetContext; | using domi::GetContext; | ||||
using domi::OpRegistry; | using domi::OpRegistry; | ||||
@@ -102,6 +99,20 @@ Status CheckOptionsValid(const std::map<string, string> &options) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void SaveDdkVersion(const std::map<string, string> &options) { | |||||
auto ddk_option = options.find(DDK_VERSION_FLAG); | |||||
if (ddk_option != options.end()) { | |||||
auto ddk_version = ddk_option->second; | |||||
if (!ddk_version.empty()) { | |||||
GELOGI("Input ddk version : %s.", ddk_version.c_str()); | |||||
domi::GetContext().ddk_version = ddk_version; | |||||
} | |||||
} else { | |||||
GELOGW("No ddkVersion!"); | |||||
return; | |||||
} | |||||
} | |||||
// Initialize GE, prepare for execution, call GELib::Initialize | // Initialize GE, prepare for execution, call GELib::Initialize | ||||
Status GEInitialize(const std::map<string, string> &options) { | Status GEInitialize(const std::map<string, string> &options) { | ||||
GELOGT(TRACE_INIT, "GEInitialize start"); | GELOGT(TRACE_INIT, "GEInitialize start"); | ||||
@@ -127,7 +138,8 @@ Status GEInitialize(const std::map<string, string> &options) { | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
TBEPluginManager::Instance().InitPreparation(options); | |||||
SaveDdkVersion(options); | |||||
// call Initialize | // call Initialize | ||||
GELOGT(TRACE_RUNNING, "Initializing environment"); | GELOGT(TRACE_RUNNING, "Initializing environment"); | ||||
Status ret = ge::GELib::Initialize(options); | Status ret = ge::GELib::Initialize(options); | ||||
@@ -169,7 +181,7 @@ Status GEFinalize() { | |||||
GELOGE(ret, "GEFinalize Failed"); | GELOGE(ret, "GEFinalize Failed"); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
TBEPluginManager::Instance().Finalize(); | |||||
if (kGeInitialized && ret == SUCCESS) { | if (kGeInitialized && ret == SUCCESS) { | ||||
kGeInitialized = false; | kGeInitialized = false; | ||||
} | } | ||||
@@ -246,20 +258,24 @@ Session::~Session() { | |||||
} | } | ||||
Status Session::AddGraph(uint32_t graph_id, const Graph &graph) { | Status Session::AddGraph(uint32_t graph_id, const Graph &graph) { | ||||
GELOGT(TRACE_INIT, "Session AddGraph start"); | |||||
std::map<std::string, std::string> options; | |||||
return AddGraph(graph_id, graph, options); | |||||
} | |||||
Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<std::string, std::string> &options) { | |||||
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, sessinon_id: %lu.", graph_id, sessionId_); | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | ||||
if (!instance_ptr || !instance_ptr->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Session AddGraph failed"); | |||||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "AddGraph failed in Sesson."); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GELOGT(TRACE_RUNNING, "Adding Graph to session"); | |||||
Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph); | |||||
GELOGD("Adding graph to session"); | |||||
Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, options); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Session AddGraph failed"); | |||||
GELOGE(ret, "AddGraph failed in Session."); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GELOGT(TRACE_STOP, "Session AddGraph finished"); | |||||
GELOGD("AddGraph finished in Session."); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -59,6 +59,9 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"types.cc" | "types.cc" | ||||
"util.cc" | "util.cc" | ||||
"model_saver.cc" | "model_saver.cc" | ||||
# new files, possibly to be deleted? | |||||
"op/attr_value_util.cc" | |||||
"op/ge_op_utils.cc" | |||||
) | ) | ||||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ||||
@@ -75,9 +78,6 @@ include_directories(${GE_SOURCE_DIR}/inc/graph) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/json/include) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/eigen) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/protobuf/src) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
@@ -94,7 +94,6 @@ target_link_libraries(ge_common | |||||
${c_sec} | ${c_sec} | ||||
${slog} | ${slog} | ||||
${mmpa} | ${mmpa} | ||||
${cce} | |||||
${resource} | ${resource} | ||||
rt | rt | ||||
dl) | dl) |
@@ -60,10 +60,10 @@ Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) { | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID); | ||||
// Write data | // Write data | ||||
mmSsize_t write_count = mmWrite(fd, const_cast<void *>(data), size); | |||||
int32_t write_count = mmWrite(fd, const_cast<void *>(data), size); | |||||
// -1: Failed to write to file; - 2: Illegal parameter | // -1: Failed to write to file; - 2: Illegal parameter | ||||
if (write_count == EN_INVALID_PARAM || write_count == EN_ERROR) { | if (write_count == EN_INVALID_PARAM || write_count == EN_ERROR) { | ||||
GELOGE(FAILED, "Write data failed. mmpa_errorno = %ld", write_count); | |||||
GELOGE(FAILED, "Write data failed. mmpa_errorno = %d", write_count); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -102,9 +102,9 @@ Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFi | |||||
ModelPartitionTable &model_partition_table, | ModelPartitionTable &model_partition_table, | ||||
const std::vector<ModelPartition> &partition_datas) { | const std::vector<ModelPartition> &partition_datas) { | ||||
GE_CHK_BOOL_RET_STATUS( | GE_CHK_BOOL_RET_STATUS( | ||||
!partition_datas.empty() && model_partition_table.num != 0 && model_partition_table.num == partition_datas.size(), | |||||
FAILED, "Invalid param:partition data size(%u), model_partition_table.num(%zu).", model_partition_table.num, | |||||
partition_datas.size()); | |||||
!partition_datas.empty() && model_partition_table.num != 0 && model_partition_table.num == partition_datas.size(), | |||||
FAILED, "Invalid param:partition data size(%u), model_partition_table.num(%zu).", model_partition_table.num, | |||||
partition_datas.size()); | |||||
// Open file | // Open file | ||||
int32_t fd = 0; | int32_t fd = 0; | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(OpenFile(fd, file_path) != SUCCESS, return FAILED); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(OpenFile(fd, file_path) != SUCCESS, return FAILED); | ||||
@@ -112,17 +112,16 @@ Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFi | |||||
do { | do { | ||||
// Write file header | // Write file header | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
WriteData(static_cast<const void *>(&file_header), sizeof(ModelFileHeader), fd) != SUCCESS, ret = FAILED; | |||||
break); | |||||
WriteData(static_cast<const void *>(&file_header), sizeof(ModelFileHeader), fd) != SUCCESS, ret = FAILED; break); | |||||
// Write model partition table | // Write model partition table | ||||
uint32_t table_size = static_cast<uint32_t>(SIZE_OF_MODEL_PARTITION_TABLE(model_partition_table)); | uint32_t table_size = static_cast<uint32_t>(SIZE_OF_MODEL_PARTITION_TABLE(model_partition_table)); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
WriteData(static_cast<const void *>(&model_partition_table), table_size, fd) != SUCCESS, ret = FAILED; break); | |||||
WriteData(static_cast<const void *>(&model_partition_table), table_size, fd) != SUCCESS, ret = FAILED; break); | |||||
// Write partition data | // Write partition data | ||||
for (const auto &partition_data : partition_datas) { | for (const auto &partition_data : partition_datas) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
WriteData(static_cast<const void *>(partition_data.data), partition_data.size, fd) != SUCCESS, ret = FAILED; | |||||
break); | |||||
WriteData(static_cast<const void *>(partition_data.data), partition_data.size, fd) != SUCCESS, ret = FAILED; | |||||
break); | |||||
} | } | ||||
} while (0); | } while (0); | ||||
// Close file | // Close file | ||||
@@ -28,23 +28,23 @@ | |||||
struct PROC_PARAM { | struct PROC_PARAM { | ||||
uint8_t *model_name; | uint8_t *model_name; | ||||
// ISV Ek buffer | |||||
/* ISV Ek buffer */ | |||||
uint8_t *model_key; | uint8_t *model_key; | ||||
uint32_t model_key_len; | uint32_t model_key_len; | ||||
// ISV root certificate buffer | |||||
/* ISV root certificate buffer */ | |||||
uint8_t *root_cert; | uint8_t *root_cert; | ||||
uint32_t root_cert_len; | uint32_t root_cert_len; | ||||
// ISV private key buffer | |||||
/* ISV private key buffer */ | |||||
uint8_t *pri_key; | uint8_t *pri_key; | ||||
uint32_t pri_key_len; | uint32_t pri_key_len; | ||||
// Raw AI Module Image buffer | |||||
/* Raw AI Module Image buffer */ | |||||
uint8_t *ai_image; | uint8_t *ai_image; | ||||
uint32_t ai_image_len; | uint32_t ai_image_len; | ||||
// ISV HW key buffer | |||||
/* ISV HW key buffer */ | |||||
uint8_t *hw_key; | uint8_t *hw_key; | ||||
uint32_t hw_key_len; | uint32_t hw_key_len; | ||||
}; | }; | ||||
@@ -61,11 +61,11 @@ using std::string; | |||||
class FileSaver { | class FileSaver { | ||||
public: | public: | ||||
/// | |||||
/// @ingroup domi_common | |||||
/// @brief save model, no encryption | |||||
/// @return Status result | |||||
/// | |||||
/** | |||||
* @ingroup domi_common | |||||
* @brief save model, no encryption | |||||
* @return Status result | |||||
*/ | |||||
static Status SaveToFile(const string &file_path, const ge::ModelData &model, | static Status SaveToFile(const string &file_path, const ge::ModelData &model, | ||||
const ModelFileHeader *model_file_header = nullptr); | const ModelFileHeader *model_file_header = nullptr); | ||||
@@ -74,26 +74,26 @@ class FileSaver { | |||||
const std::vector<ModelPartition> &partition_datas); | const std::vector<ModelPartition> &partition_datas); | ||||
protected: | protected: | ||||
/// | |||||
/// @ingroup domi_common | |||||
/// @brief Check validity of the file path | |||||
/// @return Status result | |||||
/// | |||||
/** | |||||
* @ingroup domi_common | |||||
* @brief Check validity of the file path | |||||
* @return Status result | |||||
*/ | |||||
static Status CheckPath(const string &file_path); | static Status CheckPath(const string &file_path); | ||||
static Status WriteData(const void *data, uint32_t size, int32_t fd); | static Status WriteData(const void *data, uint32_t size, int32_t fd); | ||||
static Status OpenFile(int32_t &fd, const std::string &file_path); | static Status OpenFile(int32_t &fd, const std::string &file_path); | ||||
/// | |||||
/// @ingroup domi_common | |||||
/// @brief save model to file | |||||
/// @param [in] file_path file output path | |||||
/// @param [in] file_header file header info | |||||
/// @param [in] data model data | |||||
/// @param [in] len model length | |||||
/// @return Status result | |||||
/// | |||||
/** | |||||
* @ingroup domi_common | |||||
* @brief save model to file | |||||
* @param [in] file_path file output path | |||||
* @param [in] file_header file header info | |||||
* @param [in] data model data | |||||
* @param [in] len model length | |||||
* @return Status result | |||||
*/ | |||||
static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, | static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, | ||||
int len); | int len); | ||||
@@ -18,7 +18,6 @@ | |||||
// Description: This imply file for protobuf message and json interconversion | // Description: This imply file for protobuf message and json interconversion | ||||
#include "common/convert/pb2json.h" | #include "common/convert/pb2json.h" | ||||
#include <set> | #include <set> | ||||
#include <string> | #include <string> | ||||
@@ -130,7 +129,7 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr | |||||
void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | ||||
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | ||||
bool enum2str) { | bool enum2str) { | ||||
if (field == nullptr || reflection == nullptr) { | |||||
if (nullptr == field || nullptr == reflection) { | |||||
Message2Json(message, black_fields, json); | Message2Json(message, black_fields, json); | ||||
return; | return; | ||||
} | } | ||||
@@ -19,12 +19,10 @@ | |||||
#ifndef GE_COMMON_CONVERT_PB2JSON_H_ | #ifndef GE_COMMON_CONVERT_PB2JSON_H_ | ||||
#define GE_COMMON_CONVERT_PB2JSON_H_ | #define GE_COMMON_CONVERT_PB2JSON_H_ | ||||
#include <functional> | #include <functional> | ||||
#include <memory> | #include <memory> | ||||
#include <set> | #include <set> | ||||
#include <string> | #include <string> | ||||
#include "google/protobuf/descriptor.h" | #include "google/protobuf/descriptor.h" | ||||
#include "google/protobuf/message.h" | #include "google/protobuf/message.h" | ||||
#include "nlohmann/json.hpp" | #include "nlohmann/json.hpp" | ||||
@@ -40,12 +38,12 @@ using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor; | |||||
class Pb2Json { | class Pb2Json { | ||||
public: | public: | ||||
/** | /** | ||||
* @ingroup domi_omg | |||||
* @brief Transfer protobuf object to JSON object | |||||
* @param [out] json Converted JSON object | |||||
* @return void success | |||||
* @author | |||||
*/ | |||||
* @ingroup domi_omg | |||||
* @brief Transfer protobuf object to JSON object | |||||
* @param [out] json Converted JSON object | |||||
* @return void success | |||||
* @author | |||||
*/ | |||||
static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | ||||
bool enum2str = false); | bool enum2str = false); | ||||
@@ -21,10 +21,10 @@ | |||||
#include <unistd.h> | #include <unistd.h> | ||||
#include <string> | #include <string> | ||||
#include "framework/common/debug/ge_log.h" | |||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "framework/common/ge_inner_error_codes.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
#include "framework/common/ge_inner_error_codes.h" | |||||
using std::string; | using std::string; | ||||
@@ -51,10 +51,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::DumpToFile | |||||
// Write the data to the file | // Write the data to the file | ||||
Status ret = SUCCESS; | Status ret = SUCCESS; | ||||
mmSsize_t mmpa_ret = mmWrite(fd, data, len); | |||||
int32_t mmpa_ret = mmWrite(fd, data, len); | |||||
// mmWrite return -1:Failed to write data to file;return -2:Invalid parameter | // mmWrite return -1:Failed to write data to file;return -2:Invalid parameter | ||||
if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | ||||
GELOGE(FAILED, "Write to file failed. errno = %ld", mmpa_ret); | |||||
GELOGE(FAILED, "Write to file failed. errno = %d", mmpa_ret); | |||||
ret = FAILED; | ret = FAILED; | ||||
} | } | ||||
@@ -99,10 +99,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status MemoryDumper::Dump(void | |||||
GE_CHK_BOOL_RET_STATUS(data != nullptr, FAILED, "Incorrect parameter. data is nullptr"); | GE_CHK_BOOL_RET_STATUS(data != nullptr, FAILED, "Incorrect parameter. data is nullptr"); | ||||
#ifdef FMK_SUPPORT_DUMP | #ifdef FMK_SUPPORT_DUMP | ||||
mmSsize_t mmpa_ret = mmWrite(fd_, data, len); | |||||
int32_t mmpa_ret = mmWrite(fd_, data, len); | |||||
// mmWrite return -1:failed to write data to file;return -2:invalid parameter | // mmWrite return -1:failed to write data to file;return -2:invalid parameter | ||||
if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | ||||
GELOGE(FAILED, "Write to file failed. errno = %ld", mmpa_ret); | |||||
GELOGE(FAILED, "Write to file failed. errno = %d", mmpa_ret); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -136,18 +136,18 @@ int MemoryDumper::OpenFile(const char *filename) { | |||||
string real_path; | string real_path; | ||||
char tmp_path[PATH_MAX] = {0}; | char tmp_path[PATH_MAX] = {0}; | ||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
-1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos); | |||||
string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= PATH_MAX, return kInvalidFd, "Prefix path is too long!"); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(realpath(prefix_path.c_str(), tmp_path) == nullptr, return kInvalidFd, | |||||
"Dir %s does not exit.", prefix_path.c_str()); | |||||
real_path = std::string(tmp_path) + last_path;) | |||||
-1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos); | |||||
string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= PATH_MAX, return kInvalidFd, "Prefix path is too long!"); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(realpath(prefix_path.c_str(), tmp_path) == nullptr, return kInvalidFd, | |||||
"Dir %s does not exit.", prefix_path.c_str()); | |||||
real_path = std::string(tmp_path) + last_path;) | |||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
path_split_pos == -1 || path_split_pos == 0, | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(filename) >= PATH_MAX, return kInvalidFd, "Prefix path is too long!"); | |||||
GE_IF_BOOL_EXEC(realpath(filename, tmp_path) == nullptr, | |||||
GELOGI("File %s does not exit, it will be created.", filename)); | |||||
real_path = std::string(tmp_path);) | |||||
path_split_pos == -1 || path_split_pos == 0, | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(filename) >= PATH_MAX, return kInvalidFd, "Prefix path is too long!"); | |||||
GE_IF_BOOL_EXEC(realpath(filename, tmp_path) == nullptr, | |||||
GELOGI("File %s does not exit, it will be created.", filename)); | |||||
real_path = std::string(tmp_path);) | |||||
// Open file, only the current user can read and write, to avoid malicious application access | // Open file, only the current user can read and write, to avoid malicious application access | ||||
// Using the O_EXCL, if the file already exists,return failed to avoid privilege escalation vulnerability. | // Using the O_EXCL, if the file already exists,return failed to avoid privilege escalation vulnerability. | ||||
@@ -48,19 +48,19 @@ enum DataTypeTransMode { | |||||
}; | }; | ||||
std::map<std::pair<DataType, DataType>, DataTypeTransMode> trans_mode_map{ | std::map<std::pair<DataType, DataType>, DataTypeTransMode> trans_mode_map{ | ||||
{std::pair<DataType, DataType>(DT_FLOAT, DT_FLOAT16), kTransferWithDatatypeFloatToFloat16}, | |||||
{std::pair<DataType, DataType>(DT_FLOAT, DT_INT32), kTransferWithDatatypeFloatToInt32}, | |||||
{std::pair<DataType, DataType>(DT_FLOAT16, DT_FLOAT), kTransferWithDatatypeFloat16ToFloat}, | |||||
{std::pair<DataType, DataType>(DT_FLOAT16, DT_INT32), kTransferWithDatatypeFloat16ToInt32}, | |||||
{std::pair<DataType, DataType>(DT_INT32, DT_FLOAT), kTransferWithDatatypeInt32ToFloat}, | |||||
{std::pair<DataType, DataType>(DT_INT32, DT_FLOAT16), kTransferWithDatatypeInt32ToFloat16}, | |||||
{std::pair<DataType, DataType>(DT_INT32, DT_UINT8), kTransferWithDatatypeInt32ToUint8}, | |||||
{std::pair<DataType, DataType>(DT_INT32, DT_INT8), kTransferWithDatatypeInt32ToInt8}, | |||||
{std::pair<DataType, DataType>(DT_UINT8, DT_FLOAT), kTransferWithDatatypeUint8ToFloat}, | |||||
{std::pair<DataType, DataType>(DT_UINT8, DT_INT32), kTransferWithDatatypeUint8ToInt32}, | |||||
{std::pair<DataType, DataType>(DT_INT8, DT_FLOAT), kTransferWithDatatypeInt8ToFloat}, | |||||
{std::pair<DataType, DataType>(DT_INT8, DT_INT32), kTransferWithDatatypeInt8ToInt32}, | |||||
{std::pair<DataType, DataType>(DT_INT64, DT_INT32), kTransferWithDatatypeInt64ToInt32}}; | |||||
{std::pair<DataType, DataType>(DT_FLOAT, DT_FLOAT16), kTransferWithDatatypeFloatToFloat16}, | |||||
{std::pair<DataType, DataType>(DT_FLOAT, DT_INT32), kTransferWithDatatypeFloatToInt32}, | |||||
{std::pair<DataType, DataType>(DT_FLOAT16, DT_FLOAT), kTransferWithDatatypeFloat16ToFloat}, | |||||
{std::pair<DataType, DataType>(DT_FLOAT16, DT_INT32), kTransferWithDatatypeFloat16ToInt32}, | |||||
{std::pair<DataType, DataType>(DT_INT32, DT_FLOAT), kTransferWithDatatypeInt32ToFloat}, | |||||
{std::pair<DataType, DataType>(DT_INT32, DT_FLOAT16), kTransferWithDatatypeInt32ToFloat16}, | |||||
{std::pair<DataType, DataType>(DT_INT32, DT_UINT8), kTransferWithDatatypeInt32ToUint8}, | |||||
{std::pair<DataType, DataType>(DT_INT32, DT_INT8), kTransferWithDatatypeInt32ToInt8}, | |||||
{std::pair<DataType, DataType>(DT_UINT8, DT_FLOAT), kTransferWithDatatypeUint8ToFloat}, | |||||
{std::pair<DataType, DataType>(DT_UINT8, DT_INT32), kTransferWithDatatypeUint8ToInt32}, | |||||
{std::pair<DataType, DataType>(DT_INT8, DT_FLOAT), kTransferWithDatatypeInt8ToFloat}, | |||||
{std::pair<DataType, DataType>(DT_INT8, DT_INT32), kTransferWithDatatypeInt8ToInt32}, | |||||
{std::pair<DataType, DataType>(DT_INT64, DT_INT32), kTransferWithDatatypeInt64ToInt32}}; | |||||
template <typename SrcT, typename DstT> | template <typename SrcT, typename DstT> | ||||
Status TransDataSrc2Dst(const CastArgs &args, uint8_t *dst, const size_t data_size) { | Status TransDataSrc2Dst(const CastArgs &args, uint8_t *dst, const size_t data_size) { | ||||
@@ -45,12 +45,12 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||||
} | } | ||||
} | } | ||||
/// | |||||
/// After the conversion to two-dimensional matrix, the memory arrangement is small z and large N. | |||||
/// @src_shape: N*H*W | |||||
/// @dst_shape: N*W1*H1*H0*w0 | |||||
/// @return | |||||
/// | |||||
/** | |||||
* After the conversion to two-dimensional matrix, the memory arrangement is small z and large N. | |||||
* @src_shape: N*H*W | |||||
* @dst_shape: N*W1*H1*H0*w0 | |||||
* @return | |||||
*/ | |||||
Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, ShapeVector &dst_shape, | Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, ShapeVector &dst_shape, | ||||
ShapeVector &hw_shape) { | ShapeVector &hw_shape) { | ||||
dst_shape.clear(); | dst_shape.clear(); | ||||
@@ -150,8 +150,8 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||||
auto dst_offset = (h1h0_head + w1_idx * h1h0w0) * size; | auto dst_offset = (h1h0_head + w1_idx * h1h0w0) * size; | ||||
auto src_offset = (src_h_head + w1_idx * w0) * size; | auto src_offset = (src_h_head + w1_idx * w0) * size; | ||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size * w0)); | static_cast<size_t>(size * w0)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -165,8 +165,8 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||||
auto dst_offset = (h1h0_head + num_w1 * h1h0w0 + w0_idx) * size; | auto dst_offset = (h1h0_head + num_w1 * h1h0w0 + w0_idx) * size; | ||||
auto src_offset = (src_h_head + src_w_idx) * size; | auto src_offset = (src_h_head + src_w_idx) * size; | ||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -218,8 +218,8 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||||
auto src_offset = (h1h0_head + w1_idx * h1h0w0) * size; | auto src_offset = (h1h0_head + w1_idx * h1h0w0) * size; | ||||
auto dst_offset = (dst_h_head + w1_idx * w0) * size; | auto dst_offset = (dst_h_head + w1_idx * w0) * size; | ||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size * w0)); | static_cast<size_t>(size * w0)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -233,8 +233,8 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con | |||||
auto src_offset = (h1h0_head + num_w1 * h1h0w0 + w0_idx) * size; | auto src_offset = (h1h0_head + num_w1 * h1h0w0 + w0_idx) * size; | ||||
auto dst_offset = (dst_h_head + dst_w_idx) * size; | auto dst_offset = (dst_h_head + dst_w_idx) * size; | ||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -29,13 +29,14 @@ namespace formats { | |||||
namespace { | namespace { | ||||
Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } | Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } | ||||
/// | |||||
/// FZ represents the weight of convolution,. | |||||
/// After the conversion to two-dimensional matrix, the memory arrangement is small n and large Z. | |||||
/// If 4D(eg.NCHW) is used to represent convolution kernel, N is width, HWC is height. | |||||
/// | |||||
/// frac_z axises: (C1*H*W, No, Ni, C0), which Ni = 16, C0 = 16/32, No = Ceil(N/Ni), C1 = Ceil(C/C0) | |||||
/// | |||||
/** | |||||
* FZ represents the weight of convolution,. | |||||
* After the conversion to two-dimensional matrix, the memory arrangement is small n and large Z. | |||||
* If 4D(eg.NCHW) is used to represent convolution kernel, N is width, HWC is height. | |||||
* | |||||
* frac_z axises: (C1*H*W, No, Ni, C0), which Ni = 16, C0 = 16/32, No = Ceil(N/Ni), C1 = Ceil(C/C0) | |||||
* @return | |||||
*/ | |||||
Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) { | Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) { | ||||
auto c0 = GetCubeSizeByDataType(data_type); | auto c0 = GetCubeSizeByDataType(data_type); | ||||
if (c0 < 0) { | if (c0 < 0) { | ||||
@@ -148,8 +149,8 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
auto idx = gfi * fractal_ele_cnt + col * c0 + row; | auto idx = gfi * fractal_ele_cnt + col * c0 + row; | ||||
auto offset = idx * size; | auto offset = idx * size; | ||||
auto protected_size = dst_size - offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
errno_t ret; | errno_t ret; | ||||
if (need_pad_zero) { | if (need_pad_zero) { | ||||
ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | ||||
@@ -209,8 +210,8 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||||
int64_t dst_idx = c1i * hwn1n0c0 + hi * wn1n0c0 + wi * n1n0c0 + n1n0i * c0 + c0i; | int64_t dst_idx = c1i * hwn1n0c0 + hi * wn1n0c0 + wi * n1n0c0 + n1n0i * c0 + c0i; | ||||
int64_t dst_offset = dst_idx * data_size; | int64_t dst_offset = dst_idx * data_size; | ||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | ||||
errno_t ret; | errno_t ret; | ||||
if (pad_zero) { | if (pad_zero) { | ||||
@@ -274,8 +275,8 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||||
int64_t dst_idx = c1i * hwn1n0c0 + hi * wn1n0c0 + wi * n1n0c0 + n1n0i * c0 + c0i; | int64_t dst_idx = c1i * hwn1n0c0 + hi * wn1n0c0 + wi * n1n0c0 + n1n0i * c0 + c0i; | ||||
int64_t dst_offset = dst_idx * data_size; | int64_t dst_offset = dst_idx * data_size; | ||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | ||||
errno_t ret; | errno_t ret; | ||||
if (pad_zero) { | if (pad_zero) { | ||||
@@ -105,8 +105,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
auto src_offset = src_idx * size; | auto src_offset = src_idx * size; | ||||
auto dst_offset = dst_idx * size; | auto dst_offset = dst_idx * size; | ||||
auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -105,8 +105,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
auto src_offset = src_idx * size; | auto src_offset = src_idx * size; | ||||
auto dst_offset = dst_idx * size; | auto dst_offset = dst_idx * size; | ||||
auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -105,8 +105,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||||
auto src_offset = src_idx * size; | auto src_offset = src_idx * size; | ||||
auto dst_offset = dst_idx * size; | auto dst_offset = dst_idx * size; | ||||
auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -19,7 +19,6 @@ | |||||
#include <securec.h> | #include <securec.h> | ||||
#include <cmath> | #include <cmath> | ||||
#include <cstring> | #include <cstring> | ||||
#include <functional> | #include <functional> | ||||
#include <sstream> | #include <sstream> | ||||
#include <string> | #include <string> | ||||
@@ -198,7 +198,7 @@ fp16_t &fp16_t::operator=(const int32_t &i_val) { | |||||
} | } | ||||
} else { | } else { | ||||
e_ret = FP16_EXP_BIAS; | e_ret = FP16_EXP_BIAS; | ||||
m_tmp = m_tmp << static_cast<uint32_t >(kDim_11 - len); | |||||
m_tmp = m_tmp << static_cast<uint32_t>(kDim_11 - len); | |||||
e_ret = e_ret + (len - 1); | e_ret = e_ret + (len - 1); | ||||
} | } | ||||
uint16_t m_ret = static_cast<uint16_t>(m_tmp); | uint16_t m_ret = static_cast<uint16_t>(m_tmp); | ||||
@@ -17,11 +17,10 @@ | |||||
#ifndef GE_COMMON_FP16_T_H_ | #ifndef GE_COMMON_FP16_T_H_ | ||||
#define GE_COMMON_FP16_T_H_ | #define GE_COMMON_FP16_T_H_ | ||||
#include <algorithm> | |||||
#include <cmath> | #include <cmath> | ||||
#include <cstdint> | #include <cstdint> | ||||
#include <algorithm> | |||||
namespace ge { | namespace ge { | ||||
/** | /** | ||||
*@ingroup fp16 basic parameter | *@ingroup fp16 basic parameter | ||||
@@ -1,131 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "common/ge/tbe_plugin_manager.h" | |||||
#include <dirent.h> | |||||
#include <unistd.h> | |||||
#include <algorithm> | |||||
#include <cstring> | |||||
#include <fstream> | |||||
#include <iostream> | |||||
#include <map> | |||||
#include <memory> | |||||
#include <string> | |||||
#include "common/ge/ge_util.h" | |||||
#include "framework/common/debug/log.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "framework/common/util.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
#include "framework/engine/dnnengine.h" | |||||
#include "framework/omg/omg_inner_types.h" | |||||
#include "external/ge/ge_api_types.h" | |||||
#include "register/op_registry.h" | |||||
#include "graph/opsproto_manager.h" | |||||
namespace ge { | |||||
// Get Singleton Instance | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEPluginManager &TBEPluginManager::Instance() { | |||||
static TBEPluginManager instance_ptr_; | |||||
return instance_ptr_; | |||||
} | |||||
void TBEPluginManager::ClearHandles_() { | |||||
for (const auto &handle : handles_vec_) { | |||||
if (dlclose(handle) != 0) { | |||||
GELOGW("Failed to close handle: %s", dlerror()); | |||||
} | |||||
} | |||||
handles_vec_.clear(); | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::Finalize() { ClearHandles_(); } | |||||
string TBEPluginManager::GetPath() { | |||||
Dl_info dl_info; | |||||
if (dladdr(reinterpret_cast<void *>(&TBEPluginManager::GetPath), &dl_info) == 0) { | |||||
GELOGW("Failed to read so path!"); | |||||
return string(); | |||||
} else { | |||||
string so_path = dl_info.dli_fname; | |||||
char path[PATH_MAX] = {0}; | |||||
if (so_path.length() >= PATH_MAX) { | |||||
GELOGW("File path is too long!"); | |||||
return string(); | |||||
} | |||||
if (realpath(so_path.c_str(), path) == nullptr) { | |||||
GELOGW("Failed to get realpath of %s", so_path.c_str()); | |||||
return string(); | |||||
} | |||||
so_path = path; | |||||
so_path = so_path.substr(0, so_path.rfind('/') + 1); | |||||
return so_path; | |||||
} | |||||
} | |||||
Status TBEPluginManager::CheckCustomAiCpuOpLib() { | |||||
std::vector<std::string> vec_op_type; | |||||
domi::OpRegistry::Instance()->GetOpTypeByImplyType(vec_op_type, domi::ImplyType::CUSTOM); | |||||
for (size_t i = 0; i < vec_op_type.size(); i++) { | |||||
bool aicpu_so_exist = false; | |||||
std::string ai_cpu_so_name = "lib" + vec_op_type[i] + "_aicpu.so"; | |||||
for (size_t j = 0; j < domi::GetContext().aicpu_op_run_paths.size(); j++) { | |||||
string bin_file_path = domi::GetContext().aicpu_op_run_paths[j]; | |||||
if (bin_file_path.size() >= ai_cpu_so_name.size() && | |||||
bin_file_path.compare(bin_file_path.size() - ai_cpu_so_name.size(), ai_cpu_so_name.size(), ai_cpu_so_name) == | |||||
0) { | |||||
aicpu_so_exist = true; | |||||
break; | |||||
} | |||||
} | |||||
if (!aicpu_so_exist) { | |||||
GELOGE(FAILED, "Can't find aicpu run so(%s), please check the plugin path!", ai_cpu_so_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
void TBEPluginManager::SaveDdkVersion(const std::string &ddk_version) { | |||||
if (ddk_version.empty()) { | |||||
return; | |||||
} | |||||
GELOGI("Input ddk version : %s.", ddk_version.c_str()); | |||||
// Save DDK version number to omgcontext | |||||
domi::GetContext().ddk_version = ddk_version; | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::InitPreparation( | |||||
const std::map<string, string> &options) { | |||||
Status ret = CheckCustomAiCpuOpLib(); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Check custom aicpu run so failed!"); | |||||
return; | |||||
} else { | |||||
auto ddk_version = options.find("ge.DDK_version"); | |||||
if (ddk_version != options.end()) { | |||||
SaveDdkVersion(ddk_version->second); | |||||
} else { | |||||
GELOGW("No ddkVersion!"); | |||||
return; | |||||
} | |||||
} | |||||
} | |||||
} // namespace ge |
@@ -1,62 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_COMMON_GE_TBE_PLUGIN_MANAGER_H_ | |||||
#define GE_COMMON_GE_TBE_PLUGIN_MANAGER_H_ | |||||
#include <dlfcn.h> | |||||
#include <functional> | |||||
#include <iostream> | |||||
#include <map> | |||||
#include <memory> | |||||
#include <string> | |||||
#include <type_traits> | |||||
#include <typeinfo> | |||||
#include <vector> | |||||
#include "external/ge/ge_api_error_codes.h" | |||||
#include "external/register/register.h" | |||||
namespace ge { | |||||
using SoHandlesVec = std::vector<void *>; | |||||
using std::vector; | |||||
using std::string; | |||||
using std::map; | |||||
using std::function; | |||||
class TBEPluginManager { | |||||
public: | |||||
void Finalize(); | |||||
// Get TBEPluginManager singleton instance | |||||
static TBEPluginManager& Instance(); | |||||
static string GetPath(); | |||||
static void InitPreparation(const std::map<string, string> &options); | |||||
private: | |||||
TBEPluginManager() = default; | |||||
~TBEPluginManager() = default; | |||||
void ClearHandles_(); | |||||
static Status CheckCustomAiCpuOpLib(); | |||||
static void SaveDdkVersion(const std::string &ddk_version); | |||||
SoHandlesVec handles_vec_; | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_COMMON_GE_TBE_PLUGIN_MANAGER_H_ |
@@ -26,10 +26,10 @@ | |||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
using std::string; | |||||
using ge::TBEKernelStore; | |||||
using ge::TBEKernelPtr; | |||||
using domi::ModelTaskDef; | using domi::ModelTaskDef; | ||||
using ge::TBEKernelPtr; | |||||
using ge::TBEKernelStore; | |||||
using std::string; | |||||
namespace ge { | namespace ge { | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } | ||||
@@ -201,7 +201,7 @@ ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::strin | |||||
GELOGE(FAILED, "SaveModel fail for save buffer fail"); | GELOGE(FAILED, "SaveModel fail for save buffer fail"); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
shared_ptr<OmFileSaveHelper> om_file_save_helper = ge::MakeShared<OmFileSaveHelper>(); | |||||
std::shared_ptr<OmFileSaveHelper> om_file_save_helper = ge::MakeShared<OmFileSaveHelper>(); | |||||
GE_CHECK_NOTNULL_EXEC(om_file_save_helper, return MEMALLOC_FAILED); | GE_CHECK_NOTNULL_EXEC(om_file_save_helper, return MEMALLOC_FAILED); | ||||
ModelPartition partition_model; | ModelPartition partition_model; | ||||
partition_model.data = model_buffer.GetData(); | partition_model.data = model_buffer.GetData(); | ||||
@@ -428,7 +428,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransModelT | |||||
TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | ||||
GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | ||||
kernel_store.AddTBEKernel(tbe_kernel); | kernel_store.AddTBEKernel(tbe_kernel); | ||||
GELOGI("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); | |||||
} | } | ||||
} | } | ||||
if (!kernel_store.Build()) { | if (!kernel_store.Build()) { | ||||
@@ -18,9 +18,8 @@ | |||||
#define GE_COMMON_MATH_UTIL_H_ | #define GE_COMMON_MATH_UTIL_H_ | ||||
#include <securec.h> | #include <securec.h> | ||||
#include <cmath> | |||||
#include <algorithm> | #include <algorithm> | ||||
#include <cmath> | |||||
#include "Eigen/Eigen" | #include "Eigen/Eigen" | ||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
@@ -32,13 +31,13 @@ | |||||
namespace ge { | namespace ge { | ||||
/** | /** | ||||
* @ingroup domi_calibration | |||||
* @brief Initializes an input array to a specified value | |||||
* @param [in] n array initialization length | |||||
* @param [in] alpha initialization value | |||||
* @param [out] output array to be initialized | |||||
* @return Status | |||||
*/ | |||||
* @ingroup domi_calibration | |||||
* @brief Initializes an input array to a specified value | |||||
* @param [in] n array initialization length | |||||
* @param [in] alpha initialization value | |||||
* @param [out] output array to be initialized | |||||
* @return Status | |||||
*/ | |||||
template <typename Dtype> | template <typename Dtype> | ||||
Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) { | Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) { | ||||
GE_CHECK_NOTNULL(output); | GE_CHECK_NOTNULL(output); | ||||
@@ -17,6 +17,7 @@ | |||||
#ifndef GE_COMMON_MODEL_PARSER_BASE_H_ | #ifndef GE_COMMON_MODEL_PARSER_BASE_H_ | ||||
#define GE_COMMON_MODEL_PARSER_BASE_H_ | #define GE_COMMON_MODEL_PARSER_BASE_H_ | ||||
#include <securec.h> | |||||
#include <memory> | #include <memory> | ||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
@@ -36,7 +37,7 @@ class ModelParserBase { | |||||
/// @ingroup hiai | /// @ingroup hiai | ||||
/// @brief destructor | /// @brief destructor | ||||
/// | /// | ||||
virtual ~ModelParserBase(); | |||||
~ModelParserBase(); | |||||
/// | /// | ||||
/// @ingroup hiai | /// @ingroup hiai | ||||