Browse Source

Pre Merge pull request !631 from TangQunzhang/ge_dev

pull/631/MERGE
TangQunzhang Gitee 2 years ago
parent
commit
c982aba92a
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 98 additions and 79 deletions
  1. +1
    -1
      parser/tensorflow/tensorflow_constant_parser.h
  2. +1
    -1
      parser/tensorflow/tensorflow_frameworkop_parser.cc
  3. +1
    -1
      parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc
  4. +8
    -8
      parser/tensorflow/tensorflow_fusion_op_parser.cc
  5. +1
    -1
      parser/tensorflow/tensorflow_fusion_op_parser.h
  6. +1
    -1
      parser/tensorflow/tensorflow_merge_parser.cc
  7. +1
    -1
      parser/tensorflow/tensorflow_merge_parser.h
  8. +14
    -6
      parser/tensorflow/tensorflow_parser.cc
  9. +1
    -1
      parser/tensorflow/tensorflow_parser_register.h
  10. +2
    -2
      parser/tensorflow/tensorflow_ref_switch_parser.cc
  11. +8
    -5
      parser/tensorflow/tensorflow_reshape_parser.cc
  12. +4
    -3
      parser/tensorflow/tensorflow_shape_n_parser.cc
  13. +12
    -9
      parser/tensorflow/tensorflow_squeeze_parser.cc
  14. +43
    -39
      parser/tensorflow/tensorflow_util.cc

+ 1
- 1
parser/tensorflow/tensorflow_constant_parser.h View File

@@ -21,9 +21,9 @@
#include "parser/common/data_op_parser.h"
#include "parser/tensorflow/tensorflow_op_parser.h"

