Merge pull request !686 from 王笑天/ge_devpull/691/MERGE
@@ -35,6 +35,7 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include <map> | |||||
#include "graph/ascend_string.h" | #include "graph/ascend_string.h" | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
@@ -86,7 +86,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||||
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::CAFFE))); | options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::CAFFE))); | ||||
// load custom plugin so and proto | // load custom plugin so and proto | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
domi::Status status = acl_graph_parse_util.AclParserInitialize(options); | domi::Status status = acl_graph_parse_util.AclParserInitialize(options); | ||||
if (status != domi::SUCCESS) { | if (status != domi::SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "AclParserInitialize failed, ret:%d.", status); | REPORT_CALL_ERROR("E19999", "AclParserInitialize failed, ret:%d.", status); | ||||
@@ -144,7 +144,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||||
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::CAFFE))); | options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::CAFFE))); | ||||
// load custom plugin so and proto | // load custom plugin so and proto | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
domi::Status status = acl_graph_parse_util.AclParserInitialize(options); | domi::Status status = acl_graph_parse_util.AclParserInitialize(options); | ||||
if (status != domi::SUCCESS) { | if (status != domi::SUCCESS) { | ||||
REPORT_CALL_ERROR("E19999", "AclParserInitialize failed, ret:%d.", status); | REPORT_CALL_ERROR("E19999", "AclParserInitialize failed, ret:%d.", status); | ||||
@@ -429,7 +429,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons | |||||
} | } | ||||
Status CaffeModelParser::CustomProtoParse(const char *model_path, const string &custom_proto, | Status CaffeModelParser::CustomProtoParse(const char *model_path, const string &custom_proto, | ||||
const string &caffe_proto, vector<ge::Operator> &operators) { | |||||
const string &caffe_proto, vector<ge::Operator> &operators) const { | |||||
(void)caffe_proto; | (void)caffe_proto; | ||||
string custom_proto_path = ge::parser::RealPath(custom_proto.c_str()); | string custom_proto_path = ge::parser::RealPath(custom_proto.c_str()); | ||||
if (custom_proto_path.empty()) { | if (custom_proto_path.empty()) { | ||||
@@ -1904,7 +1904,7 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto | |||||
} | } | ||||
Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message, | Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message, | ||||
google::protobuf::Message *layer) { | |||||
google::protobuf::Message *layer) const { | |||||
const google::protobuf::Reflection *layer_reflection = message.GetReflection(); | const google::protobuf::Reflection *layer_reflection = message.GetReflection(); | ||||
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); | CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); | ||||
vector<const google::protobuf::FieldDescriptor *> field_desc; | vector<const google::protobuf::FieldDescriptor *> field_desc; | ||||
@@ -168,7 +168,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
* @return FAILED parse failed | * @return FAILED parse failed | ||||
*/ | */ | ||||
Status CustomProtoParse(const char *model_path, const string &custom_proto, const string &caffe_proto, | Status CustomProtoParse(const char *model_path, const string &custom_proto, const string &caffe_proto, | ||||
std::vector<ge::Operator> &operators); | |||||
std::vector<ge::Operator> &operators) const; | |||||
/* | /* | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -396,7 +396,7 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { | |||||
Status CheckLayersSize(const google::protobuf::Message &message) const; | Status CheckLayersSize(const google::protobuf::Message &message) const; | ||||
Status ConvertLayerProto(const google::protobuf::Message &message, | Status ConvertLayerProto(const google::protobuf::Message &message, | ||||
google::protobuf::Message *layer); | |||||
google::protobuf::Message *layer) const; | |||||
Status ParseLayerField(const google::protobuf::Reflection &reflection, | Status ParseLayerField(const google::protobuf::Reflection &reflection, | ||||
const google::protobuf::Message &message, | const google::protobuf::Message &message, | ||||
@@ -149,7 +149,7 @@ static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { | |||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
domi::Status AclGrphParseUtil::LoadOpsProtoLib() { | |||||
domi::Status AclGraphParseUtil::LoadOpsProtoLib() { | |||||
string opsproto_path; | string opsproto_path; | ||||
ge::Status ret = ge::TBEPluginLoader::GetOpsProtoPath(opsproto_path); | ge::Status ret = ge::TBEPluginLoader::GetOpsProtoPath(opsproto_path); | ||||
if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
@@ -170,7 +170,7 @@ domi::Status AclGrphParseUtil::LoadOpsProtoLib() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void AclGrphParseUtil::SaveCustomCaffeProtoPath() { | |||||
void AclGraphParseUtil::SaveCustomCaffeProtoPath() { | |||||
GELOGD("Enter save custom caffe proto path."); | GELOGD("Enter save custom caffe proto path."); | ||||
std::string path_base = GetSoPath(); | std::string path_base = GetSoPath(); | ||||
path_base = path_base.substr(0, path_base.rfind('/')); | path_base = path_base.substr(0, path_base.rfind('/')); | ||||
@@ -192,7 +192,7 @@ void AclGrphParseUtil::SaveCustomCaffeProtoPath() { | |||||
// Initialize PARSER, load custom op plugin | // Initialize PARSER, load custom op plugin | ||||
// options will be used later for parser decoupling | // options will be used later for parser decoupling | ||||
domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, std::string> &options) { | |||||
domi::Status AclGraphParseUtil::AclParserInitialize(const std::map<std::string, std::string> &options) { | |||||
GELOGT(TRACE_INIT, "AclParserInitialize start"); | GELOGT(TRACE_INIT, "AclParserInitialize start"); | ||||
// check init status | // check init status | ||||
if (parser_initialized) { | if (parser_initialized) { | ||||
@@ -240,7 +240,7 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void AclGrphParseUtil::SetDefaultFormat() { | |||||
void AclGraphParseUtil::SetDefaultFormat() { | |||||
if (ge::GetParserContext().type == domi::TENSORFLOW) { | if (ge::GetParserContext().type == domi::TENSORFLOW) { | ||||
ge::GetParserContext().format = domi::DOMI_TENSOR_NHWC; | ge::GetParserContext().format = domi::DOMI_TENSOR_NHWC; | ||||
} else { | } else { | ||||
@@ -248,7 +248,7 @@ void AclGrphParseUtil::SetDefaultFormat() { | |||||
} | } | ||||
} | } | ||||
domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) const { | |||||
domi::Status AclGraphParseUtil::ParseAclOutputNodes(const string &out_nodes) const { | |||||
try { | try { | ||||
ge::GetParserContext().out_nodes_map.clear(); | ge::GetParserContext().out_nodes_map.clear(); | ||||
ge::GetParserContext().user_out_nodes.clear(); | ge::GetParserContext().user_out_nodes.clear(); | ||||
@@ -323,7 +323,7 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) cons | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
domi::Status AclGrphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_output_fp16) const { | |||||
domi::Status AclGraphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_output_fp16) const { | |||||
if (is_output_fp16.empty()) { | if (is_output_fp16.empty()) { | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -347,7 +347,7 @@ domi::Status AclGrphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_ou | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) const { | |||||
domi::Status AclGraphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) const { | |||||
ge::GetParserContext().enable_scope_fusion_passes.clear(); | ge::GetParserContext().enable_scope_fusion_passes.clear(); | ||||
if (enable_scope_fusion_passes.empty()) { | if (enable_scope_fusion_passes.empty()) { | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -356,8 +356,8 @@ domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fu | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void AclGrphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, | |||||
const string &fp16_nodes_name, size_t index, OpDescPtr &op_desc) { | |||||
void AclGraphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, | |||||
const string &fp16_nodes_name, size_t index, OpDescPtr &op_desc) { | |||||
if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) { | if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) { | ||||
if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) { | if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) { | ||||
GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str()); | GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str()); | ||||
@@ -368,8 +368,8 @@ void AclGrphParseUtil::AddAttrsForInputNodes(const vector<string> &adjust_fp16_f | |||||
} | } | ||||
} | } | ||||
domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | |||||
const string &is_input_adjust_hw_layout) const { | |||||
domi::Status AclGraphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | |||||
const string &is_input_adjust_hw_layout) const { | |||||
GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
vector<string> adjust_fp16_format_vec; | vector<string> adjust_fp16_format_vec; | ||||
if (!is_input_adjust_hw_layout.empty()) { | if (!is_input_adjust_hw_layout.empty()) { | ||||
@@ -411,7 +411,7 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, | |||||
domi::Status AclGraphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, | |||||
const std::string &input_data_names) const { | const std::string &input_data_names) const { | ||||
std::vector<std::string> input_names = StringUtils::Split(input_data_names, ','); | std::vector<std::string> input_names = StringUtils::Split(input_data_names, ','); | ||||
std::unordered_map<std::string, size_t> name_to_index; | std::unordered_map<std::string, size_t> name_to_index; | ||||
@@ -446,8 +446,8 @@ domi::Status AclGrphParseUtil::SetSpecifyIndexAttrByInputNames(const ComputeGrap | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||||
std::vector<std::string> &output_nodes_name) const { | |||||
void AclGraphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||||
std::vector<std::string> &output_nodes_name) const { | |||||
output_nodes_name.clear(); | output_nodes_name.clear(); | ||||
auto &out_tensor_names = ge::GetParserContext().out_tensor_names; | auto &out_tensor_names = ge::GetParserContext().out_tensor_names; | ||||
if (out_tensor_names.empty()) { | if (out_tensor_names.empty()) { | ||||
@@ -478,8 +478,8 @@ void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, | |||||
} | } | ||||
} | } | ||||
domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const { | |||||
domi::Status AclGraphParseUtil::GetOutputLeaf(NodePtr node, | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const { | |||||
ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); | ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); | ||||
if (tmpDescPtr == nullptr) { | if (tmpDescPtr == nullptr) { | ||||
REPORT_INNER_ERROR("E19999", "param node has no opdesc."); | REPORT_INNER_ERROR("E19999", "param node has no opdesc."); | ||||
@@ -508,7 +508,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | |||||
domi::Status AclGraphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const { | std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const { | ||||
std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes; | std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes; | ||||
if (!default_out_nodes.empty()) { | if (!default_out_nodes.empty()) { | ||||
@@ -531,8 +531,8 @@ domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_gr | |||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph, | |||||
const std::map<AscendString, AscendString> &parser_params) { | |||||
domi::Status AclGraphParseUtil::SetOutputNodeInfo(ge::Graph &graph, | |||||
const std::map<AscendString, AscendString> &parser_params) const { | |||||
(void)parser_params; | (void)parser_params; | ||||
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
@@ -588,7 +588,7 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph, | |||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) const { | |||||
domi::Status AclGraphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) const { | |||||
for (auto &ele : parser_params) { | for (auto &ele : parser_params) { | ||||
const char *key_ascend = ele.first.GetString(); | const char *key_ascend = ele.first.GetString(); | ||||
if (key_ascend == nullptr) { | if (key_ascend == nullptr) { | ||||
@@ -609,8 +609,8 @@ domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendS | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | |||||
string &graph_name) { | |||||
domi::Status AclGraphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | |||||
string &graph_name) const { | |||||
GELOGI("Parse graph user options start."); | GELOGI("Parse graph user options start."); | ||||
ge::GetParserContext().input_nodes_format_map.clear(); | ge::GetParserContext().input_nodes_format_map.clear(); | ||||
ge::GetParserContext().output_formats.clear(); | ge::GetParserContext().output_formats.clear(); | ||||
@@ -663,8 +663,8 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||||
const std::map<AscendString, AscendString> &parser_params) const { | |||||
domi::Status AclGraphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||||
const std::map<AscendString, AscendString> &parser_params) const { | |||||
// support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, | // support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, | ||||
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
@@ -938,12 +938,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const cha | |||||
return ret; | return ret; | ||||
} | } | ||||
/// | |||||
/// @brief get the Original Type of FrameworkOp | /// @brief get the Original Type of FrameworkOp | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [out] type | /// @param [out] type | ||||
/// @return Status | /// @return Status | ||||
/// | |||||
Status GetOriginalType(const ge::NodePtr &node, string &type) { | Status GetOriginalType(const ge::NodePtr &node, string &type) { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
type = node->GetType(); | type = node->GetType(); | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved. | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -34,16 +34,16 @@ namespace ge { | |||||
using google::protobuf::Message; | using google::protobuf::Message; | ||||
class AclGrphParseUtil { | |||||
class AclGraphParseUtil { | |||||
public: | public: | ||||
AclGrphParseUtil() {} | |||||
virtual ~AclGrphParseUtil() {} | |||||
AclGraphParseUtil() {} | |||||
virtual ~AclGraphParseUtil() {} | |||||
static domi::Status LoadOpsProtoLib(); | static domi::Status LoadOpsProtoLib(); | ||||
static void SaveCustomCaffeProtoPath(); | static void SaveCustomCaffeProtoPath(); | ||||
domi::Status AclParserInitialize(const std::map<std::string, std::string> &options); | domi::Status AclParserInitialize(const std::map<std::string, std::string> &options); | ||||
domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | |||||
domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params) const; | |||||
domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | ||||
std::string &graph_name); | |||||
std::string &graph_name) const; | |||||
domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString, | domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString, | ||||
AscendString> &parser_params) const; | AscendString> &parser_params) const; | ||||
@@ -67,31 +67,23 @@ class AclGrphParseUtil { | |||||
}; | }; | ||||
namespace parser { | namespace parser { | ||||
/// | |||||
/// @ingroup: domi_common | /// @ingroup: domi_common | ||||
/// @brief: get length of file | /// @brief: get length of file | ||||
/// @param [in] input_file: path of file | /// @param [in] input_file: path of file | ||||
/// @return long: File length. If the file length fails to be obtained, the value -1 is returned. | /// @return long: File length. If the file length fails to be obtained, the value -1 is returned. | ||||
/// | |||||
extern long GetFileLength(const std::string &input_file); | extern long GetFileLength(const std::string &input_file); | ||||
/// | |||||
/// @ingroup domi_common | /// @ingroup domi_common | ||||
/// @brief Absolute path for obtaining files. | /// @brief Absolute path for obtaining files. | ||||
/// @param [in] path of input file | /// @param [in] path of input file | ||||
/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned | /// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned | ||||
/// | |||||
std::string RealPath(const char *path); | std::string RealPath(const char *path); | ||||
/// | |||||
/// @ingroup domi_common | /// @ingroup domi_common | ||||
/// @brief Obtains the absolute time (timestamp) of the current system. | /// @brief Obtains the absolute time (timestamp) of the current system. | ||||
/// @return Timestamp, in microseconds (US) | /// @return Timestamp, in microseconds (US) | ||||
/// | |||||
/// | |||||
uint64_t GetCurrentTimestamp(); | uint64_t GetCurrentTimestamp(); | ||||
/// | |||||
/// @ingroup domi_common | /// @ingroup domi_common | ||||
/// @brief Reads all data from a binary file. | /// @brief Reads all data from a binary file. | ||||
/// @param [in] file_name path of file | /// @param [in] file_name path of file | ||||
@@ -99,20 +91,16 @@ uint64_t GetCurrentTimestamp(); | |||||
/// @param [out] length Output memory size | /// @param [out] length Output memory size | ||||
/// @return false fail | /// @return false fail | ||||
/// @return true success | /// @return true success | ||||
/// | |||||
bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length); | bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length); | ||||
/// | |||||
/// @ingroup domi_common | /// @ingroup domi_common | ||||
/// @brief proto file in bianary format | /// @brief proto file in bianary format | ||||
/// @param [in] file path of proto file | /// @param [in] file path of proto file | ||||
/// @param [out] proto memory for storing the proto file | /// @param [out] proto memory for storing the proto file | ||||
/// @return true success | /// @return true success | ||||
/// @return false fail | /// @return false fail | ||||
/// | |||||
bool ReadProtoFromBinaryFile(const char *file, Message *proto); | bool ReadProtoFromBinaryFile(const char *file, Message *proto); | ||||
/// | |||||
/// @ingroup domi_common | /// @ingroup domi_common | ||||
/// @brief Reads the proto structure from an array. | /// @brief Reads the proto structure from an array. | ||||
/// @param [in] data proto data to be read | /// @param [in] data proto data to be read | ||||
@@ -120,42 +108,33 @@ bool ReadProtoFromBinaryFile(const char *file, Message *proto); | |||||
/// @param [out] proto Memory for storing the proto file | /// @param [out] proto Memory for storing the proto file | ||||
/// @return true success | /// @return true success | ||||
/// @return false fail | /// @return false fail | ||||
/// | |||||
bool ReadProtoFromArray(const void *data, int size, Message *proto); | bool ReadProtoFromArray(const void *data, int size, Message *proto); | ||||
/// | |||||
/// @ingroup domi_proto | /// @ingroup domi_proto | ||||
/// @brief Reads the proto file in the text format. | /// @brief Reads the proto file in the text format. | ||||
/// @param [in] file path of proto file | /// @param [in] file path of proto file | ||||
/// @param [out] message Memory for storing the proto file | /// @param [out] message Memory for storing the proto file | ||||
/// @return true success | /// @return true success | ||||
/// @return false fail | /// @return false fail | ||||
/// | |||||
bool ReadProtoFromText(const char *file, google::protobuf::Message *message); | bool ReadProtoFromText(const char *file, google::protobuf::Message *message); | ||||
bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message); | bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message); | ||||
/// | |||||
/// @brief get the Original Type of FrameworkOp | /// @brief get the Original Type of FrameworkOp | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [out] type | /// @param [out] type | ||||
/// @return Status | /// @return Status | ||||
/// | |||||
domi::Status GetOriginalType(const ge::NodePtr &node, string &type); | domi::Status GetOriginalType(const ge::NodePtr &node, string &type); | ||||
/// | |||||
/// @ingroup domi_common | /// @ingroup domi_common | ||||
/// @brief Check whether the file path meets the whitelist verification requirements. | /// @brief Check whether the file path meets the whitelist verification requirements. | ||||
/// @param [in] filePath file path | /// @param [in] filePath file path | ||||
/// @param [out] result | /// @param [out] result | ||||
/// | |||||
bool ValidateStr(const std::string &filePath, const std::string &mode); | bool ValidateStr(const std::string &filePath, const std::string &mode); | ||||
/// | |||||
/// @ingroup domi_common | /// @ingroup domi_common | ||||
/// @brief Obtains the current time string. | /// @brief Obtains the current time string. | ||||
/// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555 | /// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555 | ||||
/// | |||||
std::string CurrentTimeInStr(); | std::string CurrentTimeInStr(); | ||||
template <typename T, typename... Args> | template <typename T, typename... Args> | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019~2021. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -178,7 +178,7 @@ void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFie | |||||
switch (field->type()) { | switch (field->type()) { | ||||
case ProtobufFieldDescriptor::TYPE_MESSAGE: { | case ProtobufFieldDescriptor::TYPE_MESSAGE: { | ||||
const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); | const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); | ||||
if (0UL != tmp_message.ByteSizeLong()) { | |||||
if (tmp_message.ByteSizeLong() != 0UL) { | |||||
Message2Json(tmp_message, black_fields, tmp_json, enum2str, depth + 1); | Message2Json(tmp_message, black_fields, tmp_json, enum2str, depth + 1); | ||||
} | } | ||||
} break; | } break; | ||||
@@ -52,7 +52,7 @@ const char *kLocation = "location"; | |||||
} | } | ||||
namespace ge { | namespace ge { | ||||
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | |||||
graphStatus PrepareBeforeParse(AclGraphParseUtil &acl_graph_parse_util, | |||||
const std::map<AscendString, AscendString> &parser_params, | const std::map<AscendString, AscendString> &parser_params, | ||||
ge::Graph &graph, std::shared_ptr<domi::ModelParser> &model_parser) { | ge::Graph &graph, std::shared_ptr<domi::ModelParser> &model_parser) { | ||||
GetParserContext().type = domi::ONNX; | GetParserContext().type = domi::ONNX; | ||||
@@ -82,7 +82,7 @@ graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | |||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util, | |||||
graphStatus HandleAfterParse(AclGraphParseUtil &acl_graph_parse_util, | |||||
const std::map<AscendString, AscendString> &parser_params, | const std::map<AscendString, AscendString> &parser_params, | ||||
ge::Graph &graph) { | ge::Graph &graph) { | ||||
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | ||||
@@ -104,7 +104,7 @@ graphStatus aclgrphParseONNX(const char *model_file, | |||||
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | ||||
GE_CHECK_NOTNULL(model_file); | GE_CHECK_NOTNULL(model_file); | ||||
// load custom plugin so and proto | // load custom plugin so and proto | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
std::shared_ptr<domi::ModelParser> model_parser; | std::shared_ptr<domi::ModelParser> model_parser; | ||||
if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { | if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { | ||||
@@ -136,7 +136,7 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, | |||||
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | ||||
GE_CHECK_NOTNULL(buffer); | GE_CHECK_NOTNULL(buffer); | ||||
// load custom plugin so and proto | // load custom plugin so and proto | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
std::shared_ptr<domi::ModelParser> model_parser; | std::shared_ptr<domi::ModelParser> model_parser; | ||||
if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { | if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { | ||||
@@ -94,7 +94,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { | |||||
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW))); | options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW))); | ||||
// load custom plugin so and proto | // load custom plugin so and proto | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
if (acl_graph_parse_util.AclParserInitialize(options) != domi::SUCCESS) { | if (acl_graph_parse_util.AclParserInitialize(options) != domi::SUCCESS) { | ||||
GELOGE(GRAPH_FAILED, "Parser Initialize failed."); | GELOGE(GRAPH_FAILED, "Parser Initialize failed."); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
@@ -142,7 +142,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map<Ascend | |||||
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW))); | options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(domi::TENSORFLOW))); | ||||
// load custom plugin so and proto | // load custom plugin so and proto | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
domi::Status status = acl_graph_parse_util.AclParserInitialize(options); | domi::Status status = acl_graph_parse_util.AclParserInitialize(options); | ||||
if (status != domi::SUCCESS) { | if (status != domi::SUCCESS) { | ||||
GELOGE(GRAPH_FAILED, "Parser Initialize failed."); | GELOGE(GRAPH_FAILED, "Parser Initialize failed."); | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import os | import os | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import onnx | import onnx | ||||
from onnx import helper | from onnx import helper | ||||
from onnx import AttributeProto, TensorProto, GraphProto | from onnx import AttributeProto, TensorProto, GraphProto | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import os | import os | ||||
import numpy as np | import numpy as np | ||||
import onnx | import onnx | ||||
@@ -42,4 +46,4 @@ def gen_onnx(): | |||||
print(model_def) | print(model_def) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
gen_onnx() | |||||
gen_onnx() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import os | import os | ||||
from tensorflow.python.framework import graph_util | from tensorflow.python.framework import graph_util | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
@@ -11,4 +15,4 @@ def generate_VarIsInitializedOp_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_VarIsInitializedOp.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_VarIsInitializedOp.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_VarIsInitializedOp_pb() | |||||
generate_VarIsInitializedOp_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import os | import os | ||||
import numpy as np | import numpy as np | ||||
@@ -38,4 +42,4 @@ def generate_case_2(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="avgpool3dgrad.pb.txt", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="avgpool3dgrad.pb.txt", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_case_2() | |||||
generate_case_2() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import os | import os | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
from tensorflow.python.framework import graph_util | from tensorflow.python.framework import graph_util | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import os | import os | ||||
from tensorflow.python.framework import graph_util | from tensorflow.python.framework import graph_util | ||||
@@ -23,4 +27,4 @@ def generate_add_pb(): | |||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_conv2d_pb() | generate_conv2d_pb() | ||||
generate_add_pb() | |||||
generate_add_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
from tensorflow.python.ops import control_flow_ops | from tensorflow.python.ops import control_flow_ops | ||||
@@ -10,4 +14,4 @@ def generate_enter_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_enter.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_enter.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_enter_pb() | |||||
generate_enter_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
@@ -9,4 +13,4 @@ def generate_fill_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_fill.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_fill.pb", as_text=False) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
generate_fill_pb() | |||||
generate_fill_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
def generate_identity_pb(): | def generate_identity_pb(): | ||||
@@ -10,4 +14,4 @@ def generate_identity_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_identity.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_identity.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_identity_pb() | |||||
generate_identity_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
from tensorflow.python.framework import graph_util | from tensorflow.python.framework import graph_util | ||||
import numpy as np | import numpy as np | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
@@ -11,4 +15,4 @@ def generate_no_op_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_no_op.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_no_op.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_no_op_pb() | |||||
generate_no_op_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
def generate_reshape_pb(): | def generate_reshape_pb(): | ||||
@@ -7,4 +11,4 @@ def generate_reshape_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_reshape.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_reshape.pb", as_text=False) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
generate_reshape_pb() | |||||
generate_reshape_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
@@ -10,4 +14,4 @@ def generate_sequeeze_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_sequeeze.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_sequeeze.pb", as_text=False) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
generate_sequeeze_pb() | |||||
generate_sequeeze_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
@@ -8,4 +12,4 @@ def generate_shape_n_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_shape_n.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_shape_n.pb", as_text=False) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
generate_shape_n_pb() | |||||
generate_shape_n_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
from tensorflow.python.ops import control_flow_ops | from tensorflow.python.ops import control_flow_ops | ||||
@@ -10,4 +14,4 @@ def generate_switch_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_switch.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_switch.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_switch_pb() | |||||
generate_switch_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
def generate_VariableV2_pb(): | def generate_VariableV2_pb(): | ||||
@@ -10,4 +14,4 @@ def generate_VariableV2_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_VariableV2.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_VariableV2.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_VariableV2_pb() | |||||
generate_VariableV2_pb() |
@@ -191,7 +191,7 @@ TEST_F(STestCaffeParser, caffe_parser_user_output_with_default) { | |||||
ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | ||||
auto ret = model_parser->Parse(model_file.c_str(), graph); | auto ret = model_parser->Parse(model_file.c_str(), graph); | ||||
ASSERT_EQ(ret, GRAPH_SUCCESS); | ASSERT_EQ(ret, GRAPH_SUCCESS); | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
std::map<AscendString, AscendString> parser_params; | std::map<AscendString, AscendString> parser_params; | ||||
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | ||||
ASSERT_EQ(status, SUCCESS); | ASSERT_EQ(status, SUCCESS); | ||||
@@ -483,7 +483,7 @@ TEST_F(STestCaffeParser, CaffeWeightsParser_CreateCustomOperator_test) | |||||
TEST_F(STestCaffeParser, CaffeWeightsParser_ParseOutputNodeTopInfo_test) | TEST_F(STestCaffeParser, CaffeWeightsParser_ParseOutputNodeTopInfo_test) | ||||
{ | { | ||||
CaffeModelParser model_parser; | CaffeModelParser model_parser; | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
domi::caffe::NetParameter net; | domi::caffe::NetParameter net; | ||||
domi::caffe::LayerParameter *lay0 = net.add_layer(); | domi::caffe::LayerParameter *lay0 = net.add_layer(); | ||||
@@ -1104,7 +1104,7 @@ TEST_F(STestTensorflowParser, parser_tensorflow_model) { | |||||
// parser tensorflow model out_node_size is equal to index | // parser tensorflow model out_node_size is equal to index | ||||
string graph_name; | string graph_name; | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
std::map<AscendString, AscendString> out_nodes_with_node_and_index = { | std::map<AscendString, AscendString> out_nodes_with_node_and_index = { | ||||
{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:1")}}; | {AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:1")}}; | ||||
ParerSTestsUtils::ClearParserInnerCtx(); | ParerSTestsUtils::ClearParserInnerCtx(); | ||||
@@ -1356,7 +1356,7 @@ TEST_F(STestTensorflowParser, tensorflow_parserAllGraph_failed) | |||||
TEST_F(STestTensorflowParser, test_parse_acl_output_nodes) | TEST_F(STestTensorflowParser, test_parse_acl_output_nodes) | ||||
{ | { | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
string graph_name; | string graph_name; | ||||
// case 1: Normal with 'node and index' | // case 1: Normal with 'node and index' | ||||
ParerSTestsUtils::ClearParserInnerCtx(); | ParerSTestsUtils::ClearParserInnerCtx(); | ||||
@@ -1523,7 +1523,7 @@ TEST_F(STestTensorflowParser, parse_AddFmkNode) | |||||
std::string modelFile = caseDir + "/origin_models/tf_add.pb"; | std::string modelFile = caseDir + "/origin_models/tf_add.pb"; | ||||
ge::Graph graph; | ge::Graph graph; | ||||
string graph_name; | string graph_name; | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
std::map<ge::AscendString, ge::AscendString> parser_options = {{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}}; | std::map<ge::AscendString, ge::AscendString> parser_options = {{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}}; | ||||
ParerSTestsUtils::ClearParserInnerCtx(); | ParerSTestsUtils::ClearParserInnerCtx(); | ||||
Status ret = acl_graph_parse_util.ParseParamsBeforeGraph(parser_options, graph_name); | Status ret = acl_graph_parse_util.ParseParamsBeforeGraph(parser_options, graph_name); | ||||
@@ -3781,9 +3781,9 @@ TEST_F(STestTensorflowParser, tensorflow_ReadBytesFromBinaryFile_test) | |||||
EXPECT_EQ(realPath, ""); | EXPECT_EQ(realPath, ""); | ||||
} | } | ||||
TEST_F(STestTensorflowParser, tensorflow_AclGrphParseUtil_ParseAclInputFp16Nodes_test) | |||||
TEST_F(STestTensorflowParser, tensorflow_AclGraphParseUtil_ParseAclInputFp16Nodes_test) | |||||
{ | { | ||||
AclGrphParseUtil parserUtil; | |||||
AclGraphParseUtil parserUtil; | |||||
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME); | ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME); | ||||
std::string input_fp16_nodes = "Add"; | std::string input_fp16_nodes = "Add"; | ||||
std::string is_input_adjust_hw_layout = "is_input_adjust_hw_layout"; | std::string is_input_adjust_hw_layout = "is_input_adjust_hw_layout"; | ||||
@@ -4010,7 +4010,7 @@ TEST_F(STestTensorflowParser, tensorflow_FP16_parser_test) | |||||
TEST_F(STestTensorflowParser, tensorflow_AclParserInitialize_test) | TEST_F(STestTensorflowParser, tensorflow_AclParserInitialize_test) | ||||
{ | { | ||||
AclGrphParseUtil parseUtil; | |||||
AclGraphParseUtil parseUtil; | |||||
std::map<std::string, std::string> options; | std::map<std::string, std::string> options; | ||||
Status ret = parseUtil.AclParserInitialize(options); | Status ret = parseUtil.AclParserInitialize(options); | ||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
@@ -4022,7 +4022,7 @@ TEST_F(STestTensorflowParser, tensorflow_AclParserInitialize_test) | |||||
TEST_F(STestTensorflowParser, tensorflow_GetOutputLeaf_test) | TEST_F(STestTensorflowParser, tensorflow_GetOutputLeaf_test) | ||||
{ | { | ||||
AclGrphParseUtil parseUtil; | |||||
AclGraphParseUtil parseUtil; | |||||
ge::ComputeGraphPtr compute_graph = build_graph(true); | ge::ComputeGraphPtr compute_graph = build_graph(true); | ||||
ge::NodePtr output_nodes_info = compute_graph->FindNode("Relu3"); | ge::NodePtr output_nodes_info = compute_graph->FindNode("Relu3"); | ||||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{output_nodes_info,0}}; | std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{output_nodes_info,0}}; | ||||
@@ -189,7 +189,7 @@ TEST_F(UtestCaffeParser, caffe_parser_user_output_with_name_and_index) { | |||||
ge::GetParserContext().user_out_nodes.push_back({"abs", 0}); | ge::GetParserContext().user_out_nodes.push_back({"abs", 0}); | ||||
auto ret = model_parser->Parse(model_file.c_str(), graph); | auto ret = model_parser->Parse(model_file.c_str(), graph); | ||||
ASSERT_EQ(ret, GRAPH_SUCCESS); | ASSERT_EQ(ret, GRAPH_SUCCESS); | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
std::map<AscendString, AscendString> parser_params; | std::map<AscendString, AscendString> parser_params; | ||||
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | ||||
ASSERT_EQ(status, SUCCESS); | ASSERT_EQ(status, SUCCESS); | ||||
@@ -216,7 +216,7 @@ TEST_F(UtestCaffeParser, caffe_parser_user_output_with_top_name) { | |||||
ge::GetParserContext().user_out_tensors.push_back("abs_out"); | ge::GetParserContext().user_out_tensors.push_back("abs_out"); | ||||
auto ret = model_parser->Parse(model_file.c_str(), graph); | auto ret = model_parser->Parse(model_file.c_str(), graph); | ||||
ASSERT_EQ(ret, GRAPH_SUCCESS); | ASSERT_EQ(ret, GRAPH_SUCCESS); | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
std::map<AscendString, AscendString> parser_params; | std::map<AscendString, AscendString> parser_params; | ||||
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | ||||
ASSERT_EQ(status, SUCCESS); | ASSERT_EQ(status, SUCCESS); | ||||
@@ -241,7 +241,7 @@ TEST_F(UtestCaffeParser, caffe_parser_user_output_with_default) { | |||||
ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | ||||
auto ret = model_parser->Parse(model_file.c_str(), graph); | auto ret = model_parser->Parse(model_file.c_str(), graph); | ||||
ASSERT_EQ(ret, GRAPH_SUCCESS); | ASSERT_EQ(ret, GRAPH_SUCCESS); | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
std::map<AscendString, AscendString> parser_params; | std::map<AscendString, AscendString> parser_params; | ||||
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | ||||
ASSERT_EQ(status, SUCCESS); | ASSERT_EQ(status, SUCCESS); | ||||
@@ -543,7 +543,7 @@ TEST_F(UtestCaffeParser, CaffeWeightsParser_CreateCustomOperator_test) | |||||
TEST_F(UtestCaffeParser, CaffeWeightsParser_ParseOutputNodeTopInfo_test) | TEST_F(UtestCaffeParser, CaffeWeightsParser_ParseOutputNodeTopInfo_test) | ||||
{ | { | ||||
CaffeModelParser model_parser; | CaffeModelParser model_parser; | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
domi::caffe::NetParameter net; | domi::caffe::NetParameter net; | ||||
domi::caffe::LayerParameter *lay0 = net.add_layer(); | domi::caffe::LayerParameter *lay0 = net.add_layer(); | ||||
@@ -53,7 +53,7 @@ class UtestAclGraphParser : public testing::Test { | |||||
}; | }; | ||||
TEST_F(UtestAclGraphParser, test_parse_acl_output_nodes) { | TEST_F(UtestAclGraphParser, test_parse_acl_output_nodes) { | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
string graph_name; | string graph_name; | ||||
// case 1: Normal with 'node and index' | // case 1: Normal with 'node and index' | ||||
ParerUTestsUtils::ClearParserInnerCtx(); | ParerUTestsUtils::ClearParserInnerCtx(); | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import onnx | import onnx | ||||
from onnx import helper | from onnx import helper | ||||
from onnx import AttributeProto, TensorProto, GraphProto | from onnx import AttributeProto, TensorProto, GraphProto | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import os | import os | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import onnx | import onnx | ||||
from onnx import helper | from onnx import helper | ||||
from onnx import AttributeProto, TensorProto, GraphProto | from onnx import AttributeProto, TensorProto, GraphProto | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import os | import os | ||||
from tensorflow.python.framework import graph_util | from tensorflow.python.framework import graph_util | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
@@ -11,4 +15,4 @@ def generate_VarIsInitializedOp_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_VarIsInitializedOp.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_VarIsInitializedOp.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_VarIsInitializedOp_pb() | |||||
generate_VarIsInitializedOp_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import os | import os | ||||
import numpy as np | import numpy as np | ||||
@@ -38,4 +42,4 @@ def generate_case_2(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="avgpool3dgrad.pb.txt", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="avgpool3dgrad.pb.txt", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_case_2() | |||||
generate_case_2() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import os | import os | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
from tensorflow.python.framework import graph_util | from tensorflow.python.framework import graph_util | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import os | import os | ||||
from tensorflow.python.framework import graph_util | from tensorflow.python.framework import graph_util | ||||
@@ -23,4 +27,4 @@ def generate_add_pb(): | |||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_conv2d_pb() | generate_conv2d_pb() | ||||
generate_add_pb() | |||||
generate_add_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
from tensorflow.python.ops import control_flow_ops | from tensorflow.python.ops import control_flow_ops | ||||
@@ -10,4 +14,4 @@ def generate_enter_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_enter.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_enter.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_enter_pb() | |||||
generate_enter_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
@@ -9,4 +13,4 @@ def generate_fill_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_fill.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_fill.pb", as_text=False) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
generate_fill_pb() | |||||
generate_fill_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
def generate_identity_pb(): | def generate_identity_pb(): | ||||
@@ -10,4 +14,4 @@ def generate_identity_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_identity.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_identity.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_identity_pb() | |||||
generate_identity_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
from tensorflow.python.framework import graph_util | from tensorflow.python.framework import graph_util | ||||
import numpy as np | import numpy as np | ||||
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
@@ -11,4 +15,4 @@ def generate_no_op_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_no_op.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_no_op.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_no_op_pb() | |||||
generate_no_op_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
def generate_reshape_pb(): | def generate_reshape_pb(): | ||||
@@ -7,4 +11,4 @@ def generate_reshape_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_reshape.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_reshape.pb", as_text=False) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
generate_reshape_pb() | |||||
generate_reshape_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
@@ -10,4 +14,4 @@ def generate_sequeeze_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_sequeeze.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_sequeeze.pb", as_text=False) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
generate_sequeeze_pb() | |||||
generate_sequeeze_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
import numpy as np | import numpy as np | ||||
@@ -8,4 +12,4 @@ def generate_shape_n_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_shape_n.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_shape_n.pb", as_text=False) | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
generate_shape_n_pb() | |||||
generate_shape_n_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
from tensorflow.python.ops import control_flow_ops | from tensorflow.python.ops import control_flow_ops | ||||
@@ -10,4 +14,4 @@ def generate_switch_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_switch.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_switch.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_switch_pb() | |||||
generate_switch_pb() |
@@ -1,3 +1,7 @@ | |||||
#!/usr/bin/env python3 | |||||
# -*- coding utf-8 -*- | |||||
# Copyright Huawei Technologies Co., Ltd 2019-2022. All rights reserved. | |||||
import tensorflow as tf | import tensorflow as tf | ||||
def generate_VariableV2_pb(): | def generate_VariableV2_pb(): | ||||
@@ -10,4 +14,4 @@ def generate_VariableV2_pb(): | |||||
tf.io.write_graph(sess.graph, logdir="./", name="test_VariableV2.pb", as_text=False) | tf.io.write_graph(sess.graph, logdir="./", name="test_VariableV2.pb", as_text=False) | ||||
if __name__=='__main__': | if __name__=='__main__': | ||||
generate_VariableV2_pb() | |||||
generate_VariableV2_pb() |
@@ -1106,7 +1106,7 @@ TEST_F(UtestTensorflowParser, parser_tensorflow_model) { | |||||
// parser tensorflow model out_node_size is equal to index | // parser tensorflow model out_node_size is equal to index | ||||
string graph_name; | string graph_name; | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
std::map<AscendString, AscendString> out_nodes_with_node_and_index = { | std::map<AscendString, AscendString> out_nodes_with_node_and_index = { | ||||
{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:1")}}; | {AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:1")}}; | ||||
ParerUTestsUtils::ClearParserInnerCtx(); | ParerUTestsUtils::ClearParserInnerCtx(); | ||||
@@ -1452,7 +1452,7 @@ TEST_F(UtestTensorflowParser, tensorflow_parserAllGraph_failed) | |||||
TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes) | TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes) | ||||
{ | { | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
string graph_name; | string graph_name; | ||||
// case 1: Normal with 'node and index' | // case 1: Normal with 'node and index' | ||||
ParerUTestsUtils::ClearParserInnerCtx(); | ParerUTestsUtils::ClearParserInnerCtx(); | ||||
@@ -1621,7 +1621,7 @@ TEST_F(UtestTensorflowParser, parse_AddFmkNode) | |||||
std::string modelFile = caseDir + "/tensorflow_model/tf_add.pb"; | std::string modelFile = caseDir + "/tensorflow_model/tf_add.pb"; | ||||
ge::Graph graph; | ge::Graph graph; | ||||
string graph_name; | string graph_name; | ||||
AclGrphParseUtil acl_graph_parse_util; | |||||
AclGraphParseUtil acl_graph_parse_util; | |||||
std::map<ge::AscendString, ge::AscendString> parser_options = {{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}}; | std::map<ge::AscendString, ge::AscendString> parser_options = {{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}}; | ||||
ParerUTestsUtils::ClearParserInnerCtx(); | ParerUTestsUtils::ClearParserInnerCtx(); | ||||
Status ret = acl_graph_parse_util.ParseParamsBeforeGraph(parser_options, graph_name); | Status ret = acl_graph_parse_util.ParseParamsBeforeGraph(parser_options, graph_name); | ||||
@@ -3885,9 +3885,9 @@ TEST_F(UtestTensorflowParser, tensorflow_ReadBytesFromBinaryFile_test) | |||||
EXPECT_EQ(realPath, ""); | EXPECT_EQ(realPath, ""); | ||||
} | } | ||||
TEST_F(UtestTensorflowParser, tensorflow_AclGrphParseUtil_ParseAclInputFp16Nodes_test) | |||||
TEST_F(UtestTensorflowParser, tensorflow_AclGraphParseUtil_ParseAclInputFp16Nodes_test) | |||||
{ | { | ||||
AclGrphParseUtil parserUtil; | |||||
AclGraphParseUtil parserUtil; | |||||
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME); | ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME); | ||||
std::string input_fp16_nodes = "Add"; | std::string input_fp16_nodes = "Add"; | ||||
std::string is_input_adjust_hw_layout = "is_input_adjust_hw_layout"; | std::string is_input_adjust_hw_layout = "is_input_adjust_hw_layout"; | ||||
@@ -4094,7 +4094,7 @@ TEST_F(UtestTensorflowParser, tensorflow_FP16_parser_test) | |||||
TEST_F(UtestTensorflowParser, tensorflow_AclParserInitialize_test) | TEST_F(UtestTensorflowParser, tensorflow_AclParserInitialize_test) | ||||
{ | { | ||||
AclGrphParseUtil parseUtil; | |||||
AclGraphParseUtil parseUtil; | |||||
std::map<std::string, std::string> options; | std::map<std::string, std::string> options; | ||||
Status ret = parseUtil.AclParserInitialize(options); | Status ret = parseUtil.AclParserInitialize(options); | ||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
@@ -4106,7 +4106,7 @@ TEST_F(UtestTensorflowParser, tensorflow_AclParserInitialize_test) | |||||
TEST_F(UtestTensorflowParser, tensorflow_GetOutputLeaf_test) | TEST_F(UtestTensorflowParser, tensorflow_GetOutputLeaf_test) | ||||
{ | { | ||||
AclGrphParseUtil parseUtil; | |||||
AclGraphParseUtil parseUtil; | |||||
ge::ComputeGraphPtr compute_graph = build_graph(true); | ge::ComputeGraphPtr compute_graph = build_graph(true); | ||||
ge::NodePtr output_nodes_info = compute_graph->FindNode("Relu3"); | ge::NodePtr output_nodes_info = compute_graph->FindNode("Relu3"); | ||||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{output_nodes_info,0}}; | std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{output_nodes_info,0}}; | ||||