diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 3af6ace..84f3ad2 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -54,6 +54,8 @@ #include "register/scope/scope_pass_registry_impl.h" #include "parser/common/auto_mapping_subgraph_io_index_func.h" #include "graph/def_types.h" +#include "framework/common/types.h" +#include "common/util/mem_utils.h" using ge::OpParserFactory; using ge::Pb2Json; @@ -230,6 +232,10 @@ const char *const kDpop = "DPOP"; const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt"; const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node"; const char *const kExternalModel = "_external_model"; +const char *const kAttrNameFileConstantPath = "file_constant_path"; +const char *const kAttrNameLocation = "location"; +const char *const kAttrNameOffset = "offset"; +const char *const kAttrNameLength = "length"; struct ParseArg { const google::protobuf::Message *proto; std::string function_name; @@ -4025,6 +4031,7 @@ Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph return INTERNAL_ERROR; } Graph graph = model.GetGraph(); + GE_CHK_STATUS_RET(ConvertFileConstToConst(graph), "Failed to Convert file const to const."); GELOGD("Get subgraph[%s] from model[%s].", ParserUtils::GetGraphName(graph).c_str(), node->GetName().c_str()); Status ret = MappingAndAddSubGraph(node, graph, root_graph); if (ret != SUCCESS) { @@ -4037,6 +4044,85 @@ Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph } return SUCCESS; } + +Status TensorFlowModelParser::GetFileConstantPath(const OpDescPtr &op_desc, std::string &file_path, size_t &offset, + size_t &length) { + NamedAttrs attrs; + if (!AttrUtils::GetNamedAttrs(op_desc, kAttrNameFileConstantPath, attrs)) { + return SUCCESS; + } + offset = 0U; + length = 0U; + + // offset and length are optional + int64_t attr_value = 0; + (void)AttrUtils::GetInt(attrs, kAttrNameOffset, attr_value); + if (attr_value != 0) { + offset = static_cast(attr_value); + } + int64_t attr_length = 0; + (void)AttrUtils::GetInt(attrs, kAttrNameLength, attr_length); + if (attr_length != 0) { + length = static_cast(attr_length); + } + if (!AttrUtils::GetStr(attrs, kAttrNameLocation, file_path)) { + REPORT_INNER_ERROR("E19999", "Failed to get file path."); + GELOGE(FAILED, "[Check][Param] Failed to get file path."); + return FAILED; + } + return SUCCESS; +} + +Status TensorFlowModelParser::ConvertFileConstToConst(const Graph &graph) { + ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); + for (auto &node : compute_graph->GetAllNodes()) { + if (node->GetType() == FILECONSTANT) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const auto &tensor_desc = op_desc->GetOutputDescPtr(0U); + GE_CHECK_NOTNULL(tensor_desc); + GELOGD("File constant data type:%u", tensor_desc->GetDataType()); + GELOGD("File constant GetDimNum:%zu", tensor_desc->GetShape().GetDimNum()); + GELOGD("File constant data type:%ld", tensor_desc->GetShape().GetShapeSize()); + int64_t weight_size = 0; + GE_CHK_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*tensor_desc, weight_size), + "Failed to get file constant weight size."); + std::string file_path; + size_t offset = 0U; + size_t length = 0U; + GE_CHK_STATUS_RET(GetFileConstantPath(op_desc, file_path, offset, length), + "Failed to get file path of %s.", op_desc->GetName().c_str()); + const size_t file_length = (length == 0U ? static_cast(weight_size) : length); + const std::string real_path = RealPath(file_path.c_str()); + GELOGD("Load weight from file:%s, file length:%zu", real_path.c_str(), file_length); + GE_CHECK_NOTNULL(real_path.c_str()); + std::ifstream ifs(real_path, std::ifstream::binary); + if (!ifs.is_open()) { + REPORT_CALL_ERROR("E19999", "Read file %s failed.", file_path.c_str()); + GELOGE(FAILED, "[Read][File]Failed, file %s.", file_path.c_str()); + return FAILED; + } + ifs.clear(); + ifs.seekg(offset, ifs.beg); + const std::unique_ptr bin_buff = std::unique_ptr(new (std::nothrow) char[file_length]); + (void)ifs.read(static_cast(bin_buff.get()), static_cast(file_length)); + ifs.close(); + GeTensorPtr const_value = MakeShared(op_desc->GetOutputDesc(0U), + reinterpret_cast(bin_buff.get()), file_length); + if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, const_value)) { + REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(), + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(), + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + return FAILED; + } + op_desc->SetType(CONSTANT); + + GELOGD("Convert node:%s from file constant to const success.", op_desc->GetName().c_str()); + } + } + return SUCCESS; +} } // namespace ge namespace domi { diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h index 59b6faa..222718f 100644 --- a/parser/tensorflow/tensorflow_parser.h +++ b/parser/tensorflow/tensorflow_parser.h @@ -626,6 +626,8 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { static Status CheckAndUpdateInputDesc(const ge::ComputeGraphPtr &compute_graph); static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes); static Status AddExternalGraph(const ComputeGraphPtr &root_graph); + static Status GetFileConstantPath(const OpDescPtr &op_desc, std::string &file_path, size_t &offset, size_t &length); + static Status ConvertFileConstToConst(const Graph &graph); /** * save