Browse Source

Pre Merge pull request !730 from 刘豪/fileconstant_to_const

pull/730/MERGE
刘豪 Gitee 2 years ago
parent
commit
5fe6372ff4
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 88 additions and 0 deletions
  1. +86
    -0
      parser/tensorflow/tensorflow_parser.cc
  2. +2
    -0
      parser/tensorflow/tensorflow_parser.h

+ 86
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -54,6 +54,8 @@
#include "register/scope/scope_pass_registry_impl.h"
#include "parser/common/auto_mapping_subgraph_io_index_func.h"
#include "graph/def_types.h"
#include "framework/common/types.h"
#include "common/util/mem_utils.h"

using ge::OpParserFactory;
using ge::Pb2Json;
@@ -230,6 +232,10 @@ const char *const kDpop = "DPOP";
const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt";
const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node";
const char *const kExternalModel = "_external_model";
const char *const kAttrNameFileConstantPath = "file_constant_path";
const char *const kAttrNameLocation = "location";
const char *const kAttrNameOffset = "offset";
const char *const kAttrNameLength = "length";
struct ParseArg {
const google::protobuf::Message *proto;
std::string function_name;
@@ -4025,6 +4031,7 @@ Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph
return INTERNAL_ERROR;
}
Graph graph = model.GetGraph();
GE_CHK_STATUS_RET(ConvertFileConstToConst(graph), "Failed to Convert file const to const.");
GELOGD("Get subgraph[%s] from model[%s].", ParserUtils::GetGraphName(graph).c_str(), node->GetName().c_str());
Status ret = MappingAndAddSubGraph(node, graph, root_graph);
if (ret != SUCCESS) {
@@ -4037,6 +4044,85 @@ Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph
}
return SUCCESS;
}

Status TensorFlowModelParser::GetFileConstantPath(const OpDescPtr &op_desc, std::string &file_path, size_t &offset,
size_t &length) {
NamedAttrs attrs;
if (!AttrUtils::GetNamedAttrs(op_desc, kAttrNameFileConstantPath, attrs)) {
return SUCCESS;
}
offset = 0U;
length = 0U;

// offset and length are optional
int64_t attr_value = 0;
(void)AttrUtils::GetInt(attrs, kAttrNameOffset, attr_value);
if (attr_value != 0) {
offset = static_cast<size_t>(attr_value);
}
int64_t attr_length = 0;
(void)AttrUtils::GetInt(attrs, kAttrNameLength, attr_length);
if (attr_length != 0) {
length = static_cast<size_t>(attr_length);
}
if (!AttrUtils::GetStr(attrs, kAttrNameLocation, file_path)) {
REPORT_INNER_ERROR("E19999", "Failed to get file path.");
GELOGE(FAILED, "[Check][Param] Failed to get file path.");
return FAILED;
}
return SUCCESS;
}

Status TensorFlowModelParser::ConvertFileConstToConst(const Graph &graph) {
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
for (auto &node : compute_graph->GetAllNodes()) {
if (node->GetType() == FILECONSTANT) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
const auto &tensor_desc = op_desc->GetOutputDescPtr(0U);
GE_CHECK_NOTNULL(tensor_desc);
GELOGD("File constant data type:%u", tensor_desc->GetDataType());
GELOGD("File constant GetDimNum:%zu", tensor_desc->GetShape().GetDimNum());
GELOGD("File constant data type:%ld", tensor_desc->GetShape().GetShapeSize());
int64_t weight_size = 0;
GE_CHK_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*tensor_desc, weight_size),
"Failed to get file constant weight size.");
std::string file_path;
size_t offset = 0U;
size_t length = 0U;
GE_CHK_STATUS_RET(GetFileConstantPath(op_desc, file_path, offset, length),
"Failed to get file path of %s.", op_desc->GetName().c_str());
const size_t file_length = (length == 0U ? static_cast<size_t>(weight_size) : length);
const std::string real_path = RealPath(file_path.c_str());
GELOGD("Load weight from file:%s, file length:%zu", real_path.c_str(), file_length);
GE_CHECK_NOTNULL(real_path.c_str());
std::ifstream ifs(real_path, std::ifstream::binary);
if (!ifs.is_open()) {
REPORT_CALL_ERROR("E19999", "Read file %s failed.", file_path.c_str());
GELOGE(FAILED, "[Read][File]Failed, file %s.", file_path.c_str());
return FAILED;
}
ifs.clear();
ifs.seekg(offset, ifs.beg);
const std::unique_ptr<char[]> bin_buff = std::unique_ptr<char[]>(new (std::nothrow) char[file_length]);
(void)ifs.read(static_cast<char *>(bin_buff.get()), static_cast<int64_t>(file_length));
ifs.close();
GeTensorPtr const_value = MakeShared<GeTensor>(op_desc->GetOutputDesc(0U),
reinterpret_cast<uint8_t *>(bin_buff.get()), file_length);
if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, const_value)) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str());
return FAILED;
}
op_desc->SetType(CONSTANT);

GELOGD("Convert node:%s from file constant to const success.", op_desc->GetName().c_str());
}
}
return SUCCESS;
}
} // namespace ge

namespace domi {


+ 2
- 0
parser/tensorflow/tensorflow_parser.h View File

@@ -626,6 +626,8 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
static Status CheckAndUpdateInputDesc(const ge::ComputeGraphPtr &compute_graph);
static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes);
static Status AddExternalGraph(const ComputeGraphPtr &root_graph);
static Status GetFileConstantPath(const OpDescPtr &op_desc, std::string &file_path, size_t &offset, size_t &length);
static Status ConvertFileConstToConst(const Graph &graph);

/**
* save <node_name, node_def>


Loading…
Cancel
Save