Browse Source

static check clean

pull/416/head
isaacxr 3 years ago
parent
commit
b01c90a062
22 changed files with 201 additions and 114 deletions
  1. +3
    -2
      parser/caffe/caffe_parser.cc
  2. +1
    -1
      parser/common/model_saver.cc
  3. +1
    -0
      parser/common/model_saver.h
  4. +1
    -1
      parser/common/op_def/arg_op.h
  5. +1
    -1
      parser/common/op_def/constant_op.h
  6. +1
    -1
      parser/common/op_def/fill_op.h
  7. +1
    -1
      parser/common/op_def/frameworkop_op.h
  8. +1
    -1
      parser/common/op_def/no_op_op.h
  9. +1
    -1
      parser/common/op_def/ref_switch_op.h
  10. +1
    -1
      parser/common/op_def/shape_n_op.h
  11. +1
    -1
      parser/common/op_def/var_is_initialized_op_op.h
  12. +1
    -1
      parser/common/op_def/variable_op.h
  13. +1
    -1
      parser/common/parser_fp16_t.h
  14. +1
    -1
      parser/common/pass_manager.cc
  15. +1
    -1
      parser/common/pass_manager.h
  16. +3
    -3
      parser/common/pre_checker.cc
  17. +1
    -1
      parser/common/pre_checker.h
  18. +11
    -13
      parser/onnx/onnx_constant_parser.h
  19. +7
    -8
      parser/tensorflow/tensorflow_parser.cc
  20. +3
    -3
      parser/tensorflow/tensorflow_parser.h
  21. +88
    -0
      parser/tensorflow/tensorflow_util.cc
  22. +71
    -71
      parser/tensorflow/tensorflow_util.h

+ 3
- 2
parser/caffe/caffe_parser.cc View File

@@ -2180,7 +2180,7 @@ Status CaffeWeightsParser::CheckNodes(ge::ComputeGraphPtr &graph) {
ErrorManager::GetInstance().ATCReportErrMessage("E11029", {"opname"}, {node->GetName()}); ErrorManager::GetInstance().ATCReportErrMessage("E11029", {"opname"}, {node->GetName()});
GELOGE(ge::GRAPH_FAILED, "[Find][Node] Op[%s] in model file does not exist in weight file.", GELOGE(ge::GRAPH_FAILED, "[Find][Node] Op[%s] in model file does not exist in weight file.",
node->GetName().c_str()); node->GetName().c_str());
PreChecker::Instance().RefreshErrorMessageByName(node->GetName(), PreChecker::PARAM_INVALID,
PreChecker::Instance().RefreshErrorMessageByName(node->GetName(), PreChecker::ErrorCode::PARAM_INVALID,
"Node does not exist in weight file."); "Node does not exist in weight file.");
} else { } else {
REPORT_INNER_ERROR("E19999", "Op:%s(%s)'s input %d is not linked, check invalid", REPORT_INNER_ERROR("E19999", "Op:%s(%s)'s input %d is not linked, check invalid",
@@ -2188,7 +2188,8 @@ Status CaffeWeightsParser::CheckNodes(ge::ComputeGraphPtr &graph) {
GELOGE(ge::GRAPH_FAILED, "[Check][Param] Op[%s]'s input %d is not linked.", node->GetName().c_str(), GELOGE(ge::GRAPH_FAILED, "[Check][Param] Op[%s]'s input %d is not linked.", node->GetName().c_str(),
in_anchor_ptr->GetIdx()); in_anchor_ptr->GetIdx());
string check_msg = "input " + to_string(in_anchor_ptr->GetIdx()) + "is not linked in weight file"; string check_msg = "input " + to_string(in_anchor_ptr->GetIdx()) + "is not linked in weight file";
PreChecker::Instance().RefreshErrorMessageByName(node->GetName(), PreChecker::PARAM_INVALID, check_msg);
PreChecker::Instance().RefreshErrorMessageByName(node->GetName(), PreChecker::ErrorCode::PARAM_INVALID,
check_msg);
} }
return FAILED; return FAILED;
} }


+ 1
- 1
parser/common/model_saver.cc View File

