@@ -1 +1 @@
-Subproject commit 5894aaa3c0cb12565e5e57a0a49f3e732e608f3d
+Subproject commit 5d062a35640733026457c91966a558769570b0f8
@@ -21,9 +21,9 @@
 #include "parser/common/data_op_parser.h"
 #include "parser/tensorflow/tensorflow_op_parser.h"
-namespace ge {
 using domi::tensorflow::NodeDef;
+namespace ge {
 class PARSER_FUNC_VISIBILITY TensorFlowConstantParser : public TensorFlowOpParser {
 public:
 Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;
@@ -31,7 +31,7 @@ namespace ge {
 Status ParseParams(const Message *op_src, FrameworkOpOperator *op) {
 GE_CHECK_NOTNULL(op_src);
 GE_CHECK_NOTNULL(op);
-const domi::tensorflow::NodeDef *node = reinterpret_cast<const domi::tensorflow::NodeDef *>(op_src);
+const domi::tensorflow::NodeDef *node = static_cast<const domi::tensorflow::NodeDef *>(op_src);
 GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
 string type = node->op();
@@ -33,7 +33,7 @@ Status TensorFlowFusionCustomParserAdapter::ParseParams(const vector<const NodeD
 std::vector<const google::protobuf::Message *> inside_nodes;
 for (auto inside_node : v_input_const) {
 GE_CHECK_NOTNULL(inside_node);
-const google::protobuf::Message *node_src = reinterpret_cast<const google::protobuf::Message *>(inside_node);
+const google::protobuf::Message *node_src = dynamic_cast<const google::protobuf::Message *>(inside_node);
 inside_nodes.push_back(node_src);
 }
 std::string ori_type = op_dest->GetType();
@@ -31,21 +31,21 @@ namespace ge {
 do { \
 google::protobuf::RepeatedField<FIELD> val_vec; \
 int32_t val_size = 0; \
-val_vec = (tensor).FIELD##_val(); \
+val_vec = (tensor).FIELD##_val(); \
 val_size = val_vec.size(); \
-if ((index) < val_size) { \
-(param) = val_vec.Get(index); \
-} else if ((tensor).has_tensor_shape()) { \
-const std::string tensor_content = (tensor).tensor_content(); \
-char *buf = const_cast<char *>(tensor_content.data()); \
-FIELD *buf_v = reinterpret_cast<FIELD *>(buf); \
+if ((index) < val_size) { \
+(param) = val_vec.Get(index); \
+} else if ((tensor).has_tensor_shape()) { \
+const std::string &tensor_content = (tensor).tensor_content(); \
+const char *buf = tensor_content.data(); \
+const FIELD *buf_v = reinterpret_cast<const FIELD *>(buf); \
 if (static_cast<uint32_t>(index) >= tensor_content.length() / sizeof(FIELD)) { \
 REPORT_INNER_ERROR("E19999", "Const data size of node:%s is smaller than index:%d, not supported!", \
 node_def->name().c_str(), index); \
 GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index :%d,not supported!", index); \
 return domi::PARAM_INVALID; \
 } \
-(param) = buf_v[index]; \
+(param) = buf_v[index]; \
 } else { \
 REPORT_INNER_ERROR("E19999", "Const data size of node:%s is smaller than index:%d, not supported!", \
 node_def->name().c_str(), index); \
@@ -25,11 +25,11 @@
 #include "proto/tensorflow/graph.pb.h"
 #include "proto/tensorflow/node_def.pb.h"
-namespace ge {
 using google::protobuf::Message;
 using domi::tensorflow::NodeDef;
 using domi::tensorflow::TensorProto;
+namespace ge {
 /**
 * @ingroup domi_omg
 * @brief Used to parse TensorFlow operator information
@@ -39,7 +39,7 @@ Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr &
 int32_t input_tensor_num = attr_num.i();
 // add dynamic input
-graphStatus ret = op_desc->AddDynamicInputDesc("x", input_tensor_num);
+const graphStatus ret = op_desc->AddDynamicInputDesc("x", input_tensor_num);
 if (ret != GRAPH_SUCCESS) {
 REPORT_CALL_ERROR("E19999", "Add Dynamic InputDesc name:x to node:%s(%s) failed",
 op_desc->GetName().c_str(), op_desc->GetType().c_str());
@@ -25,4 +25,4 @@ class PARSER_FUNC_VISIBILITY TensorFlowMergeParser : public TensorFlowOpParser {
 Status ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) override;
 };
 } // namespace ge
-#endif //_DOMI_OMG_PARSER_TENSORFLOW_TENSORFLOW_MERGE_PARSER_H_
+#endif // _DOMI_OMG_PARSER_TENSORFLOW_TENSORFLOW_MERGE_PARSER_H_
@@ -191,21 +191,25 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<Ascend
 GELOGI("AclgrphParse graph %s success.", ParserUtils::GetGraphName(graph).c_str());
 return ge::SUCCESS;
 }
-void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node)
-{
+void AddDumpOriginName(const ge::NodePtr parent_node, const std::string& subgraph_name, ge::ComputeGraphPtr graph) {
+if (parent_node == nullptr) {
+return; // Root graph no need set dump origin name as parser always keep the origin node name
+}
 std::vector<std::string> original_names;
-auto parend_desc = parent_node->GetOpDesc();
-(void)ge::AttrUtils::GetListStr(parend_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+(void)ge::AttrUtils::GetListStr(parent_node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
 if (original_names.empty()) {
 original_names.emplace_back(parent_node->GetName());
 }
 // for fusion node also used original_names[0]
-(void)original_names[0].append("/").append(subgraph_name).append("/").append(node->GetName());
-if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names)) {
-GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), node->GetOpDesc()->GetName().c_str());
+std::string prefix = original_names[0].append("/").append(subgraph_name).append("/");
+for (const ge::NodePtr &node : graph->GetDirectNode()) {
+original_names[0] = prefix + node->GetName();
+if (!ge::AttrUtils::SetListStr(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names)) {
+GELOGW("Set dump origin name to %s fail.", node->GetOpDesc()->GetName().c_str());
+}
+GELOGD("Add dump origin name %s for node %s.", original_names[0].c_str(), node->GetName().c_str());
 }
-GELOGD("Add dump origin name %s for node %s.", original_names[0].c_str(), node->GetName().c_str());
 }
 } // namespace ge
@@ -273,6 +277,7 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque
 }
 Status PostOpProcessForSubgraph(const ParseArg &arg) {
+AddDumpOriginName(arg.parent_node, arg.subgraph_name, arg.graph);
 if (arg.parent_node == nullptr) {
 return SUCCESS;
 }
@@ -297,7 +302,6 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) {
 if ((node->GetOpDesc() == nullptr) || (node->GetType() == "Variable") || (node->GetType() == "VariableV2")) {
 continue;
 }
-AddDumpOriginName(arg.subgraph_name, arg.parent_node, node);
 node->GetOpDesc()->SetName(node->GetOwnerComputeGraph()->GetName() + "/" + node->GetName());
 }
@@ -909,8 +913,10 @@ Status TensorFlowModelParser::CheckOpType(const domi::tensorflow::NodeDef *node_
 GE_IF_BOOL_EXEC(
 op_type == ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS,
 GE_CHK_STATUS_RET(CheckOpShapeDim(node_def, check_dims[op_type], valid), "failed to check op shape");
-GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP; GELOGI("Set op %s to frameworkop", node_name.c_str());
-framework_ops_[node_name] = node_def;););
+GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP;
+GELOGI("Set op %s to frameworkop", node_name.c_str());
+framework_ops_[node_name] = node_def;);
+);
 GE_IF_BOOL_EXEC(
 op_type == ge::parser::ADD || op_type == ge::parser::MULTIPLY || op_type == ge::parser::MEAN,
@@ -970,7 +976,8 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co
 GELOGD("CCE %s parsering", node_op.c_str()););
 GE_IF_BOOL_EXEC((implyType == domi::ImplyType::HCCL) && (op_type != ge::parser::FRAMEWORKOP),
 GELOGD("HCCL %s parsering", node_op.c_str()););
-GE_IF_BOOL_EXEC(op_type == ge::parser::FRAMEWORKOP, GELOGD("FRAMEWORKOP %s parsering", node_op.c_str()););
+GE_IF_BOOL_EXEC(op_type == ge::parser::FRAMEWORKOP,
+GELOGD("FRAMEWORKOP %s parsering", node_op.c_str()););
 GELOGD("TF op node name = %s, op type= %s, trans to op type %s", node_name.c_str(), node_op.c_str(), op_type.c_str());
 // Construct operator by IR
@@ -2322,7 +2329,8 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
 // Loop analysis of op_nodes and map them to nodes in graph
 ret = AddFmkNode(graph, scope_graph, op_node_name_list, isDatasetInit);
 PARSER_TIMESTAMP_END(AddFmkNode, "TensorFlowModelParser::AddFmkNode");
-GE_CHK_STATUS_EXEC(ret, DeleteFuisonNodeDef(); return ret, "AddFmkNode failed");
+GE_CHK_STATUS_EXEC(ret, DeleteFuisonNodeDef();
+return ret, "AddFmkNode failed");
 GELOGD("[TF Parser] Add framework node success");
 ret = AddEdges(graph);
@@ -3004,12 +3012,16 @@ Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node,
 GE_IF_BOOL_EXEC(
 type == domi::tensorflow::DT_INT32,
 const int32_t *data = reinterpret_cast<const int32_t *>(tensor.tensor_content().data());
-for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); });
+for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) {
+perm_value.push_back(data[i]);
+});
 GE_IF_BOOL_EXEC(
 type == domi::tensorflow::DT_INT64,
 const int64_t *data = reinterpret_cast<const int64_t *>(tensor.tensor_content().data());
-for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); });
+for (int i = 0; i < parser::DIM_DEFAULT_SIZE; i++) {
+perm_value.push_back(data[i]);
+});
 // 0, 1, 2, 3 present dim num.
 vector<int64_t> perm_to_nchw = {0, 3, 1, 2};
@@ -61,7 +61,7 @@ class PARSER_FUNC_VISIBILITY TensorflowParserBuilder : public TensorflowWeightPa
 public:
 using ParseParamsFn = std::function<domi::Status(const domi::tensorflow::NodeDef *, Param *)>;
-explicit TensorflowParserBuilder(const std::string &davinci_optype) : davinci_optype_(davinci_optype) {}
+explicit TensorflowParserBuilder(const std::string &davinci_optype) noexcept : davinci_optype_(davinci_optype) {}
 ~TensorflowParserBuilder() override {}
@@ -40,8 +40,8 @@ Status TensorFlowRefSwitchParser::ParseT(const domi::tensorflow::NodeDef *node,
 GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "type"), "check Attr T failed");
-domi::tensorflow::DataType tfType = attr.type();
-ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tfType);
+const domi::tensorflow::DataType tfType = attr.type();
+const ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tfType);
 CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED,
 REPORT_CALL_ERROR("E19999", "Data type %s of node %s is not supported",
 DataType_Name(tfType).c_str(),
@@ -34,26 +34,29 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att
 auto a_list = attr_value.list();
 GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), PARAM_INVALID,
 "parse ge_desc failed.");
-uint32_t size_type = 1;
+uint32_t size_type = 1U;
 auto data_type = ge_desc.GetDataType();
 bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
 GE_IF_BOOL_EXEC(!type_ret,
 REPORT_CALL_ERROR("E19999", "Data type %s is not supported",
 ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
-GELOGE(FAILED, "Can't GetDataTypeLength of data_type: %s",
+GELOGE(FAILED, "Can't GetDataTypeLength of data_type: %s.",
 ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
 return PARAM_INVALID);
 // calculate size
 int64_t real_size = 1;
-for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
+for (uint32_t j = 0U; j < ge_desc.GetShape().GetDimNum(); ++j) {
 int64_t tmp_dim = ge_desc.GetShape().GetDim(j);
-GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;);
+if (tmp_dim < 0) {
+real_size = tmp_dim * (-1) * real_size;
+continue;
+}
 real_size *= tmp_dim;
 }
 PARSER_INT64_MULCHECK(real_size, size_type);
 ge::TensorUtils::SetSize(ge_desc, real_size * size_type);
 ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
-GELOGI("after translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
+GELOGI("After translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
 ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(),
 ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), real_size * size_type, size_type);
 return SUCCESS;
@@ -83,7 +83,8 @@ Status TensorFlowShapeNParser::ParseN(const domi::tensorflow::NodeDef *node, Sha
 // The upper caller guarantees the input params is not empty.
 domi::tensorflow::AttrValue attr;
 const int64_t attr_n = 2;
-CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr), op->N(attr_n); return SUCCESS);
+CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr), op->N(attr_n);
+return SUCCESS);
 GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "int"), "check Attr N failed");
@@ -133,7 +134,7 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr
 if (ret != SUCCESS) {
 return ret;
 }
-graphStatus status = op_dest->AddDynamicOutputDesc("y", dynamic_tensor_num);
+const graphStatus status = op_dest->AddDynamicOutputDesc("y", dynamic_tensor_num);
 if (status != GRAPH_SUCCESS) {
 REPORT_CALL_ERROR("E19999", "Add Dynamic OuputDesc name:y to node:%s(%s) failed",
 op_dest->GetName().c_str(), op_dest->GetType().c_str());
@@ -141,7 +142,7 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr
 return FAILED;
 }
 }
-graphStatus status = op_dest->AddDynamicInputDesc("x", dynamic_tensor_num);
+const graphStatus status = op_dest->AddDynamicInputDesc("x", dynamic_tensor_num);
 if (status != GRAPH_SUCCESS) {
 REPORT_CALL_ERROR("E19999", "Add Dynamic InputDesc name:x to node:%s(%s) failed",
 op_dest->GetName().c_str(), op_dest->GetType().c_str());
@@ -38,27 +38,30 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att
 auto a_list = attr_value.list();
 GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), domi::PARAM_INVALID,
 "parse ge_desc failed.");
-uint32_t size_type;
-auto data_type = ge_desc.GetDataType();
-bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
+uint32_t size_type = 0U;
+const auto data_type = ge_desc.GetDataType();
+const bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
 GE_IF_BOOL_EXEC(!type_ret,
 REPORT_CALL_ERROR("E19999", "Data type %s is not supported",
 ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
-GELOGE(domi::PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s",
+GELOGE(domi::PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s.",
 ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
 return domi::PARAM_INVALID);
 // calculate size
 int64_t real_size = 1;
-for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
+for (uint32_t j = 0U; j < ge_desc.GetShape().GetDimNum(); ++j) {
 int64_t tmp_dim = ge_desc.GetShape().GetDim(j);
-GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;);
+if (tmp_dim < 0) {
+real_size = tmp_dim * (-1) * real_size;
+continue;
+}
 PARSER_INT64_MULCHECK(real_size, tmp_dim);
 real_size *= tmp_dim;
 }
 PARSER_INT64_MULCHECK(real_size, size_type);
 ge::TensorUtils::SetSize(ge_desc, real_size * size_type);
 ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
-GELOGD("after translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
+GELOGD("After translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
 ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(),
 ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), real_size * size_type, size_type);
 return SUCCESS;
@@ -80,8 +83,8 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr
 domi::tensorflow::AttrValue axis;
 domi::tensorflow::AttrValue dims;
-bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis);
-bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims);
+const bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis);
+const bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims);
 if (!has_axis && !has_dims) {
 return SUCCESS;
 }
@@ -30,6 +30,25 @@
 using domi::tensorflow::DT_INVALID;
+#define VALIDATE_FIELD(attr_value, type, num_set, name, type_string, oneof_case) \
+do { \
+if ((attr_value).has_list()) { \
+if ((attr_value).list().name##_size() > 0) { \
+if (type != "list(" type_string ")") { \
+GELOGE(FAILED, "GeAttrValue had value with type 'list(" type_string ")'when '%s' expected", type.c_str()); \
+return FAILED; \
+} \
+++(num_set); \
+} \
+} else if ((attr_value).value_case() == domi::tensorflow::AttrValue::oneof_case) { \
+if (type != (type_string)) { \
+GELOGE(FAILED, "GeAttrValue had value with type '" type_string "' when '%s' expected", type.c_str()); \
+return FAILED; \
+} \
+++(num_set); \
+} \
+} while (false)
 namespace ge {
 /***************************TensorFlow attribute type, constant definition*******************************************/
 const std::string TENSORFLOW_ATTR_TYPE_STRING = "string";
@@ -126,7 +145,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrVa
 GE_CHECK_NOTNULL(node_def);
 const google::protobuf::Map<std::string, domi::tensorflow::AttrValue> &attr = node_def->attr();
 const google::protobuf::Map<std::string, domi::tensorflow::AttrValue>::const_iterator it = attr.find(attr_name);
-if (it != attr.end()) {
+if (it != attr.cend()) {
 attr_value = it->second;
 return true;
 }
@@ -136,34 +155,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrVa
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::CheckAttrHasType(
 const domi::tensorflow::AttrValue &attr_value, const std::string &type) {
-uint32_t num_set = 0;
-#define VALIDATE_FIELD(name, type_string, oneof_case) \
-do { \
-if (attr_value.has_list()) { \
-if (attr_value.list().name##_size() > 0) { \
-if (type != "list(" type_string ")") { \
-GELOGE(FAILED, "GeAttrValue had value with type 'list(" type_string ")'when '%s' expected", type.c_str()); \
-return FAILED; \
-} \
-++num_set; \
-} \
-} else if (attr_value.value_case() == domi::tensorflow::AttrValue::oneof_case) { \
-if (type != (type_string)) { \
-GELOGE(FAILED, "GeAttrValue had value with type '" type_string "' when '%s' expected", type.c_str()); \
-return FAILED; \
-} \
-++num_set; \
-} \
-} while (false)
-VALIDATE_FIELD(s, "string", kS);
-VALIDATE_FIELD(i, "int", kI);
-VALIDATE_FIELD(f, "float", kF);
-VALIDATE_FIELD(b, "bool", kB);
-VALIDATE_FIELD(type, "type", kType);
-VALIDATE_FIELD(shape, "shape", kShape);
-VALIDATE_FIELD(tensor, "tensor", kTensor);
-VALIDATE_FIELD(func, "func", kFunc);
+uint32_t num_set = 0U;
+VALIDATE_FIELD(attr_value, type, num_set, s, "string", kS);
+VALIDATE_FIELD(attr_value, type, num_set, i, "int", kI);
+VALIDATE_FIELD(attr_value, type, num_set, f, "float", kF);
+VALIDATE_FIELD(attr_value, type, num_set, b, "bool", kB);
+VALIDATE_FIELD(attr_value, type, num_set, type, "type", kType);
+VALIDATE_FIELD(attr_value, type, num_set, shape, "shape", kShape);
+VALIDATE_FIELD(attr_value, type, num_set, tensor, "tensor", kTensor);
+VALIDATE_FIELD(attr_value, type, num_set, func, "func", kFunc);
 #undef VALIDATE_FIELD
@@ -173,7 +173,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch
 }
 // Okay to have an empty list, but not to be missing a non-list value.
-if ((num_set == 0) && (!ge::StringUtils::StartWith(type, "list("))) {
+if ((num_set == 0U) && (!ge::StringUtils::StartWith(type, "list("))) {
 GELOGE(FAILED, "GeAttrValue missing value with expected type '%s'", type.c_str());
 return FAILED;
 }
@@ -210,7 +210,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Pa
 const domi::tensorflow::NodeDef *node_src, const std::string &attr_src, domi::tensorflow::DataType &data_type) {
 GE_CHECK_NOTNULL(node_src);
-std::string node_name = node_src->name();
+const std::string &node_name = node_src->name();
 // Find the value of attr_src from node_src
 domi::tensorflow::AttrValue attr_value;
@@ -236,7 +236,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromA
 ge_desc.SetOriginFormat(ge::FORMAT_ND);
 tf_datatype = a_list.func(i).attr().at(SERIALIZE_DATATYPE).i();
-ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_datatype);
+const ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_datatype);
 GE_CHK_BOOL_RET_STATUS(type != ge::DataType::DT_UNDEFINED, PARAM_INVALID,
 "In FrameworkOp translate datatype:%d failed, domi cann't support.", tf_datatype);
 ge_desc.SetDataType(type);
@@ -266,21 +266,25 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr
 int32_t tf_datatype = 0;
 GE_CHK_BOOL_RET_STATUS(ParseFromAttrValueList(ge_desc, a_list, i, tf_datatype), PARAM_INVALID,
 "parse ge_desc failed.");
-uint32_t size_type = 1;
+uint32_t size_type = 1U;
 auto data_type = ge_desc.GetDataType();
 GE_CHK_BOOL_RET_STATUS(ge::TypeUtils::GetDataTypeLength(data_type, size_type), PARAM_INVALID,
 "dataType no define size , parse ge_desc failed.");
 // get size
-for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
+for (uint32_t j = 0U; j < ge_desc.GetShape().GetDimNum(); ++j) {
 int64_t tmp_dim = ge_desc.GetShape().GetDim(j);
 // The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate.
 // Here, special treatment is given to the two operators.
 // Adjust shape to fit resnet50 network only.
-GE_IF_BOOL_EXEC((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape());
-break;);
-GE_IF_BOOL_EXEC((type == ge::parser::MEAN) && (tmp_dim == 0), std::vector<int64_t> data_dim = {tmp_dim};
-ge_desc.SetShape(ge::GeShape(data_dim)); break;);
+if ((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0)) {
+ge_desc.SetShape(ge::GeShape());
+break;
+}
+if ((type == ge::parser::MEAN) && (tmp_dim == 0)) {
+std::vector<int64_t> data_dim = {tmp_dim};
+ge_desc.SetShape(ge::GeShape(data_dim));
+break;
+}
 }
 ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
 GELOGD("IO:%d: after translate tf_desc, datatype: %s, format: %s, size_type: %u", io,
@@ -158,7 +158,7 @@ void STestTensorflowParser::RegisterCustomOp() {
 domi::OpRegistry::Instance()->registrationDatas.clear();
 }
-extern void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node);
+void AddDumpOriginName(const ge::NodePtr parent_node, const std::string& subgraph_name, ge::ComputeGraphPtr graph);
 namespace {
 NodeDef* AddNode(GraphDef& graph, string type, string name) {
@@ -4291,35 +4291,43 @@ TEST_F(STestTensorflowParser, tensorflow_optimizer_fmk_fusion_op) {
 TEST_F(STestTensorflowParser, AddDumpOriginName_test)
 {
 GeTensorDesc scalar_tensor(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
-ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
-ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
-data_op->SetType(parser::WHILE);
-data_op->SetName("WHILE0");
-data_op->AddInputDesc(ge::GeTensorDesc());
-data_op->AddOutputDesc(ge::GeTensorDesc());
-ge::NodePtr while0 = graph->AddNode(data_op);
-data_op = std::make_shared<ge::OpDesc>();
-data_op->SetType(parser::LOOPCOND);
-data_op->SetName("COND0");
-data_op->AddInputDesc(ge::GeTensorDesc());
-data_op->AddOutputDesc(ge::GeTensorDesc());
-ge::NodePtr cond0 = graph->AddNode(data_op);
-AddDumpOriginName(std::string("while"), while0, cond0);
-data_op = std::make_shared<ge::OpDesc>();
-data_op->SetType(parser::DATA);
-data_op->SetName("Data1");
-data_op->AddInputDesc(ge::GeTensorDesc());
-data_op->AddOutputDesc(ge::GeTensorDesc());
-ge::NodePtr data1 = graph->AddNode(data_op);
-AddDumpOriginName(std::string("cond"), cond0, data1);
-auto desc = data1->GetOpDesc();
+ge::ComputeGraphPtr parent_graph = std::make_shared<ge::ComputeGraph>("parent_graph");
+ge::OpDescPtr parent = std::make_shared<ge::OpDesc>();
+parent->SetType("Foo");
+parent->SetName("foo");
+ge::NodePtr foo = parent_graph->AddNode(parent);
+ge::ComputeGraphPtr sub_graph = std::make_shared<ge::ComputeGraph>("sub_graph");
+auto child = std::make_shared<ge::OpDesc>();
+child->SetType("Bar");
+child->SetName("bar");
+ge::NodePtr bar = sub_graph->AddNode(child);
+AddDumpOriginName(foo, "f", sub_graph);
 std::vector<std::string> original_names;
-(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
-EXPECT_EQ(original_names.empty(), false);
-EXPECT_EQ(original_names[0], "WHILE0/while/COND0/cond/Data1");
+(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+EXPECT_EQ(original_names.size(), 1U);
+EXPECT_EQ(original_names[0], "foo/f/bar");
+(void)ge::AttrUtils::SetListStr(foo->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+AddDumpOriginName(foo, "f", sub_graph);
+original_names.clear();
+(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+EXPECT_EQ(original_names.size(), 1U);
+EXPECT_EQ(original_names[0], "foo/f/bar/f/bar");
+original_names.push_back("abc");
+(void)ge::AttrUtils::SetListStr(foo->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+AddDumpOriginName(foo, "f", sub_graph);
+original_names.clear();
+(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+EXPECT_EQ(original_names.size(), 2U);
+EXPECT_EQ(original_names[0], "foo/f/bar/f/bar/f/bar");
+EXPECT_EQ(original_names[1], "abc");
 }
 } // namespace ge
@@ -169,7 +169,7 @@ static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_de
 return SUCCESS;
 }
-extern void AddDumpOriginName(const std::string& subgraph_name, const ge::NodePtr parent_node, ge::NodePtr node);
+void AddDumpOriginName(const ge::NodePtr parent_node, const std::string& subgraph_name, ge::ComputeGraphPtr graph);
 void UtestTensorflowParser::RegisterCustomOp() {
 REGISTER_CUSTOM_OP("Add")
@@ -4779,35 +4779,43 @@ TEST_F(UtestTensorflowParser, tensorflow_ComputeArgRange)
 TEST_F(UtestTensorflowParser, AddDumpOriginName_test)
 {
 GeTensorDesc scalar_tensor(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
-ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
-ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
-data_op->SetType(parser::WHILE);
-data_op->SetName("WHILE0");
-data_op->AddInputDesc(ge::GeTensorDesc());
-data_op->AddOutputDesc(ge::GeTensorDesc());
-ge::NodePtr while0 = graph->AddNode(data_op);
+ge::ComputeGraphPtr parent_graph = std::make_shared<ge::ComputeGraph>("parent_graph");
+ge::OpDescPtr parent = std::make_shared<ge::OpDesc>();
+parent->SetType("Foo");
+parent->SetName("foo");
+ge::NodePtr foo = parent_graph->AddNode(parent);
-data_op = std::make_shared<ge::OpDesc>();
-data_op->SetType(parser::LOOPCOND);
-data_op->SetName("COND0");
-data_op->AddInputDesc(ge::GeTensorDesc());
-data_op->AddOutputDesc(ge::GeTensorDesc());
-ge::NodePtr cond0 = graph->AddNode(data_op);
-AddDumpOriginName(std::string("while"), while0, cond0);
-data_op = std::make_shared<ge::OpDesc>();
-data_op->SetType(parser::DATA);
-data_op->SetName("Data1");
-data_op->AddInputDesc(ge::GeTensorDesc());
-data_op->AddOutputDesc(ge::GeTensorDesc());
-ge::NodePtr data1 = graph->AddNode(data_op);
-AddDumpOriginName(std::string("cond"), cond0, data1);
+ge::ComputeGraphPtr sub_graph = std::make_shared<ge::ComputeGraph>("sub_graph");
+auto child = std::make_shared<ge::OpDesc>();
+child->SetType("Bar");
+child->SetName("bar");
+ge::NodePtr bar = sub_graph->AddNode(child);
+AddDumpOriginName(foo, "f", sub_graph);
-auto desc = data1->GetOpDesc();
 std::vector<std::string> original_names;
-(void)ge::AttrUtils::GetListStr(desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
-EXPECT_EQ(original_names.empty(), false);
-EXPECT_EQ(original_names[0], "WHILE0/while/COND0/cond/Data1");
+(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+EXPECT_EQ(original_names.size(), 1U);
+EXPECT_EQ(original_names[0], "foo/f/bar");
+(void)ge::AttrUtils::SetListStr(foo->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+AddDumpOriginName(foo, "f", sub_graph);
+original_names.clear();
+(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+EXPECT_EQ(original_names.size(), 1U);
+EXPECT_EQ(original_names[0], "foo/f/bar/f/bar");
+original_names.push_back("abc");
+(void)ge::AttrUtils::SetListStr(foo->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+AddDumpOriginName(foo, "f", sub_graph);
+original_names.clear();
+(void)ge::AttrUtils::GetListStr(bar->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names);
+EXPECT_EQ(original_names.size(), 2U);
+EXPECT_EQ(original_names[0], "foo/f/bar/f/bar/f/bar");
+EXPECT_EQ(original_names[1], "abc");
 }
 } // namespace ge