namespace ge {
using domi::tensorflow::NodeDef;

namespace ge {
class PARSER_FUNC_VISIBILITY TensorFlowConstantParser : public TensorFlowOpParser {
public:
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;


+ 1
- 1
parser/tensorflow/tensorflow_frameworkop_parser.cc View File

@@ -31,7 +31,7 @@ namespace ge {
Status ParseParams(const Message *op_src, FrameworkOpOperator *op) {
GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op);
const domi::tensorflow::NodeDef *node = reinterpret_cast<const domi::tensorflow::NodeDef *>(op_src);
const domi::tensorflow::NodeDef *node = static_cast<const domi::tensorflow::NodeDef *>(op_src);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
string type = node->op();



+ 1
- 1
parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc View File

@@ -33,7 +33,7 @@ Status TensorFlowFusionCustomParserAdapter::ParseParams(const vector<const NodeD
std::vector<const google::protobuf::Message *> inside_nodes;
for (auto inside_node : v_input_const) {
GE_CHECK_NOTNULL(inside_node);
const google::protobuf::Message *node_src = reinterpret_cast<const google::protobuf::Message *>(inside_node);
const google::protobuf::Message *node_src = dynamic_cast<const google::protobuf::Message *>(inside_node);
inside_nodes.push_back(node_src);
}
std::string ori_type = op_dest->GetType();


+ 8
- 8
parser/tensorflow/tensorflow_fusion_op_parser.cc View File

@@ -31,21 +31,21 @@ namespace ge {
do { \
google::protobuf::RepeatedField<FIELD> val_vec; \
int32_t val_size = 0; \
val_vec = (tensor).FIELD##_val(); \
val_vec = (tensor).FIELD##_val(); \
val_size = val_vec.size(); \
if ((index) < val_size) { \
(param) = val_vec.Get(index); \
} else if ((tensor).has_tensor_shape()) { \
const std::string tensor_content = (tensor).tensor_content(); \
char *buf = const_cast<char *>(tensor_content.data()); \
FIELD *buf_v = reinterpret_cast<FIELD *>(buf); \
if ((index) < val_size) { \
(param) = val_vec.Get(index); \
} else if ((tensor).has_tensor_shape()) { \
const std::string &tensor_content = (tensor).tensor_content(); \
const char *buf = tensor_content.data(); \
const FIELD *buf_v = reinterpret_cast<const FIELD *>(buf); \
if (static_cast<uint32_t>(index) >= tensor_content.length() / sizeof(FIELD)) { \
REPORT_INNER_ERROR("E19999", "Const data size of node:%s is smaller than index:%d, not supported!", \
node_def->name().c_str(), index); \
GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index :%d,not supported!", index); \
return domi::PARAM_INVALID; \
} \
(param) = buf_v[index]; \
(param) = buf_v[index]; \
} else { \
REPORT_INNER_ERROR("E19999", "Const data size of node:%s is smaller than index:%d, not supported!", \
node_def->name().c_str(), index); \


+ 1
- 1
parser/tensorflow/tensorflow_fusion_op_parser.h View File

@@ -25,11 +25,11 @@
#include "proto/tensorflow/graph.pb.h"
#include "proto/tensorflow/node_def.pb.h"

namespace ge {
using google::protobuf::Message;
using domi::tensorflow::NodeDef;
using domi::tensorflow::TensorProto;

namespace ge {
/**
* @ingroup domi_omg
* @brief Used to parse TensorFlow operator information


+ 1
- 1
parser/tensorflow/tensorflow_merge_parser.cc View File

@@ -39,7 +39,7 @@ Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr &
int32_t input_tensor_num = attr_num.i();

// add dynamic input
graphStatus ret = op_desc->AddDynamicInputDesc("x", input_tensor_num);
const graphStatus ret = op_desc->AddDynamicInputDesc("x", input_tensor_num);
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add Dynamic InputDesc name:x to node:%s(%s) failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str());


+ 1
- 1
parser/tensorflow/tensorflow_merge_parser.h View File

@@ -25,4 +25,4 @@ class PARSER_FUNC_VISIBILITY TensorFlowMergeParser : public TensorFlowOpParser {
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) override;
};
} // namespace ge
#endif //_DOMI_OMG_PARSER_TENSORFLOW_TENSORFLOW_MERGE_PARSER_H_
#endif // _DOMI_OMG_PARSER_TENSORFLOW_TENSORFLOW_MERGE_PARSER_H_

+ 14
- 6
parser/tensorflow/tensorflow_parser.cc View File

@@ -913,8 +913,10 @@ Status TensorFlowModelParser::CheckOpType(const domi::tensorflow::NodeDef *node_
GE_IF_BOOL_EXEC(
op_type == ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS,
GE_CHK_STATUS_RET(CheckOpShapeDim(node_def, check_dims[op_type], valid), "failed to check op shape");
GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP; GELOGI("Set op %s to frameworkop", node_name.c_str());
framework_ops_[node_name] = node_def;););
GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP;
GELOGI("Set op %s to frameworkop", node_name.c_str());
framework_ops_[node_name] = node_def;);
);

GE_IF_BOOL_EXEC(
op_type == ge::parser::ADD || op_type == ge::parser::MULTIPLY || op_type == ge::parser::MEAN,
@@ -974,7 +976,8 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co
GELOGD("CCE %s parsering", node_op.c_str()););
GE_IF_BOOL_EXEC((implyType == domi::ImplyType::HCCL) && (op_type != ge::parser::FRAMEWORKOP),
GELOGD("HCCL %s parsering", node_op.c_str()););
GE_IF_BOOL_EXEC(op_type == ge::parser::FRAMEWORKOP, GELOGD("FRAMEWORKOP %s parsering", node_op.c_str()););
GE_IF_BOOL_EXEC(op_type == ge::parser::FRAMEWORKOP,
GELOGD("FRAMEWORKOP %s parsering", node_op.c_str()););
GELOGD("TF op node name = %s, op type= %s, trans to op type %s", node_name.c_str(), node_op.c_str(), op_type.c_str());

// Construct operator by IR
@@ -2326,7 +2329,8 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
// Loop analysis of op_nodes and map them to nodes in graph
ret = AddFmkNode(graph, scope_graph, op_node_name_list, isDatasetInit);
PARSER_TIMESTAMP_END(AddFmkNode, "TensorFlowModelParser::AddFmkNode");
GE_CHK_STATUS_EXEC(ret, DeleteFuisonNodeDef(); return ret, "AddFmkNode failed");
GE_CHK_STATUS_EXEC(ret, DeleteFuisonNodeDef();
return ret, "AddFmkNode failed");
GELOGD("[TF Parser] Add framework node success");

ret = AddEdges(graph);
@@ -3008,12 +3012,16 @@ Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node,
GE_IF_BOOL_EXEC(
type == domi::tensorflow::DT_INT32,
const int32_t *data = reinterpret_cast<const int32_t *>(tensor.tensor_content().data());
for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); });
for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) {
perm_value.push_back(data[i]);
});

GE_IF_BOOL_EXEC(
type == domi::tensorflow::DT_INT64,
const int64_t *data = reinterpret_cast<const int64_t *>(tensor.tensor_content().data());
for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); });
for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) {
perm_value.push_back(data[i]);
});

// 0, 1, 2, 3 present dim num.
vector<int64_t> perm_to_nchw = {0, 3, 1, 2};


+ 1
- 1
parser/tensorflow/tensorflow_parser_register.h View File

@@ -61,7 +61,7 @@ class PARSER_FUNC_VISIBILITY TensorflowParserBuilder : public TensorflowWeightPa
public:
using ParseParamsFn = std::function<domi::Status(const domi::tensorflow::NodeDef *, Param *)>;

explicit TensorflowParserBuilder(const std::string &davinci_optype) : davinci_optype_(davinci_optype) {}
explicit TensorflowParserBuilder(const std::string &davinci_optype) noexcept : davinci_optype_(davinci_optype) {}

~TensorflowParserBuilder() override {}



+ 2
- 2
parser/tensorflow/tensorflow_ref_switch_parser.cc View File

@@ -40,8 +40,8 @@ Status TensorFlowRefSwitchParser::ParseT(const domi::tensorflow::NodeDef *node,

GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "type"), "check Attr T failed");

domi::tensorflow::DataType tfType = attr.type();
ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tfType);
const domi::tensorflow::DataType tfType = attr.type();
const ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tfType);
CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED,
REPORT_CALL_ERROR("E19999", "Data type %s of node %s is not supported",
DataType_Name(tfType).c_str(),


+ 8
- 5
parser/tensorflow/tensorflow_reshape_parser.cc View File

@@ -34,26 +34,29 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att
auto a_list = attr_value.list();
GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), PARAM_INVALID,
"parse ge_desc failed.");
uint32_t size_type = 1;
uint32_t size_type = 1U;
auto data_type = ge_desc.GetDataType();
bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
GE_IF_BOOL_EXEC(!type_ret,
REPORT_CALL_ERROR("E19999", "Data type %s is not supported",
ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
GELOGE(FAILED, "Can't GetDataTypeLength of data_type: %s",
GELOGE(FAILED, "Can't GetDataTypeLength of data_type: %s.",
ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
return PARAM_INVALID);
// calculate size
int64_t real_size = 1;
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
for (uint32_t j = 0U; j < ge_desc.GetShape().GetDimNum(); ++j) {
int64_t tmp_dim = ge_desc.GetShape().GetDim(j);
GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;);
if (tmp_dim < 0) {
real_size = tmp_dim * (-1) * real_size;
continue;
}
real_size *= tmp_dim;
}
PARSER_INT64_MULCHECK(real_size, size_type);
ge::TensorUtils::SetSize(ge_desc, real_size * size_type);
ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
GELOGI("after translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
GELOGI("After translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(),
ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), real_size * size_type, size_type);
return SUCCESS;


+ 4
- 3
parser/tensorflow/tensorflow_shape_n_parser.cc View File

@@ -83,7 +83,8 @@ Status TensorFlowShapeNParser::ParseN(const domi::tensorflow::NodeDef *node, Sha
// The upper caller guarantees the input params is not empty.
domi::tensorflow::AttrValue attr;
const int64_t attr_n = 2;
CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr), op->N(attr_n); return SUCCESS);
CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr), op->N(attr_n);
return SUCCESS);

GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "int"), "check Attr N failed");

@@ -133,7 +134,7 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr
if (ret != SUCCESS) {
return ret;
}
graphStatus status = op_dest->AddDynamicOutputDesc("y", dynamic_tensor_num);
const graphStatus status = op_dest->AddDynamicOutputDesc("y", dynamic_tensor_num);
if (status != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add Dynamic OuputDesc name:y to node:%s(%s) failed",
op_dest->GetName().c_str(), op_dest->GetType().c_str());
@@ -141,7 +142,7 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr
return FAILED;
}
}
graphStatus status = op_dest->AddDynamicInputDesc("x", dynamic_tensor_num);
const graphStatus status = op_dest->AddDynamicInputDesc("x", dynamic_tensor_num);
if (status != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add Dynamic InputDesc name:x to node:%s(%s) failed",
op_dest->GetName().c_str(), op_dest->GetType().c_str());


+ 12
- 9
parser/tensorflow/tensorflow_squeeze_parser.cc View File

@@ -38,27 +38,30 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att
auto a_list = attr_value.list();
GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), domi::PARAM_INVALID,
"parse ge_desc failed.");
uint32_t size_type;
auto data_type = ge_desc.GetDataType();
bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
uint32_t size_type = 0U;
const auto data_type = ge_desc.GetDataType();
const bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
GE_IF_BOOL_EXEC(!type_ret,
REPORT_CALL_ERROR("E19999", "Data type %s is not supported",
ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
GELOGE(domi::PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s",
GELOGE(domi::PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s.",
ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
return domi::PARAM_INVALID);
// calculate size
int64_t real_size = 1;
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
for (uint32_t j = 0U; j < ge_desc.GetShape().GetDimNum(); ++j) {
int64_t tmp_dim = ge_desc.GetShape().GetDim(j);
GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;);
if (tmp_dim < 0) {
real_size = tmp_dim * (-1) * real_size;
continue;
}
PARSER_INT64_MULCHECK(real_size, tmp_dim);
real_size *= tmp_dim;
}
PARSER_INT64_MULCHECK(real_size, size_type);
ge::TensorUtils::SetSize(ge_desc, real_size * size_type);
ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
GELOGD("after translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
GELOGD("After translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(),
ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), real_size * size_type, size_type);
return SUCCESS;
@@ -80,8 +83,8 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr
domi::tensorflow::AttrValue axis;
domi::tensorflow::AttrValue dims;

bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis);
bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims);
const bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis);
const bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims);
if (!has_axis && !has_dims) {
return SUCCESS;
}


+ 43
- 39
parser/tensorflow/tensorflow_util.cc View File

@@ -30,6 +30,25 @@

using domi::tensorflow::DT_INVALID;

#define VALIDATE_FIELD(attr_value, type, num_set, name, type_string, oneof_case) \
do { \
if ((attr_value).has_list()) { \
if ((attr_value).list().name##_size() > 0) { \
if (type != "list(" type_string ")") { \
GELOGE(FAILED, "GeAttrValue had value with type 'list(" type_string ")'when '%s' expected", type.c_str()); \
return FAILED; \
} \
++(num_set); \
} \
} else if ((attr_value).value_case() == domi::tensorflow::AttrValue::oneof_case) { \
if (type != (type_string)) { \
GELOGE(FAILED, "GeAttrValue had value with type '" type_string "' when '%s' expected", type.c_str()); \
return FAILED; \
} \
++(num_set); \
} \
} while (false)

namespace ge {
/***************************TensorFlow attribute type, constant definition*******************************************/
const std::string TENSORFLOW_ATTR_TYPE_STRING = "string";
@@ -126,7 +145,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrVa
GE_CHECK_NOTNULL(node_def);
const google::protobuf::Map<std::string, domi::tensorflow::AttrValue> &attr = node_def->attr();
const google::protobuf::Map<std::string, domi::tensorflow::AttrValue>::const_iterator it = attr.find(attr_name);
if (it != attr.end()) {
if (it != attr.cend()) {
attr_value = it->second;
return true;
}
@@ -136,34 +155,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrVa

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::CheckAttrHasType(
const domi::tensorflow::AttrValue &attr_value, const std::string &type) {
uint32_t num_set = 0;
#define VALIDATE_FIELD(name, type_string, oneof_case) \
do { \
if (attr_value.has_list()) { \
if (attr_value.list().name##_size() > 0) { \
if (type != "list(" type_string ")") { \
GELOGE(FAILED, "GeAttrValue had value with type 'list(" type_string ")'when '%s' expected", type.c_str()); \
return FAILED; \
} \
++num_set; \
} \
} else if (attr_value.value_case() == domi::tensorflow::AttrValue::oneof_case) { \
if (type != (type_string)) { \
GELOGE(FAILED, "GeAttrValue had value with type '" type_string "' when '%s' expected", type.c_str()); \
return FAILED; \
} \
++num_set; \
} \
} while (false)

VALIDATE_FIELD(s, "string", kS);
VALIDATE_FIELD(i, "int", kI);
VALIDATE_FIELD(f, "float", kF);
VALIDATE_FIELD(b, "bool", kB);
VALIDATE_FIELD(type, "type", kType);
VALIDATE_FIELD(shape, "shape", kShape);
VALIDATE_FIELD(tensor, "tensor", kTensor);
VALIDATE_FIELD(func, "func", kFunc);
uint32_t num_set = 0U;
VALIDATE_FIELD(attr_value, type, num_set, s, "string", kS);
VALIDATE_FIELD(attr_value, type, num_set, i, "int", kI);
VALIDATE_FIELD(attr_value, type, num_set, f, "float", kF);
VALIDATE_FIELD(attr_value, type, num_set, b, "bool", kB);
VALIDATE_FIELD(attr_value, type, num_set, type, "type", kType);
VALIDATE_FIELD(attr_value, type, num_set, shape, "shape", kShape);
VALIDATE_FIELD(attr_value, type, num_set, tensor, "tensor", kTensor);
VALIDATE_FIELD(attr_value, type, num_set, func, "func", kFunc);

#undef VALIDATE_FIELD

@@ -173,7 +173,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch
}

// Okay to have an empty list, but not to be missing a non-list value.
if ((num_set == 0) && (!ge::StringUtils::StartWith(type, "list("))) {
if ((num_set == 0U) && (!ge::StringUtils::StartWith(type, "list("))) {
GELOGE(FAILED, "GeAttrValue missing value with expected type '%s'", type.c_str());
return FAILED;
}
@@ -210,7 +210,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Pa
const domi::tensorflow::NodeDef *node_src, const std::string &attr_src, domi::tensorflow::DataType &data_type) {
GE_CHECK_NOTNULL(node_src);

std::string node_name = node_src->name();
const std::string &node_name = node_src->name();

// Find the value of attr_src from node_src
domi::tensorflow::AttrValue attr_value;
@@ -236,7 +236,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromA
ge_desc.SetOriginFormat(ge::FORMAT_ND);

tf_datatype = a_list.func(i).attr().at(SERIALIZE_DATATYPE).i();
ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_datatype);
const ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_datatype);
GE_CHK_BOOL_RET_STATUS(type != ge::DataType::DT_UNDEFINED, PARAM_INVALID,
"In FrameworkOp translate datatype:%d failed, domi cann't support.", tf_datatype);
ge_desc.SetDataType(type);
@@ -266,21 +266,25 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr
int32_t tf_datatype = 0;
GE_CHK_BOOL_RET_STATUS(ParseFromAttrValueList(ge_desc, a_list, i, tf_datatype), PARAM_INVALID,
"parse ge_desc failed.");
uint32_t size_type = 1;
uint32_t size_type = 1U;
auto data_type = ge_desc.GetDataType();
GE_CHK_BOOL_RET_STATUS(ge::TypeUtils::GetDataTypeLength(data_type, size_type), PARAM_INVALID,
"dataType no define size , parse ge_desc failed.");
// get size
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
for (uint32_t j = 0U; j < ge_desc.GetShape().GetDimNum(); ++j) {
int64_t tmp_dim = ge_desc.GetShape().GetDim(j);

// The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate.
// Here, special treatment is given to the two operators.
// Adjust shape to fit resnet50 network only.
GE_IF_BOOL_EXEC((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape());
break;);
GE_IF_BOOL_EXEC((type == ge::parser::MEAN) && (tmp_dim == 0), std::vector<int64_t> data_dim = {tmp_dim};
ge_desc.SetShape(ge::GeShape(data_dim)); break;);
if ((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0)) {
ge_desc.SetShape(ge::GeShape());
break;
}
if ((type == ge::parser::MEAN) && (tmp_dim == 0)) {
std::vector<int64_t> data_dim = {tmp_dim};
ge_desc.SetShape(ge::GeShape(data_dim));
break;
}
}
ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
GELOGD("IO:%d: after translate tf_desc, datatype: %s, format: %s, size_type: %u", io,


Loading…
Cancel
Save