diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc
index 1abdb2f..3c62f8b 100644
--- a/parser/caffe/caffe_parser.cc
+++ b/parser/caffe/caffe_parser.cc
@@ -2180,7 +2180,7 @@ Status CaffeWeightsParser::CheckNodes(ge::ComputeGraphPtr &graph) {
       ErrorManager::GetInstance().ATCReportErrMessage("E11029", {"opname"}, {node->GetName()});
       GELOGE(ge::GRAPH_FAILED, "[Find][Node] Op[%s] in model file does not exist in weight file.",
              node->GetName().c_str());
-      PreChecker::Instance().RefreshErrorMessageByName(node->GetName(), PreChecker::PARAM_INVALID,
+      PreChecker::Instance().RefreshErrorMessageByName(node->GetName(), PreChecker::ErrorCode::PARAM_INVALID,
                                                        "Node does not exist in weight file.");
     } else {
       REPORT_INNER_ERROR("E19999", "Op:%s(%s)'s input %d is not linked, check invalid",
@@ -2188,7 +2188,8 @@ Status CaffeWeightsParser::CheckNodes(ge::ComputeGraphPtr &graph) {
       GELOGE(ge::GRAPH_FAILED, "[Check][Param] Op[%s]'s input %d is not linked.",
              node->GetName().c_str(), in_anchor_ptr->GetIdx());
       string check_msg = "input " + to_string(in_anchor_ptr->GetIdx()) + "is not linked in weight file";
-      PreChecker::Instance().RefreshErrorMessageByName(node->GetName(), PreChecker::PARAM_INVALID, check_msg);
+      PreChecker::Instance().RefreshErrorMessageByName(node->GetName(), PreChecker::ErrorCode::PARAM_INVALID,
+                                                       check_msg);
     }
     return FAILED;
   }
diff --git a/parser/common/model_saver.cc b/parser/common/model_saver.cc
index b9e3841..9af4fa8 100644
--- a/parser/common/model_saver.cc
+++ b/parser/common/model_saver.cc
@@ -75,7 +75,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi
   mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len);
   if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) {
     ErrorManager::GetInstance().ATCReportErrMessage(
-        "E19004", {"file", "errmsg"}, {file_path, strerror(errno)});
+        "E19004", {"file", "errmsg"}, {file_path, strerror(errno)});
     // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose
     GELOGE(FAILED, "[WriteTo][File] %s failed. errno = %ld, %s", file_path, mmpa_ret, strerror(errno));
     ret = FAILED;
diff --git a/parser/common/model_saver.h b/parser/common/model_saver.h
index bc31dba..cee6c07 100644
--- a/parser/common/model_saver.h
+++ b/parser/common/model_saver.h
@@ -20,6 +20,7 @@
 #include
 
 #include "ge/ge_api_error_codes.h"
+#include "ge/ge_api_types.h"
 #include "register/register_types.h"
 #include "nlohmann/json.hpp"