@@ -75,7 +75,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi
mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len); mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len);
if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) {
ErrorManager::GetInstance().ATCReportErrMessage( ErrorManager::GetInstance().ATCReportErrMessage(
"E19004", {"file", "errmsg"}, {file_path, strerror(errno)});
"E19004", {"file", "errmsg"}, {file_path, strerror(errno)});
// Need to both print the error info of mmWrite and mmClose, so return ret after mmClose // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose
GELOGE(FAILED, "[WriteTo][File] %s failed. errno = %ld, %s", file_path, mmpa_ret, strerror(errno)); GELOGE(FAILED, "[WriteTo][File] %s failed. errno = %ld, %s", file_path, mmpa_ret, strerror(errno));
ret = FAILED; ret = FAILED;


+ 1
- 0
parser/common/model_saver.h View File

@@ -20,6 +20,7 @@
#include <string> #include <string>


#include "ge/ge_api_error_codes.h" #include "ge/ge_api_error_codes.h"
#include "ge/ge_api_types.h"
#include "register/register_types.h" #include "register/register_types.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"




+ 1
- 1
parser/common/op_def/arg_op.h View File

@@ -23,7 +23,7 @@ class ArgOpOperator : public ParserOperator {
public: public:
ArgOpOperator(); ArgOpOperator();


~ArgOpOperator();
~ArgOpOperator() override;


ArgOpOperator &Name(const std::string &name); ArgOpOperator &Name(const std::string &name);




+ 1
- 1
parser/common/op_def/constant_op.h View File

@@ -24,7 +24,7 @@ namespace ge {
class ConstantOperator : public ParserOperator { class ConstantOperator : public ParserOperator {
public: public:
ConstantOperator(); ConstantOperator();
~ConstantOperator();
~ConstantOperator() override;


ConstantOperator &Name(const std::string &name); ConstantOperator &Name(const std::string &name);
ConstantOperator &VectorAttr(std::string key, std::vector<int64_t> &value); ConstantOperator &VectorAttr(std::string key, std::vector<int64_t> &value);


+ 1
- 1
parser/common/op_def/fill_op.h View File

@@ -23,7 +23,7 @@ class FillOperator : public ParserOperator {
public: public:
FillOperator(); FillOperator();


~FillOperator();
~FillOperator() override;


FillOperator &DataType(int64_t dataType); FillOperator &DataType(int64_t dataType);




+ 1
- 1
parser/common/op_def/frameworkop_op.h View File

@@ -24,7 +24,7 @@ class FrameworkOpOperator : public ParserOperator {
public: public:
FrameworkOpOperator(); FrameworkOpOperator();


~FrameworkOpOperator();
~FrameworkOpOperator() override;


FrameworkOpOperator &Name(const std::string &name); FrameworkOpOperator &Name(const std::string &name);




+ 1
- 1
parser/common/op_def/no_op_op.h View File

@@ -24,7 +24,7 @@ namespace ge {
class NoOpOperator : public ParserOperator { class NoOpOperator : public ParserOperator {
public: public:
NoOpOperator(); NoOpOperator();
~NoOpOperator();
~NoOpOperator() override;


NoOpOperator &Name(const std::string &name); NoOpOperator &Name(const std::string &name);
}; };


+ 1
- 1
parser/common/op_def/ref_switch_op.h View File

@@ -24,7 +24,7 @@ namespace ge {
class RefSwitchOperator : public ParserOperator { class RefSwitchOperator : public ParserOperator {
public: public:
RefSwitchOperator(); RefSwitchOperator();
~RefSwitchOperator();
~RefSwitchOperator() override;


RefSwitchOperator &Name(const std::string &name); RefSwitchOperator &Name(const std::string &name);
RefSwitchOperator &T(ge::DataType t); RefSwitchOperator &T(ge::DataType t);


+ 1
- 1
parser/common/op_def/shape_n_op.h View File

@@ -24,7 +24,7 @@ namespace ge {
class ShapeNOperator : public ParserOperator { class ShapeNOperator : public ParserOperator {
public: public:
ShapeNOperator(); ShapeNOperator();
~ShapeNOperator();
~ShapeNOperator() override;


ShapeNOperator &Name(const std::string &name); ShapeNOperator &Name(const std::string &name);




+ 1
- 1
parser/common/op_def/var_is_initialized_op_op.h View File

@@ -24,7 +24,7 @@ namespace ge {
class VarIsInitializedOpOperator : public ParserOperator { class VarIsInitializedOpOperator : public ParserOperator {
public: public:
VarIsInitializedOpOperator(); VarIsInitializedOpOperator();
~VarIsInitializedOpOperator();
~VarIsInitializedOpOperator() override;


VarIsInitializedOpOperator &Name(const std::string &name); VarIsInitializedOpOperator &Name(const std::string &name);
VarIsInitializedOpOperator &VectorAttr(const std::string &key, std::vector<int64_t> &value); VarIsInitializedOpOperator &VectorAttr(const std::string &key, std::vector<int64_t> &value);


+ 1
- 1
parser/common/op_def/variable_op.h View File

@@ -25,7 +25,7 @@ namespace ge {
class VariableOperator : public ParserOperator { class VariableOperator : public ParserOperator {
public: public:
VariableOperator(); VariableOperator();
~VariableOperator();
~VariableOperator() override;


VariableOperator &Name(const std::string &name); VariableOperator &Name(const std::string &name);




+ 1
- 1
parser/common/parser_fp16_t.h View File

@@ -586,7 +586,7 @@ T MinMan(const int16_t &e_a, T &m_a, const int16_t &e_b, T &m_b) {
template<typename T> template<typename T>
T RightShift(T man, int16_t shift) { T RightShift(T man, int16_t shift) {
int bits = sizeof(T) * 8; // one byte have 8 bits int bits = sizeof(T) * 8; // one byte have 8 bits
T mask = (((T) 1u) << ((unsigned int) (bits - 1)));
T mask = static_cast<T>(1u) << static_cast<uint32_t>(bits - 1);
for (int i = 0; i < shift; i++) { for (int i = 0; i < shift; i++) {
man = ((man & mask) | (man >> 1)); man = ((man & mask) | (man >> 1));
} }


+ 1
- 1
parser/common/pass_manager.cc View File

@@ -27,7 +27,7 @@ const std::vector<std::pair<std::string, GraphPass *>> &PassManager::GraphPasses
return names_to_graph_passes_; return names_to_graph_passes_;
} }


Status PassManager::AddPass(const string &pass_name, GraphPass *pass) {
Status PassManager::AddPass(const string &pass_name, GraphPass *const pass) {
GE_CHECK_NOTNULL(pass); GE_CHECK_NOTNULL(pass);
names_to_graph_passes_.emplace_back(pass_name, pass); names_to_graph_passes_.emplace_back(pass_name, pass);
return SUCCESS; return SUCCESS;


+ 1
- 1
parser/common/pass_manager.h View File

@@ -41,7 +41,7 @@ public:
/// @param [in] pass Pass to be added, it will be destroyed when pass manager destroys. /// @param [in] pass Pass to be added, it will be destroyed when pass manager destroys.
/// @author /// @author
/// ///
Status AddPass(const string &pass_name, GraphPass *pass);
Status AddPass(const string &pass_name, GraphPass *const pass);


/// ///
/// Optimize graph with added pass /// Optimize graph with added pass


+ 3
- 3
parser/common/pre_checker.cc View File

@@ -98,7 +98,7 @@ Status PreChecker::CheckName(OpId id) {
// If the name is duplicate, an error is logged // If the name is duplicate, an error is logged
if (id != v.first && info.name == v.second.name) { if (id != v.first && info.name == v.second.name) {
Cause cause; Cause cause;
cause.code = NAME_REPEATED;
cause.code = ErrorCode::NAME_REPEATED;
cause.message = "The name is repeated."; cause.message = "The name is repeated.";


GELOGI("Name %s repeated.", info.name.c_str()); GELOGI("Name %s repeated.", info.name.c_str());
@@ -248,7 +248,7 @@ Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string
std::string op_type; std::string op_type;
if (!domi::OpRegistry::Instance()->GetOmTypeByOriOpType(type, op_type)) { if (!domi::OpRegistry::Instance()->GetOmTypeByOriOpType(type, op_type)) {
Cause cause; Cause cause;
cause.code = TYPE_UNSUPPORTED;
cause.code = ErrorCode::TYPE_UNSUPPORTED;
cause.message = "The type is not supported."; cause.message = "The type is not supported.";
GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str());
if (!is_tensorflow) { if (!is_tensorflow) {
@@ -262,7 +262,7 @@ Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string
// Log error if type not found // Log error if type not found
if (fmk_op_types_->find(type) == fmk_op_types_->end()) { if (fmk_op_types_->find(type) == fmk_op_types_->end()) {
Cause cause; Cause cause;
cause.code = TYPE_UNSUPPORTED;
cause.code = ErrorCode::TYPE_UNSUPPORTED;
cause.message = "The type is not supported."; cause.message = "The type is not supported.";


GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str());


+ 1
- 1
parser/common/pre_checker.h View File

@@ -44,7 +44,7 @@ class PreChecker {
* @ingroup domi_omg * @ingroup domi_omg
* @brief error code, 1~99:Error, 100~199:Waring。 * @brief error code, 1~99:Error, 100~199:Waring。
*/ */
enum ErrorCode {
enum class ErrorCode {
// no error // no error
OK = 0, OK = 0,




+ 11
- 13
parser/onnx/onnx_constant_parser.h View File

@@ -23,8 +23,6 @@
#include "parser/common/data_op_parser.h" #include "parser/common/data_op_parser.h"
#include "parser/onnx/onnx_op_parser.h" #include "parser/onnx/onnx_op_parser.h"


using ge::onnx::NodeProto;

namespace ge { namespace ge {
class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser {
public: public:
@@ -60,17 +58,17 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser {


DataType data_type = tensor.GetTensorDesc().GetDataType(); DataType data_type = tensor.GetTensorDesc().GetDataType();
switch (data_type) { switch (data_type) {
#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \
case dt_type: \
{ \
unique_ptr<value_type> addr_trans(new(std::nothrow) value_type[count]()); \
GE_CHECK_NOTNULL(addr_trans); \
for (int32_t i = 0; i < count; i++) { \
*(addr_trans.get() + i) = static_cast<value_type>(*(addr.get() + i)); \
} \
tensor.SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), count * sizeof(value_type)); \
break; \
} \
#define CASE_SET_DATA(dt_type, value_type, addr, count, tensor) \
case dt_type: \
{ \
unique_ptr<value_type> addr_trans(new(std::nothrow) value_type[count]()); \
GE_CHECK_NOTNULL(addr_trans); \
for (int32_t i = 0; i < (count); i++) { \
*(addr_trans.get() + i) = static_cast<value_type>(*((addr).get() + i)); \
} \
(tensor).SetData(reinterpret_cast<uint8_t *>(addr_trans.get()), (count) * sizeof(value_type)); \
break; \
} \


CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor) CASE_SET_DATA(DT_FLOAT16, uint16_t, addr, count, tensor)
CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor) CASE_SET_DATA(DT_INT16, int16_t, addr, count, tensor)


+ 7
- 8
parser/tensorflow/tensorflow_parser.cc View File

@@ -586,7 +586,7 @@ Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def,
} }


void TensorFlowModelParser::GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num, void TensorFlowModelParser::GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num,
size_t &output_tensor_num) {
size_t &output_tensor_num) const {
// The caller guarantees that the pointer is not null // The caller guarantees that the pointer is not null
auto iter = op_node_context_map_.find(op_desc->GetName()); auto iter = op_node_context_map_.find(op_desc->GetName());
if (iter == op_node_context_map_.end()) { if (iter == op_node_context_map_.end()) {
@@ -817,8 +817,7 @@ Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) {
return SUCCESS; return SUCCESS;
} }


Status TensorFlowModelParser::AddFmkNodeDefToMap(const domi::tensorflow::GraphDef &graph_def,
const domi::tensorflow::NodeDef *node_def,
Status TensorFlowModelParser::AddFmkNodeDefToMap(const domi::tensorflow::NodeDef *node_def,
vector<string> &op_node_name_list) { vector<string> &op_node_name_list) {
GE_CHECK_NOTNULL(node_def); GE_CHECK_NOTNULL(node_def);
const string &node_name = node_def->name(); const string &node_name = node_def->name();
@@ -1224,7 +1223,7 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g
GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str());); GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str()););


// Do not exit immediately when there is an error, wait until all errors are collected before exiting // Do not exit immediately when there is an error, wait until all errors are collected before exiting
GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(graph_def, node_def, op_node_name_list), has_error = true,
GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(node_def, op_node_name_list), has_error = true,
"add node failed."); "add node failed.");
} }


@@ -1459,7 +1458,7 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
} }


// Do not exit immediately when there is an error, wait until all errors are collected before exiting // Do not exit immediately when there is an error, wait until all errors are collected before exiting
GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(graph_def, node_def, op_node_name_list), has_error = true);
GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(node_def, op_node_name_list), has_error = true);
} }


// The fusion operator has passed the verification. // The fusion operator has passed the verification.
@@ -1545,7 +1544,7 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
return SUCCESS; return SUCCESS;
} }


Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) {
Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const {
// Number of data nodes // Number of data nodes
uint32_t data_node_count = 0; uint32_t data_node_count = 0;
for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) { for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) {
@@ -2275,7 +2274,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
} }


// Do not exit immediately when there is an error, wait until all errors are collected before exiting // Do not exit immediately when there is an error, wait until all errors are collected before exiting
Status ret = AddFmkNodeDefToMap(*graph_def, node_def, op_node_name_list);
Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list);
GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed");
} }
PARSER_TIMESTAMP_END(AddFmkNodeDefToMap, "TensorFlowModelParser::AddFmkNodeDefToMap"); PARSER_TIMESTAMP_END(AddFmkNodeDefToMap, "TensorFlowModelParser::AddFmkNodeDefToMap");
@@ -2865,7 +2864,7 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph
// mutable_node return vale is not empty // mutable_node return vale is not empty
domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i); domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i);
const string &node_name = node_def->name(); const string &node_name = node_def->name();
Status ret = AddFmkNodeDefToMap(*graph_def, node_def, op_node_name_list);
Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list);
GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed");
if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) { if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) {
identity_to_optimize.push_back(node_def); identity_to_optimize.push_back(node_def);


+ 3
- 3
parser/tensorflow/tensorflow_parser.h View File

@@ -185,7 +185,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
* @return FAILED add failed * @return FAILED add failed


*/ */
Status AddFmkNodeDefToMap(const domi::tensorflow::GraphDef &graph_def, const domi::tensorflow::NodeDef *node_def,
Status AddFmkNodeDefToMap(const domi::tensorflow::NodeDef *node_def,
vector<string> &op_node_name_list); vector<string> &op_node_name_list);


/** /**
@@ -243,7 +243,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
* @return SUCCESS check successfully * @return SUCCESS check successfully
* @return FAILED check failed * @return FAILED check failed
*/ */
Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def);
Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const;


/** /**
* @ingroup domi_omg * @ingroup domi_omg
@@ -516,7 +516,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
Status UppdateOutputMap(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, Status UppdateOutputMap(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context); OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context);
void GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num, void GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num,
size_t &output_tensor_num);
size_t &output_tensor_num) const;
static Status CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set<int> &dims, bool &valid); static Status CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set<int> &dims, bool &valid);
Status CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type); Status CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type);




+ 88
- 0
parser/tensorflow/tensorflow_util.cc View File

@@ -31,6 +31,94 @@
using domi::tensorflow::DT_INVALID; using domi::tensorflow::DT_INVALID;


namespace ge { namespace ge {
/***************************TensorFlow attribute type, constant definition*******************************************/
const std::string TENSORFLOW_ATTR_TYPE_STRING = "string";
const std::string TENSORFLOW_ATTR_TYPE_INT = "int";
const std::string TENSORFLOW_ATTR_TYPE_FLOAT = "float";
const std::string TENSORFLOW_ATTR_TYPE_BOOL = "bool";
const std::string TENSORFLOW_ATTR_TYPE_TYPE = "type";
const std::string TENSORFLOW_ATTR_TYPE_SHAPE = "shape";
const std::string TENSORFLOW_ATTR_TYPE_TENSOR = "tensor";
const std::string TENSORFLOW_ATTR_TYPE_FUNC = "func";

const std::string TENSORFLOW_ATTR_LIST_TYPE_STRING = "list(string)";
const std::string TENSORFLOW_ATTR_LIST_TYPE_INT = "list(int)";
const std::string TENSORFLOW_ATTR_LIST_TYPE_FLOAT = "list(float)";
const std::string TENSORFLOW_ATTR_LIST_TYPE_BOOL = "list(bool)";
const std::string TENSORFLOW_ATTR_LIST_TYPE_TYPE = "list(type)";
const std::string TENSORFLOW_ATTR_LIST_TYPE_SHAPE = "list(shape)";
const std::string TENSORFLOW_ATTR_LIST_TYPE_TENSOR = "list(tensor)";
const std::string TENSORFLOW_ATTR_LIST_TYPE_FUNC = "list(func)";

/***************************constant definition*******************************************/
const std::string TENSORFLOW_ATTR_OUTPUT_OP = "output_op";

const std::string TENSORFLOW_ATTR_T = "T";
const std::string TENSORFLOW_ATTR_N = "N";
const std::string TENSORFLOW_ATTR_DATA_FORMAT = "data_format";
const std::string TENSORFLOW_ATTR_PADDING = "padding";
const std::string TENSORFLOW_ATTR_KSIZE = "ksize";
const std::string TENSORFLOW_ATTR_STRIDES = "strides";
const std::string TENSORFLOW_ATTR_DILATIONS = "dilations";
const std::string TENSORFLOW_ATTR_DTYPE = "dtype";
const std::string TENSORFLOW_ATTR_VALUE = "value";
const std::string TENSORFLOW_ATTR_TRANSINPUT = "transpose_a";
const std::string TENSORFLOW_ATTR_TRANSWEIGHT = "transpose_b";
const std::string TENSORFLOW_ATTR_SHAPE = "shape";
const std::string TENSORFLOW_ATTR_TIDX = "Tidx";
const std::string TENSORFLOW_ATTR_TPADDINGS = "Tpaddings";
const std::string TENSORFLOW_ATTR_TMULTIPLES = "Tmultiples";
const std::string TENSORFLOW_ATTR_TINDICES = "Tindices";
const std::string TENSORFLOW_ATTR_TPARAMS = "Tparams";
const std::string TENSORFLOW_ATTR_TAXIS = "Taxis";
const std::string TENSORFLOW_ATTR_DSTT = "DstT";
const std::string TENSORFLOW_ATTR_SRCT = "SrcT";
const std::string TENSORFLOW_ATTR_PERM = "perm";
const std::string TENSORFLOW_ATTR_INDEX = "Index";
const std::string TENSORFLOW_ATTR_TSHAPE = "Tshape";
const std::string TENSORFLOW_ATTR_AXIS = "Axis";
const std::string TENSORFLOW_ATTR_BIAS = "bias";
const std::string TENSORFLOW_ATTR_DEPTH_RADIUS = "depth_radius";
const std::string TENSORFLOW_ATTR_ALPHA = "alpha";
const std::string TENSORFLOW_ATTR_BETA = "beta";
const std::string TENSORFLOW_ATTR_MODE = "mode";

// op:Const
const std::string TENSORFLOWF_NODE_OP_CONST = "Const";
const std::string TENSORFLOWF_NODE_OP_IDENTITY = "Identity";
const std::string TENSORFLOWF_NODE_OP_SWITCH = "Switch";
const std::string TENSORFLOWF_NODE_OP_PLACEHOLDER = "Placeholder";
const std::string TENSORFLOWF_NODE_OP_ADDN = "AddN";
const std::string TENSORFLOWF_NODE_OP_MATMUL = "MatMul";
const std::string TENSORFLOWF_NODE_OP_RELU = "Relu";
const std::string TENSORFLOWF_NODE_OP_SHAPE = "Shape";
const std::string TENSORFLOWF_NODE_OP_TRANSPOSE = "Transpose";
const std::string TENSORFLOWF_NODE_OP_MERGE = "Merge";

// data_format
const std::string TENSORFLOWF_TENSOR_NCHW = "NCHW";
const std::string TENSORFLOWF_TENSOR_NHWC = "NHWC";

const int TENSORFLOW_CONV_STRIDE_NUM = 4;
const int TENSORFLOW_CONV_DILATION_NUM = 4;

// padding
const std::string TENSORFLOWF_OP_PADDING_VALID = "VALID";
const std::string TENSORFLOWF_OP_PADDING_SAME = "SAME";

// normal input size
const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_MATMUL = 2;
const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_RESHAPE = 1;
const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_POOL = 1;

// normal weight size
const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_MATMUL = 1;
const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_RESHAPE = 1;

// input or output
const uint32_t TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG = 1;
const uint32_t TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG = 2;

using AttrValueMap = ::google::protobuf::Map<std::string, domi::tensorflow::AttrValue>; using AttrValueMap = ::google::protobuf::Map<std::string, domi::tensorflow::AttrValue>;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrValue( FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrValue(
const domi::tensorflow::NodeDef *node_def, const std::string &attr_name, domi::tensorflow::AttrValue &attr_value) { const domi::tensorflow::NodeDef *node_def, const std::string &attr_name, domi::tensorflow::AttrValue &attr_value) {


+ 71
- 71
parser/tensorflow/tensorflow_util.h View File

@@ -44,92 +44,92 @@ using domi::tensorflow::FunctionDefLibrary;


namespace ge { namespace ge {
/***************************TensorFlow attribute type, constant definition*******************************************/ /***************************TensorFlow attribute type, constant definition*******************************************/
static const std::string TENSORFLOW_ATTR_TYPE_STRING = "string";
static const std::string TENSORFLOW_ATTR_TYPE_INT = "int";
static const std::string TENSORFLOW_ATTR_TYPE_FLOAT = "float";
static const std::string TENSORFLOW_ATTR_TYPE_BOOL = "bool";
static const std::string TENSORFLOW_ATTR_TYPE_TYPE = "type";
static const std::string TENSORFLOW_ATTR_TYPE_SHAPE = "shape";
static const std::string TENSORFLOW_ATTR_TYPE_TENSOR = "tensor";
static const std::string TENSORFLOW_ATTR_TYPE_FUNC = "func";
static const std::string TENSORFLOW_ATTR_LIST_TYPE_STRING = "list(string)";
static const std::string TENSORFLOW_ATTR_LIST_TYPE_INT = "list(int)";
static const std::string TENSORFLOW_ATTR_LIST_TYPE_FLOAT = "list(float)";
static const std::string TENSORFLOW_ATTR_LIST_TYPE_BOOL = "list(bool)";
static const std::string TENSORFLOW_ATTR_LIST_TYPE_TYPE = "list(type)";
static const std::string TENSORFLOW_ATTR_LIST_TYPE_SHAPE = "list(shape)";
static const std::string TENSORFLOW_ATTR_LIST_TYPE_TENSOR = "list(tensor)";
static const std::string TENSORFLOW_ATTR_LIST_TYPE_FUNC = "list(func)";
extern const std::string TENSORFLOW_ATTR_TYPE_STRING;
extern const std::string TENSORFLOW_ATTR_TYPE_INT;
extern const std::string TENSORFLOW_ATTR_TYPE_FLOAT;
extern const std::string TENSORFLOW_ATTR_TYPE_BOOL;
extern const std::string TENSORFLOW_ATTR_TYPE_TYPE;
extern const std::string TENSORFLOW_ATTR_TYPE_SHAPE;
extern const std::string TENSORFLOW_ATTR_TYPE_TENSOR;
extern const std::string TENSORFLOW_ATTR_TYPE_FUNC;
extern const std::string TENSORFLOW_ATTR_LIST_TYPE_STRING;
extern const std::string TENSORFLOW_ATTR_LIST_TYPE_INT;
extern const std::string TENSORFLOW_ATTR_LIST_TYPE_FLOAT;
extern const std::string TENSORFLOW_ATTR_LIST_TYPE_BOOL;
extern const std::string TENSORFLOW_ATTR_LIST_TYPE_TYPE;
extern const std::string TENSORFLOW_ATTR_LIST_TYPE_SHAPE;
extern const std::string TENSORFLOW_ATTR_LIST_TYPE_TENSOR;
extern const std::string TENSORFLOW_ATTR_LIST_TYPE_FUNC;


/***************************constant definition*******************************************/ /***************************constant definition*******************************************/
static const std::string TENSORFLOW_ATTR_OUTPUT_OP = "output_op";
static const std::string TENSORFLOW_ATTR_T = "T";
static const std::string TENSORFLOW_ATTR_N = "N";
static const std::string TENSORFLOW_ATTR_DATA_FORMAT = "data_format";
static const std::string TENSORFLOW_ATTR_PADDING = "padding";
static const std::string TENSORFLOW_ATTR_KSIZE = "ksize";
static const std::string TENSORFLOW_ATTR_STRIDES = "strides";
static const std::string TENSORFLOW_ATTR_DILATIONS = "dilations";
static const std::string TENSORFLOW_ATTR_DTYPE = "dtype";
static const std::string TENSORFLOW_ATTR_VALUE = "value";
static const std::string TENSORFLOW_ATTR_TRANSINPUT = "transpose_a";
static const std::string TENSORFLOW_ATTR_TRANSWEIGHT = "transpose_b";
static const std::string TENSORFLOW_ATTR_SHAPE = "shape";
static const std::string TENSORFLOW_ATTR_TIDX = "Tidx";
static const std::string TENSORFLOW_ATTR_TPADDINGS = "Tpaddings";
static const std::string TENSORFLOW_ATTR_TMULTIPLES = "Tmultiples";
static const std::string TENSORFLOW_ATTR_TINDICES = "Tindices";
static const std::string TENSORFLOW_ATTR_TPARAMS = "Tparams";
static const std::string TENSORFLOW_ATTR_TAXIS = "Taxis";
static const std::string TENSORFLOW_ATTR_DSTT = "DstT";
static const std::string TENSORFLOW_ATTR_SRCT = "SrcT";
static const std::string TENSORFLOW_ATTR_PERM = "perm";
static const std::string TENSORFLOW_ATTR_INDEX = "Index";
static const std::string TENSORFLOW_ATTR_TSHAPE = "Tshape";
static const std::string TENSORFLOW_ATTR_AXIS = "Axis";
static const std::string TENSORFLOW_ATTR_BIAS = "bias";
static const std::string TENSORFLOW_ATTR_DEPTH_RADIUS = "depth_radius";
static const std::string TENSORFLOW_ATTR_ALPHA = "alpha";
static const std::string TENSORFLOW_ATTR_BETA = "beta";
static const std::string TENSORFLOW_ATTR_MODE = "mode";
extern const std::string TENSORFLOW_ATTR_OUTPUT_OP;
extern const std::string TENSORFLOW_ATTR_T;
extern const std::string TENSORFLOW_ATTR_N;
extern const std::string TENSORFLOW_ATTR_DATA_FORMAT;
extern const std::string TENSORFLOW_ATTR_PADDING;
extern const std::string TENSORFLOW_ATTR_KSIZE;
extern const std::string TENSORFLOW_ATTR_STRIDES;
extern const std::string TENSORFLOW_ATTR_DILATIONS;
extern const std::string TENSORFLOW_ATTR_DTYPE;
extern const std::string TENSORFLOW_ATTR_VALUE;
extern const std::string TENSORFLOW_ATTR_TRANSINPUT;
extern const std::string TENSORFLOW_ATTR_TRANSWEIGHT;
extern const std::string TENSORFLOW_ATTR_SHAPE;
extern const std::string TENSORFLOW_ATTR_TIDX;
extern const std::string TENSORFLOW_ATTR_TPADDINGS;
extern const std::string TENSORFLOW_ATTR_TMULTIPLES;
extern const std::string TENSORFLOW_ATTR_TINDICES;
extern const std::string TENSORFLOW_ATTR_TPARAMS;
extern const std::string TENSORFLOW_ATTR_TAXIS;
extern const std::string TENSORFLOW_ATTR_DSTT;
extern const std::string TENSORFLOW_ATTR_SRCT;
extern const std::string TENSORFLOW_ATTR_PERM;
extern const std::string TENSORFLOW_ATTR_INDEX;
extern const std::string TENSORFLOW_ATTR_TSHAPE;
extern const std::string TENSORFLOW_ATTR_AXIS;
extern const std::string TENSORFLOW_ATTR_BIAS;
extern const std::string TENSORFLOW_ATTR_DEPTH_RADIUS;
extern const std::string TENSORFLOW_ATTR_ALPHA;
extern const std::string TENSORFLOW_ATTR_BETA;
extern const std::string TENSORFLOW_ATTR_MODE;


// op:Const // op:Const
static const std::string TENSORFLOWF_NODE_OP_CONST = "Const";
static const std::string TENSORFLOWF_NODE_OP_IDENTITY = "Identity";
static const std::string TENSORFLOWF_NODE_OP_SWITCH = "Switch";
static const std::string TENSORFLOWF_NODE_OP_PLACEHOLDER = "Placeholder";
static const std::string TENSORFLOWF_NODE_OP_ADDN = "AddN";
static const std::string TENSORFLOWF_NODE_OP_MATMUL = "MatMul";
static const std::string TENSORFLOWF_NODE_OP_RELU = "Relu";
static const std::string TENSORFLOWF_NODE_OP_SHAPE = "Shape";
static const std::string TENSORFLOWF_NODE_OP_TRANSPOSE = "Transpose";
static const std::string TENSORFLOWF_NODE_OP_MERGE = "Merge";
extern const std::string TENSORFLOWF_NODE_OP_CONST;
extern const std::string TENSORFLOWF_NODE_OP_IDENTITY;
extern const std::string TENSORFLOWF_NODE_OP_SWITCH;
extern const std::string TENSORFLOWF_NODE_OP_PLACEHOLDER;
extern const std::string TENSORFLOWF_NODE_OP_ADDN;
extern const std::string TENSORFLOWF_NODE_OP_MATMUL;
extern const std::string TENSORFLOWF_NODE_OP_RELU;
extern const std::string TENSORFLOWF_NODE_OP_SHAPE;
extern const std::string TENSORFLOWF_NODE_OP_TRANSPOSE;
extern const std::string TENSORFLOWF_NODE_OP_MERGE;


// data_format // data_format
static const std::string TENSORFLOWF_TENSOR_NCHW = "NCHW";
static const std::string TENSORFLOWF_TENSOR_NHWC = "NHWC";
extern const std::string TENSORFLOWF_TENSOR_NCHW;
extern const std::string TENSORFLOWF_TENSOR_NHWC;


static const int TENSORFLOW_CONV_STRIDE_NUM = 4;
static const int TENSORFLOW_CONV_DILATION_NUM = 4;
extern const int TENSORFLOW_CONV_STRIDE_NUM;
extern const int TENSORFLOW_CONV_DILATION_NUM;


// padding // padding
static const std::string TENSORFLOWF_OP_PADDING_VALID = "VALID";
static const std::string TENSORFLOWF_OP_PADDING_SAME = "SAME";
extern const std::string TENSORFLOWF_OP_PADDING_VALID;
extern const std::string TENSORFLOWF_OP_PADDING_SAME;


// normal input size // normal input size
static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_MATMUL = 2;
static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_RESHAPE = 1;
static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_POOL = 1;
extern const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_MATMUL;
extern const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_RESHAPE;
extern const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_POOL;


// normal weight size // normal weight size
static const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_MATMUL = 1;
static const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_RESHAPE = 1;
extern const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_MATMUL;
extern const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_RESHAPE;


// input or output // input or output
static const uint32_t TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG = 1;
static const uint32_t TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG = 2;
extern const uint32_t TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG;
extern const uint32_t TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG;


class TensorFlowUtil { class TensorFlowUtil {
public: public:


Loading…
Cancel
Save