@@ -1,4 +1,4 @@ | |||||
[submodule "metadef"] | [submodule "metadef"] | ||||
path = metadef | path = metadef | ||||
url = https://gitee.com/ascend/metadef.git | url = https://gitee.com/ascend/metadef.git | ||||
branch = development | |||||
branch = r1.2.0 |
@@ -33,11 +33,11 @@ if (ENABLE_OPEN_SRC) | |||||
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | ||||
endif() | endif() | ||||
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | ||||
find_module(slog libslog.so ${GE_LIB_PATH}) | |||||
find_module(slog libalog.so ${GE_LIB_PATH}) | |||||
find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) | find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) | ||||
find_module(error_manager liberror_manager.so ${GE_LIB_PATH}) | find_module(error_manager liberror_manager.so ${GE_LIB_PATH}) | ||||
elseif(ENABLE_GE_COV OR ENABLE_GE_UT) | elseif(ENABLE_GE_COV OR ENABLE_GE_UT) | ||||
message(STATUS "Runing on llt mode, no need to depend other component") | |||||
message(STATUS "Running on llt mode, no need to depend other component.") | |||||
else() | else() | ||||
if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | ||||
set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) | set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) | ||||
@@ -47,7 +47,7 @@ if (ENABLE_OPEN_SRC) | |||||
set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | ||||
find_module(slog libslog.so ${ASCEND_ATC_DIR}) | |||||
find_module(slog libalog.so ${ASCEND_ATC_DIR}) | |||||
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | ||||
find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | ||||
endif() | endif() | ||||
@@ -2,6 +2,9 @@ approvers: | |||||
- ji_chen | - ji_chen | ||||
- wqtshg | - wqtshg | ||||
- ljl0711 | - ljl0711 | ||||
- startzgf168 | |||||
- lbisdaddy | |||||
- andylhy | |||||
reviewers: | reviewers: | ||||
- xchu42 | - xchu42 | ||||
- sheng-nan | - sheng-nan |
@@ -7,7 +7,6 @@ function(find_module module name path) | |||||
if (TARGET ${module}) | if (TARGET ${module}) | ||||
return() | return() | ||||
endif() | endif() | ||||
add_library(${module} INTERFACE) | |||||
find_library(${module}_LIBRARY_DIR NAMES ${name} NAMES_PER_DIR PATHS ${path} | find_library(${module}_LIBRARY_DIR NAMES ${name} NAMES_PER_DIR PATHS ${path} | ||||
PATH_SUFFIXES lib | PATH_SUFFIXES lib | ||||
) | ) | ||||
@@ -16,5 +15,9 @@ function(find_module module name path) | |||||
if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND") | if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND") | ||||
message(FATAL_ERROR "${name} not found in ${path}") | message(FATAL_ERROR "${name} not found in ${path}") | ||||
endif() | endif() | ||||
target_link_libraries(${module} INTERFACE ${${module}_LIBRARY_DIR}) | |||||
add_library(${module} SHARED IMPORTED) | |||||
set_target_properties(${module} PROPERTIES | |||||
IMPORTED_LOCATION ${${module}_LIBRARY_DIR} | |||||
) | |||||
endfunction() | endfunction() |
@@ -16,6 +16,7 @@ target_compile_definitions(intf_pub INTERFACE | |||||
$<$<CONFIG:Debug>:CFG_BUILD_DEBUG> | $<$<CONFIG:Debug>:CFG_BUILD_DEBUG> | ||||
WIN64=1 | WIN64=1 | ||||
LINUX=0 | LINUX=0 | ||||
LOG_CPP | |||||
) | ) | ||||
target_link_options(intf_pub INTERFACE | target_link_options(intf_pub INTERFACE | ||||
-Wl,-z,relro | -Wl,-z,relro | ||||
@@ -17,17 +17,23 @@ | |||||
#ifndef INC_EXTERNAL_PARSER_ONNX_PARSER_H_ | #ifndef INC_EXTERNAL_PARSER_ONNX_PARSER_H_ | ||||
#define INC_EXTERNAL_PARSER_ONNX_PARSER_H_ | #define INC_EXTERNAL_PARSER_ONNX_PARSER_H_ | ||||
#include <memory> | |||||
#include <map> | |||||
#include "graph/ascend_string.h" | #include "graph/ascend_string.h" | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/graph.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/graph.h" | |||||
namespace ge { | namespace ge { | ||||
graphStatus aclgrphParseONNX(const char *model_file, | graphStatus aclgrphParseONNX(const char *model_file, | ||||
const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph); | |||||
const std::map<ge::AscendString, | |||||
ge::AscendString> &parser_params, | |||||
ge::Graph &graph); | |||||
graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, | |||||
const std::map<ge::AscendString, ge::AscendString> &parser_params, | |||||
graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t buffer_size, | |||||
const std::map<ge::AscendString, | |||||
ge::AscendString> &parser_params, | |||||
ge::Graph &graph); | ge::Graph &graph); | ||||
} // namespace ge | } // namespace ge | ||||
@@ -1 +1 @@ | |||||
Subproject commit c14d2be38171eed63416e71178774103faf1f5cd | |||||
Subproject commit af156f825aa53a24bd30ae4065e3ea356cf555ef |
@@ -193,6 +193,7 @@ const int kMaxParseDepth = 5; | |||||
const int32_t kMinLineWorldSize = 3; | const int32_t kMinLineWorldSize = 3; | ||||
const int32_t kMaxIdentifier = 536870911; // 2^29 - 1 | const int32_t kMaxIdentifier = 536870911; // 2^29 - 1 | ||||
const int32_t kBase = 10; | const int32_t kBase = 10; | ||||
const uint32_t kInteval = 2; | |||||
const char *const kPython = "Python"; | const char *const kPython = "Python"; | ||||
const char *const kProposalLayer = "ProposalLayer"; | const char *const kProposalLayer = "ProposalLayer"; | ||||
const char *const kDetectionOutput = "DetectionOutput"; | const char *const kDetectionOutput = "DetectionOutput"; | ||||
@@ -793,13 +794,22 @@ Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection * | |||||
CASE_FIELD_TYPE_REPEATED(STRING, String, string); | CASE_FIELD_TYPE_REPEATED(STRING, String, string); | ||||
#undef CASE_FIELD_TYPE_REPEATED | #undef CASE_FIELD_TYPE_REPEATED | ||||
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | ||||
for (int i = 0; i < field_size; ++i) { | |||||
const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); | |||||
if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { | |||||
GELOGE(FAILED, "ParseOperatorAttrs of field: %s failed.", field->name().c_str()); | |||||
return FAILED; | |||||
} | |||||
nlohmann::json message_json; | |||||
Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(), | |||||
message_json[field->name()], false); | |||||
std::string repeated_message_str; | |||||
try { | |||||
repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); | |||||
} catch (std::exception &e) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()}); | |||||
GELOGE(FAILED, "Failed to convert JSON to string, reason: %s.", e.what()); | |||||
return FAILED; | |||||
} catch (...) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19008"); | |||||
GELOGE(FAILED, "Failed to convert JSON to string."); | |||||
return FAILED; | |||||
} | } | ||||
(void)ops.SetAttr(field->name(), repeated_message_str); | |||||
break; | break; | ||||
} | } | ||||
default: { | default: { | ||||
@@ -56,20 +56,17 @@ class CaffeModelParser : public domi::ModelParser { | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
* @brief Parse the relevant data from memory and save it to graph | |||||
* @param [in] memory buffer of model file | |||||
* @param [in] buffer size | |||||
* @brief Parse the relevant data from the memory and save it to graph | |||||
* @param [in] file Path of model file | |||||
* @param [in|out] graph graph for saving model information | * @param [in|out] graph graph for saving model information | ||||
* @return SUCCESS parse successfully | * @return SUCCESS parse successfully | ||||
* @return FAILED parse failed | * @return FAILED parse failed | ||||
*/ | */ | ||||
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override { | Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override { | ||||
return domi::SUCCESS; | |||||
return domi::SUCCESS; | |||||
} | } | ||||
#endif | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -30,7 +30,6 @@ enum DataType | |||||
DT_RESOURCE = 23; // resource type | DT_RESOURCE = 23; // resource type | ||||
DT_STRING_REF = 24; // string_ref type | DT_STRING_REF = 24; // string_ref type | ||||
DT_DUAL = 25; /**< dual output type */ | DT_DUAL = 25; /**< dual output type */ | ||||
DT_VARIANT = 26; // variant type | |||||
} | } | ||||
message AttrDef | message AttrDef | ||||
@@ -406,53 +406,6 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
domi::Status AclGrphParseUtil::ParseAclWeightCompressConf(const ComputeGraphPtr &graph, | |||||
const string &compress_weight_conf) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
if (compress_weight_conf.empty()) { | |||||
return SUCCESS; | |||||
} | |||||
std::string real_path = ge::parser::RealPath(compress_weight_conf.c_str()); | |||||
if (real_path.empty()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||||
{"compress_weight_conf", compress_weight_conf}); | |||||
GELOGE(PARAM_INVALID, "Can not get real path for %s.", compress_weight_conf.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
std::ifstream ifs(real_path); | |||||
if (!ifs.is_open()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||||
{"compress_weight_conf", compress_weight_conf}); | |||||
GELOGE(FAILED, "Open file %s failed", compress_weight_conf.c_str()); | |||||
return FAILED; | |||||
} | |||||
std::string compress_nodes; | |||||
ifs >> compress_nodes; | |||||
ifs.close(); | |||||
if (compress_nodes.empty()) { | |||||
GELOGW("Compress weight of nodes info is empty"); | |||||
return SUCCESS; | |||||
} | |||||
GELOGI("Compress weight of nodes: %s", compress_nodes.c_str()); | |||||
vector<string> compress_node_vec = StringUtils::Split(compress_nodes, ';'); | |||||
for (size_t i = 0; i < compress_node_vec.size(); ++i) { | |||||
ge::NodePtr node = graph->FindNode(compress_node_vec[i]); | |||||
if (node == nullptr) { | |||||
GELOGW("Node %s is not in graph", compress_node_vec[i].c_str()); | |||||
continue; | |||||
} | |||||
auto op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) { | |||||
GELOGE(domi::FAILED, "Node %s SetBool failed.", compress_node_vec[i].c_str()); | |||||
return domi::FAILED; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | ||||
std::vector<std::string> &output_nodes_name) { | std::vector<std::string> &output_nodes_name) { | ||||
output_nodes_name.clear(); | output_nodes_name.clear(); | ||||
@@ -641,7 +594,7 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendStrin | |||||
domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | ||||
const std::map<AscendString, AscendString> &parser_params) { | const std::map<AscendString, AscendString> &parser_params) { | ||||
// support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, compress_weight_conf, | |||||
// support paragrams: input_fp16_nodes, is_input_adjust_hw_layout | |||||
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
@@ -654,11 +607,6 @@ domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||||
ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS, | ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS, | ||||
return PARAM_INVALID, "Parse input_fp16_nodes failed"); | return PARAM_INVALID, "Parse input_fp16_nodes failed"); | ||||
string compress_weight_conf; | |||||
GetAclParams(parser_params, ge::ir_option::COMPRESS_WEIGHT_CONF, compress_weight_conf); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclWeightCompressConf(compute_graph, compress_weight_conf) != SUCCESS, | |||||
return PARAM_INVALID, "Parse compress_weight_conf failed"); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -784,7 +732,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co | |||||
google::protobuf::io::CodedInputStream coded_stream(&istream); | google::protobuf::io::CodedInputStream coded_stream(&istream); | ||||
bool ret = ReadProtoFromCodedInputStream(coded_stream, proto); | bool ret = ReadProtoFromCodedInputStream(coded_stream, proto); | ||||
fs.close(); | fs.close(); | ||||
if (!ret) { | if (!ret) { | ||||
@@ -60,7 +60,6 @@ class AclGrphParseUtil { | |||||
uint32_t index, OpDescPtr &op_desc); | uint32_t index, OpDescPtr &op_desc); | ||||
domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, | ||||
const string &is_input_adjust_hw_layout); | const string &is_input_adjust_hw_layout); | ||||
domi::Status ParseAclWeightCompressConf(const ComputeGraphPtr &graph, const string &compress_weight_conf); | |||||
domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | ||||
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | ||||
}; | }; | ||||
@@ -47,11 +47,11 @@ class Pb2Json { | |||||
static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | ||||
bool enum2str = false); | bool enum2str = false); | ||||
protected: | |||||
static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | ||||
const ProtobufReflection *reflection, const std::set<std::string> &black_fields, | const ProtobufReflection *reflection, const std::set<std::string> &black_fields, | ||||
Json &json, bool enum2str); | Json &json, bool enum2str); | ||||
protected: | |||||
static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, | static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, | ||||
bool enum2str, Json &json); | bool enum2str, Json &json); | ||||
@@ -16,7 +16,6 @@ | |||||
#include "framework/omg/parser/parser_api.h" | #include "framework/omg/parser/parser_api.h" | ||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "tbe_plugin_loader.h" | #include "tbe_plugin_loader.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "parser/common/register_tbe.h" | #include "parser/common/register_tbe.h" | ||||
@@ -41,7 +40,7 @@ Status ParserInitialize(const std::map<std::string, std::string> &options) { | |||||
std::string fmk_type = std::to_string(domi::TENSORFLOW); | std::string fmk_type = std::to_string(domi::TENSORFLOW); | ||||
auto it = options.find(ge::FRAMEWORK_TYPE); | auto it = options.find(ge::FRAMEWORK_TYPE); | ||||
if (it != options.end()) { | if (it != options.end()) { | ||||
fmk_type = it->second; | |||||
fmk_type = it->second; | |||||
} | } | ||||
std::vector<OpRegistrationData> registrationDatas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> registrationDatas = domi::OpRegistry::Instance()->registrationDatas; | ||||
GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); | GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); | ||||
@@ -28,6 +28,29 @@ | |||||
#include "register/op_registry.h" | #include "register/op_registry.h" | ||||
namespace ge { | namespace ge { | ||||
namespace { | |||||
Status HandleNewOp(const NodePtr &node, const ComputeGraphPtr &compute_graph, const NodePtr &new_node) { | |||||
GE_CHECK_NOTNULL(node); | |||||
GE_CHECK_NOTNULL(new_node); | |||||
if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Set owner graph for node:%s failed.", new_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
auto op_desc = new_node->GetOpDesc(); | |||||
static std::atomic_long new_node_index(0); | |||||
auto new_name = "PartitionedCall_" + new_node->GetName() + "_" + to_string(new_node_index++); | |||||
op_desc->SetName(new_name); | |||||
bool ret = ge::AttrUtils::SetListStr(op_desc, | |||||
ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, | |||||
std::move(std::vector<std::string>{node->GetName()})); | |||||
if (!ret) { | |||||
GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), op_desc->GetName().c_str()); | |||||
} | |||||
GELOGD("Handle new op[%s] for node[%s] success.", new_node->GetName().c_str(), node->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
} | |||||
Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { | Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { | ||||
GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); | GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); | ||||
for (const auto &gn : graph.GetDirectNode()) { | for (const auto &gn : graph.GetDirectNode()) { | ||||
@@ -68,17 +91,14 @@ Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &n | |||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
// add subgraph node to graph. | // add subgraph node to graph. | ||||
std::unordered_map<std::string, NodePtr> all_new_nodes; | |||||
std::vector<NodePtr> input_nodes; | std::vector<NodePtr> input_nodes; | ||||
for (const auto &n : sub_compute_graph->GetDirectNode()) { | for (const auto &n : sub_compute_graph->GetDirectNode()) { | ||||
auto new_node = compute_graph->AddNode(n); | auto new_node = compute_graph->AddNode(n); | ||||
GE_CHECK_NOTNULL(new_node); | GE_CHECK_NOTNULL(new_node); | ||||
all_new_nodes[new_node->GetName()] = new_node; | |||||
if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Set owner graph for node:%s failed.", new_node->GetName().c_str()); | |||||
if (HandleNewOp(node, compute_graph, new_node) != SUCCESS) { | |||||
GELOGE(FAILED, "Handle new op[%s] for node[%s] failed.", new_node->GetName().c_str(), node->GetName().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (new_node->GetType() == ge::parser::DATA) { | if (new_node->GetType() == ge::parser::DATA) { | ||||
input_nodes.emplace_back(new_node); | input_nodes.emplace_back(new_node); | ||||
} | } | ||||
@@ -30,7 +30,6 @@ enum DataType | |||||
DT_RESOURCE = 23; // resource type | DT_RESOURCE = 23; // resource type | ||||
DT_STRING_REF = 24; // string_ref type | DT_STRING_REF = 24; // string_ref type | ||||
DT_DUAL = 25; /**< dual output type */ | DT_DUAL = 25; /**< dual output type */ | ||||
DT_VARIANT = 26; // variant type | |||||
} | } | ||||
message AttrDef | message AttrDef | ||||
@@ -115,7 +115,6 @@ target_include_directories(fmk_onnx_parser_stub PRIVATE | |||||
${PARSER_DIR}/parser | ${PARSER_DIR}/parser | ||||
${PARSER_DIR}/../inc | ${PARSER_DIR}/../inc | ||||
${METADEF_DIR}/inc | ${METADEF_DIR}/inc | ||||
${METADEF_DIR}/inc/graph | |||||
${METADEF_DIR}/inc/external | ${METADEF_DIR}/inc/external | ||||
${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
) | ) | ||||
@@ -52,7 +52,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
libregister \ | libregister \ | ||||
liberror_manager \ | liberror_manager \ | ||||
LOCAL_STATIC_LIBRARIES += libmmpa | |||||
LOCAL_STATIC_LIBRARIES += libmmpa | |||||
LOCAL_LDFLAGS := -lrt -ldl | LOCAL_LDFLAGS := -lrt -ldl | ||||
@@ -62,7 +62,6 @@ include $(BUILD_HOST_SHARED_LIBRARY) | |||||
include $(CLEAR_VARS) | include $(CLEAR_VARS) | ||||
LOCAL_C_INCLUDES := \ | LOCAL_C_INCLUDES := \ | ||||
$(TOPDIR)inc \ | |||||
$(TOPDIR)metadef/inc \ | $(TOPDIR)metadef/inc \ | ||||
$(TOPDIR)parser/inc \ | $(TOPDIR)parser/inc \ | ||||
$(TOPDIR)inc/external \ | $(TOPDIR)inc/external \ | ||||
@@ -88,4 +87,3 @@ LOCAL_SHARED_LIBRARIES := | |||||
LOCAL_LDFLAGS := -lrt -ldl | LOCAL_LDFLAGS := -lrt -ldl | ||||
include $(BUILD_HOST_SHARED_LIBRARY) | include $(BUILD_HOST_SHARED_LIBRARY) | ||||
@@ -19,8 +19,8 @@ | |||||
#include <iostream> | #include <iostream> | ||||
#include "common/convert/pb2json.h" | #include "common/convert/pb2json.h" | ||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "common/ge_types.h" | |||||
#include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
#include "common/ge_types.h" | |||||
#include "external/graph/operator_factory.h" | #include "external/graph/operator_factory.h" | ||||
#include "external/register/register_error_codes.h" | #include "external/register/register_error_codes.h" | ||||
#include "external/parser/onnx_parser.h" | #include "external/parser/onnx_parser.h" | ||||
@@ -39,17 +39,18 @@ | |||||
#include "register/op_registry.h" | #include "register/op_registry.h" | ||||
namespace ge { | namespace ge { | ||||
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | |||||
const std::map<AscendString, AscendString> &parser_params, | |||||
ge::Graph &graph, std::shared_ptr<domi::ModelParser> &model_parser) { | |||||
graphStatus aclgrphParseONNX(const char *model_file, | |||||
const std::map<AscendString, | |||||
AscendString> &parser_params, | |||||
ge::Graph &graph) { | |||||
GE_CHECK_NOTNULL(model_file); | |||||
GetParserContext().type = domi::ONNX; | GetParserContext().type = domi::ONNX; | ||||
std::map<string, string> options; | std::map<string, string> options; | ||||
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); | options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); | ||||
if (acl_graph_parse_util.AclParserInitialize(options) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Acl parser initialize failed."); | |||||
return ge::FAILED; | |||||
} | |||||
// load custom plugin so and proto | |||||
AclGrphParseUtil acl_graph_parse_util; | |||||
(void)acl_graph_parse_util.AclParserInitialize(options); | |||||
string output_name; | string output_name; | ||||
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | ||||
@@ -62,40 +63,9 @@ graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | |||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | ||||
model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); | |||||
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); | |||||
GE_CHECK_NOTNULL(model_parser); | GE_CHECK_NOTNULL(model_parser); | ||||
return ge::SUCCESS; | |||||
} | |||||
graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util, | |||||
const std::map<AscendString, AscendString> &parser_params, | |||||
ge::Graph &graph) { | |||||
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Parser params after graph failed."); | |||||
return ge::FAILED; | |||||
} | |||||
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||||
return ge::FAILED; | |||||
} | |||||
return ge::SUCCESS; | |||||
} | |||||
graphStatus aclgrphParseONNX(const char *model_file, | |||||
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | |||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
GE_CHECK_NOTNULL(model_file); | |||||
// load custom plugin so and proto | |||||
AclGrphParseUtil acl_graph_parse_util; | |||||
std::shared_ptr<domi::ModelParser> model_parser; | |||||
if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Prepare before parse failed."); | |||||
return ge::FAILED; | |||||
} | |||||
GE_CHECK_NOTNULL(model_parser); | |||||
// parse caffe model_file to GE graph | // parse caffe model_file to GE graph | ||||
ge::graphStatus ret = model_parser->Parse(model_file, graph); | ge::graphStatus ret = model_parser->Parse(model_file, graph); | ||||
if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
@@ -104,44 +74,65 @@ graphStatus aclgrphParseONNX(const char *model_file, | |||||
} | } | ||||
GELOGI("Parser graph %s success.", graph.GetName().c_str()); | GELOGI("Parser graph %s success.", graph.GetName().c_str()); | ||||
if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Handle after parse failed."); | |||||
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Parser params after graph failed."); | |||||
return ge::FAILED; | return ge::FAILED; | ||||
} | } | ||||
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||||
return ge::FAILED; | |||||
} | |||||
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | ||||
#endif | |||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, | |||||
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | |||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t buffer_size, | |||||
const std::map<AscendString, | |||||
AscendString> &parser_params, | |||||
ge::Graph &graph) { | |||||
GE_CHECK_NOTNULL(buffer); | GE_CHECK_NOTNULL(buffer); | ||||
GetParserContext().type = domi::ONNX; | |||||
std::map<string, string> options; | |||||
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); | |||||
// load custom plugin so and proto | // load custom plugin so and proto | ||||
AclGrphParseUtil acl_graph_parse_util; | AclGrphParseUtil acl_graph_parse_util; | ||||
std::shared_ptr<domi::ModelParser> model_parser; | |||||
(void)acl_graph_parse_util.AclParserInitialize(options); | |||||
if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Prepare before parse failed."); | |||||
string output_name; | |||||
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Parser params before graph failed."); | |||||
return ge::FAILED; | return ge::FAILED; | ||||
} | } | ||||
// Create an empty computegraph | |||||
string graph_name = output_name.empty() ? "tmpGraph" : output_name; | |||||
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name); | |||||
GE_CHECK_NOTNULL(compute_graph); | |||||
// parse caffe model_file to GE graph | |||||
ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)size, graph); | |||||
graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||||
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); | |||||
GE_CHECK_NOTNULL(model_parser); | |||||
// parse caffe model_file and weights_file to GE graph | |||||
ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)buffer_size, graph); | |||||
if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | ||||
return ge::FAILED; | return ge::FAILED; | ||||
} | } | ||||
GELOGI("Parser graph %s success.", graph.GetName().c_str()); | GELOGI("Parser graph %s success.", graph.GetName().c_str()); | ||||
if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Handle after parse failed."); | |||||
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Parser params after graph failed."); | |||||
return ge::FAILED; | return ge::FAILED; | ||||
} | } | ||||
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||||
#endif | |||||
return ge::SUCCESS; | |||||
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||||
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||||
return ge::FAILED; | |||||
} | |||||
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||||
return ge::SUCCESS; | |||||
} | } | ||||
} // namespace ge | } // namespace ge | ||||
@@ -159,7 +150,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, | |||||
GELOGE(FAILED, "Onnx graph has zero input"); | GELOGE(FAILED, "Onnx graph has zero input"); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
// get input value info map | // get input value info map | ||||
std::map<std::string, ge::onnx::TensorProto> input_name_tensor; | std::map<std::string, ge::onnx::TensorProto> input_name_tensor; | ||||
for (int i = 0; i < onnx_graph.input_size(); i++) { | for (int i = 0; i < onnx_graph.input_size(); i++) { | ||||
@@ -173,7 +163,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, | |||||
initializer_name_tensor.erase(initializer_iter); | initializer_name_tensor.erase(initializer_iter); | ||||
continue; | continue; | ||||
} | } | ||||
ge::onnx::TensorProto tensor_tmp; | ge::onnx::TensorProto tensor_tmp; | ||||
if (value_info.has_type()) { | if (value_info.has_type()) { | ||||
const ge::onnx::TypeProto type = value_info.type(); | const ge::onnx::TypeProto type = value_info.type(); | ||||
@@ -194,7 +183,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, | |||||
} | } | ||||
input_name_tensor[value_info.name()] = tensor_tmp; | input_name_tensor[value_info.name()] = tensor_tmp; | ||||
} | } | ||||
// Construct node for input | // Construct node for input | ||||
int64_t index = 0; | int64_t index = 0; | ||||
for (auto it : input_name_tensor) { | for (auto it : input_name_tensor) { | ||||
@@ -350,9 +338,11 @@ Status OnnxModelParser::SetOperatorInputs() { | |||||
for (auto in_iter = inputs_map_.begin(); in_iter != inputs_map_.end(); in_iter++) { | for (auto in_iter = inputs_map_.begin(); in_iter != inputs_map_.end(); in_iter++) { | ||||
auto out_iter = outputs_map_.find(in_iter->first); | auto out_iter = outputs_map_.find(in_iter->first); | ||||
if (out_iter == outputs_map_.end()) { | if (out_iter == outputs_map_.end()) { | ||||
GELOGE(INTERNAL_ERROR, "Unknown input: %s:%d in node: %s", in_iter->first.c_str(), in_iter->second[0].second, | |||||
GELOGW("Unknown input: %s:%d for node: %s, which maybe option input.", | |||||
in_iter->first.c_str(), | |||||
in_iter->second[0].second, | |||||
in_iter->second[0].first.c_str()); | in_iter->second[0].first.c_str()); | ||||
return INTERNAL_ERROR; | |||||
continue; | |||||
} | } | ||||
std::vector<std::pair<std::string, int>> &input_node_indexs = in_iter->second; | std::vector<std::pair<std::string, int>> &input_node_indexs = in_iter->second; | ||||
@@ -511,11 +501,10 @@ Status OnnxModelParser::GetGraphInputs(std::vector<ge::Operator> &input_ops) { | |||||
input_ops.emplace_back(in_op->second); | input_ops.emplace_back(in_op->second); | ||||
GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); | GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) { | |||||
Status OnnxModelParser::GetModelFromfile(const char *file, ge::onnx::ModelProto &onnx_model) { | |||||
GE_CHECK_NOTNULL(file); | GE_CHECK_NOTNULL(file); | ||||
GELOGI("File path is %s.", file); | GELOGI("File path is %s.", file); | ||||
@@ -529,20 +518,18 @@ Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) { | Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) { | ||||
GE_CHECK_NOTNULL(data); | GE_CHECK_NOTNULL(data); | ||||
// 1. Get graph from onnx model file. | |||||
if (!ge::parser::ReadProtoFromArray(data, size, &onnx_model)) { | |||||
// 1. Get graph from memory. | |||||
if (!ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &onnx_model)) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
"E19021", {"reason"}, {"Read onnx model from memory failed."}); | |||||
GELOGE(PARAM_INVALID, "Read onnx model from memory failed."); | |||||
"E19021", {"reason"}, {"Read onnx model file failed."}); | |||||
GELOGE(PARAM_INVALID, "Read onnx model file failed."); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
#endif | |||||
Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) { | Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) { | ||||
if (!onnx_model.has_graph()) { | if (!onnx_model.has_graph()) { | ||||
@@ -551,13 +538,11 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
ge::onnx::GraphProto onnx_graph = onnx_model.graph(); | ge::onnx::GraphProto onnx_graph = onnx_model.graph(); | ||||
auto opset_import = onnx_model.opset_import(); | auto opset_import = onnx_model.opset_import(); | ||||
for (auto it : opset_import) { | for (auto it : opset_import) { | ||||
domain_verseion_[it.domain()] = it.version(); | domain_verseion_[it.domain()] = it.version(); | ||||
GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); | GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); | ||||
} | } | ||||
// 2. Get all inializer. | // 2. Get all inializer. | ||||
std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; | std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; | ||||
for (int i = 0; i < onnx_graph.initializer_size(); i++) { | for (int i = 0; i < onnx_graph.initializer_size(); i++) { | ||||
@@ -567,7 +552,6 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
GELOGI("Initializer name: %s .", initializer_tensor.name().c_str()); | GELOGI("Initializer name: %s .", initializer_tensor.name().c_str()); | ||||
} | } | ||||
} | } | ||||
// 3. Parse Input from graph. | // 3. Parse Input from graph. | ||||
GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); | GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); | ||||
Status ret = ParseInput(onnx_graph, initializer_name_tensor); | Status ret = ParseInput(onnx_graph, initializer_name_tensor); | ||||
@@ -576,21 +560,18 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
return ret; | return ret; | ||||
} | } | ||||
GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | ||||
// 4. Parse Constant from graph. | // 4. Parse Constant from graph. | ||||
ret = ParseInitializer(onnx_graph, initializer_name_tensor); | ret = ParseInitializer(onnx_graph, initializer_name_tensor); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Parse initializer for onnx failed."); | GELOGE(ret, "Parse initializer for onnx failed."); | ||||
return ret; | return ret; | ||||
} | } | ||||
// 5. Update node name for node do not has name. | // 5. Update node name for node do not has name. | ||||
ret = UpdateAllNodeName(onnx_graph); | ret = UpdateAllNodeName(onnx_graph); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Update all node name for onnx failed."); | GELOGE(ret, "Update all node name for onnx failed."); | ||||
return ret; | return ret; | ||||
} | } | ||||
// 6 Precheck. | // 6 Precheck. | ||||
ret = Prechecker(onnx_graph); | ret = Prechecker(onnx_graph); | ||||
bool is_precheck_failed = (ret != SUCCESS) || (ge::PreChecker::Instance().HasError()); | bool is_precheck_failed = (ret != SUCCESS) || (ge::PreChecker::Instance().HasError()); | ||||
@@ -624,7 +605,6 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
// 9. Construct graph. | // 9. Construct graph. | ||||
std::vector<ge::Operator> input_ops; | std::vector<ge::Operator> input_ops; | ||||
ret = GetGraphInputs(input_ops); | ret = GetGraphInputs(input_ops); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Get graph inputs failed."); | GELOGE(ret, "Get graph inputs failed."); | ||||
@@ -642,35 +622,33 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { | Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { | ||||
ge::onnx::ModelProto onnx_model; | ge::onnx::ModelProto onnx_model; | ||||
Status ret = GetModelFromFile(file, onnx_model); | |||||
Status ret = GetModelFromfile(file, onnx_model); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(FAILED, "get model from file failed."); | |||||
return FAILED; | |||||
GELOGE(ret, "Get model from file failed."); | |||||
return ret; | |||||
} | } | ||||
ret = ModelParseToGraph(onnx_model, graph); | ret = ModelParseToGraph(onnx_model, graph); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(FAILED, "parse model failed."); | |||||
return FAILED; | |||||
GELOGE(ret, "Parse model failed."); | |||||
return ret; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
Status OnnxModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) { | Status OnnxModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) { | ||||
ge::onnx::ModelProto onnx_model; | ge::onnx::ModelProto onnx_model; | ||||
Status ret = GetModelFromMemory(data, size, onnx_model); | Status ret = GetModelFromMemory(data, size, onnx_model); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(FAILED, "get model from file failed."); | |||||
return FAILED; | |||||
GELOGE(ret, "Get model from memory failed."); | |||||
return ret; | |||||
} | } | ||||
ret = ModelParseToGraph(onnx_model, graph); | ret = ModelParseToGraph(onnx_model, graph); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(FAILED, "parse model failed."); | |||||
return FAILED; | |||||
GELOGE(ret, "Parse model failed."); | |||||
return ret; | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
#endif | |||||
Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) { | Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) { | ||||
if (model_file == nullptr) { | if (model_file == nullptr) { | ||||
@@ -700,4 +678,4 @@ ge::DataType OnnxModelParser::ConvertToGeDataType(const uint32_t type) { | |||||
namespace domi { | namespace domi { | ||||
REGISTER_MODEL_PARSER_CREATOR(ONNX, ge::OnnxModelParser); | REGISTER_MODEL_PARSER_CREATOR(ONNX, ge::OnnxModelParser); | ||||
REGISTER_WEIGHTS_PARSER_CREATOR(ONNX, ge::OnnxWeightsParser); | REGISTER_WEIGHTS_PARSER_CREATOR(ONNX, ge::OnnxWeightsParser); | ||||
} | |||||
} |
@@ -38,11 +38,11 @@ class OnnxModelParser : public domi::ModelParser { | |||||
ge::DataType ConvertToGeDataType(const uint32_t type) override; | ge::DataType ConvertToGeDataType(const uint32_t type) override; | ||||
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override { return domi::SUCCESS; } | |||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override; | Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override; | ||||
#endif | |||||
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override { | |||||
return domi::SUCCESS; | |||||
} | |||||
Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override { | Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override { | ||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
@@ -81,12 +81,10 @@ class OnnxModelParser : public domi::ModelParser { | |||||
Status GetGraphInputs(std::vector<ge::Operator> &input_ops); | Status GetGraphInputs(std::vector<ge::Operator> &input_ops); | ||||
Status Prechecker(ge::onnx::GraphProto &onnx_graph); | Status Prechecker(ge::onnx::GraphProto &onnx_graph); | ||||
Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model); | |||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
Status GetModelFromfile(const char *file, ge::onnx::ModelProto &onnx_model); | |||||
Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model); | Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model); | ||||
#endif | |||||
Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); | Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); | ||||
@@ -30,7 +30,6 @@ enum DataType | |||||
DT_RESOURCE = 23; // resource type | DT_RESOURCE = 23; // resource type | ||||
DT_STRING_REF = 24; // string_ref type | DT_STRING_REF = 24; // string_ref type | ||||
DT_DUAL = 25; /**< dual output type */ | DT_DUAL = 25; /**< dual output type */ | ||||
DT_VARIANT = 26; // variant type | |||||
} | } | ||||
message AttrDef | message AttrDef | ||||
@@ -30,7 +30,6 @@ enum DataType | |||||
DT_RESOURCE = 23; // resource type | DT_RESOURCE = 23; // resource type | ||||
DT_STRING_REF = 24; // string_ref type | DT_STRING_REF = 24; // string_ref type | ||||
DT_DUAL = 25; /**< dual output type */ | DT_DUAL = 25; /**< dual output type */ | ||||
DT_VARIANT = 26; // variant type | |||||
} | } | ||||
message AttrDef | message AttrDef | ||||
@@ -721,23 +721,15 @@ Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||||
GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str()); | GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str()); | ||||
ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor(); | ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor(); | ||||
GE_CHECK_NOTNULL(in_archor_ptr); | GE_CHECK_NOTNULL(in_archor_ptr); | ||||
GE_IF_BOOL_EXEC(nodedef_map_[src_op_name]->op() != TENSORFLOWF_NODE_OP_SWITCH, | |||||
ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); | |||||
GE_CHECK_NOTNULL(out_archor_ptr); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E12014", {"opname1", "opname2"}, | |||||
{src->GetName(), dest->GetName()}); | |||||
return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), | |||||
dest->GetName().c_str());); | |||||
GE_IF_BOOL_EXEC(nodedef_map_[src_op_name]->op() == TENSORFLOWF_NODE_OP_SWITCH, | |||||
ge::OutDataAnchorPtr out_data_archor_ptr = src->GetOutDataAnchor(outputpair.first); | |||||
GE_CHECK_NOTNULL(out_data_archor_ptr); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
ge::GraphUtils::AddEdge(out_data_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E12014", {"opname1", "opname2"}, | |||||
{src->GetName(), dest->GetName()}); | |||||
return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), | |||||
dest->GetName().c_str());); | |||||
ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); | |||||
GE_CHECK_NOTNULL(out_archor_ptr); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E12014", {"opname1", "opname2"}, | |||||
{src->GetName(), dest->GetName()}); | |||||
return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), | |||||
dest->GetName().c_str() | |||||
); | |||||
} | } | ||||
} | } | ||||
dest_input_map.erase(input_iter); | dest_input_map.erase(input_iter); | ||||
@@ -3221,7 +3213,7 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||||
} | } | ||||
// 2.4 remove the input const nodes | // 2.4 remove the input const nodes | ||||
Status ret = RemoveInputs(current_node, unused_inputs); | |||||
Status ret = RemoveInputs(graph_def, current_node, unused_inputs, all_nodedef_map); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E12006", {"opname"}, {current_op_name}); | ErrorManager::GetInstance().ATCReportErrMessage("E12006", {"opname"}, {current_op_name}); | ||||
GELOGE(INTERNAL_ERROR, "Op[%s] remove input failed.", current_op_name.c_str()); | GELOGE(INTERNAL_ERROR, "Op[%s] remove input failed.", current_op_name.c_str()); | ||||
@@ -3232,6 +3224,34 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||||
domi::tensorflow::NodeDef *node_def, | |||||
const map<string, NodeDef *> &all_node_map, | |||||
const vector<string> &removed_inputs_vec) { | |||||
GE_CHECK_NOTNULL(graph_def); | |||||
GE_CHECK_NOTNULL(node_def); | |||||
for (const auto &remove_input : removed_inputs_vec) { | |||||
string input_node_name = NodeNameFromInput(remove_input); | |||||
auto it = all_node_map.find(input_node_name); | |||||
if (it == all_node_map.end()) { | |||||
GELOGE(FAILED, "Can not find node name:%s in all node map.", input_node_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
NodeDef *input_node_def = it->second; | |||||
if (input_node_def->op() == SWITCH || input_node_def->op() == REFSWITCH) { | |||||
NodeDef *identity_node_def = graph_def->add_node(); | |||||
GE_CHECK_NOTNULL(identity_node_def); | |||||
input_node_name = input_node_name + "identity"; | |||||
identity_node_def->set_name(input_node_name); | |||||
identity_node_def->set_op(IDENTITY); | |||||
identity_node_def->add_input(remove_input); | |||||
} | |||||
string control_input = "^" + input_node_name; | |||||
node_def->add_input(control_input); | |||||
GELOGD("Add control input:%s for node:%s", control_input.c_str(), node_def->name().c_str()); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
* @brief Delete input from nodedef | * @brief Delete input from nodedef | ||||
@@ -3241,7 +3261,10 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||||
* @return false remove failed | * @return false remove failed | ||||
* | * | ||||
*/ | */ | ||||
Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, const set<uint32_t> &remove_index_set) { | |||||
Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||||
domi::tensorflow::NodeDef *node_def, | |||||
const set<uint32_t> &remove_index_set, | |||||
const map<string, NodeDef *> &all_node_map) { | |||||
GE_CHECK_NOTNULL(node_def); | GE_CHECK_NOTNULL(node_def); | ||||
if (remove_index_set.empty()) { | if (remove_index_set.empty()) { | ||||
GELOGI("The size of remove_index_set is zero."); | GELOGI("The size of remove_index_set is zero."); | ||||
@@ -3258,6 +3281,7 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, | |||||
RemoveInputAttr(node_def, remove_inputs_map); | RemoveInputAttr(node_def, remove_inputs_map); | ||||
int index = 0; | int index = 0; | ||||
vector<string> removed_inputs_vec; | |||||
auto *inputs = node_def->mutable_input(); | auto *inputs = node_def->mutable_input(); | ||||
for (auto input_it = inputs->begin(); input_it != inputs->end(); ++index) { | for (auto input_it = inputs->begin(); input_it != inputs->end(); ++index) { | ||||
// 1.decide whether to remove the input | // 1.decide whether to remove the input | ||||
@@ -3269,6 +3293,7 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, | |||||
std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) { | std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) { | ||||
GELOGD("Remove input:%s, index:%d", remove_input_name.c_str(), index); | GELOGD("Remove input:%s, index:%d", remove_input_name.c_str(), index); | ||||
flag = true; | flag = true; | ||||
removed_inputs_vec.emplace_back(remove_input_name); | |||||
break; | break; | ||||
} | } | ||||
} | } | ||||
@@ -3281,6 +3306,11 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, | |||||
} | } | ||||
} | } | ||||
Status ret = AddControlEdgeAfterRemoveInputs(graph_def, node_def, all_node_map, removed_inputs_vec); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(FAILED, "Add control edges for node:%s failed.", node_def->name().c_str()); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -86,26 +86,22 @@ class TensorFlowModelParser : public domi::ModelParser { | |||||
* @param [in|out] graph save model information after parsing | * @param [in|out] graph save model information after parsing | ||||
* @return SUCCESS parse successfully | * @return SUCCESS parse successfully | ||||
* @return FAILED parse failed | * @return FAILED parse failed | ||||
*/ | */ | ||||
Status Parse(const char *file, ge::Graph &graph) override; | Status Parse(const char *file, ge::Graph &graph) override; | ||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
* @brief Parse the relevant data from memory and save it to graph | |||||
* @param [in] memory buffer of model file | |||||
* @param [in] buffer size | |||||
* @param [in|out] graph graph for saving model information | |||||
* @brief Parse the relevant data from the memory and save it to graph | |||||
* @param [in] file Path of the model file | |||||
* @param [in|out] graph save model information after parsing | |||||
* @return SUCCESS parse successfully | * @return SUCCESS parse successfully | ||||
* @return FAILED parse failed | * @return FAILED parse failed | ||||
*/ | */ | ||||
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | ||||
#ifndef ONLY_COMPILE_OPEN_SRC | |||||
Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override { | Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override { | ||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
#endif | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -541,7 +537,15 @@ class TensorFlowModelParser : public domi::ModelParser { | |||||
* @return false remove failed | * @return false remove failed | ||||
* | * | ||||
*/ | */ | ||||
Status RemoveInputs(domi::tensorflow::NodeDef *node_def, const set<uint32_t> &remove_index_set); | |||||
Status RemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||||
domi::tensorflow::NodeDef *node_def, | |||||
const set<uint32_t> &remove_index_set, | |||||
const map<string, NodeDef *> &all_node_map); | |||||
Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||||
domi::tensorflow::NodeDef *node_def, | |||||
const map<string, NodeDef *> &all_node_map, | |||||
const vector<string> &removed_inputs_vec); | |||||
void RemoveInputAttr(domi::tensorflow::NodeDef *node_def, const map<string, vector<int>> &remove_inputs_map); | void RemoveInputAttr(domi::tensorflow::NodeDef *node_def, const map<string, vector<int>> &remove_inputs_map); | ||||