diff --git a/parser/common/op_def/arg_op.h b/parser/common/op_def/arg_op.h
index b867a91..ed0a2db 100644
--- a/parser/common/op_def/arg_op.h
+++ b/parser/common/op_def/arg_op.h
@@ -23,7 +23,7 @@
 class ArgOpOperator : public ParserOperator {
  public:
   ArgOpOperator();
-  ~ArgOpOperator();
+  ~ArgOpOperator() override;

   ArgOpOperator &Name(const std::string &name);
diff --git a/parser/common/op_def/constant_op.h b/parser/common/op_def/constant_op.h
index 29549e5..59d586b 100644
--- a/parser/common/op_def/constant_op.h
+++ b/parser/common/op_def/constant_op.h
@@ -24,7 +24,7 @@
 namespace ge {
 class ConstantOperator : public ParserOperator {
  public:
   ConstantOperator();
-  ~ConstantOperator();
+  ~ConstantOperator() override;

   ConstantOperator &Name(const std::string &name);
   ConstantOperator &VectorAttr(std::string key, std::vector &value);
diff --git a/parser/common/op_def/fill_op.h b/parser/common/op_def/fill_op.h
index 8b25ee8..4040b49 100644
--- a/parser/common/op_def/fill_op.h
+++ b/parser/common/op_def/fill_op.h
@@ -23,7 +23,7 @@
 class FillOperator : public ParserOperator {
  public:
   FillOperator();
-  ~FillOperator();
+  ~FillOperator() override;

   FillOperator &DataType(int64_t dataType);
diff --git a/parser/common/op_def/frameworkop_op.h b/parser/common/op_def/frameworkop_op.h
index c01f0f7..ad88c47 100644
--- a/parser/common/op_def/frameworkop_op.h
+++ b/parser/common/op_def/frameworkop_op.h
@@ -24,7 +24,7 @@
 class FrameworkOpOperator : public ParserOperator {
  public:
   FrameworkOpOperator();
-  ~FrameworkOpOperator();
+  ~FrameworkOpOperator() override;

   FrameworkOpOperator &Name(const std::string &name);
diff --git a/parser/common/op_def/no_op_op.h b/parser/common/op_def/no_op_op.h
index 0208c90..f0fe684 100644
--- a/parser/common/op_def/no_op_op.h
+++ b/parser/common/op_def/no_op_op.h
@@ -24,7 +24,7 @@
 namespace ge {
 class NoOpOperator : public ParserOperator {
  public:
   NoOpOperator();
-  ~NoOpOperator();
+  ~NoOpOperator() override;

   NoOpOperator &Name(const std::string &name);
 };
diff --git a/parser/common/op_def/ref_switch_op.h b/parser/common/op_def/ref_switch_op.h
index baf2167..de17756 100644
--- a/parser/common/op_def/ref_switch_op.h
+++ b/parser/common/op_def/ref_switch_op.h
@@ -24,7 +24,7 @@
 namespace ge {
 class RefSwitchOperator : public ParserOperator {
  public:
   RefSwitchOperator();
-  ~RefSwitchOperator();
+  ~RefSwitchOperator() override;

   RefSwitchOperator &Name(const std::string &name);
   RefSwitchOperator &T(ge::DataType t);
diff --git a/parser/common/op_def/shape_n_op.h b/parser/common/op_def/shape_n_op.h
index bb69235..2211ea4 100644
--- a/parser/common/op_def/shape_n_op.h
+++ b/parser/common/op_def/shape_n_op.h
@@ -24,7 +24,7 @@
 namespace ge {
 class ShapeNOperator : public ParserOperator {
  public:
   ShapeNOperator();
-  ~ShapeNOperator();
+  ~ShapeNOperator() override;

   ShapeNOperator &Name(const std::string &name);
diff --git a/parser/common/op_def/var_is_initialized_op_op.h b/parser/common/op_def/var_is_initialized_op_op.h
index 88b649f..ee586e3 100644
--- a/parser/common/op_def/var_is_initialized_op_op.h
+++ b/parser/common/op_def/var_is_initialized_op_op.h
@@ -24,7 +24,7 @@
 namespace ge {
 class VarIsInitializedOpOperator : public ParserOperator {
  public:
   VarIsInitializedOpOperator();
-  ~VarIsInitializedOpOperator();
+  ~VarIsInitializedOpOperator() override;

   VarIsInitializedOpOperator &Name(const std::string &name);
   VarIsInitializedOpOperator &VectorAttr(const std::string &key, std::vector &value);
diff --git a/parser/common/op_def/variable_op.h b/parser/common/op_def/variable_op.h
index 166681e..ded8cac 100644
--- a/parser/common/op_def/variable_op.h
+++ b/parser/common/op_def/variable_op.h
@@ -25,7 +25,7 @@
 namespace ge {
 class VariableOperator : public ParserOperator {
  public:
   VariableOperator();
-  ~VariableOperator();
+  ~VariableOperator() override;

   VariableOperator &Name(const std::string &name);
diff --git a/parser/common/parser_fp16_t.h b/parser/common/parser_fp16_t.h
index 7a4a5eb..a8bf36f 100644
--- a/parser/common/parser_fp16_t.h
+++ b/parser/common/parser_fp16_t.h
@@ -586,7 +586,7 @@ T MinMan(const int16_t &e_a, T &m_a, const int16_t &e_b, T &m_b) {
 template <typename T>
 T RightShift(T man, int16_t shift) {
   int bits = sizeof(T) * 8;  // one byte have 8 bits
-  T mask = (((T) 1u) << ((unsigned int) (bits - 1)));
+  T mask = static_cast<T>(1u) << static_cast<unsigned int>(bits - 1);
   for (int i = 0; i < shift; i++) {
     man = ((man & mask) | (man >> 1));
   }
diff --git a/parser/common/pass_manager.cc b/parser/common/pass_manager.cc
index c22828c..f08fad3 100644
--- a/parser/common/pass_manager.cc
+++ b/parser/common/pass_manager.cc
@@ -27,7 +27,7 @@ const std::vector<std::pair<std::string, GraphPass *>> &PassManager::GraphPasses
   return names_to_graph_passes_;
 }

-Status PassManager::AddPass(const string &pass_name, GraphPass *pass) {
+Status PassManager::AddPass(const string &pass_name, GraphPass *const pass) {
   GE_CHECK_NOTNULL(pass);
   names_to_graph_passes_.emplace_back(pass_name, pass);
   return SUCCESS;
diff --git a/parser/common/pass_manager.h b/parser/common/pass_manager.h
index 9d53b6a..5befe03 100644
--- a/parser/common/pass_manager.h
+++ b/parser/common/pass_manager.h
@@ -41,7 +41,7 @@ public:
   /// @param [in] pass Pass to be added, it will be destroyed when pass manager destroys.
  /// @author
  ///
-  Status AddPass(const string &pass_name, GraphPass *pass);
+  Status AddPass(const string &pass_name, GraphPass *const pass);

  ///
  /// Optimize graph with added pass
diff --git a/parser/common/pre_checker.cc b/parser/common/pre_checker.cc
index 56d754b..8f8d577 100644
--- a/parser/common/pre_checker.cc
+++ b/parser/common/pre_checker.cc
@@ -98,7 +98,7 @@ Status PreChecker::CheckName(OpId id) {
     // If the name is duplicate, an error is logged
     if (id != v.first && info.name == v.second.name) {
       Cause cause;
-      cause.code = NAME_REPEATED;
+      cause.code = ErrorCode::NAME_REPEATED;
       cause.message = "The name is repeated.";
       GELOGI("Name %s repeated.", info.name.c_str());

@@ -248,7 +248,7 @@ Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string
   std::string op_type;
   if (!domi::OpRegistry::Instance()->GetOmTypeByOriOpType(type, op_type)) {
     Cause cause;
-    cause.code = TYPE_UNSUPPORTED;
+    cause.code = ErrorCode::TYPE_UNSUPPORTED;
     cause.message = "The type is not supported.";
     GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str());
     if (!is_tensorflow) {
@@ -262,7 +262,7 @@ Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string
   // Log error if type not found
   if (fmk_op_types_->find(type) == fmk_op_types_->end()) {
     Cause cause;
-    cause.code = TYPE_UNSUPPORTED;
+    cause.code = ErrorCode::TYPE_UNSUPPORTED;
     cause.message = "The type is not supported.";
     GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str());
diff --git a/parser/common/pre_checker.h b/parser/common/pre_checker.h
index 12d3323..cdf35ad 100644
--- a/parser/common/pre_checker.h
+++ b/parser/common/pre_checker.h
@@ -44,7 +44,7 @@ class PreChecker {
    * @ingroup domi_omg
    * @brief error code, 1~99:Error, 100~199:Waring。
    */
-  enum ErrorCode {
+  enum class ErrorCode {
     // no error
     OK = 0,
diff --git a/parser/onnx/onnx_constant_parser.h b/parser/onnx/onnx_constant_parser.h
index f8adb3e..4b607be 100644
--- a/parser/onnx/onnx_constant_parser.h
+++ b/parser/onnx/onnx_constant_parser.h
@@ -23,8 +23,6 @@
 #include "parser/common/data_op_parser.h"
 #include "parser/onnx/onnx_op_parser.h"

-using ge::onnx::NodeProto;
-
 namespace ge {
 class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser {
  public:
@@ -60,17 +58,17 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser {
     DataType data_type = tensor.GetTensorDesc().GetDataType();
     switch (data_type) {
-#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \
-  case dt_type: \
-  { \
-    unique_ptr<value_type[]> addr_trans(new(std::nothrow) value_type[count]()); \
-    GE_CHECK_NOTNULL(addr_trans); \
-    for (int32_t i = 0; i < count; i++) { \
-      *(addr_trans.get() + i) = static_cast<value_type>(*(addr.get() + i)); \
-    } \
-    tensor.SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), count * sizeof(value_type)); \
-    break; \
-  } \
+#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor)                          \
+  case dt_type:                                                                          \
+  {                                                                                      \
+    unique_ptr<value_type[]> addr_trans(new(std::nothrow) value_type[count]());          \
+    GE_CHECK_NOTNULL(addr_trans);                                                        \
+    for (int32_t i = 0; i < (count); i++) {                                              \
+      *(addr_trans.get() + i) = static_cast<value_type>(*((addr).get() + i));            \
+    }                                                                                    \
+    (tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(value_type)); \
+    break;                                                                               \
+  }                                                                                      \

       CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor)
       CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor)
diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc
index 269d738..2c74064 100644
--- a/parser/tensorflow/tensorflow_parser.cc
+++ b/parser/tensorflow/tensorflow_parser.cc
@@ -586,7 +586,7 @@ Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def,
 }

 void TensorFlowModelParser::GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num,
-                                                    size_t &output_tensor_num) {
+                                                    size_t &output_tensor_num) const {
   // The caller guarantees that the pointer is not null
   auto iter = op_node_context_map_.find(op_desc->GetName());
   if (iter == op_node_context_map_.end()) {
@@ -817,8 +817,7 @@ Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) {
   return SUCCESS;
 }

-Status TensorFlowModelParser::AddFmkNodeDefToMap(const domi::tensorflow::GraphDef &graph_def,
-                                                 const domi::tensorflow::NodeDef *node_def,
+Status TensorFlowModelParser::AddFmkNodeDefToMap(const domi::tensorflow::NodeDef *node_def,
                                                  vector<string> &op_node_name_list) {
   GE_CHECK_NOTNULL(node_def);
   const string &node_name = node_def->name();
@@ -1224,7 +1223,7 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g
       GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str()););

     // Do not exit immediately when there is an error, wait until all errors are collected before exiting
-    GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(graph_def, node_def, op_node_name_list), has_error = true,
+    GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(node_def, op_node_name_list), has_error = true,
                        "add node failed.");
   }

@@ -1459,7 +1458,7 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
     }

     // Do not exit immediately when there is an error, wait until all errors are collected before exiting
-    GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(graph_def, node_def, op_node_name_list), has_error = true);
+    GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(node_def, op_node_name_list), has_error = true);
   }

   // The fusion operator has passed the verification.
@@ -1545,7 +1544,7 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
   return SUCCESS;
 }

-Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) {
+Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const {
   // Number of data nodes
   uint32_t data_node_count = 0;
   for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) {
@@ -2275,7 +2274,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
     }

     // Do not exit immediately when there is an error, wait until all errors are collected before exiting
-    Status ret = AddFmkNodeDefToMap(*graph_def, node_def, op_node_name_list);
+    Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list);
     GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed");
   }
   PARSER_TIMESTAMP_END(AddFmkNodeDefToMap, "TensorFlowModelParser::AddFmkNodeDefToMap");
@@ -2865,7 +2864,7 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph
     // mutable_node return vale is not empty
     domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i);
     const string &node_name = node_def->name();
-    Status ret = AddFmkNodeDefToMap(*graph_def, node_def, op_node_name_list);
+    Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list);
     GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed");
     if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) {
       identity_to_optimize.push_back(node_def);
diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h
index ff39a79..e18d474 100644
--- a/parser/tensorflow/tensorflow_parser.h
+++ b/parser/tensorflow/tensorflow_parser.h
@@ -185,7 +185,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
    * @return FAILED add failed
    */
-  Status AddFmkNodeDefToMap(const domi::tensorflow::GraphDef &graph_def, const domi::tensorflow::NodeDef *node_def,
+  Status AddFmkNodeDefToMap(const domi::tensorflow::NodeDef *node_def,
                             vector<string> &op_node_name_list);

   /**
@@ -243,7 +243,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
    * @return SUCCESS check successfully
    * @return FAILED check failed
    */
-  Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def);
+  Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const;

   /**
    * @ingroup domi_omg
@@ -516,7 +516,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
   Status UppdateOutputMap(shared_ptr &scope_graph, const ge::ScopeFusionOpInfo &info,
                           OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context);
   void GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num,
-                               size_t &output_tensor_num);
+                               size_t &output_tensor_num) const;
   static Status CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set &dims, bool &valid);
   Status CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type);
diff --git a/parser/tensorflow/tensorflow_util.cc b/parser/tensorflow/tensorflow_util.cc
index ecf2586..3184286 100644
--- a/parser/tensorflow/tensorflow_util.cc
+++ b/parser/tensorflow/tensorflow_util.cc
@@ -31,6 +31,94 @@ using domi::tensorflow::DT_INVALID;

 namespace ge {
+/***************************TensorFlow attribute type, constant definition*******************************************/
+const std::string TENSORFLOW_ATTR_TYPE_STRING = "string";
+const std::string TENSORFLOW_ATTR_TYPE_INT = "int";
+const std::string TENSORFLOW_ATTR_TYPE_FLOAT = "float";
+const std::string TENSORFLOW_ATTR_TYPE_BOOL = "bool";
+const std::string TENSORFLOW_ATTR_TYPE_TYPE = "type";
+const std::string TENSORFLOW_ATTR_TYPE_SHAPE = "shape";
+const std::string TENSORFLOW_ATTR_TYPE_TENSOR = "tensor";
+const std::string TENSORFLOW_ATTR_TYPE_FUNC = "func";
+
+const std::string TENSORFLOW_ATTR_LIST_TYPE_STRING = "list(string)";
+const std::string TENSORFLOW_ATTR_LIST_TYPE_INT = "list(int)";
+const std::string TENSORFLOW_ATTR_LIST_TYPE_FLOAT = "list(float)";
+const std::string TENSORFLOW_ATTR_LIST_TYPE_BOOL = "list(bool)";
+const std::string TENSORFLOW_ATTR_LIST_TYPE_TYPE = "list(type)";
+const std::string TENSORFLOW_ATTR_LIST_TYPE_SHAPE = "list(shape)";
+const std::string TENSORFLOW_ATTR_LIST_TYPE_TENSOR = "list(tensor)";
+const std::string TENSORFLOW_ATTR_LIST_TYPE_FUNC = "list(func)";
+
+/***************************constant definition*******************************************/
+const std::string TENSORFLOW_ATTR_OUTPUT_OP = "output_op";
+
+const std::string TENSORFLOW_ATTR_T = "T";
+const std::string TENSORFLOW_ATTR_N = "N";
+const std::string TENSORFLOW_ATTR_DATA_FORMAT = "data_format";
+const std::string TENSORFLOW_ATTR_PADDING = "padding";
+const std::string TENSORFLOW_ATTR_KSIZE = "ksize";
+const std::string TENSORFLOW_ATTR_STRIDES = "strides";
+const std::string TENSORFLOW_ATTR_DILATIONS = "dilations";
+const std::string TENSORFLOW_ATTR_DTYPE = "dtype";
+const std::string TENSORFLOW_ATTR_VALUE = "value";
+const std::string TENSORFLOW_ATTR_TRANSINPUT = "transpose_a";
+const std::string TENSORFLOW_ATTR_TRANSWEIGHT = "transpose_b";
+const std::string TENSORFLOW_ATTR_SHAPE = "shape";
+const std::string TENSORFLOW_ATTR_TIDX = "Tidx";
+const std::string TENSORFLOW_ATTR_TPADDINGS = "Tpaddings";
+const std::string TENSORFLOW_ATTR_TMULTIPLES = "Tmultiples";
+const std::string TENSORFLOW_ATTR_TINDICES = "Tindices";
+const std::string TENSORFLOW_ATTR_TPARAMS = "Tparams";
+const std::string TENSORFLOW_ATTR_TAXIS = "Taxis";
+const std::string TENSORFLOW_ATTR_DSTT = "DstT";
+const std::string TENSORFLOW_ATTR_SRCT = "SrcT";
+const std::string TENSORFLOW_ATTR_PERM = "perm";
+const std::string TENSORFLOW_ATTR_INDEX = "Index";
+const std::string TENSORFLOW_ATTR_TSHAPE = "Tshape";
+const std::string TENSORFLOW_ATTR_AXIS = "Axis";
+const std::string TENSORFLOW_ATTR_BIAS = "bias";
+const std::string TENSORFLOW_ATTR_DEPTH_RADIUS = "depth_radius";
+const std::string TENSORFLOW_ATTR_ALPHA = "alpha";
+const std::string TENSORFLOW_ATTR_BETA = "beta";
+const std::string TENSORFLOW_ATTR_MODE = "mode";
+
+// op:Const
+const std::string TENSORFLOWF_NODE_OP_CONST = "Const";
+const std::string TENSORFLOWF_NODE_OP_IDENTITY = "Identity";
+const std::string TENSORFLOWF_NODE_OP_SWITCH = "Switch";
+const std::string TENSORFLOWF_NODE_OP_PLACEHOLDER = "Placeholder";
+const std::string TENSORFLOWF_NODE_OP_ADDN = "AddN";
+const std::string TENSORFLOWF_NODE_OP_MATMUL = "MatMul";
+const std::string TENSORFLOWF_NODE_OP_RELU = "Relu";
+const std::string TENSORFLOWF_NODE_OP_SHAPE = "Shape";
+const std::string TENSORFLOWF_NODE_OP_TRANSPOSE = "Transpose";
+const std::string TENSORFLOWF_NODE_OP_MERGE = "Merge";
+
+// data_format
+const std::string TENSORFLOWF_TENSOR_NCHW = "NCHW";
+const std::string TENSORFLOWF_TENSOR_NHWC = "NHWC";
+
+const int TENSORFLOW_CONV_STRIDE_NUM = 4;
+const int TENSORFLOW_CONV_DILATION_NUM = 4;
+
+// padding
+const std::string TENSORFLOWF_OP_PADDING_VALID = "VALID";
+const std::string TENSORFLOWF_OP_PADDING_SAME = "SAME";
+
+// normal input size
+const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_MATMUL = 2;
+const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_RESHAPE = 1;
+const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_POOL = 1;
+
+// normal weight size
+const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_MATMUL = 1;
+const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_RESHAPE = 1;
+
+// input or output
+const uint32_t TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG = 1;
+const uint32_t TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG = 2;
+
 using AttrValueMap = ::google::protobuf::Map<std::string, domi::tensorflow::AttrValue>;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrValue(
     const domi::tensorflow::NodeDef *node_def, const std::string &attr_name, domi::tensorflow::AttrValue &attr_value) {
diff --git a/parser/tensorflow/tensorflow_util.h b/parser/tensorflow/tensorflow_util.h
index 18bd744..c3a5565 100644
--- a/parser/tensorflow/tensorflow_util.h
+++ b/parser/tensorflow/tensorflow_util.h
@@ -44,92 +44,92 @@ using domi::tensorflow::FunctionDefLibrary;
 namespace ge {

 /***************************TensorFlow attribute type, constant definition*******************************************/
-static const std::string TENSORFLOW_ATTR_TYPE_STRING = "string";
-static const std::string TENSORFLOW_ATTR_TYPE_INT = "int";
-static const std::string TENSORFLOW_ATTR_TYPE_FLOAT = "float";
-static const std::string TENSORFLOW_ATTR_TYPE_BOOL = "bool";
-static const std::string TENSORFLOW_ATTR_TYPE_TYPE = "type";
-static const std::string TENSORFLOW_ATTR_TYPE_SHAPE = "shape";
-static const std::string TENSORFLOW_ATTR_TYPE_TENSOR = "tensor";
-static const std::string TENSORFLOW_ATTR_TYPE_FUNC = "func";
-
-static const std::string TENSORFLOW_ATTR_LIST_TYPE_STRING = "list(string)";
-static const std::string TENSORFLOW_ATTR_LIST_TYPE_INT = "list(int)";
-static const std::string TENSORFLOW_ATTR_LIST_TYPE_FLOAT = "list(float)";
-static const std::string TENSORFLOW_ATTR_LIST_TYPE_BOOL = "list(bool)";
-static const std::string TENSORFLOW_ATTR_LIST_TYPE_TYPE = "list(type)";
-static const std::string TENSORFLOW_ATTR_LIST_TYPE_SHAPE = "list(shape)";
-static const std::string TENSORFLOW_ATTR_LIST_TYPE_TENSOR = "list(tensor)";
-static const std::string TENSORFLOW_ATTR_LIST_TYPE_FUNC = "list(func)";
+extern const std::string TENSORFLOW_ATTR_TYPE_STRING;
+extern const std::string TENSORFLOW_ATTR_TYPE_INT;
+extern const std::string TENSORFLOW_ATTR_TYPE_FLOAT;
+extern const std::string TENSORFLOW_ATTR_TYPE_BOOL;
+extern const std::string TENSORFLOW_ATTR_TYPE_TYPE;
+extern const std::string TENSORFLOW_ATTR_TYPE_SHAPE;
+extern const std::string TENSORFLOW_ATTR_TYPE_TENSOR;
+extern const std::string TENSORFLOW_ATTR_TYPE_FUNC;
+
+extern const std::string TENSORFLOW_ATTR_LIST_TYPE_STRING;
+extern const std::string TENSORFLOW_ATTR_LIST_TYPE_INT;
+extern const std::string TENSORFLOW_ATTR_LIST_TYPE_FLOAT;
+extern const std::string TENSORFLOW_ATTR_LIST_TYPE_BOOL;
+extern const std::string TENSORFLOW_ATTR_LIST_TYPE_TYPE;
+extern const std::string TENSORFLOW_ATTR_LIST_TYPE_SHAPE;
+extern const std::string TENSORFLOW_ATTR_LIST_TYPE_TENSOR;
+extern const std::string TENSORFLOW_ATTR_LIST_TYPE_FUNC;

 /***************************constant definition*******************************************/
-static const std::string TENSORFLOW_ATTR_OUTPUT_OP = "output_op";
-
-static const std::string TENSORFLOW_ATTR_T = "T";
-static const std::string TENSORFLOW_ATTR_N = "N";
-static const std::string TENSORFLOW_ATTR_DATA_FORMAT = "data_format";
-static const std::string TENSORFLOW_ATTR_PADDING = "padding";
-static const std::string TENSORFLOW_ATTR_KSIZE = "ksize";
-static const std::string TENSORFLOW_ATTR_STRIDES = "strides";
-static const std::string TENSORFLOW_ATTR_DILATIONS = "dilations";
-static const std::string TENSORFLOW_ATTR_DTYPE = "dtype";
-static const std::string TENSORFLOW_ATTR_VALUE = "value";
-static const std::string TENSORFLOW_ATTR_TRANSINPUT = "transpose_a";
-static const std::string TENSORFLOW_ATTR_TRANSWEIGHT = "transpose_b";
-static const std::string TENSORFLOW_ATTR_SHAPE = "shape";
-static const std::string TENSORFLOW_ATTR_TIDX = "Tidx";
-static const std::string TENSORFLOW_ATTR_TPADDINGS = "Tpaddings";
-static const std::string TENSORFLOW_ATTR_TMULTIPLES = "Tmultiples";
-static const std::string TENSORFLOW_ATTR_TINDICES = "Tindices";
-static const std::string TENSORFLOW_ATTR_TPARAMS = "Tparams";
-static const std::string TENSORFLOW_ATTR_TAXIS = "Taxis";
-static const std::string TENSORFLOW_ATTR_DSTT = "DstT";
-static const std::string TENSORFLOW_ATTR_SRCT = "SrcT";
-static const std::string TENSORFLOW_ATTR_PERM = "perm";
-static const std::string TENSORFLOW_ATTR_INDEX = "Index";
-static const std::string TENSORFLOW_ATTR_TSHAPE = "Tshape";
-static const std::string TENSORFLOW_ATTR_AXIS = "Axis";
-static const std::string TENSORFLOW_ATTR_BIAS = "bias";
-static const std::string TENSORFLOW_ATTR_DEPTH_RADIUS = "depth_radius";
-static const std::string TENSORFLOW_ATTR_ALPHA = "alpha";
-static const std::string TENSORFLOW_ATTR_BETA = "beta";
-static const std::string TENSORFLOW_ATTR_MODE = "mode";
+extern const std::string TENSORFLOW_ATTR_OUTPUT_OP;
+
+extern const std::string TENSORFLOW_ATTR_T;
+extern const std::string TENSORFLOW_ATTR_N;
+extern const std::string TENSORFLOW_ATTR_DATA_FORMAT;
+extern const std::string TENSORFLOW_ATTR_PADDING;
+extern const std::string TENSORFLOW_ATTR_KSIZE;
+extern const std::string TENSORFLOW_ATTR_STRIDES;
+extern const std::string TENSORFLOW_ATTR_DILATIONS;
+extern const std::string TENSORFLOW_ATTR_DTYPE;
+extern const std::string TENSORFLOW_ATTR_VALUE;
+extern const std::string TENSORFLOW_ATTR_TRANSINPUT;
+extern const std::string TENSORFLOW_ATTR_TRANSWEIGHT;
+extern const std::string TENSORFLOW_ATTR_SHAPE;
+extern const std::string TENSORFLOW_ATTR_TIDX;
+extern const std::string TENSORFLOW_ATTR_TPADDINGS;
+extern const std::string TENSORFLOW_ATTR_TMULTIPLES;
+extern const std::string TENSORFLOW_ATTR_TINDICES;
+extern const std::string TENSORFLOW_ATTR_TPARAMS;
+extern const std::string TENSORFLOW_ATTR_TAXIS;
+extern const std::string TENSORFLOW_ATTR_DSTT;
+extern const std::string TENSORFLOW_ATTR_SRCT;
+extern const std::string TENSORFLOW_ATTR_PERM;
+extern const std::string TENSORFLOW_ATTR_INDEX;
+extern const std::string TENSORFLOW_ATTR_TSHAPE;
+extern const std::string TENSORFLOW_ATTR_AXIS;
+extern const std::string TENSORFLOW_ATTR_BIAS;
+extern const std::string TENSORFLOW_ATTR_DEPTH_RADIUS;
+extern const std::string TENSORFLOW_ATTR_ALPHA;
+extern const std::string TENSORFLOW_ATTR_BETA;
+extern const std::string TENSORFLOW_ATTR_MODE;

 // op:Const
-static const std::string TENSORFLOWF_NODE_OP_CONST = "Const";
-static const std::string TENSORFLOWF_NODE_OP_IDENTITY = "Identity";
-static const std::string TENSORFLOWF_NODE_OP_SWITCH = "Switch";
-static const std::string TENSORFLOWF_NODE_OP_PLACEHOLDER = "Placeholder";
-static const std::string TENSORFLOWF_NODE_OP_ADDN = "AddN";
-static const std::string TENSORFLOWF_NODE_OP_MATMUL = "MatMul";
-static const std::string TENSORFLOWF_NODE_OP_RELU = "Relu";
-static const std::string TENSORFLOWF_NODE_OP_SHAPE = "Shape";
-static const std::string TENSORFLOWF_NODE_OP_TRANSPOSE = "Transpose";
-static const std::string TENSORFLOWF_NODE_OP_MERGE = "Merge";
+extern const std::string TENSORFLOWF_NODE_OP_CONST;
+extern const std::string TENSORFLOWF_NODE_OP_IDENTITY;
+extern const std::string TENSORFLOWF_NODE_OP_SWITCH;
+extern const std::string TENSORFLOWF_NODE_OP_PLACEHOLDER;
+extern const std::string TENSORFLOWF_NODE_OP_ADDN;
+extern const std::string TENSORFLOWF_NODE_OP_MATMUL;
+extern const std::string TENSORFLOWF_NODE_OP_RELU;
+extern const std::string TENSORFLOWF_NODE_OP_SHAPE;
+extern const std::string TENSORFLOWF_NODE_OP_TRANSPOSE;
+extern const std::string TENSORFLOWF_NODE_OP_MERGE;

 // data_format
-static const std::string TENSORFLOWF_TENSOR_NCHW = "NCHW";
-static const std::string TENSORFLOWF_TENSOR_NHWC = "NHWC";
+extern const std::string TENSORFLOWF_TENSOR_NCHW;
+extern const std::string TENSORFLOWF_TENSOR_NHWC;

-static const int TENSORFLOW_CONV_STRIDE_NUM = 4;
-static const int TENSORFLOW_CONV_DILATION_NUM = 4;
+extern const int TENSORFLOW_CONV_STRIDE_NUM;
+extern const int TENSORFLOW_CONV_DILATION_NUM;

 // padding
-static const std::string TENSORFLOWF_OP_PADDING_VALID = "VALID";
-static const std::string TENSORFLOWF_OP_PADDING_SAME = "SAME";
+extern const std::string TENSORFLOWF_OP_PADDING_VALID;
+extern const std::string TENSORFLOWF_OP_PADDING_SAME;

 // normal input size
-static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_MATMUL = 2;
-static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_RESHAPE = 1;
-static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_POOL = 1;
+extern const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_MATMUL;
+extern const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_RESHAPE;
+extern const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_POOL;

 // normal weight size
-static const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_MATMUL = 1;
-static const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_RESHAPE = 1;
+extern const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_MATMUL;
+extern const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_RESHAPE;

 // input or output
-static const uint32_t TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG = 1;
-static const uint32_t TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG = 2;
+extern const uint32_t TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG;
+extern const uint32_t TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG;

 class TensorFlowUtil {
  public: