Browse Source

Pre Merge pull request !633 from lipeiyang/ge_dev

pull/633/MERGE
lipeiyang Gitee 2 years ago
parent
commit
f23b94933f
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 184 additions and 145 deletions
  1. +1
    -1
      metadef
  2. +1
    -1
      parser/tensorflow/tensorflow_constant_parser.h
  3. +1
    -1
      parser/tensorflow/tensorflow_frameworkop_parser.cc
  4. +1
    -1
      parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc
  5. +8
    -8
      parser/tensorflow/tensorflow_fusion_op_parser.cc
  6. +1
    -1
      parser/tensorflow/tensorflow_fusion_op_parser.h
  7. +1
    -1
      parser/tensorflow/tensorflow_merge_parser.cc
  8. +1
    -1
      parser/tensorflow/tensorflow_merge_parser.h
  9. +28
    -16
      parser/tensorflow/tensorflow_parser.cc
  10. +1
    -1
      parser/tensorflow/tensorflow_parser_register.h
  11. +2
    -2
      parser/tensorflow/tensorflow_ref_switch_parser.cc
  12. +8
    -5
      parser/tensorflow/tensorflow_reshape_parser.cc
  13. +4
    -3
      parser/tensorflow/tensorflow_shape_n_parser.cc
  14. +12
    -9
      parser/tensorflow/tensorflow_squeeze_parser.cc
  15. +43
    -39
      parser/tensorflow/tensorflow_util.cc
  16. +37
    -29
      tests/st/testcase/test_tensorflow_parser.cc
  17. +34
    -26
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 5894aaa3c0cb12565e5e57a0a49f3e732e608f3d
Subproject commit 5d062a35640733026457c91966a558769570b0f8

+ 1
- 1
parser/tensorflow/tensorflow_constant_parser.h View File

@@ -21,9 +21,9 @@
#include "parser/common/data_op_parser.h"
#include "parser/tensorflow/tensorflow_op_parser.h"

namespace ge {
using domi::tensorflow::NodeDef;

namespace ge {
class PARSER_FUNC_VISIBILITY TensorFlowConstantParser : public TensorFlowOpParser {
public:
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;


+ 1
- 1
parser/tensorflow/tensorflow_frameworkop_parser.cc View File

@@ -31,7 +31,7 @@ namespace ge {
Status ParseParams(const Message *op_src, FrameworkOpOperator *op) {
GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op);
const domi::tensorflow::NodeDef *node = reinterpret_cast<const domi::tensorflow::NodeDef *>(op_src);
const domi::tensorflow::NodeDef *node = static_cast<const domi::tensorflow::NodeDef *>(op_src);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
string type = node->op();



+ 1
- 1
parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc View File

@@ -33,7 +33,7 @@ Status TensorFlowFusionCustomParserAdapter::ParseParams(const vector<const NodeD
std::vector<const google::protobuf::Message *> inside_nodes;
for (auto inside_node : v_input_const) {
GE_CHECK_NOTNULL(inside_node);
const google::protobuf::Message *node_src = reinterpret_cast<const google::protobuf::Message *>(inside_node);
const google::protobuf::Message *node_src = dynamic_cast<const google::protobuf::Message *>(inside_node);
inside_nodes.push_back(node_src);
}
std::string ori_type = op_dest->GetType();


+ 8
- 8
parser/tensorflow/tensorflow_fusion_op_parser.cc View File

@@ -31,21 +31,21 @@ namespace ge {
do { \
google::protobuf::RepeatedField<FIELD> val_vec; \
int32_t val_size = 0; \
val_vec = (tensor).FIELD##_val(); \
val_vec = (tensor).FIELD##_val(); \
val_size = val_vec.size(); \
if ((index) < val_size) { \
(param) = val_vec.Get(index); \
} else if ((tensor).has_tensor_shape()) { \
const std::string tensor_content = (tensor).tensor_content(); \
char *buf = const_cast<char *>(tensor_content.data()); \
FIELD *buf_v = reinterpret_cast<FIELD *>(buf); \
if ((index) < val_size) { \
(param) = val_vec.Get(index); \
} else if ((tensor).has_tensor_shape()) { \
const std::string &tensor_content = (tensor).tensor_content(); \
const char *buf = tensor_content.data(); \
const FIELD *buf_v = reinterpret_cast<const FIELD *>(buf); \
if (static_cast<uint32_t>(index) >= tensor_content.length() / sizeof(FIELD)) { \
REPORT_INNER_ERROR("E19999", "Const data size of node:%s is smaller than index:%d, not supported!", \
node_def->name().c_str(), index); \
GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index :%d,not supported!", index); \
return domi::PARAM_INVALID; \
} \
(param) = buf_v[index]; \
(param) = buf_v[index]; \
} else { \
REPORT_INNER_ERROR("E19999", "Const data size of node:%s is smaller than index:%d, not supported!", \
node_def->name().c_str(), index); \


+ 1
- 1
parser/tensorflow/tensorflow_fusion_op_parser.h View File

@@ -25,11 +25,11 @@
#include "proto/tensorflow/graph.pb.h"
#include "proto/tensorflow/node_def.pb.h"

namespace ge {
using google::protobuf::Message;
using domi::tensorflow::NodeDef;
using domi::tensorflow::TensorProto;

namespace ge {
/**
* @ingroup domi_omg
* @brief Used to parse TensorFlow operator information


+ 1
- 1
parser/tensorflow/tensorflow_merge_parser.cc View File

@@ -39,7 +39,7 @@ Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr &
int32_t input_tensor_num = attr_num.i();

// add dynamic input
graphStatus ret = op_desc->AddDynamicInputDesc("x", input_tensor_num);
const graphStatus ret = op_desc->AddDynamicInputDesc("x", input_tensor_num);
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add Dynamic InputDesc name:x to node:%s(%s) failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str());


+ 1
- 1
parser/tensorflow/tensorflow_merge_parser.h View File

@@ -25,4 +25,4 @@ class PARSER_FUNC_VISIBILITY TensorFlowMergeParser : public TensorFlowOpParser {
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) override;
};
} // namespace ge
#endif //_DOMI_OMG_PARSER_TENSORFLOW_TENSORFLOW_MERGE_PARSER_H_
#endif // _DOMI_OMG_PARSER_TENSORFLOW_TENSORFLOW_MERGE_PARSER_H_

+ 28
- 16
parser/tensorflow/tensorflow_parser.cc View File

@@ -191,21 +191,25 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<Ascend
GELOGI("AclgrphParse graph %s success.", ParserUtils::GetGraphName(graph).c_str());
return ge::SUCCESS;
}
void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node)
{

void AddDumpOriginName(const ge::NodePtr parent_node, const std::string& subgraph_name, ge::ComputeGraphPtr graph) {
if (parent_node == nullptr) {
return; // Root graph no need set dump origin name as parser always keep the origin node name
}
std::vector<std::string> original_names;
auto parend_desc = parent_node->GetOpDesc();
(void)ge::AttrUtils::GetListStr(parend_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
(void)ge::AttrUtils::GetListStr(parent_node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
if (original_names.empty()) {
original_names.emplace_back(parent_node->GetName());
}
// for fusion node also used original_names[0]
(void)original_names[0].append("/").append(subgraph_name).append("/").append(node->GetName());

if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names)) {
GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), node->GetOpDesc()->GetName().c_str());
std::string prefix = original_names[0].append("/").append(subgraph_name).append("/");
for (const ge::NodePtr &node : graph->GetDirectNode()) {
original_names[0] = prefix + node->GetName();
if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names)) {
GELOGW("Set dump origin name to %s fail.", node->GetOpDesc()->GetName().c_str());
}
GELOGD("Add dump origin name %s for node %s.", original_names[0].c_str(), node->GetName().c_str());
}
GELOGD("Add dump origin name %s for node %s.", original_names[0].c_str(), node->GetName().c_str());
}
} // namespace ge

@@ -273,6 +277,7 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque
}

Status PostOpProcessForSubgraph(const ParseArg &arg) {
AddDumpOriginName(arg.parent_node, arg.subgraph_name, arg.graph);
if (arg.parent_node == nullptr) {
return SUCCESS;
}
@@ -297,7 +302,6 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) {
if ((node->GetOpDesc() == nullptr) || (node->GetType() == "Variable") || (node->GetType() == "VariableV2")) {
continue;
}
AddDumpOriginName(arg.subgraph_name, arg.parent_node, node);
node->GetOpDesc()->SetName(node->GetOwnerComputeGraph()->GetName() + "/" + node->GetName());
}

@@ -909,8 +913,10 @@ Status TensorFlowModelParser::CheckOpType(const domi::tensorflow::NodeDef *node_
GE_IF_BOOL_EXEC(
op_type == ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS,
GE_CHK_STATUS_RET(CheckOpShapeDim(node_def, check_dims[op_type], valid), "failed to check op shape");
GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP; GELOGI("Set op %s to frameworkop", node_name.c_str());
framework_ops_[node_name] = node_def;););
GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP;
GELOGI("Set op %s to frameworkop", node_name.c_str());
framework_ops_[node_name] = node_def;);
);

GE_IF_BOOL_EXEC(
op_type == ge::parser::ADD || op_type == ge::parser::MULTIPLY || op_type == ge::parser::MEAN,
@@ -970,7 +976,8 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co
GELOGD("CCE %s parsering", node_op.c_str()););
GE_IF_BOOL_EXEC((implyType == domi::ImplyType::HCCL) && (op_type != ge::parser::FRAMEWORKOP),
GELOGD("HCCL %s parsering", node_op.c_str()););
GE_IF_BOOL_EXEC(op_type == ge::parser::FRAMEWORKOP, GELOGD("FRAMEWORKOP %s parsering", node_op.c_str()););
GE_IF_BOOL_EXEC(op_type == ge::parser::FRAMEWORKOP,
GELOGD("FRAMEWORKOP %s parsering", node_op.c_str()););
GELOGD("TF op node name = %s, op type= %s, trans to op type %s", node_name.c_str(), node_op.c_str(), op_type.c_str());

// Construct operator by IR
@@ -2322,7 +2329,8 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
// Loop analysis of op_nodes and map them to nodes in graph
ret = AddFmkNode(graph, scope_graph, op_node_name_list, isDatasetInit);
PARSER_TIMESTAMP_END(AddFmkNode, "TensorFlowModelParser::AddFmkNode");
GE_CHK_STATUS_EXEC(ret, DeleteFuisonNodeDef(); return ret, "AddFmkNode failed");
GE_CHK_STATUS_EXEC(ret, DeleteFuisonNodeDef();
return ret, "AddFmkNode failed");
GELOGD("[TF Parser] Add framework node success");

ret = AddEdges(graph);
@@ -3004,12 +3012,16 @@ Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node,
GE_IF_BOOL_EXEC(
type == domi::tensorflow::DT_INT32,
const int32_t *data = reinterpret_cast<const int32_t *>(tensor.tensor_content().data());
for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); });
for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) {
perm_value.push_back(data[i]);
});

GE_IF_BOOL_EXEC(
type == domi::tensorflow::DT_INT64,
const int64_t *data = reinterpret_cast<const int64_t *>(tensor.tensor_content().data());
for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); });
for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) {
perm_value.push_back(data[i]);
});

// 0, 1, 2, 3 present dim num.
vector<int64_t> perm_to_nchw = {0, 3, 1, 2};


+ 1
- 1
parser/tensorflow/tensorflow_parser_register.h View File

@@ -61,7 +61,7 @@ class PARSER_FUNC_VISIBILITY TensorflowParserBuilder : public TensorflowWeightPa
public:
using ParseParamsFn = std::function<domi::Status(const domi::tensorflow::NodeDef *, Param *)>;

explicit TensorflowParserBuilder(const std::string &davinci_optype) : davinci_optype_(davinci_optype) {}
explicit TensorflowParserBuilder(const std::string &davinci_optype) noexcept : davinci_optype_(davinci_optype) {}

~TensorflowParserBuilder() override {}



+ 2
- 2
parser/tensorflow/tensorflow_ref_switch_parser.cc View File

@@ -40,8 +40,8 @@ Status TensorFlowRefSwitchParser::ParseT(const domi::tensorflow::NodeDef *node,

GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "type"), "check Attr T failed");

domi::tensorflow::DataType tfType = attr.type();
ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tfType);
const domi::tensorflow::DataType tfType = attr.type();
const ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tfType);
CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED,
REPORT_CALL_ERROR("E19999", "Data type %s of node %s is not supported",
DataType_Name(tfType).c_str(),


+ 8
- 5
parser/tensorflow/tensorflow_reshape_parser.cc View File

@@ -34,26 +34,29 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att
auto a_list = attr_value.list();
GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), PARAM_INVALID,
"parse ge_desc failed.");
uint32_t size_type = 1;
uint32_t size_type = 1U;
auto data_type = ge_desc.GetDataType();
bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
GE_IF_BOOL_EXEC(!type_ret,
REPORT_CALL_ERROR("E19999", "Data type %s is not supported",
ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
GELOGE(FAILED, "Can't GetDataTypeLength of data_type: %s",
GELOGE(FAILED, "Can't GetDataTypeLength of data_type: %s.",
ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
return PARAM_INVALID);
// calculate size
int64_t real_size = 1;
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
for (uint32_t j = 0U; j < ge_desc.GetShape().GetDimNum(); ++j) {
int64_t tmp_dim = ge_desc.GetShape().GetDim(j);
GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;);
if (tmp_dim < 0) {
real_size = tmp_dim * (-1) * real_size;
continue;
}
real_size *= tmp_dim;
}
PARSER_INT64_MULCHECK(real_size, size_type);
ge::TensorUtils::SetSize(ge_desc, real_size * size_type);
ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
GELOGI("after translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
GELOGI("After translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(),
ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), real_size * size_type, size_type);
return SUCCESS;


+ 4
- 3
parser/tensorflow/tensorflow_shape_n_parser.cc View File

@@ -83,7 +83,8 @@ Status TensorFlowShapeNParser::ParseN(const domi::tensorflow::NodeDef *node, Sha
// The upper caller guarantees the input params is not empty.
domi::tensorflow::AttrValue attr;
const int64_t attr_n = 2;
CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr), op->N(attr_n); return SUCCESS);
CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr), op->N(attr_n);
return SUCCESS);

GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "int"), "check Attr N failed");

@@ -133,7 +134,7 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr
if (ret != SUCCESS) {
return ret;
}
graphStatus status = op_dest->AddDynamicOutputDesc("y", dynamic_tensor_num);
const graphStatus status = op_dest->AddDynamicOutputDesc("y", dynamic_tensor_num);
if (status != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add Dynamic OuputDesc name:y to node:%s(%s) failed",
op_dest->GetName().c_str(), op_dest->GetType().c_str());
@@ -141,7 +142,7 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr
return FAILED;
}
}
graphStatus status = op_dest->AddDynamicInputDesc("x", dynamic_tensor_num);
const graphStatus status = op_dest->AddDynamicInputDesc("x", dynamic_tensor_num);
if (status != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add Dynamic InputDesc name:x to node:%s(%s) failed",
op_dest->GetName().c_str(), op_dest->GetType().c_str());


+ 12
- 9
parser/tensorflow/tensorflow_squeeze_parser.cc View File

@@ -38,27 +38,30 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att
auto a_list = attr_value.list();
GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), domi::PARAM_INVALID,
"parse ge_desc failed.");
uint32_t size_type;
auto data_type = ge_desc.GetDataType();
bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
uint32_t size_type = 0U;
const auto data_type = ge_desc.GetDataType();
const bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
GE_IF_BOOL_EXEC(!type_ret,
REPORT_CALL_ERROR("E19999", "Data type %s is not supported",
ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
GELOGE(domi::PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s",
GELOGE(domi::PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s.",
ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
return domi::PARAM_INVALID);
// calculate size
int64_t real_size = 1;
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
for (uint32_t j = 0U; j < ge_desc.GetShape().GetDimNum(); ++j) {
int64_t tmp_dim = ge_desc.GetShape().GetDim(j);
GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;);
if (tmp_dim < 0) {
real_size = tmp_dim * (-1) * real_size;
continue;
}
PARSER_INT64_MULCHECK(real_size, tmp_dim);
real_size *= tmp_dim;
}
PARSER_INT64_MULCHECK(real_size, size_type);
ge::TensorUtils::SetSize(ge_desc, real_size * size_type);
ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
GELOGD("after translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
GELOGD("After translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(),
ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), real_size * size_type, size_type);
return SUCCESS;
@@ -80,8 +83,8 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr
domi::tensorflow::AttrValue axis;
domi::tensorflow::AttrValue dims;

bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis);
bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims);
const bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis);
const bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims);
if (!has_axis && !has_dims) {
return SUCCESS;
}


+ 43
- 39
parser/tensorflow/tensorflow_util.cc View File

@@ -30,6 +30,25 @@

using domi::tensorflow::DT_INVALID;

#define VALIDATE_FIELD(attr_value, type, num_set, name, type_string, oneof_case) \
do { \
if ((attr_value).has_list()) { \
if ((attr_value).list().name##_size() > 0) { \
if (type != "list(" type_string ")") { \
GELOGE(FAILED, "GeAttrValue had value with type 'list(" type_string ")'when '%s' expected", type.c_str()); \
return FAILED; \
} \
++(num_set); \
} \
} else if ((attr_value).value_case() == domi::tensorflow::AttrValue::oneof_case) { \
if (type != (type_string)) { \
GELOGE(FAILED, "GeAttrValue had value with type '" type_string "' when '%s' expected", type.c_str()); \
return FAILED; \
} \
++(num_set); \
} \
} while (false)

namespace ge {
/***************************TensorFlow attribute type, constant definition*******************************************/
const std::string TENSORFLOW_ATTR_TYPE_STRING = "string";
@@ -126,7 +145,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrVa
GE_CHECK_NOTNULL(node_def);
const google::protobuf::Map<std::string, domi::tensorflow::AttrValue> &attr = node_def->attr();
const google::protobuf::Map<std::string, domi::tensorflow::AttrValue>::const_iterator it = attr.find(attr_name);
if (it != attr.end()) {
if (it != attr.cend()) {
attr_value = it->second;
return true;
}
@@ -136,34 +155,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrVa

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::CheckAttrHasType(
const domi::tensorflow::AttrValue &attr_value, const std::string &type) {
uint32_t num_set = 0;
#define VALIDATE_FIELD(name, type_string, oneof_case) \
do { \
if (attr_value.has_list()) { \
if (attr_value.list().name##_size() > 0) { \
if (type != "list(" type_string ")") { \
GELOGE(FAILED, "GeAttrValue had value with type 'list(" type_string ")'when '%s' expected", type.c_str()); \
return FAILED; \
} \
++num_set; \
} \
} else if (attr_value.value_case() == domi::tensorflow::AttrValue::oneof_case) { \
if (type != (type_string)) { \
GELOGE(FAILED, "GeAttrValue had value with type '" type_string "' when '%s' expected", type.c_str()); \
return FAILED; \
} \
++num_set; \
} \
} while (false)

VALIDATE_FIELD(s, "string", kS);
VALIDATE_FIELD(i, "int", kI);
VALIDATE_FIELD(f, "float", kF);
VALIDATE_FIELD(b, "bool", kB);
VALIDATE_FIELD(type, "type", kType);
VALIDATE_FIELD(shape, "shape", kShape);
VALIDATE_FIELD(tensor, "tensor", kTensor);
VALIDATE_FIELD(func, "func", kFunc);
uint32_t num_set = 0U;
VALIDATE_FIELD(attr_value, type, num_set, s, "string", kS);
VALIDATE_FIELD(attr_value, type, num_set, i, "int", kI);
VALIDATE_FIELD(attr_value, type, num_set, f, "float", kF);
VALIDATE_FIELD(attr_value, type, num_set, b, "bool", kB);
VALIDATE_FIELD(attr_value, type, num_set, type, "type", kType);
VALIDATE_FIELD(attr_value, type, num_set, shape, "shape", kShape);
VALIDATE_FIELD(attr_value, type, num_set, tensor, "tensor", kTensor);
VALIDATE_FIELD(attr_value, type, num_set, func, "func", kFunc);

#undef VALIDATE_FIELD

@@ -173,7 +173,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch
}

// Okay to have an empty list, but not to be missing a non-list value.
if ((num_set == 0) && (!ge::StringUtils::StartWith(type, "list("))) {
if ((num_set == 0U) && (!ge::StringUtils::StartWith(type, "list("))) {
GELOGE(FAILED, "GeAttrValue missing value with expected type '%s'", type.c_str());
return FAILED;
}
@@ -210,7 +210,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Pa
const domi::tensorflow::NodeDef *node_src, const std::string &attr_src, domi::tensorflow::DataType &data_type) {
GE_CHECK_NOTNULL(node_src);

std::string node_name = node_src->name();
const std::string &node_name = node_src->name();

// Find the value of attr_src from node_src
domi::tensorflow::AttrValue attr_value;
@@ -236,7 +236,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromA
ge_desc.SetOriginFormat(ge::FORMAT_ND);

tf_datatype = a_list.func(i).attr().at(SERIALIZE_DATATYPE).i();
ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_datatype);
const ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_datatype);
GE_CHK_BOOL_RET_STATUS(type != ge::DataType::DT_UNDEFINED, PARAM_INVALID,
"In FrameworkOp translate datatype:%d failed, domi cann't support.", tf_datatype);
ge_desc.SetDataType(type);
@@ -266,21 +266,25 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr
int32_t tf_datatype = 0;
GE_CHK_BOOL_RET_STATUS(ParseFromAttrValueList(ge_desc, a_list, i, tf_datatype), PARAM_INVALID,
"parse ge_desc failed.");
uint32_t size_type = 1;
uint32_t size_type = 1U;
auto data_type = ge_desc.GetDataType();
GE_CHK_BOOL_RET_STATUS(ge::TypeUtils::GetDataTypeLength(data_type, size_type), PARAM_INVALID,
"dataType no define size , parse ge_desc failed.");
// get size
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
for (uint32_t j = 0U; j < ge_desc.GetShape().GetDimNum(); ++j) {
int64_t tmp_dim = ge_desc.GetShape().GetDim(j);

// The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate.
// Here, special treatment is given to the two operators.
// Adjust shape to fit resnet50 network only.
GE_IF_BOOL_EXEC((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape());
break;);
GE_IF_BOOL_EXEC((type == ge::parser::MEAN) && (tmp_dim == 0), std::vector<int64_t> data_dim = {tmp_dim};
ge_desc.SetShape(ge::GeShape(data_dim)); break;);
if ((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0)) {
ge_desc.SetShape(ge::GeShape());
break;
}
if ((type == ge::parser::MEAN) && (tmp_dim == 0)) {
std::vector<int64_t> data_dim = {tmp_dim};
ge_desc.SetShape(ge::GeShape(data_dim));
break;
}
}
ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
GELOGD("IO:%d: after translate tf_desc, datatype: %s, format: %s, size_type: %u", io,


+ 37
- 29
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -158,7 +158,7 @@ void STestTensorflowParser::RegisterCustomOp() {
domi::OpRegistry::Instance()->registrationDatas.clear();
}

extern void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node);
void AddDumpOriginName(const ge::NodePtr parent_node, const std::string& subgraph_name, ge::ComputeGraphPtr graph);

namespace {
NodeDef* AddNode(GraphDef& graph, string type, string name) {
@@ -4291,35 +4291,43 @@ TEST_F(STestTensorflowParser, tensorflow_optimizer_fmk_fusion_op) {
TEST_F(STestTensorflowParser, AddDumpOriginName_test)
{
GeTensorDesc scalar_tensor(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::WHILE);
data_op->SetName("WHILE0");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr while0 = graph->AddNode(data_op);

data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::LOOPCOND);
data_op->SetName("COND0");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr cond0 = graph->AddNode(data_op);
AddDumpOriginName(std::string("while"), while0, cond0);

data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::DATA);
data_op->SetName("Data1");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr data1 = graph->AddNode(data_op);
AddDumpOriginName(std::string("cond"), cond0, data1);

auto desc = data1->GetOpDesc();
ge::ComputeGraphPtr parent_graph = std::make_shared<ge::ComputeGraph>("parent_graph");
ge::OpDescPtr parent = std::make_shared<ge::OpDesc>();
parent->SetType("Foo");
parent->SetName("foo");
ge::NodePtr foo = parent_graph->AddNode(parent);


ge::ComputeGraphPtr sub_graph = std::make_shared<ge::ComputeGraph>("sub_graph");
auto child = std::make_shared<ge::OpDesc>();
child->SetType("Bar");
child->SetName("bar");
ge::NodePtr bar = sub_graph->AddNode(child);

AddDumpOriginName(foo, "f", sub_graph);

std::vector<std::string> original_names;
(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.empty(), false);
EXPECT_EQ(original_names[0], "WHILE0/while/COND0/cond/Data1");
(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.size(), 1U);
EXPECT_EQ(original_names[0], "foo/f/bar");

(void)ge::AttrUtils::SetListStr(foo->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
AddDumpOriginName(foo, "f", sub_graph);

original_names.clear();
(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.size(), 1U);
EXPECT_EQ(original_names[0], "foo/f/bar/f/bar");

original_names.push_back("abc");
(void)ge::AttrUtils::SetListStr(foo->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
AddDumpOriginName(foo, "f", sub_graph);

original_names.clear();
(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.size(), 2U);
EXPECT_EQ(original_names[0], "foo/f/bar/f/bar/f/bar");
EXPECT_EQ(original_names[1], "abc");
}

} // namespace ge

+ 34
- 26
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -169,7 +169,7 @@ static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_de
return SUCCESS;
}

extern void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node);
void AddDumpOriginName(const ge::NodePtr parent_node, const std::string& subgraph_name, ge::ComputeGraphPtr graph);

void UtestTensorflowParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Add")
@@ -4779,35 +4779,43 @@ TEST_F(UtestTensorflowParser, tensorflow_ComputeArgRange)
TEST_F(UtestTensorflowParser, AddDumpOriginName_test)
{
GeTensorDesc scalar_tensor(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::WHILE);
data_op->SetName("WHILE0");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr while0 = graph->AddNode(data_op);
ge::ComputeGraphPtr parent_graph = std::make_shared<ge::ComputeGraph>("parent_graph");
ge::OpDescPtr parent = std::make_shared<ge::OpDesc>();
parent->SetType("Foo");
parent->SetName("foo");
ge::NodePtr foo = parent_graph->AddNode(parent);

data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::LOOPCOND);
data_op->SetName("COND0");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr cond0 = graph->AddNode(data_op);
AddDumpOriginName(std::string("while"), while0, cond0);

data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(parser::DATA);
data_op->SetName("Data1");
data_op->AddInputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr data1 = graph->AddNode(data_op);
AddDumpOriginName(std::string("cond"), cond0, data1);
ge::ComputeGraphPtr sub_graph = std::make_shared<ge::ComputeGraph>("sub_graph");
auto child = std::make_shared<ge::OpDesc>();
child->SetType("Bar");
child->SetName("bar");
ge::NodePtr bar = sub_graph->AddNode(child);

AddDumpOriginName(foo, "f", sub_graph);

auto desc = data1->GetOpDesc();
std::vector<std::string> original_names;
(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.empty(), false);
EXPECT_EQ(original_names[0], "WHILE0/while/COND0/cond/Data1");
(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.size(), 1U);
EXPECT_EQ(original_names[0], "foo/f/bar");

(void)ge::AttrUtils::SetListStr(foo->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
AddDumpOriginName(foo, "f", sub_graph);

original_names.clear();
(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.size(), 1U);
EXPECT_EQ(original_names[0], "foo/f/bar/f/bar");

original_names.push_back("abc");
(void)ge::AttrUtils::SetListStr(foo->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
AddDumpOriginName(foo, "f", sub_graph);

original_names.clear();
(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
EXPECT_EQ(original_names.size(), 2U);
EXPECT_EQ(original_names[0], "foo/f/bar/f/bar/f/bar");
EXPECT_EQ(original_names[1], "abc");
}

} // namespace ge

Loading…
Cancel
Save